├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── NOTICE ├── README.md ├── analysis ├── __init__.py ├── analysis.ipynb └── utils.py ├── assets ├── demo_overview.png ├── patchsae.gif ├── ref_1.gif ├── ref_2.gif ├── ref_3.gif ├── ref_imgs_munich.gif └── sae_arch.gif ├── configs ├── classnames │ ├── caltech101_classnames.txt │ ├── cifar100_cat_to_name.json │ ├── food101_cat_to_name.json │ ├── imagenet-sketch_classnames.txt │ ├── imagenet_classnames.txt │ ├── oxford_flowers_classnames.json │ └── pet_cat_to_name.json └── models │ └── maple │ └── vit_b16_c2_ep5_batch4_2ctx.yaml ├── demo.ipynb ├── requirements.txt ├── scripts ├── 01_run_train.sh ├── 02_run_compute_feature_data.sh ├── 03_run_class_level.sh └── 04_run_topk_eval.sh ├── src ├── demo │ ├── app.py │ ├── core.py │ └── utils.py ├── models │ ├── architecture │ │ └── maple.py │ ├── clip │ │ ├── __init__.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── clip.py │ │ ├── model.py │ │ └── simple_tokenizer.py │ ├── config │ │ ├── __init__.py │ │ ├── default_config.py │ │ └── maple.py │ ├── templates │ │ └── openai_imagenet_templates.py │ └── utils.py └── sae_training │ ├── config.py │ ├── hooked_vit.py │ ├── sae_trainer.py │ ├── sparse_autoencoder.py │ ├── utils.py │ └── vit_activations_store.py └── tasks ├── README.md ├── classification_with_top_k_masking.py ├── compute_class_wise_sae_activation.py ├── compute_sae_feature_data.py ├── train_sae_vit.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | data/ 3 | .vscode/ 4 | out/ 5 | .DS_Store 6 | *.zip 7 | wandb/ 8 | logs/ 9 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: end-of-file-fixer 6 | - id: trailing-whitespace 7 | - id: check-yaml 8 | - id: check-executables-have-shebangs 9 | - id: check-toml 10 | - id: check-added-large-files 11 | args: [--maxkb=15000] 12 | - id: check-case-conflict 13 | - id: check-merge-conflict 14 | - id: check-symlinks 15 | - id: debug-statements 16 | - id: detect-private-key 17 | - id: mixed-line-ending 18 | args: [--fix=lf] 19 | - id: requirements-txt-fixer 20 | - repo: https://github.com/astral-sh/ruff-pre-commit 21 | rev: v0.11.6 22 | hooks: 23 | - id: ruff 24 | args: [ --fix, --select=I ] 25 | - id: ruff-format 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024-2025 Hyesu Lim, Jinho Choi, Jaegul Choo, Steffen Schneider 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Code from the following third-party repositories was used. 2 | Please include the contents of the LICENSE as well this NOTICE file in all 3 | re-distributions of this code. 4 | 5 | ================================================================================= 6 | 7 | https://github.com/muzairkhattak/multimodal-prompt-learning 8 | 9 | MIT License 10 | 11 | Copyright (c) 2022 Muhammad Uzair Khattak 12 | Copyright (c) 2021 Kaiyang Zhou 13 | 14 | Permission is hereby granted, free of charge, to any person obtaining a copy 15 | of this software and associated documentation files (the "Software"), to deal 16 | in the Software without restriction, including without limitation the rights 17 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 18 | copies of the Software, and to permit persons to whom the Software is 19 | furnished to do so, subject to the following conditions: 20 | 21 | The above copyright notice and this permission notice shall be included in all 22 | copies or substantial portions of the Software. 23 | 24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 26 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 27 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 28 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 29 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 30 | SOFTWARE. 31 | 32 | ================================================================================= 33 | 34 | https://github.com/muzairkhattak/PromptSRC 35 | 36 | MIT License 37 | Copyright (c) 2023 Muhammad Uzair Khattak 38 | Copyright (c) 2022 Muhammad Uzair Khattak 39 | Copyright (c) 2021 Kaiyang Zhou 40 | 41 | Permission is hereby granted, free of charge, to any person obtaining a copy 42 | of this software and associated documentation files (the "Software"), to deal 43 | in the Software without restriction, including without limitation the rights 44 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 45 | copies of the Software, and to permit persons to whom the Software is 46 | furnished to do so, subject to the following conditions: 47 | 48 | The above copyright notice and this permission notice shall be included in all 49 | copies or substantial portions of the Software. 50 | 51 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 52 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 53 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 54 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 55 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 56 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 57 | SOFTWARE. 58 | 59 | ================================================================================= 60 | 61 | https://github.com/HugoFry/mats_sae_training_for_ViTs 62 | 63 | MIT License 64 | 65 | Copyright (c) 2023 Joseph Bloom 66 | 67 | Permission is hereby granted, free of charge, to any person obtaining a copy 68 | of this software and associated documentation files (the "Software"), to deal 69 | in the Software without restriction, including without limitation the rights 70 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 71 | copies of the Software, and to permit persons to whom the Software is 72 | furnished to do so, subject to the following conditions: 73 | 74 | The above copyright notice and this permission notice shall be included in all 75 | copies or substantial portions of the Software. 76 | 77 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 78 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 79 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 80 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 81 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 82 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 83 | SOFTWARE. 84 | 85 | ================================================================================= 86 | 87 | https://github.com/jbloomAus/SAELens 88 | 89 | MIT License 90 | 91 | Copyright (c) 2023 Joseph Bloom 92 | 93 | Permission is hereby granted, free of charge, to any person obtaining a copy 94 | of this software and associated documentation files (the "Software"), to deal 95 | in the Software without restriction, including without limitation the rights 96 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 97 | copies of the Software, and to permit persons to whom the Software is 98 | furnished to do so, subject to the following conditions: 99 | 100 | The above copyright notice and this permission notice shall be included in all 101 | copies or substantial portions of the Software. 102 | 103 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 104 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 105 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 106 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 107 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 108 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 109 | SOFTWARE. 
110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PatchSAE: Sparse Autoencoders Reveal Selective Remapping of Visual Concepts During Adaptation 2 | 3 | [![Website & Demo](https://img.shields.io/badge/🔗_Website_&_Demo-blue)](https://dynamical-inference.ai/patchsae/) 4 | [![Paper](https://img.shields.io/badge/📑_Paper-arXiv-red)](https://arxiv.org/abs/2412.05276) 5 | [![OpenReview](https://img.shields.io/badge/OpenReview-ICLR_2025-green)](https://openreview.net/forum?id=imT03YXlG2) 6 | [![Hugging Face Demo](https://img.shields.io/badge/🤗_Hugging_Face-Demo-yellow)](https://huggingface.co/spaces/dynamical-inference/patchsae-demo) 7 | 8 |
9 | ![PatchSAE visualization](assets/patchsae.gif) 10 |
11 | 12 | ## 🚀 Quick Navigation 13 | 14 | - [Getting Started](#-getting-started) 15 | - [Interactive Demo](#-interactive-demo) 16 | - [Training & Analysis](#-patchsae-training-and-analysis) 17 | - [Status Updates](#-status-updates) 18 | - [License & Credits](#-license--credits) 19 | 20 | ## 🛠 Getting Started 21 | 22 | Set up your environment with these simple steps: 23 | 24 | ```bash 25 | # Create and activate environment 26 | conda create --name patchsae python=3.12 27 | conda activate patchsae 28 | 29 | # Install dependencies 30 | pip install -r requirements.txt 31 | 32 | # Always set PYTHONPATH before running any scripts 33 | cd patchsae 34 | PYTHONPATH=./ python src/demo/app.py 35 | ``` 36 | 37 | ## 🎮 Interactive Demo 38 | 39 | ### Online Demo on Hugging Face 🤗 [![Website & Demo](https://img.shields.io/badge/🔗_Website_&_Demo-blue)](https://dynamical-inference.ai/patchsae/) 40 | 41 | Explore our pre-computed images and SAE latents without any installation! 42 | > 💡 The demo may experience slowdowns due to network constraints. For optimal performance, consider disabling your VPN if you encounter any delays. 43 | 44 |
45 | ![Demo interface](assets/demo_overview.png) 46 |
47 | 48 | 49 | ### Local Demo: Try Your Own Images 50 | 51 | Want to experiment with your own images? Follow these steps: 52 | 53 | #### 1. Setup Local Demo 54 | 55 | First, download the necessary files: 56 | 57 | - [out.zip](https://drive.google.com/file/d/1NJzF8PriKz_mopBY4l8_44R0FVi2uw2g/edit) 58 | - [data.zip](https://drive.google.com/file/d/1reuDjXsiMkntf1JJPLC5a3CcWuJ6Ji3Z/edit) 59 | 60 | You can download the files using `gdown` as follows: 61 | 62 | ```bash 63 | # Activate environment first (see Getting Started) 64 | 65 | # Download necessary files (35MB + 513MB) 66 | gdown --id 1NJzF8PriKz_mopBY4l8_44R0FVi2uw2g # out.zip 67 | gdown --id 1reuDjXsiMkntf1JJPLC5a3CcWuJ6Ji3Z # data.zip 68 | 69 | # Extract files 70 | unzip data.zip 71 | unzip out.zip 72 | ``` 73 | 74 | > 💡 Need `gdown`? Install it with: `conda install conda-forge::gdown` 75 | 76 | Your folder structure should look like: 77 | 78 | ``` 79 | patchsae/ 80 | ├── configs/ 81 | ├── data/ # From data.zip 82 | ├── out/ # From out.zip 83 | ├── src/ 84 | │ └── demo/ 85 | │ └── app.py 86 | ├── tasks/ 87 | ├── requirements.txt 88 | └── ... (other files) 89 | ``` 90 | 91 | #### 2. Launch the Demo 92 | 93 | ```bash 94 | PYTHONPATH=./ python src/demo/app.py 95 | ``` 96 | 97 | ⚠️ **Note**: 98 | - First run will download datasets from HuggingFace automatically (About 30GB in total) 99 | - Demo runs on CPU by default 100 | - Access the interface at http://127.0.0.1:7860 (or the URL shown in terminal) 101 | 102 | ## 📊 PatchSAE Training and Analysis 103 | 104 | - **Training Instructions**: See [tasks/README.md](./tasks/README.md) 105 | - **Analysis Notebooks**: 106 | - [demo.ipynb](./demo.ipynb) 107 | - [analysis.ipynb](./analysis/analysis.ipynb) 108 | 109 | ## 📝 Status Updates 110 | 111 | - **Jan 13, 2025**: Training & Analysis code work properly. Minor error in data loading by class when using ImageNet. 112 | - **Jan 09, 2025**: Analysis code works. Updated training with evaluation during training, fixed optimizer bug. 113 | - **Jan 07, 2025**: Added analysis code. Reproducibility tests completed (trained on ImageNet, tested on Oxford-Flowers). 114 | - **Jan 06, 2025**: Training code updated. Reproducibility testing in progress. 115 | - **Jan 02, 2025**: Training code incomplete in this version. Updates coming soon. 116 | 117 | ## 📜 License & Credits 118 | 119 | ### Reference Implementations 120 | 121 | - [SAE for ViT](https://github.com/HugoFry/mats_sae_training_for_ViTs) 122 | - [SAELens](https://github.com/jbloomAus/SAELens) 123 | - [Differentiable and Fast Geometric Median in NumPy and PyTorch](https://github.com/krishnap25/geom_median) 124 | - [Self-regulating Prompts: Foundational Model Adaptation without Forgetting [ICCV 2023]](https://github.com/muzairkhattak/PromptSRC) 125 | - Used in: `configs/` and `msrc/models/` 126 | - [MaPLe: Multi-modal Prompt Learning CVPR 2023](https://github.com/muzairkhattak/multimodal-prompt-learning) 127 | - Used in: `configs/models/maple/...yaml` and `data/clip/maple/imagenet/model.pth.tar-2` 128 | 129 | ### License Notice 130 | 131 | Our code is distributed under an MIT license, please see the [LICENSE](LICENSE) file for details. 132 | The [NOTICE](NOTICE) file lists license for all third-party code included in this repository. 133 | Please include the contents of the LICENSE and NOTICE files in all re-distributions of this code. 
134 | 135 | --- 136 | 137 | ### Citation 138 | 139 | If you find our code or models useful in your work, please cite our [paper](https://arxiv.org/abs/2412.05276): 140 | 141 | ``` 142 | @inproceedings{ 143 | lim2025patchsae, 144 | title={Sparse autoencoders reveal selective remapping of visual concepts during adaptation}, 145 | author={Hyesu Lim and Jinho Choi and Jaegul Choo and Steffen Schneider}, 146 | booktitle={The Thirteenth International Conference on Learning Representations}, 147 | year={2025}, 148 | url={https://openreview.net/forum?id=imT03YXlG2} 149 | } 150 | ``` 151 | -------------------------------------------------------------------------------- /analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/analysis/__init__.py -------------------------------------------------------------------------------- /analysis/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | import plotly.express as px 6 | import torch 7 | 8 | DATASET_INFO = { 9 | "imagenet": { 10 | "path": "evanarlian/imagenet_1k_resized_256", 11 | "split": "train", 12 | }, 13 | "imagenet-sketch": { 14 | "path": "clip-benchmark/wds_imagenet_sketch", 15 | "split": "train", 16 | }, 17 | "oxford_flowers": { 18 | "path": "nelorth/oxford-flowers", 19 | "split": "train", 20 | }, 21 | "caltech101": { 22 | "path": "HuggingFaceM4/Caltech-101", 23 | "split": "train", 24 | "name": "with_background_category", 25 | }, 26 | } 27 | 28 | 29 | def calculate_entropy(top_val, top_label, ignore_label_idx: int = None, eps=1e-9): 30 | dict_size = top_label.shape[0] 31 | entropy = torch.zeros(dict_size) 32 | 33 | for i in range(dict_size): 34 | unique_labels, counts = top_label[i].unique(return_counts=True) 35 | if ignore_label_idx is not None: 36 | counts = counts[unique_labels != ignore_label_idx] 37 | unique_labels = unique_labels[unique_labels != ignore_label_idx] 38 | if len(unique_labels) != 0: 39 | if counts.sum().item() < 10: 40 | entropy[i] = -1 # discount as too few datapoints! 
41 | else: 42 | summed_probs = torch.zeros_like(unique_labels, dtype=top_val.dtype) 43 | for j, label in enumerate(unique_labels): 44 | summed_probs[j] = top_val[i][top_label[i] == label].sum().item() 45 | summed_probs = summed_probs / summed_probs.sum() 46 | entropy[i] = -torch.sum(summed_probs * torch.log(summed_probs + eps)) 47 | else: 48 | entropy[i] = -1 49 | return entropy 50 | 51 | 52 | def load_stats(save_directory: str, device: torch.device): 53 | mean_acts = torch.load( 54 | os.path.join(save_directory, "sae_mean_acts.pt"), 55 | map_location=torch.device(device), 56 | ) 57 | sparsity = torch.load( 58 | os.path.join(save_directory, "sae_sparsity.pt"), 59 | map_location=torch.device(device), 60 | ) 61 | top_val = torch.load( 62 | os.path.join(save_directory, "max_activating_image_values.pt"), 63 | map_location=torch.device(device), 64 | ) 65 | top_idx = torch.load( 66 | os.path.join(save_directory, "max_activating_image_indices.pt"), 67 | map_location=torch.device(device), 68 | ) 69 | top_label = torch.load( 70 | os.path.join(save_directory, "max_activating_image_label_indices.pt"), 71 | map_location=torch.device(device), 72 | ) 73 | try: 74 | top_entropy = torch.load( 75 | os.path.join(save_directory, "top_entropy.pt"), 76 | map_location=torch.device(device), 77 | ) 78 | except: # noqa: E722 79 | print("Calculating top entropy") 80 | top_entropy = calculate_entropy(top_val, top_label) 81 | torch.save(top_entropy, os.path.join(save_directory, "top_entropy.pt")) 82 | 83 | print(f"Stats loaded from {save_directory}") 84 | 85 | stats = { 86 | "mean_acts": mean_acts.to(device), 87 | "sparsity": sparsity.to(device), 88 | "top_val": top_val.to(device), 89 | "top_idx": top_idx.to(device).to(torch.int64), 90 | "top_label": top_label.to(device).to(torch.int64), 91 | "top_entropy": top_entropy.to(device), 92 | } 93 | return stats 94 | 95 | 96 | def get_stats_scatter_plot(stats, mask=None, save_directory: str = None, eps=1e-9): 97 | if mask is None: 98 | mask = torch.ones_like( 99 | stats["sparsity"], dtype=torch.bool, device=stats["sparsity"].device 100 | ) 101 | 102 | indices = torch.where(mask)[0] 103 | plotting_data = torch.stack( 104 | [ 105 | torch.log10(stats["sparsity"][mask] + eps), 106 | torch.log10(stats["mean_acts"][mask] + eps), 107 | stats["top_entropy"][mask], 108 | indices, 109 | ], 110 | dim=0, 111 | ) 112 | plotting_data = plotting_data.transpose(0, 1) 113 | 114 | x_label = "log10(sparsity)" 115 | y_label = "log10(mean_acts)" 116 | color_label = "entropy" 117 | hover_label = "index" 118 | 119 | df = pd.DataFrame( 120 | plotting_data.numpy(), columns=[x_label, y_label, color_label, hover_label] 121 | ) 122 | fig = px.scatter( 123 | df, 124 | x=x_label, 125 | y=y_label, 126 | color=color_label, 127 | marginal_x="histogram", 128 | marginal_y="histogram", 129 | opacity=0.5, 130 | hover_data=[hover_label], 131 | ) 132 | 133 | if save_directory is not None: 134 | fig.write_image(os.path.join(save_directory, "scatter_plot.png")) 135 | fig.show() 136 | 137 | 138 | def plot_ref_images(stats, dataset, latent_idx: int, plot_top_k: int = 10, eps=1e-9): 139 | resize_size = 224 140 | num_cols = 5 141 | num_rows = plot_top_k // num_cols 142 | 143 | images = [] 144 | labels = [] 145 | 146 | for i, idx in enumerate(stats["top_idx"][latent_idx][:plot_top_k]): 147 | img = dataset[idx.item()]["image"] 148 | images.append(img.resize((resize_size, resize_size))) 149 | assert dataset[idx.item()]["label"] == stats["top_label"][latent_idx][i], ( 150 | "label mismatch, try matching dataset shuffle 
seed" 151 | ) 152 | labels.append(dataset[idx.item()]["label"]) 153 | 154 | _, axes = plt.subplots(num_rows, num_cols, figsize=(4.5 * num_cols, 6 * num_rows)) 155 | 156 | for i, (image, label) in enumerate(zip(images, labels)): 157 | ax = axes[i // num_cols, i % num_cols] 158 | ax.imshow(image) 159 | top_val_i = torch.log10(stats["top_val"][latent_idx][i] + eps).item() 160 | ax.set_title(f"{label}(a: {top_val_i:.2f})", fontsize=35) 161 | ax.axis("off") 162 | 163 | mean_acts = torch.log10(stats["mean_acts"][latent_idx] + eps).item() 164 | sparsity = torch.log10(stats["sparsity"][latent_idx] + eps).item() 165 | plt.suptitle( 166 | f"Index {latent_idx} (f: {sparsity:.2f}, a: {mean_acts:.2f}, e: {stats['top_entropy'][latent_idx]:.2f})\n", 167 | fontsize=35, 168 | ) 169 | 170 | plt.tight_layout() 171 | plt.show() 172 | -------------------------------------------------------------------------------- /assets/demo_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/demo_overview.png -------------------------------------------------------------------------------- /assets/patchsae.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/patchsae.gif -------------------------------------------------------------------------------- /assets/ref_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/ref_1.gif -------------------------------------------------------------------------------- /assets/ref_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/ref_2.gif -------------------------------------------------------------------------------- /assets/ref_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/ref_3.gif -------------------------------------------------------------------------------- /assets/ref_imgs_munich.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/ref_imgs_munich.gif -------------------------------------------------------------------------------- /assets/sae_arch.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/assets/sae_arch.gif -------------------------------------------------------------------------------- /configs/classnames/caltech101_classnames.txt: -------------------------------------------------------------------------------- 1 | accordion 2 | airplanes 3 | anchor 4 | ant 5 | background_google 6 | barrel 7 | bass 8 | beaver 9 | binocular 10 | bonsai 11 | brain 12 | brontosaurus 13 | buddha 14 | butterfly 15 | camera 16 | cannon 17 | car_side 18 | ceiling_fan 19 | cellphone 20 | chair 21 | chandelier 22 | cougar_body 23 | cougar_face 24 | crab 25 | crayfish 26 | crocodile 27 | 
crocodile_head 28 | cup 29 | dalmatian 30 | dollar_bill 31 | dolphin 32 | dragonfly 33 | electric_guitar 34 | elephant 35 | emu 36 | euphonium 37 | ewer 38 | faces 39 | faces_easy 40 | ferry 41 | flamingo 42 | flamingo_head 43 | garfield 44 | gerenuk 45 | gramophone 46 | grand_piano 47 | hawksbill 48 | headphone 49 | hedgehog 50 | helicopter 51 | ibis 52 | inline_skate 53 | joshua_tree 54 | kangaroo 55 | ketch 56 | lamp 57 | laptop 58 | leopards 59 | llama 60 | lobster 61 | lotus 62 | mandolin 63 | mayfly 64 | menorah 65 | metronome 66 | minaret 67 | motorbikes 68 | nautilus 69 | octopus 70 | okapi 71 | pagoda 72 | panda 73 | pigeon 74 | pizza 75 | platypus 76 | pyramid 77 | revolver 78 | rhino 79 | rooster 80 | saxophone 81 | schooner 82 | scissors 83 | scorpion 84 | sea_horse 85 | snoopy 86 | soccer_ball 87 | stapler 88 | starfish 89 | stegosaurus 90 | stop_sign 91 | strawberry 92 | sunflower 93 | tick 94 | trilobite 95 | umbrella 96 | watch 97 | water_lilly 98 | wheelchair 99 | wild_cat 100 | windsor_chair 101 | wrench 102 | yin_yang 103 | -------------------------------------------------------------------------------- /configs/classnames/cifar100_cat_to_name.json: -------------------------------------------------------------------------------- 1 | {"0": "apple", "1": "aquarium_fish", "2": "baby", "3": "bear", "4": "beaver", "5": "bed", "6": "bee", "7": "beetle", "8": "bicycle", "9": "bottle", "10": "bowl", "11": "boy", "12": "bridge", "13": "bus", "14": "butterfly", "15": "camel", "16": "can", "17": "castle", "18": "caterpillar", "19": "cattle", "20": "chair", "21": "chimpanzee", "22": "clock", "23": "cloud", "24": "cockroach", "25": "couch", "26": "cra", "27": "crocodile", "28": "cup", "29": "dinosaur", "30": "dolphin", "31": "elephant", "32": "flatfish", "33": "forest", "34": "fox", "35": "girl", "36": "hamster", "37": "house", "38": "kangaroo", "39": "keyboard", "40": "lamp", "41": "lawn_mower", "42": "leopard", "43": "lion", "44": "lizard", "45": "lobster", "46": "man", "47": "maple_tree", "48": "motorcycle", "49": "mountain", "50": "mouse", "51": "mushroom", "52": "oak_tree", "53": "orange", "54": "orchid", "55": "otter", "56": "palm_tree", "57": "pear", "58": "pickup_truck", "59": "pine_tree", "60": "plain", "61": "plate", "62": "poppy", "63": "porcupine", "64": "possum", "65": "rabbit", "66": "raccoon", "67": "ray", "68": "road", "69": "rocket", "70": "rose", "71": "sea", "72": "seal", "73": "shark", "74": "shrew", "75": "skunk", "76": "skyscraper", "77": "snail", "78": "snake", "79": "spider", "80": "squirrel", "81": "streetcar", "82": "sunflower", "83": "sweet_pepper", "84": "table", "85": "tank", "86": "telephone", "87": "television", "88": "tiger", "89": "tractor", "90": "train", "91": "trout", "92": "tulip", "93": "turtle", "94": "wardrobe", "95": "whale", "96": "willow_tree", "97": "wolf", "98": "woman", "99": "worm"} 2 | -------------------------------------------------------------------------------- /configs/classnames/food101_cat_to_name.json: -------------------------------------------------------------------------------- 1 | {"0": "apple_pie", "1": "baby_back_ribs", "2": "baklava", "3": "beef_carpaccio", "4": "beef_tartare", "5": "beet_salad", "6": "beignets", "7": "bibimbap", "8": "bread_pudding", "9": "breakfast_burrito", "10": "bruschetta", "11": "caesar_salad", "12": "cannoli", "13": "caprese_salad", "14": "carrot_cake", "15": "ceviche", "16": "cheesecake", "17": "cheese_plate", "18": "chicken_curry", "19": "chicken_quesadilla", "20": "chicken_wings", "21": 
"chocolate_cake", "22": "chocolate_mousse", "23": "churros", "24": "clam_chowder", "25": "club_sandwich", "26": "crab_cakes", "27": "creme_brulee", "28": "croque_madame", "29": "cup_cakes", "30": "deviled_eggs", "31": "donuts", "32": "dumplings", "33": "edamame", "34": "eggs_benedict", "35": "escargots", "36": "falafel", "37": "filet_mignon", "38": "fish_and_chips", "39": "foie_gras", "40": "french_fries", "41": "french_onion_soup", "42": "french_toast", "43": "fried_calamari", "44": "fried_rice", "45": "frozen_yogurt", "46": "garlic_bread", "47": "gnocchi", "48": "greek_salad", "49": "grilled_cheese_sandwich", "50": "grilled_salmon", "51": "guacamole", "52": "gyoza", "53": "hamburger", "54": "hot_and_sour_soup", "55": "hot_dog", "56": "huevos_rancheros", "57": "hummus", "58": "ice_cream", "59": "lasagna", "60": "lobster_bisque", "61": "lobster_roll_sandwich", "62": "macaroni_and_cheese", "63": "macarons", "64": "miso_soup", "65": "mussels", "66": "nachos", "67": "omelette", "68": "onion_rings", "69": "oysters", "70": "pad_thai", "71": "paella", "72": "pancakes", "73": "panna_cotta", "74": "peking_duck", "75": "pho", "76": "pizza", "77": "pork_chop", "78": "poutine", "79": "prime_rib", "80": "pulled_pork_sandwich", "81": "ramen", "82": "ravioli", "83": "red_velvet_cake", "84": "risotto", "85": "samosa", "86": "sashimi", "87": "scallops", "88": "seaweed_salad", "89": "shrimp_and_grits", "90": "spaghetti_bolognese", "91": "spaghetti_carbonara", "92": "spring_rolls", "93": "steak", "94": "strawberry_shortcake", "95": "sushi", "96": "tacos", "97": "takoyaki", "98": "tiramisu", "99": "tuna_tartare", "100": "waffles"} 2 | -------------------------------------------------------------------------------- /configs/classnames/imagenet_classnames.txt: -------------------------------------------------------------------------------- 1 | n01440764 tench 2 | n01443537 goldfish 3 | n01484850 great white shark 4 | n01491361 tiger shark 5 | n01494475 hammerhead shark 6 | n01496331 electric ray 7 | n01498041 stingray 8 | n01514668 rooster 9 | n01514859 hen 10 | n01518878 ostrich 11 | n01530575 brambling 12 | n01531178 goldfinch 13 | n01532829 house finch 14 | n01534433 junco 15 | n01537544 indigo bunting 16 | n01558993 American robin 17 | n01560419 bulbul 18 | n01580077 jay 19 | n01582220 magpie 20 | n01592084 chickadee 21 | n01601694 American dipper 22 | n01608432 kite (bird of prey) 23 | n01614925 bald eagle 24 | n01616318 vulture 25 | n01622779 great grey owl 26 | n01629819 fire salamander 27 | n01630670 smooth newt 28 | n01631663 newt 29 | n01632458 spotted salamander 30 | n01632777 axolotl 31 | n01641577 American bullfrog 32 | n01644373 tree frog 33 | n01644900 tailed frog 34 | n01664065 loggerhead sea turtle 35 | n01665541 leatherback sea turtle 36 | n01667114 mud turtle 37 | n01667778 terrapin 38 | n01669191 box turtle 39 | n01675722 banded gecko 40 | n01677366 green iguana 41 | n01682714 Carolina anole 42 | n01685808 desert grassland whiptail lizard 43 | n01687978 agama 44 | n01688243 frilled-necked lizard 45 | n01689811 alligator lizard 46 | n01692333 Gila monster 47 | n01693334 European green lizard 48 | n01694178 chameleon 49 | n01695060 Komodo dragon 50 | n01697457 Nile crocodile 51 | n01698640 American alligator 52 | n01704323 triceratops 53 | n01728572 worm snake 54 | n01728920 ring-necked snake 55 | n01729322 eastern hog-nosed snake 56 | n01729977 smooth green snake 57 | n01734418 kingsnake 58 | n01735189 garter snake 59 | n01737021 water snake 60 | n01739381 vine snake 61 | n01740131 night 
snake 62 | n01742172 boa constrictor 63 | n01744401 African rock python 64 | n01748264 Indian cobra 65 | n01749939 green mamba 66 | n01751748 sea snake 67 | n01753488 Saharan horned viper 68 | n01755581 eastern diamondback rattlesnake 69 | n01756291 sidewinder rattlesnake 70 | n01768244 trilobite 71 | n01770081 harvestman 72 | n01770393 scorpion 73 | n01773157 yellow garden spider 74 | n01773549 barn spider 75 | n01773797 European garden spider 76 | n01774384 southern black widow 77 | n01774750 tarantula 78 | n01775062 wolf spider 79 | n01776313 tick 80 | n01784675 centipede 81 | n01795545 black grouse 82 | n01796340 ptarmigan 83 | n01797886 ruffed grouse 84 | n01798484 prairie grouse 85 | n01806143 peafowl 86 | n01806567 quail 87 | n01807496 partridge 88 | n01817953 african grey parrot 89 | n01818515 macaw 90 | n01819313 sulphur-crested cockatoo 91 | n01820546 lorikeet 92 | n01824575 coucal 93 | n01828970 bee eater 94 | n01829413 hornbill 95 | n01833805 hummingbird 96 | n01843065 jacamar 97 | n01843383 toucan 98 | n01847000 duck 99 | n01855032 red-breasted merganser 100 | n01855672 goose 101 | n01860187 black swan 102 | n01871265 tusker 103 | n01872401 echidna 104 | n01873310 platypus 105 | n01877812 wallaby 106 | n01882714 koala 107 | n01883070 wombat 108 | n01910747 jellyfish 109 | n01914609 sea anemone 110 | n01917289 brain coral 111 | n01924916 flatworm 112 | n01930112 nematode 113 | n01943899 conch 114 | n01944390 snail 115 | n01945685 slug 116 | n01950731 sea slug 117 | n01955084 chiton 118 | n01968897 chambered nautilus 119 | n01978287 Dungeness crab 120 | n01978455 rock crab 121 | n01980166 fiddler crab 122 | n01981276 red king crab 123 | n01983481 American lobster 124 | n01984695 spiny lobster 125 | n01985128 crayfish 126 | n01986214 hermit crab 127 | n01990800 isopod 128 | n02002556 white stork 129 | n02002724 black stork 130 | n02006656 spoonbill 131 | n02007558 flamingo 132 | n02009229 little blue heron 133 | n02009912 great egret 134 | n02011460 bittern bird 135 | n02012849 crane bird 136 | n02013706 limpkin 137 | n02017213 common gallinule 138 | n02018207 American coot 139 | n02018795 bustard 140 | n02025239 ruddy turnstone 141 | n02027492 dunlin 142 | n02028035 common redshank 143 | n02033041 dowitcher 144 | n02037110 oystercatcher 145 | n02051845 pelican 146 | n02056570 king penguin 147 | n02058221 albatross 148 | n02066245 grey whale 149 | n02071294 killer whale 150 | n02074367 dugong 151 | n02077923 sea lion 152 | n02085620 Chihuahua 153 | n02085782 Japanese Chin 154 | n02085936 Maltese 155 | n02086079 Pekingese 156 | n02086240 Shih Tzu 157 | n02086646 King Charles Spaniel 158 | n02086910 Papillon 159 | n02087046 toy terrier 160 | n02087394 Rhodesian Ridgeback 161 | n02088094 Afghan Hound 162 | n02088238 Basset Hound 163 | n02088364 Beagle 164 | n02088466 Bloodhound 165 | n02088632 Bluetick Coonhound 166 | n02089078 Black and Tan Coonhound 167 | n02089867 Treeing Walker Coonhound 168 | n02089973 English foxhound 169 | n02090379 Redbone Coonhound 170 | n02090622 borzoi 171 | n02090721 Irish Wolfhound 172 | n02091032 Italian Greyhound 173 | n02091134 Whippet 174 | n02091244 Ibizan Hound 175 | n02091467 Norwegian Elkhound 176 | n02091635 Otterhound 177 | n02091831 Saluki 178 | n02092002 Scottish Deerhound 179 | n02092339 Weimaraner 180 | n02093256 Staffordshire Bull Terrier 181 | n02093428 American Staffordshire Terrier 182 | n02093647 Bedlington Terrier 183 | n02093754 Border Terrier 184 | n02093859 Kerry Blue Terrier 185 | n02093991 Irish Terrier 186 | n02094114 Norfolk 
Terrier 187 | n02094258 Norwich Terrier 188 | n02094433 Yorkshire Terrier 189 | n02095314 Wire Fox Terrier 190 | n02095570 Lakeland Terrier 191 | n02095889 Sealyham Terrier 192 | n02096051 Airedale Terrier 193 | n02096177 Cairn Terrier 194 | n02096294 Australian Terrier 195 | n02096437 Dandie Dinmont Terrier 196 | n02096585 Boston Terrier 197 | n02097047 Miniature Schnauzer 198 | n02097130 Giant Schnauzer 199 | n02097209 Standard Schnauzer 200 | n02097298 Scottish Terrier 201 | n02097474 Tibetan Terrier 202 | n02097658 Australian Silky Terrier 203 | n02098105 Soft-coated Wheaten Terrier 204 | n02098286 West Highland White Terrier 205 | n02098413 Lhasa Apso 206 | n02099267 Flat-Coated Retriever 207 | n02099429 Curly-coated Retriever 208 | n02099601 Golden Retriever 209 | n02099712 Labrador Retriever 210 | n02099849 Chesapeake Bay Retriever 211 | n02100236 German Shorthaired Pointer 212 | n02100583 Vizsla 213 | n02100735 English Setter 214 | n02100877 Irish Setter 215 | n02101006 Gordon Setter 216 | n02101388 Brittany dog 217 | n02101556 Clumber Spaniel 218 | n02102040 English Springer Spaniel 219 | n02102177 Welsh Springer Spaniel 220 | n02102318 Cocker Spaniel 221 | n02102480 Sussex Spaniel 222 | n02102973 Irish Water Spaniel 223 | n02104029 Kuvasz 224 | n02104365 Schipperke 225 | n02105056 Groenendael dog 226 | n02105162 Malinois 227 | n02105251 Briard 228 | n02105412 Australian Kelpie 229 | n02105505 Komondor 230 | n02105641 Old English Sheepdog 231 | n02105855 Shetland Sheepdog 232 | n02106030 collie 233 | n02106166 Border Collie 234 | n02106382 Bouvier des Flandres dog 235 | n02106550 Rottweiler 236 | n02106662 German Shepherd Dog 237 | n02107142 Dobermann 238 | n02107312 Miniature Pinscher 239 | n02107574 Greater Swiss Mountain Dog 240 | n02107683 Bernese Mountain Dog 241 | n02107908 Appenzeller Sennenhund 242 | n02108000 Entlebucher Sennenhund 243 | n02108089 Boxer 244 | n02108422 Bullmastiff 245 | n02108551 Tibetan Mastiff 246 | n02108915 French Bulldog 247 | n02109047 Great Dane 248 | n02109525 St. 
Bernard 249 | n02109961 husky 250 | n02110063 Alaskan Malamute 251 | n02110185 Siberian Husky 252 | n02110341 Dalmatian 253 | n02110627 Affenpinscher 254 | n02110806 Basenji 255 | n02110958 pug 256 | n02111129 Leonberger 257 | n02111277 Newfoundland dog 258 | n02111500 Great Pyrenees dog 259 | n02111889 Samoyed 260 | n02112018 Pomeranian 261 | n02112137 Chow Chow 262 | n02112350 Keeshond 263 | n02112706 brussels griffon 264 | n02113023 Pembroke Welsh Corgi 265 | n02113186 Cardigan Welsh Corgi 266 | n02113624 Toy Poodle 267 | n02113712 Miniature Poodle 268 | n02113799 Standard Poodle 269 | n02113978 Mexican hairless dog (xoloitzcuintli) 270 | n02114367 grey wolf 271 | n02114548 Alaskan tundra wolf 272 | n02114712 red wolf or maned wolf 273 | n02114855 coyote 274 | n02115641 dingo 275 | n02115913 dhole 276 | n02116738 African wild dog 277 | n02117135 hyena 278 | n02119022 red fox 279 | n02119789 kit fox 280 | n02120079 Arctic fox 281 | n02120505 grey fox 282 | n02123045 tabby cat 283 | n02123159 tiger cat 284 | n02123394 Persian cat 285 | n02123597 Siamese cat 286 | n02124075 Egyptian Mau 287 | n02125311 cougar 288 | n02127052 lynx 289 | n02128385 leopard 290 | n02128757 snow leopard 291 | n02128925 jaguar 292 | n02129165 lion 293 | n02129604 tiger 294 | n02130308 cheetah 295 | n02132136 brown bear 296 | n02133161 American black bear 297 | n02134084 polar bear 298 | n02134418 sloth bear 299 | n02137549 mongoose 300 | n02138441 meerkat 301 | n02165105 tiger beetle 302 | n02165456 ladybug 303 | n02167151 ground beetle 304 | n02168699 longhorn beetle 305 | n02169497 leaf beetle 306 | n02172182 dung beetle 307 | n02174001 rhinoceros beetle 308 | n02177972 weevil 309 | n02190166 fly 310 | n02206856 bee 311 | n02219486 ant 312 | n02226429 grasshopper 313 | n02229544 cricket insect 314 | n02231487 stick insect 315 | n02233338 cockroach 316 | n02236044 praying mantis 317 | n02256656 cicada 318 | n02259212 leafhopper 319 | n02264363 lacewing 320 | n02268443 dragonfly 321 | n02268853 damselfly 322 | n02276258 red admiral butterfly 323 | n02277742 ringlet butterfly 324 | n02279972 monarch butterfly 325 | n02280649 small white butterfly 326 | n02281406 sulphur butterfly 327 | n02281787 gossamer-winged butterfly 328 | n02317335 starfish 329 | n02319095 sea urchin 330 | n02321529 sea cucumber 331 | n02325366 cottontail rabbit 332 | n02326432 hare 333 | n02328150 Angora rabbit 334 | n02342885 hamster 335 | n02346627 porcupine 336 | n02356798 fox squirrel 337 | n02361337 marmot 338 | n02363005 beaver 339 | n02364673 guinea pig 340 | n02389026 common sorrel horse 341 | n02391049 zebra 342 | n02395406 pig 343 | n02396427 wild boar 344 | n02397096 warthog 345 | n02398521 hippopotamus 346 | n02403003 ox 347 | n02408429 water buffalo 348 | n02410509 bison 349 | n02412080 ram (adult male sheep) 350 | n02415577 bighorn sheep 351 | n02417914 Alpine ibex 352 | n02422106 hartebeest 353 | n02422699 impala (antelope) 354 | n02423022 gazelle 355 | n02437312 arabian camel 356 | n02437616 llama 357 | n02441942 weasel 358 | n02442845 mink 359 | n02443114 European polecat 360 | n02443484 black-footed ferret 361 | n02444819 otter 362 | n02445715 skunk 363 | n02447366 badger 364 | n02454379 armadillo 365 | n02457408 three-toed sloth 366 | n02480495 orangutan 367 | n02480855 gorilla 368 | n02481823 chimpanzee 369 | n02483362 gibbon 370 | n02483708 siamang 371 | n02484975 guenon 372 | n02486261 patas monkey 373 | n02486410 baboon 374 | n02487347 macaque 375 | n02488291 langur 376 | n02488702 black-and-white colobus 377 | 
n02489166 proboscis monkey 378 | n02490219 marmoset 379 | n02492035 white-headed capuchin 380 | n02492660 howler monkey 381 | n02493509 titi monkey 382 | n02493793 Geoffroy's spider monkey 383 | n02494079 common squirrel monkey 384 | n02497673 ring-tailed lemur 385 | n02500267 indri 386 | n02504013 Asian elephant 387 | n02504458 African bush elephant 388 | n02509815 red panda 389 | n02510455 giant panda 390 | n02514041 snoek fish 391 | n02526121 eel 392 | n02536864 silver salmon 393 | n02606052 rock beauty fish 394 | n02607072 clownfish 395 | n02640242 sturgeon 396 | n02641379 gar fish 397 | n02643566 lionfish 398 | n02655020 pufferfish 399 | n02666196 abacus 400 | n02667093 abaya 401 | n02669723 academic gown 402 | n02672831 accordion 403 | n02676566 acoustic guitar 404 | n02687172 aircraft carrier 405 | n02690373 airliner 406 | n02692877 airship 407 | n02699494 altar 408 | n02701002 ambulance 409 | n02704792 amphibious vehicle 410 | n02708093 analog clock 411 | n02727426 apiary 412 | n02730930 apron 413 | n02747177 trash can 414 | n02749479 assault rifle 415 | n02769748 backpack 416 | n02776631 bakery 417 | n02777292 balance beam 418 | n02782093 balloon 419 | n02783161 ballpoint pen 420 | n02786058 Band-Aid 421 | n02787622 banjo 422 | n02788148 baluster / handrail 423 | n02790996 barbell 424 | n02791124 barber chair 425 | n02791270 barbershop 426 | n02793495 barn 427 | n02794156 barometer 428 | n02795169 barrel 429 | n02797295 wheelbarrow 430 | n02799071 baseball 431 | n02802426 basketball 432 | n02804414 bassinet 433 | n02804610 bassoon 434 | n02807133 swimming cap 435 | n02808304 bath towel 436 | n02808440 bathtub 437 | n02814533 station wagon 438 | n02814860 lighthouse 439 | n02815834 beaker 440 | n02817516 military hat (bearskin or shako) 441 | n02823428 beer bottle 442 | n02823750 beer glass 443 | n02825657 bell tower 444 | n02834397 baby bib 445 | n02835271 tandem bicycle 446 | n02837789 bikini 447 | n02840245 ring binder 448 | n02841315 binoculars 449 | n02843684 birdhouse 450 | n02859443 boathouse 451 | n02860847 bobsleigh 452 | n02865351 bolo tie 453 | n02869837 poke bonnet 454 | n02870880 bookcase 455 | n02871525 bookstore 456 | n02877765 bottle cap 457 | n02879718 hunting bow 458 | n02883205 bow tie 459 | n02892201 brass memorial plaque 460 | n02892767 bra 461 | n02894605 breakwater 462 | n02895154 breastplate 463 | n02906734 broom 464 | n02909870 bucket 465 | n02910353 buckle 466 | n02916936 bulletproof vest 467 | n02917067 high-speed train 468 | n02927161 butcher shop 469 | n02930766 taxicab 470 | n02939185 cauldron 471 | n02948072 candle 472 | n02950826 cannon 473 | n02951358 canoe 474 | n02951585 can opener 475 | n02963159 cardigan 476 | n02965783 car mirror 477 | n02966193 carousel 478 | n02966687 tool kit 479 | n02971356 cardboard box / carton 480 | n02974003 car wheel 481 | n02977058 automated teller machine 482 | n02978881 cassette 483 | n02979186 cassette player 484 | n02980441 castle 485 | n02981792 catamaran 486 | n02988304 CD player 487 | n02992211 cello 488 | n02992529 mobile phone 489 | n02999410 chain 490 | n03000134 chain-link fence 491 | n03000247 chain mail 492 | n03000684 chainsaw 493 | n03014705 storage chest 494 | n03016953 chiffonier 495 | n03017168 bell or wind chime 496 | n03018349 china cabinet 497 | n03026506 Christmas stocking 498 | n03028079 church 499 | n03032252 movie theater 500 | n03041632 cleaver 501 | n03042490 cliff dwelling 502 | n03045698 cloak 503 | n03047690 clogs 504 | n03062245 cocktail shaker 505 | n03063599 coffee mug 506 | n03063689 
coffeemaker 507 | n03065424 spiral or coil 508 | n03075370 combination lock 509 | n03085013 computer keyboard 510 | n03089624 candy store 511 | n03095699 container ship 512 | n03100240 convertible 513 | n03109150 corkscrew 514 | n03110669 cornet 515 | n03124043 cowboy boot 516 | n03124170 cowboy hat 517 | n03125729 cradle 518 | n03126707 construction crane 519 | n03127747 crash helmet 520 | n03127925 crate 521 | n03131574 infant bed 522 | n03133878 Crock Pot 523 | n03134739 croquet ball 524 | n03141823 crutch 525 | n03146219 cuirass 526 | n03160309 dam 527 | n03179701 desk 528 | n03180011 desktop computer 529 | n03187595 rotary dial telephone 530 | n03188531 diaper 531 | n03196217 digital clock 532 | n03197337 digital watch 533 | n03201208 dining table 534 | n03207743 dishcloth 535 | n03207941 dishwasher 536 | n03208938 disc brake 537 | n03216828 dock 538 | n03218198 dog sled 539 | n03220513 dome 540 | n03223299 doormat 541 | n03240683 drilling rig 542 | n03249569 drum 543 | n03250847 drumstick 544 | n03255030 dumbbell 545 | n03259280 Dutch oven 546 | n03271574 electric fan 547 | n03272010 electric guitar 548 | n03272562 electric locomotive 549 | n03290653 entertainment center 550 | n03291819 envelope 551 | n03297495 espresso machine 552 | n03314780 face powder 553 | n03325584 feather boa 554 | n03337140 filing cabinet 555 | n03344393 fireboat 556 | n03345487 fire truck 557 | n03347037 fire screen 558 | n03355925 flagpole 559 | n03372029 flute 560 | n03376595 folding chair 561 | n03379051 football helmet 562 | n03384352 forklift 563 | n03388043 fountain 564 | n03388183 fountain pen 565 | n03388549 four-poster bed 566 | n03393912 freight car 567 | n03394916 French horn 568 | n03400231 frying pan 569 | n03404251 fur coat 570 | n03417042 garbage truck 571 | n03424325 gas mask or respirator 572 | n03425413 gas pump 573 | n03443371 goblet 574 | n03444034 go-kart 575 | n03445777 golf ball 576 | n03445924 golf cart 577 | n03447447 gondola 578 | n03447721 gong 579 | n03450230 gown 580 | n03452741 grand piano 581 | n03457902 greenhouse 582 | n03459775 radiator grille 583 | n03461385 grocery store 584 | n03467068 guillotine 585 | n03476684 hair clip 586 | n03476991 hair spray 587 | n03478589 half-track 588 | n03481172 hammer 589 | n03482405 hamper 590 | n03483316 hair dryer 591 | n03485407 hand-held computer 592 | n03485794 handkerchief 593 | n03492542 hard disk drive 594 | n03494278 harmonica 595 | n03495258 harp 596 | n03496892 combine harvester 597 | n03498962 hatchet 598 | n03527444 holster 599 | n03529860 home theater 600 | n03530642 honeycomb 601 | n03532672 hook 602 | n03534580 hoop skirt 603 | n03535780 gymnastic horizontal bar 604 | n03538406 horse-drawn vehicle 605 | n03544143 hourglass 606 | n03584254 iPod 607 | n03584829 clothes iron 608 | n03590841 carved pumpkin 609 | n03594734 jeans 610 | n03594945 jeep 611 | n03595614 T-shirt 612 | n03598930 jigsaw puzzle 613 | n03599486 rickshaw 614 | n03602883 joystick 615 | n03617480 kimono 616 | n03623198 knee pad 617 | n03627232 knot 618 | n03630383 lab coat 619 | n03633091 ladle 620 | n03637318 lampshade 621 | n03642806 laptop computer 622 | n03649909 lawn mower 623 | n03657121 lens cap 624 | n03658185 letter opener 625 | n03661043 library 626 | n03662601 lifeboat 627 | n03666591 lighter 628 | n03670208 limousine 629 | n03673027 ocean liner 630 | n03676483 lipstick 631 | n03680355 slip-on shoe 632 | n03690938 lotion 633 | n03691459 music speaker 634 | n03692522 loupe magnifying glass 635 | n03697007 sawmill 636 | n03706229 magnetic compass 637 
| n03709823 messenger bag 638 | n03710193 mailbox 639 | n03710637 tights 640 | n03710721 one-piece bathing suit 641 | n03717622 manhole cover 642 | n03720891 maraca 643 | n03721384 marimba 644 | n03724870 mask 645 | n03729826 matchstick 646 | n03733131 maypole 647 | n03733281 maze 648 | n03733805 measuring cup 649 | n03742115 medicine cabinet 650 | n03743016 megalith 651 | n03759954 microphone 652 | n03761084 microwave oven 653 | n03763968 military uniform 654 | n03764736 milk can 655 | n03769881 minibus 656 | n03770439 miniskirt 657 | n03770679 minivan 658 | n03773504 missile 659 | n03775071 mitten 660 | n03775546 mixing bowl 661 | n03776460 mobile home 662 | n03777568 ford model t 663 | n03777754 modem 664 | n03781244 monastery 665 | n03782006 monitor 666 | n03785016 moped 667 | n03786901 mortar and pestle 668 | n03787032 graduation cap 669 | n03788195 mosque 670 | n03788365 mosquito net 671 | n03791053 vespa 672 | n03792782 mountain bike 673 | n03792972 tent 674 | n03793489 computer mouse 675 | n03794056 mousetrap 676 | n03796401 moving van 677 | n03803284 muzzle 678 | n03804744 metal nail 679 | n03814639 neck brace 680 | n03814906 necklace 681 | n03825788 baby pacifier 682 | n03832673 notebook computer 683 | n03837869 obelisk 684 | n03838899 oboe 685 | n03840681 ocarina 686 | n03841143 odometer 687 | n03843555 oil filter 688 | n03854065 pipe organ 689 | n03857828 oscilloscope 690 | n03866082 overskirt 691 | n03868242 bullock cart 692 | n03868863 oxygen mask 693 | n03871628 product packet / packaging 694 | n03873416 paddle 695 | n03874293 paddle wheel 696 | n03874599 padlock 697 | n03876231 paintbrush 698 | n03877472 pajamas 699 | n03877845 palace 700 | n03884397 pan flute 701 | n03887697 paper towel 702 | n03888257 parachute 703 | n03888605 parallel bars 704 | n03891251 park bench 705 | n03891332 parking meter 706 | n03895866 railroad car 707 | n03899768 patio 708 | n03902125 payphone 709 | n03903868 pedestal 710 | n03908618 pencil case 711 | n03908714 pencil sharpener 712 | n03916031 perfume 713 | n03920288 Petri dish 714 | n03924679 photocopier 715 | n03929660 plectrum 716 | n03929855 Pickelhaube 717 | n03930313 picket fence 718 | n03930630 pickup truck 719 | n03933933 pier 720 | n03935335 piggy bank 721 | n03937543 pill bottle 722 | n03938244 pillow 723 | n03942813 ping-pong ball 724 | n03944341 pinwheel 725 | n03947888 pirate ship 726 | n03950228 drink pitcher 727 | n03954731 block plane 728 | n03956157 planetarium 729 | n03958227 plastic bag 730 | n03961711 plate rack 731 | n03967562 farm plow 732 | n03970156 plunger 733 | n03976467 Polaroid camera 734 | n03976657 pole 735 | n03977966 police van 736 | n03980874 poncho 737 | n03982430 pool table 738 | n03983396 soda bottle 739 | n03991062 plant pot 740 | n03992509 potter's wheel 741 | n03995372 power drill 742 | n03998194 prayer rug 743 | n04004767 printer 744 | n04005630 prison 745 | n04008634 missile 746 | n04009552 projector 747 | n04019541 hockey puck 748 | n04023962 punching bag 749 | n04026417 purse 750 | n04033901 quill 751 | n04033995 quilt 752 | n04037443 race car 753 | n04039381 racket 754 | n04040759 radiator 755 | n04041544 radio 756 | n04044716 radio telescope 757 | n04049303 rain barrel 758 | n04065272 recreational vehicle 759 | n04067472 fishing casting reel 760 | n04069434 reflex camera 761 | n04070727 refrigerator 762 | n04074963 remote control 763 | n04081281 restaurant 764 | n04086273 revolver 765 | n04090263 rifle 766 | n04099969 rocking chair 767 | n04111531 rotisserie 768 | n04116512 eraser 769 | n04118538 
rugby ball 770 | n04118776 ruler measuring stick 771 | n04120489 sneaker 772 | n04125021 safe 773 | n04127249 safety pin 774 | n04131690 salt shaker 775 | n04133789 sandal 776 | n04136333 sarong 777 | n04141076 saxophone 778 | n04141327 scabbard 779 | n04141975 weighing scale 780 | n04146614 school bus 781 | n04147183 schooner 782 | n04149813 scoreboard 783 | n04152593 CRT monitor 784 | n04153751 screw 785 | n04154565 screwdriver 786 | n04162706 seat belt 787 | n04179913 sewing machine 788 | n04192698 shield 789 | n04200800 shoe store 790 | n04201297 shoji screen / room divider 791 | n04204238 shopping basket 792 | n04204347 shopping cart 793 | n04208210 shovel 794 | n04209133 shower cap 795 | n04209239 shower curtain 796 | n04228054 ski 797 | n04229816 balaclava ski mask 798 | n04235860 sleeping bag 799 | n04238763 slide rule 800 | n04239074 sliding door 801 | n04243546 slot machine 802 | n04251144 snorkel 803 | n04252077 snowmobile 804 | n04252225 snowplow 805 | n04254120 soap dispenser 806 | n04254680 soccer ball 807 | n04254777 sock 808 | n04258138 solar thermal collector 809 | n04259630 sombrero 810 | n04263257 soup bowl 811 | n04264628 keyboard space bar 812 | n04265275 space heater 813 | n04266014 space shuttle 814 | n04270147 spatula 815 | n04273569 motorboat 816 | n04275548 spider web 817 | n04277352 spindle 818 | n04285008 sports car 819 | n04286575 spotlight 820 | n04296562 stage 821 | n04310018 steam locomotive 822 | n04311004 through arch bridge 823 | n04311174 steel drum 824 | n04317175 stethoscope 825 | n04325704 scarf 826 | n04326547 stone wall 827 | n04328186 stopwatch 828 | n04330267 stove 829 | n04332243 strainer 830 | n04335435 tram 831 | n04336792 stretcher 832 | n04344873 couch 833 | n04346328 stupa 834 | n04347754 submarine 835 | n04350905 suit 836 | n04355338 sundial 837 | n04355933 sunglasses 838 | n04356056 sunglasses 839 | n04357314 sunscreen 840 | n04366367 suspension bridge 841 | n04367480 mop 842 | n04370456 sweatshirt 843 | n04371430 swim trunks / shorts 844 | n04371774 swing 845 | n04372370 electrical switch 846 | n04376876 syringe 847 | n04380533 table lamp 848 | n04389033 tank 849 | n04392985 tape player 850 | n04398044 teapot 851 | n04399382 teddy bear 852 | n04404412 television 853 | n04409515 tennis ball 854 | n04417672 thatched roof 855 | n04418357 front curtain 856 | n04423845 thimble 857 | n04428191 threshing machine 858 | n04429376 throne 859 | n04435653 tile roof 860 | n04442312 toaster 861 | n04443257 tobacco shop 862 | n04447861 toilet seat 863 | n04456115 torch 864 | n04458633 totem pole 865 | n04461696 tow truck 866 | n04462240 toy store 867 | n04465501 tractor 868 | n04467665 semi-trailer truck 869 | n04476259 tray 870 | n04479046 trench coat 871 | n04482393 tricycle 872 | n04483307 trimaran 873 | n04485082 tripod 874 | n04486054 triumphal arch 875 | n04487081 trolleybus 876 | n04487394 trombone 877 | n04493381 hot tub 878 | n04501370 turnstile 879 | n04505470 typewriter keyboard 880 | n04507155 umbrella 881 | n04509417 unicycle 882 | n04515003 upright piano 883 | n04517823 vacuum cleaner 884 | n04522168 vase 885 | n04523525 vaulted or arched ceiling 886 | n04525038 velvet fabric 887 | n04525305 vending machine 888 | n04532106 vestment 889 | n04532670 viaduct 890 | n04536866 violin 891 | n04540053 volleyball 892 | n04542943 waffle iron 893 | n04548280 wall clock 894 | n04548362 wallet 895 | n04550184 wardrobe 896 | n04552348 military aircraft 897 | n04553703 sink 898 | n04554684 washing machine 899 | n04557648 water bottle 900 | n04560804 
water jug 901 | n04562935 water tower 902 | n04579145 whiskey jug 903 | n04579432 whistle 904 | n04584207 hair wig 905 | n04589890 window screen 906 | n04590129 window shade 907 | n04591157 Windsor tie 908 | n04591713 wine bottle 909 | n04592741 airplane wing 910 | n04596742 wok 911 | n04597913 wooden spoon 912 | n04599235 wool 913 | n04604644 split-rail fence 914 | n04606251 shipwreck 915 | n04612504 sailboat 916 | n04613696 yurt 917 | n06359193 website 918 | n06596364 comic book 919 | n06785654 crossword 920 | n06794110 traffic or street sign 921 | n06874185 traffic light 922 | n07248320 dust jacket 923 | n07565083 menu 924 | n07579787 plate 925 | n07583066 guacamole 926 | n07584110 consomme 927 | n07590611 hot pot 928 | n07613480 trifle 929 | n07614500 ice cream 930 | n07615774 popsicle 931 | n07684084 baguette 932 | n07693725 bagel 933 | n07695742 pretzel 934 | n07697313 cheeseburger 935 | n07697537 hot dog 936 | n07711569 mashed potatoes 937 | n07714571 cabbage 938 | n07714990 broccoli 939 | n07715103 cauliflower 940 | n07716358 zucchini 941 | n07716906 spaghetti squash 942 | n07717410 acorn squash 943 | n07717556 butternut squash 944 | n07718472 cucumber 945 | n07718747 artichoke 946 | n07720875 bell pepper 947 | n07730033 cardoon 948 | n07734744 mushroom 949 | n07742313 Granny Smith apple 950 | n07745940 strawberry 951 | n07747607 orange 952 | n07749582 lemon 953 | n07753113 fig 954 | n07753275 pineapple 955 | n07753592 banana 956 | n07754684 jackfruit 957 | n07760859 cherimoya (custard apple) 958 | n07768694 pomegranate 959 | n07802026 hay 960 | n07831146 carbonara 961 | n07836838 chocolate syrup 962 | n07860988 dough 963 | n07871810 meatloaf 964 | n07873807 pizza 965 | n07875152 pot pie 966 | n07880968 burrito 967 | n07892512 red wine 968 | n07920052 espresso 969 | n07930864 tea cup 970 | n07932039 eggnog 971 | n09193705 mountain 972 | n09229709 bubble 973 | n09246464 cliff 974 | n09256479 coral reef 975 | n09288635 geyser 976 | n09332890 lakeshore 977 | n09399592 promontory 978 | n09421951 sandbar 979 | n09428293 beach 980 | n09468604 valley 981 | n09472597 volcano 982 | n09835506 baseball player 983 | n10148035 bridegroom 984 | n10565667 scuba diver 985 | n11879895 rapeseed 986 | n11939491 daisy 987 | n12057211 yellow lady's slipper 988 | n12144580 corn 989 | n12267677 acorn 990 | n12620546 rose hip 991 | n12768682 horse chestnut seed 992 | n12985857 coral fungus 993 | n12998815 agaric 994 | n13037406 gyromitra 995 | n13040303 stinkhorn mushroom 996 | n13044778 earth star fungus 997 | n13052670 hen of the woods mushroom 998 | n13054560 bolete 999 | n13133613 corn cob 1000 | n15075141 toilet paper 1001 | -------------------------------------------------------------------------------- /configs/classnames/oxford_flowers_classnames.json: -------------------------------------------------------------------------------- 1 | {"21": "fire lily", "3": "canterbury bells", "45": "bolero deep blue", "1": "pink primrose", "34": "mexican aster", "27": "prince of wales feathers", "7": "moon orchid", "16": "globe-flower", "25": "grape hyacinth", "26": "corn poppy", "79": "toad lily", "39": "siam tulip", "24": "red ginger", "67": "spring crocus", "35": "alpine sea holly", "32": "garden phlox", "10": "globe thistle", "6": "tiger lily", "93": "ball moss", "33": "love in the mist", "9": "monkshood", "102": "blackberry lily", "14": "spear thistle", "19": "balloon flower", "100": "blanket flower", "13": "king protea", "49": "oxeye daisy", "15": "yellow iris", "61": "cautleya spicata", "31": 
"carnation", "64": "silverbush", "68": "bearded iris", "63": "black-eyed susan", "69": "windflower", "62": "japanese anemone", "20": "giant white arum lily", "38": "great masterwort", "4": "sweet pea", "86": "tree mallow", "101": "trumpet creeper", "42": "daffodil", "22": "pincushion flower", "2": "hard-leaved pocket orchid", "54": "sunflower", "66": "osteospermum", "70": "tree poppy", "85": "desert-rose", "99": "bromelia", "87": "magnolia", "5": "english marigold", "92": "bee balm", "28": "stemless gentian", "97": "mallow", "57": "gaura", "40": "lenten rose", "47": "marigold", "59": "orange dahlia", "48": "buttercup", "55": "pelargonium", "36": "ruby-lipped cattleya", "91": "hippeastrum", "29": "artichoke", "71": "gazania", "90": "canna lily", "18": "peruvian lily", "98": "mexican petunia", "8": "bird of paradise", "30": "sweet william", "17": "purple coneflower", "52": "wild pansy", "84": "columbine", "12": "colt's foot", "11": "snapdragon", "96": "camellia", "23": "fritillary", "50": "common dandelion", "44": "poinsettia", "53": "primula", "72": "azalea", "65": "californian poppy", "80": "anthurium", "76": "morning glory", "37": "cape flower", "56": "bishop of llandaff", "60": "pink-yellow dahlia", "82": "clematis", "58": "geranium", "75": "thorn apple", "41": "barbeton daisy", "95": "bougainvillea", "43": "sword lily", "83": "hibiscus", "78": "lotus", "88": "cyclamen", "94": "foxglove", "81": "frangipani", "74": "rose", "89": "watercress", "73": "water lily", "46": "wallflower", "77": "passion flower", "51": "petunia"} 2 | -------------------------------------------------------------------------------- /configs/classnames/pet_cat_to_name.json: -------------------------------------------------------------------------------- 1 | {"0": "abyssinian", "1": "american_bulldog", "2": "american_pit_bull_terrier", "3": "basset_hound", "4": "beagle", "5": "bengal", "6": "birman", "7": "bombay", "8": "boxer", "9": "british_shorthair", "10": "chihuahua", "11": "egyptian_mau", "12": "english_cocker_spaniel", "13": "english_setter", "14": "german_shorthaired", "15": "great_pyrenees", "16": "havanese", "17": "japanese_chin", "18": "keeshond", "19": "leonberger", "20": "maine_coon", "21": "miniature_pinscher", "22": "newfoundland", "23": "persian", "24": "pomeranian", "25": "pug", "26": "ragdoll", "27": "russian_blue", "28": "saint_bernard", "29": "samoyed", "30": "scottish_terrier", "31": "shiba_inu", "32": "siamese", "33": "sphynx", "34": "staffordshire_bull_terrier", "35": "wheaten_terrier", "36": "yorkshire_terrier"} 2 | -------------------------------------------------------------------------------- /configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 4 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.0035 18 | MAX_EPOCH: 5 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | MAPLE: 33 | N_CTX: 2 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" 36 | PROMPT_DEPTH: 9 37 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==3.1.0 2 | einops==0.8.0 3 | ftfy==6.3.1 4 | geom-median==0.1.0 5 | gradio==5.8.0 6 | jaxtyping==0.2.36 7 | matplotlib==3.9.3 8 | nbformat 9 | numpy==1.26.4 10 | pandas==2.2.3 11 | Pillow==11.0.0 12 | plotly==5.24.1 13 | regex==2024.11.6 14 | Requests==2.32.3 15 | scipy==1.14.1 16 | setuptools==75.1.0 17 | torchvision==0.17.2 18 | tqdm==4.67.1 19 | transformer_lens==2.9.1 20 | transformers==4.46.3 21 | wandb==0.19.0 22 | yacs==0.1.8 23 | -------------------------------------------------------------------------------- /scripts/01_run_train.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | PYTHONPATH=./ nohup python -u tasks/train_sae_vit.py \ 4 | --batch_size 128 \ 5 | --checkpoint_path patchsae_checkpoints \ 6 | --n_checkpoints 10 \ 7 | --use_ghost_grads \ 8 | --log_to_wandb --wandb_project patchsae_test --wandb_entity hyesulim-hs \ 9 | > logs/01_test_training.txt 10 | -------------------------------------------------------------------------------- /scripts/02_run_compute_feature_data.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | PYTHONPATH=./ nohup python -u tasks/compute_sae_feature_data.py \ 4 | --root_dir ./ \ 5 | --dataset_name imagenet \ 6 | --sae_path patchsae_checkpoints/YOUR-OWN-PATH/clip-vit-base-patch16_-2_resid_49152.pt \ 7 | --vit_type base > logs/02_test_extract.txt 8 | -------------------------------------------------------------------------------- /scripts/03_run_class_level.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | PYTHONPATH=./ nohup python -u tasks/compute_class_wise_sae_activation.py \ 4 | --root_dir ./ \ 5 | --dataset_name imagenet \ 6 | --threshold 0.2 \ 7 | --sae_path /patchsae_checkpoints/YOUR-OWN-PATH/clip-vit-base-patch16_-2_resid_49152.pt \ 8 | --vit_type base > logs/03_test_class_level.txt 2> logs/03_test_class_level.err & 9 | -------------------------------------------------------------------------------- /scripts/04_run_topk_eval.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | PYTHONPATH=./ python tasks/classification_with_top_k_masking.py \ 4 | --root_dir ./ \ 5 | --dataset_name imagenet \ 6 | --sae_path patchsae_checkpoints/YOUR-OWN-PATH/clip-vit-base-patch16_-2_resid_49152.pt \ 7 | --cls_wise_sae_activation_path ./out/feature_data/sae_openai/base/imagenet/cls_sae_cnt.npy \ 8 | --vit_type base \ 9 | > logs/04_test_topk_eval.txt 10 | -------------------------------------------------------------------------------- /src/demo/app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gradio as gr 4 | import numpy as np 5 | import plotly.graph_objects as go 6 | import torch 7 | from PIL import Image, ImageDraw 8 | from plotly.subplots import make_subplots 9 | 10 | from src.demo.utils import load_sae_tester 11 | 12 | IMAGE_SIZE = 400 13 | DATASET_LIST = ["imagenet"] 14 | GRID_NUM = 14 15 | 16 | 17 | def get_grid_loc(evt, image): 18 | # Get click coordinates 19 | x, y = evt._data["index"][0], evt._data["index"][1] 20 | 21 | cell_width = image.width // GRID_NUM 22 | cell_height = image.height // GRID_NUM 23 | 24 | grid_x = x // cell_width 25 | grid_y = y // cell_height 26 | return grid_x, grid_y, cell_width, cell_height 27 | 28 | 29 | def plot_activation( 30 | evt: gr.EventData, 31 | current_image, 32 | activation, 33 | model_name: str, 34 | colors: tuple[str, str], 35 | ): 36 | """Plot activation distribution for the full image and optionally a selected tile""" 37 | mean_activation = activation.mean(0) 38 | 39 | tile_activation = None 40 | tile_x = None 41 | tile_y = None 42 | 43 | if evt is not None and evt._data is not None: 44 | tile_x, tile_y, _, _ = get_grid_loc(evt, current_image) 45 | token_idx = tile_y * GRID_NUM + tile_x + 1 46 | tile_activation = activation[token_idx] 47 | 48 | fig = create_activation_plot( 49 | mean_activation, 50 | tile_activation, 51 | tile_x, 52 | tile_y, 53 | model_name=model_name, 54 | colors=colors, 55 | ) 56 | 57 | return fig 58 | 59 | 60 | def create_activation_plot( 61 | mean_activation, 62 | tile_activation=None, 63 | tile_x=None, 64 | tile_y=None, 65 | top_k=5, 66 | colors=("blue", "cyan"), 67 | model_name="CLIP", 68 | ): 69 | """Create plotly figure with activation traces and annotations""" 70 | fig = go.Figure() 71 | 72 | # Add trace for mean activation across full image 73 | model_label = model_name.split("-")[0] 74 | add_activation_trace( 75 | fig, mean_activation, f"{model_label} Image-level", colors[0], top_k 76 | ) 77 | 78 | # Add trace for tile activation if provided 79 | if tile_activation is not None: 80 | add_activation_trace( 81 | fig, 82 | tile_activation, 83 | f"{model_label} Tile ({tile_x}, {tile_y})", 84 | colors[1], 85 | top_k, 86 | ) 87 | 88 | # Update layout 89 | fig.update_layout( 90 | title="Activation Distribution", 91 | xaxis_title="SAE latent index", 92 | yaxis_title="Activation Value", 93 | template="plotly_white", 94 | legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5), 95 | ) 96 | 97 | return fig 98 | 99 | 100 | def add_activation_trace(fig, activation, label, color, top_k): 101 | """Add a single activation trace with annotations to the figure""" 102 | # Add line trace 103 | fig.add_trace( 104 | go.Scatter( 105 | x=np.arange(len(activation)), 106 | y=activation, 107 | mode="lines", 108 | name=label, 109 | line=dict(color=color, dash="solid"), 110 | showlegend=True, 111 | ) 112 | ) 113 | 114 | # Add annotations for top activations 115 | top_indices = np.argsort(activation)[::-1][:top_k] 116 | 
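# np.argsort sorts in ascending order, so reversing with [::-1] and slicing keeps the
# indices of the top_k largest activations; each of them is labeled on the trace below
# with its SAE latent index.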
for idx in top_indices: 117 | fig.add_annotation( 118 | x=idx, 119 | y=activation[idx], 120 | text=str(idx), 121 | showarrow=True, 122 | arrowhead=2, 123 | ax=0, 124 | ay=-15, 125 | arrowcolor=color, 126 | opacity=0.7, 127 | ) 128 | 129 | 130 | def plot_activation_distribution( 131 | evt: gr.EventData, current_image, clip_act, maple_act, model_name: str 132 | ): 133 | fig = make_subplots( 134 | rows=2, 135 | cols=1, 136 | shared_xaxes=True, 137 | subplot_titles=["CLIP Activation", f"{model_name} Activation"], 138 | ) 139 | 140 | fig_clip = plot_activation( 141 | evt, current_image, clip_act, "CLIP", colors=("#00b4d8", "#90e0ef") 142 | ) 143 | fig_maple = plot_activation( 144 | evt, current_image, maple_act, model_name, colors=("#ff5a5f", "#ffcad4") 145 | ) 146 | 147 | def _attach_fig(fig, sub_fig, row, col, yref): 148 | for trace in sub_fig.data: 149 | fig.add_trace(trace, row=row, col=col) 150 | 151 | for annotation in sub_fig.layout.annotations: 152 | annotation.update(yref=yref) 153 | fig.add_annotation(annotation) 154 | return fig 155 | 156 | fig = _attach_fig(fig, fig_clip, row=1, col=1, yref="y1") 157 | fig = _attach_fig(fig, fig_maple, row=2, col=1, yref="y2") 158 | 159 | fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1) 160 | fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1) 161 | fig.update_yaxes(title_text="Activation Value", row=1, col=1) 162 | fig.update_yaxes(title_text="Activation Value", row=2, col=1) 163 | fig.update_layout( 164 | template="plotly_white", 165 | showlegend=True, 166 | legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5), 167 | margin=dict(l=20, r=20, t=40, b=20), 168 | ) 169 | 170 | return fig 171 | 172 | 173 | def get_top_images(model_type, slider_value, toggle_btn): 174 | out_top_images = sae_tester[model_type].get_top_images( 175 | slider_value, top_k=5, show_seg_mask=toggle_btn 176 | ) 177 | 178 | out_top_images = [plt_to_pil_direct(img) for img in out_top_images] 179 | return out_top_images 180 | 181 | 182 | def get_segmask(image, sae_act, slider_value): 183 | temp = sae_act[:, slider_value] 184 | mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14) 185 | mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][ 186 | 0 187 | ].numpy() 188 | mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10) 189 | 190 | base_opacity = 30 191 | image_array = np.array(image)[..., :3] 192 | rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8) 193 | rgba_overlay[..., :3] = image_array[..., :3] 194 | 195 | darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8) 196 | rgba_overlay[mask == 0, :3] = darkened_image[mask == 0] 197 | rgba_overlay[..., 3] = 255 # Fully opaque 198 | 199 | return rgba_overlay 200 | 201 | 202 | def plt_to_pil_direct(fig): 203 | # Draw the canvas to render the figure 204 | fig.canvas.draw() 205 | 206 | # Convert the figure to a NumPy array 207 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 208 | width, height = fig.canvas.get_width_height() 209 | image = data.reshape((height, width, 3)) 210 | 211 | # Create a PIL Image from the NumPy array 212 | return Image.fromarray(image) 213 | 214 | 215 | def show_segmentation_masks( 216 | selected_image, slider_value, sae_act, model_type, toggle_btn=False 217 | ): 218 | slider_value = int(slider_value.split("-")[-1]) 219 | rgba_overlay = get_segmask(selected_image, sae_act, slider_value) 220 | top_images = get_top_images(model_type, slider_value, 
toggle_btn) 221 | 222 | act_values = [] 223 | for dataset in REF_DATASET_LIST: 224 | act_value = sae_data_dict["mean_act_values"][dataset][slider_value, :5] 225 | act_value = [str(round(value.item(), 3)) for value in act_value] 226 | act_value = " | ".join(act_value) 227 | out = f"#### Activation values: {act_value}" 228 | act_values.append(out) 229 | 230 | return rgba_overlay, top_images, act_values 231 | 232 | 233 | def load_results(resized_image, radio_choice, clip_act, maple_act, toggle_btn): 234 | if clip_act is None: 235 | return None, None, None, None, None, None, None 236 | 237 | init_seg, init_tops, init_values = show_segmentation_masks( 238 | resized_image, radio_choice, clip_act, "CLIP", toggle_btn 239 | ) 240 | 241 | slider_value = int(radio_choice.split("-")[-1]) 242 | maple_init_seg = get_segmask(resized_image, maple_act, slider_value) 243 | 244 | out = (init_seg, maple_init_seg) 245 | out += tuple(init_tops) 246 | out += tuple(init_values) 247 | return out 248 | 249 | 250 | def load_image_and_act(image, clip_act, maple_act, model_name): 251 | resized_image = image.resize((IMAGE_SIZE, IMAGE_SIZE)) 252 | sae_tester["CLIP"].register_image(resized_image) 253 | clip_act = sae_tester["CLIP"].get_activation_distribution() 254 | 255 | sae_tester[model_name].register_image(resized_image) 256 | maple_act = sae_tester[model_name].get_activation_distribution() 257 | 258 | neuron_plot = plot_activation_distribution( 259 | None, resized_image, clip_act, maple_act, model_name 260 | ) 261 | 262 | radio_names = get_init_radio_options(clip_act, maple_act) 263 | radio_choices = gr.Radio( 264 | choices=radio_names, label="Top activating SAE latent", value=radio_names[0] 265 | ) 266 | feautre_idx = radio_names[0].split("-")[-1] 267 | markdown_display = ( 268 | f"## Segmentation mask for the selected SAE latent - {feautre_idx}" 269 | ) 270 | 271 | return ( 272 | resized_image, 273 | resized_image, 274 | neuron_plot, 275 | clip_act, 276 | maple_act, 277 | radio_choices, 278 | markdown_display, 279 | ) 280 | 281 | 282 | def highlight_grid(evt: gr.EventData, image, clip_act, maple_act, model_name): 283 | grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image) 284 | 285 | highlighted_image = image.copy() 286 | draw = ImageDraw.Draw(highlighted_image) 287 | box = [ 288 | grid_x * cell_width, 289 | grid_y * cell_height, 290 | (grid_x + 1) * cell_width, 291 | (grid_y + 1) * cell_height, 292 | ] 293 | draw.rectangle(box, outline="red", width=3) 294 | 295 | neuron_plot = plot_activation_distribution( 296 | evt, image, clip_act, maple_act, model_name 297 | ) 298 | 299 | radio, choices = update_radio_options(clip_act, maple_act, grid_x, grid_y) 300 | feautre_idx = choices[0].split("-")[-1] 301 | markdown_display = ( 302 | f"## Segmentation mask for the selected SAE latent - {feautre_idx}" 303 | ) 304 | 305 | return (highlighted_image, neuron_plot, radio, markdown_display) 306 | 307 | 308 | def get_init_radio_options(clip_act, maple_act): 309 | clip_neuron_dict = {} 310 | maple_neuron_dict = {} 311 | 312 | def _get_top_actvation(activations, neuron_dict, top_k=5): 313 | activations = activations.mean(0) 314 | top_neurons = list(np.argsort(activations)[::-1][:top_k]) 315 | for top_neuron in top_neurons: 316 | neuron_dict[top_neuron] = activations[top_neuron] 317 | sorted_dict = dict( 318 | sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True) 319 | ) 320 | return sorted_dict 321 | 322 | clip_neuron_dict = _get_top_actvation(clip_act, clip_neuron_dict) 323 | maple_neuron_dict = 
_get_top_actvation(maple_act, maple_neuron_dict) 324 | 325 | radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict) 326 | 327 | return radio_choices 328 | 329 | 330 | def update_radio_options(clip_act, maple_act, grid_x, grid_y): 331 | def _sort_and_save_top_k(activations, neuron_dict, top_k=5): 332 | top_neurons = list(np.argsort(activations)[::-1][:top_k]) 333 | for top_neuron in top_neurons: 334 | neuron_dict[top_neuron] = activations[top_neuron] 335 | 336 | def _get_top_actvation(activations, neuron_dict, token_idx): 337 | image_activation = activations.mean(0) 338 | _sort_and_save_top_k(image_activation, neuron_dict) 339 | 340 | tile_activations = activations[token_idx] 341 | _sort_and_save_top_k(tile_activations, neuron_dict) 342 | 343 | sorted_dict = dict( 344 | sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True) 345 | ) 346 | return sorted_dict 347 | 348 | token_idx = grid_y * GRID_NUM + grid_x + 1 349 | clip_neuron_dict = {} 350 | maple_neuron_dict = {} 351 | clip_neuron_dict = _get_top_actvation(clip_act, clip_neuron_dict, token_idx) 352 | maple_neuron_dict = _get_top_actvation(maple_act, maple_neuron_dict, token_idx) 353 | 354 | clip_keys = list(clip_neuron_dict.keys()) 355 | maple_keys = list(maple_neuron_dict.keys()) 356 | 357 | common_keys = list(set(clip_keys).intersection(set(maple_keys))) 358 | clip_only_keys = list(set(clip_keys) - (set(maple_keys))) 359 | maple_only_keys = list(set(maple_keys) - (set(clip_keys))) 360 | 361 | common_keys.sort( 362 | key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True 363 | ) 364 | clip_only_keys.sort(reverse=True) 365 | maple_only_keys.sort(reverse=True) 366 | 367 | out = [] 368 | out.extend([f"common-{i}" for i in common_keys[:5]]) 369 | out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]]) 370 | out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]]) 371 | 372 | radio_choices = gr.Radio( 373 | choices=out, label="Top activating SAE latent", value=out[0] 374 | ) 375 | return radio_choices, out 376 | 377 | 378 | def get_radio_names(clip_neuron_dict, maple_neuron_dict): 379 | clip_keys = list(clip_neuron_dict.keys()) 380 | maple_keys = list(maple_neuron_dict.keys()) 381 | 382 | common_keys = list(set(clip_keys).intersection(set(maple_keys))) 383 | clip_only_keys = list(set(clip_keys) - (set(maple_keys))) 384 | maple_only_keys = list(set(maple_keys) - (set(clip_keys))) 385 | 386 | common_keys.sort( 387 | key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True 388 | ) 389 | clip_only_keys.sort(reverse=True) 390 | maple_only_keys.sort(reverse=True) 391 | 392 | out = [] 393 | out.extend([f"common-{i}" for i in common_keys[:5]]) 394 | out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]]) 395 | out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]]) 396 | 397 | return out 398 | 399 | 400 | if __name__ == "__main__": 401 | parser = argparse.ArgumentParser() 402 | parser.add_argument( 403 | "--include-imagenet", 404 | action="store_true", 405 | default=False, 406 | help="Include ImageNet in the Demo", 407 | ) 408 | args = parser.parse_args() 409 | 410 | sae_tester = load_sae_tester("./data/sae_weight/base/out.pt", args.include_imagenet) 411 | sae_data_dict = {"mean_act_values": {}} 412 | if args.include_imagenet: 413 | REF_DATASET_LIST = ["imagenet", "imagenet-sketch", "caltech101"] 414 | else: 415 | REF_DATASET_LIST = ["imagenet-sketch", "caltech101"] 416 | for dataset in ["imagenet", "imagenet-sketch", "caltech101"]: 417 | data = torch.load( 418 | 
f"./out/feature_data/sae_base/base/{dataset}/max_activating_image_values.pt", 419 | map_location="cpu", 420 | ) 421 | sae_data_dict["mean_act_values"][dataset] = data 422 | 423 | with gr.Blocks( 424 | theme=gr.themes.Citrus(), 425 | css=""" 426 | .image-row .gr-image { margin: 0 !important; padding: 0 !important; } 427 | .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */ 428 | """, 429 | ) as demo: 430 | with gr.Row(): 431 | with gr.Column(): 432 | # Left View: Image selection and click handling 433 | gr.Markdown("## Select input image and patch on the image") 434 | 435 | current_image = gr.State() 436 | clip_act = gr.State() 437 | maple_act = gr.State() 438 | 439 | image_display = gr.Image(type="pil", interactive=True) 440 | 441 | with gr.Column(): 442 | gr.Markdown("## SAE latent activations of CLIP and MaPLE") 443 | model_options = [ 444 | f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST 445 | ] 446 | model_selector = gr.Dropdown( 447 | choices=model_options, 448 | value=model_options[0], 449 | label="Select adapted model (MaPLe)", 450 | ) 451 | 452 | neuron_plot = gr.Plot(label="Neuron Activation", show_label=False) 453 | 454 | with gr.Row(): 455 | with gr.Column(): 456 | markdown_display = gr.Markdown( 457 | "## Segmentation mask for the selected SAE latent - " 458 | ) 459 | gr.Markdown("### Localize SAE latent activation using CLIP") 460 | seg_mask_display = gr.Image(type="pil", show_label=False) 461 | 462 | gr.Markdown("### Localize SAE latent activation using MaPLE") 463 | seg_mask_display_maple = gr.Image(type="pil", show_label=False) 464 | 465 | with gr.Column(): 466 | radio_choices = gr.Radio( 467 | choices=[], 468 | label="Top activating SAE latent", 469 | ) 470 | toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False) 471 | 472 | image_display_dict = {} 473 | activation_dict = {} 474 | for dataset in REF_DATASET_LIST: 475 | image_display_dict[dataset] = gr.Image( 476 | type="pil", label=dataset, show_label=False 477 | ) 478 | activation_dict[dataset] = gr.Markdown("") 479 | 480 | image_display.upload( 481 | fn=load_image_and_act, 482 | inputs=[image_display, clip_act, maple_act, model_selector], 483 | outputs=[ 484 | image_display, 485 | current_image, 486 | neuron_plot, 487 | clip_act, 488 | maple_act, 489 | radio_choices, 490 | markdown_display, 491 | ], 492 | ) 493 | 494 | outputs = [seg_mask_display, seg_mask_display_maple] 495 | outputs += list(image_display_dict.values()) 496 | outputs += list(activation_dict.values()) 497 | 498 | radio_choices.change( 499 | fn=load_results, 500 | inputs=[current_image, radio_choices, clip_act, maple_act, toggle_btn], 501 | outputs=outputs, 502 | ) 503 | 504 | toggle_btn.change( 505 | fn=load_results, 506 | inputs=[current_image, radio_choices, clip_act, maple_act, toggle_btn], 507 | outputs=outputs, 508 | ) 509 | 510 | image_display.select( 511 | fn=highlight_grid, 512 | inputs=[current_image, clip_act, maple_act, model_selector], 513 | outputs=[ 514 | image_display, 515 | neuron_plot, 516 | radio_choices, 517 | markdown_display, 518 | ], 519 | ) 520 | 521 | radio_choices.change( 522 | fn=load_results, 523 | inputs=[current_image, radio_choices, clip_act, maple_act, toggle_btn], 524 | outputs=outputs, 525 | ) 526 | 527 | demo.launch() 528 | -------------------------------------------------------------------------------- /src/demo/core.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from io 
import BytesIO 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import requests 8 | import torch 9 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 10 | from PIL import Image 11 | 12 | 13 | class UtilMixin: 14 | def _get_max_activating_images_and_labels( 15 | self, neuron_idx, dataset, max_activating_image_indices 16 | ): 17 | img_list = max_activating_image_indices[neuron_idx] 18 | images = [] 19 | labels = [] 20 | for i in img_list: 21 | try: 22 | images.append(dataset[i.item()]["image"]) 23 | labels.append(dataset[i.item()]["label"]) 24 | except Exception: 25 | images.append(dataset[i.item()]["jpg"]) 26 | labels.append(dataset[i.item()]["cls"]) 27 | return images, labels 28 | 29 | def _create_patches(self, patch=256): 30 | temp = self.processed_image["pixel_values"].clone() 31 | patches = temp[0].data.unfold(0, 3, 3) 32 | patches = patches.unfold(1, patch, patch) 33 | patches = patches.unfold(2, patch, patch) 34 | return patches 35 | 36 | 37 | class VisualizeMixin: 38 | def _plot_input_image(self): 39 | plt.imshow(self.input_image) 40 | 41 | def _plot_feature_mask(self, patches, feat_idx, mask=None, plot=True): 42 | if mask is None: 43 | mask = self.sae_act[0, :, feat_idx].cpu() 44 | 45 | fig, axs = plt.subplots(patches.size(1), patches.size(2), figsize=(6, 6)) 46 | plt.subplots_adjust(wspace=0.01, hspace=0.01) 47 | 48 | for i in range(patches.size(1)): 49 | for j in range(patches.size(2)): 50 | patch = patches[0, i, j].permute(1, 2, 0) 51 | patch *= torch.tensor(self.vit.processor.image_processor.image_std) 52 | patch += torch.tensor(self.vit.processor.image_processor.image_mean) 53 | masked_patch = patch * mask[i * patches.size(2) + j + 1] 54 | masked_patch = (masked_patch - masked_patch.min()) / ( 55 | masked_patch.max() - masked_patch.min() + 1e-8 56 | ) 57 | axs[i, j].imshow(masked_patch) 58 | axs[i, j].axis("off") 59 | 60 | fig.suptitle(feat_idx) 61 | plt.close() 62 | return fig 63 | 64 | def _plot_patches(self, patches, highlight_patch_idx=None): 65 | fig, axs = plt.subplots(patches.size(1), patches.size(2), figsize=(6, 6)) 66 | plt.subplots_adjust(wspace=0.01, hspace=0.01) 67 | for i in range(patches.size(1)): 68 | for j in range(patches.size(2)): 69 | patch = patches[0, i, j].permute(1, 2, 0) 70 | patch *= torch.tensor(self.vit.processor.image_processor.image_std) 71 | patch += torch.tensor(self.vit.processor.image_processor.image_mean) 72 | axs[i, j].imshow(patch) 73 | if i * patches.size(2) + j == highlight_patch_idx: 74 | for spine in axs[i, j].spines.values(): 75 | spine.set_edgecolor("red") 76 | spine.set_linewidth(3) 77 | axs[i, j].set_xticks([]) 78 | axs[i, j].set_yticks([]) 79 | else: 80 | axs[i, j].axis("off") 81 | plt.show() 82 | 83 | def _plot_union_top_neruons( 84 | self, top_k, union_top_neurons, token_idx, token_act, save=False 85 | ): 86 | print(f"Union of top {top_k} neurons: {union_top_neurons}") 87 | 88 | plt.figure(figsize=(10, 5)) 89 | plt.plot(token_act) 90 | plt.plot( 91 | union_top_neurons, 92 | token_act[union_top_neurons], 93 | "ro", 94 | label="Top neurons", 95 | markersize=5, 96 | ) 97 | 98 | # Annotate feature indices 99 | for idx in union_top_neurons: 100 | plt.text( 101 | idx, token_act[idx] + 0.05, str(idx), fontsize=9, ha="center" 102 | ) # Adjust the 0.05 value as needed for spacing 103 | 104 | plt.legend() 105 | plt.title(f"token {token_idx} activation") 106 | 107 | if save: 108 | img_name = os.path.basename(self.img_url).replace(".jpg", "") 109 | save_name = 
f"{self.save_dir}/{img_name}/activation/{token_idx}.jpg" 110 | os.makedirs(os.path.dirname(save_name), exist_ok=True) 111 | plt.savefig(save_name) 112 | 113 | plt.show() 114 | plt.close() 115 | 116 | def _plot_images( 117 | self, 118 | dataset_name, 119 | images, 120 | neuron_idx, 121 | labels=None, 122 | suptitle=None, 123 | top_k=5, 124 | save=False, 125 | ): 126 | images = [img.resize((224, 224)) for img in images] 127 | num_cols = min(top_k, 5) 128 | num_rows = (top_k + num_cols - 1) // num_cols 129 | fig, axes = plt.subplots( 130 | num_rows, num_cols, figsize=(4.5 * num_cols, 5 * num_rows) 131 | ) 132 | axes = axes.flatten() # Flatten the 2D array of axes 133 | 134 | for i in range(top_k): 135 | axes[i].imshow(images[i]) # Display the image 136 | axes[i].axis("off") # Hide axes 137 | if labels is not None: 138 | class_name = self.class_names[dataset_name][int(labels[i])] 139 | axes[i].set_title(f"{labels[i]} {class_name}", fontsize=25) 140 | # plt.suptitle(suptitle) 141 | plt.tight_layout() 142 | 143 | if save: 144 | img_name = os.path.basename(self.img_url).replace(".jpg", "") 145 | save_name = ( 146 | f"{self.save_dir}/{img_name}/top_images/{dataset_name}/{neuron_idx}.jpg" 147 | ) 148 | os.makedirs(os.path.dirname(save_name), exist_ok=True) 149 | plt.savefig(save_name) 150 | 151 | plt.close() 152 | 153 | return fig 154 | 155 | def _fig_to_img(self, fig): 156 | canvas = FigureCanvas(fig) 157 | canvas.draw() 158 | img = np.frombuffer(canvas.tostring_rgb(), dtype="uint8") 159 | img = img.reshape(canvas.get_width_height()[::-1] + (3,)) 160 | return img 161 | 162 | def _plot_multiple_images(self, figs, neuron_idx, top_k=5, save=False): 163 | # Create a new figure to hold all subplots 164 | num_plots = len(figs) 165 | cols = 1 # Number of columns in the subplot grid 166 | rows = (num_plots + cols - 1) // cols # Calculate rows required 167 | 168 | combined_fig = plt.figure(figsize=(20, 12)) # Adjust figsize as needed 169 | 170 | for i, fig in enumerate(figs): 171 | ax = combined_fig.add_subplot(rows, cols, i + 1) 172 | img = self._fig_to_img(fig) 173 | ax.imshow(img) 174 | ax.axis("off") 175 | 176 | if save: 177 | img_name = os.path.basename(self.img_url).replace(".jpg", "") 178 | save_name = f"{self.save_dir}/{img_name}/top_images/{neuron_idx}.jpg" 179 | os.makedirs(os.path.dirname(save_name), exist_ok=True) 180 | plt.savefig(save_name) 181 | 182 | combined_fig.show() 183 | # plt.close(combined_fig) 184 | 185 | 186 | class SAETester(VisualizeMixin, UtilMixin): 187 | def __init__( 188 | self, 189 | vit, 190 | cfg, 191 | sae, 192 | mean_acts, 193 | max_act_images, 194 | datasets, 195 | class_names, 196 | noisy_threshold=0.1, 197 | device="cpu", 198 | ): 199 | self.vit = vit 200 | self.cfg = cfg 201 | self.sae = sae 202 | self.mean_acts = mean_acts 203 | self.max_act_images = max_act_images 204 | self.datasets = datasets 205 | self.class_names = class_names 206 | self.noisy_threshold = noisy_threshold 207 | self.device = device 208 | 209 | def register_image(self, img_url: str) -> None: 210 | """Load and process an image from a URL or local path.""" 211 | if isinstance(img_url, str): 212 | image = self._load_image(img_url) 213 | else: 214 | image = img_url 215 | self.input_image = image 216 | self.processed_image = self.vit.processor( 217 | images=image, text="", return_tensors="pt", padding=True 218 | ) 219 | 220 | def _load_image(self, img_url: str) -> Image.Image: 221 | """Helper method to load image from URL or local path.""" 222 | if "http" in img_url: 223 | response = 
requests.get(img_url) 224 | response.raise_for_status() 225 | return Image.open(BytesIO(response.content)) 226 | return Image.open(img_url) 227 | 228 | @property 229 | def processed_image(self): 230 | return self._processed_image 231 | 232 | @processed_image.setter 233 | def processed_image(self, value): 234 | self._processed_image = value 235 | 236 | @property 237 | def input_image(self): 238 | return self._input_image 239 | 240 | @input_image.setter 241 | def input_image(self, value): 242 | self._input_image = value 243 | 244 | def show_input_image(self): 245 | self._plot_input_image() 246 | 247 | def run( 248 | self, highlight_patch_idx, patch_size=16, top_k=5, num_images=5, seg_mask=True 249 | ): 250 | # idx = 0 is cls token 251 | self.show_patches( 252 | highlight_patch_idx=highlight_patch_idx - 1, patch_size=patch_size 253 | ) 254 | top_neurons = self.get_top_neurons(highlight_patch_idx, top_k=top_k) 255 | self.show_ref_images_of_neuron_indices( 256 | top_neurons, top_k=num_images, seg_mask=True 257 | ) 258 | 259 | def show_patches(self, highlight_patch_idx=None, patch_size=16): 260 | if not hasattr(self, "input_image"): 261 | assert not hasattr(self, "input_image"), "register image first" 262 | 263 | patches = self._create_patches(patch=patch_size) 264 | self._plot_patches(patches.cpu().data, highlight_patch_idx=highlight_patch_idx) 265 | 266 | def show_segmentation_mask(self, feat_idx, patch_size=16, mask=None, plot=True): 267 | patches = self._create_patches(patch=patch_size) 268 | fig = self._plot_feature_mask( 269 | patches.cpu().data, feat_idx, mask=None, plot=plot 270 | ) 271 | return fig 272 | 273 | def get_segmentation_mask(self, image, feat_idx: int): 274 | if image.mode == "L": 275 | image = image.convert("RGB") 276 | 277 | vit_act = self._run_vit_hook(image) 278 | sae_act = self._run_sae_hook(vit_act) 279 | token_act = sae_act[0].detach().cpu().numpy() 280 | filtered_mean_act = self._filter_out_nosiy_activation(token_act) 281 | 282 | temp = filtered_mean_act[:, feat_idx] 283 | mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14) 284 | mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][ 285 | 0 286 | ].numpy() 287 | mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10) 288 | 289 | base_opacity = 30 290 | image_array = np.array(image)[..., :3] 291 | rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8) 292 | rgba_overlay[..., :3] = image_array[..., :3] 293 | 294 | darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8) 295 | rgba_overlay[mask == 0, :3] = darkened_image[mask == 0] 296 | rgba_overlay[..., 3] = 255 # Fully opaque 297 | 298 | return Image.fromarray(rgba_overlay) 299 | 300 | def get_top_neurons(self, token_idx=None, top_k=5, plot=True): 301 | if token_idx is None: 302 | token_acts, top_neurons, self.sae_act = self._get_img_acts_and_top_neurons( 303 | top_k=top_k 304 | ) 305 | else: 306 | token_acts, top_neurons, self.sae_act = ( 307 | self._get_token_acts_and_top_neurons(token_idx=token_idx, top_k=top_k) 308 | ) 309 | if plot: 310 | self._plot_union_top_neruons(top_k, top_neurons, token_idx, token_acts) 311 | return top_neurons 312 | 313 | def get_top_images(self, neuron_idx: int, top_k=5, show_seg_mask=False): 314 | out_top_images = [] 315 | for dataset_name in self.datasets.keys(): 316 | if self.datasets[dataset_name] is None: 317 | continue 318 | images, labels = self._get_max_activating_images_and_labels( 319 | neuron_idx, 320 | self.datasets[dataset_name], 321 | 
self.max_act_images[dataset_name], 322 | ) 323 | 324 | if show_seg_mask: 325 | images = [ 326 | self.get_segmentation_mask(img, neuron_idx) 327 | for img in images[:top_k] 328 | ] 329 | suptitle = f"{dataset_name} - {neuron_idx}" 330 | fig = self._plot_images( 331 | dataset_name, 332 | images, 333 | neuron_idx, 334 | labels, 335 | suptitle=suptitle, 336 | top_k=top_k, 337 | save=False, 338 | ) 339 | out_top_images.append(fig) 340 | 341 | return out_top_images 342 | 343 | def show_ref_images_of_neuron_indices( 344 | self, neuron_indices: list[int], top_k=5, save=False, seg_mask=False 345 | ): 346 | for neuron_idx in neuron_indices: 347 | figs = self.get_top_images(neuron_idx, top_k=top_k, show_seg_mask=False) 348 | self._plot_multiple_images(figs, neuron_indices, top_k=top_k, save=False) 349 | 350 | if seg_mask: 351 | figs = self.get_top_images(neuron_idx, top_k=top_k, show_seg_mask=True) 352 | self._plot_multiple_images( 353 | figs, neuron_indices, top_k=top_k, save=False 354 | ) 355 | 356 | def get_activation_distribution(self): 357 | vit_act = self._run_vit_hook() 358 | sae_act = self._run_sae_hook(vit_act) 359 | token_act = sae_act[0].detach().cpu().numpy() 360 | filtered_mean_act = self._filter_out_nosiy_activation(token_act) 361 | self.sae_act = sae_act 362 | return filtered_mean_act 363 | 364 | def _get_img_acts_and_top_neurons(self, top_k=5, threshold=0.2): 365 | vit_act = self._run_vit_hook() 366 | sae_act = self._run_sae_hook(vit_act) 367 | 368 | token_act = sae_act[0].detach().cpu().numpy() 369 | filtered_mean_act = self._filter_out_nosiy_activation(token_act) 370 | token_act = (filtered_mean_act > threshold).sum(0) 371 | filtered_mean_act = filtered_mean_act.sum(0) 372 | top_neurons = np.argsort(filtered_mean_act)[::-1][:top_k] 373 | 374 | return token_act, top_neurons, sae_act 375 | 376 | def _get_token_acts_and_top_neurons(self, token_idx, top_k=5): 377 | vit_act = self._run_vit_hook() 378 | sae_act = self._run_sae_hook(vit_act) 379 | 380 | token_act = sae_act[0, token_idx, :].detach().cpu().numpy() 381 | filtered_mean_act = self._filter_out_nosiy_activation(token_act) 382 | top_neurons = np.argsort(filtered_mean_act)[::-1][:top_k] 383 | 384 | return token_act, top_neurons, sae_act 385 | 386 | def _run_vit_hook(self, image=None): 387 | if image is None: 388 | inputs = self.processed_image.to(self.device) 389 | else: 390 | inputs = self.vit.processor( 391 | images=image, text="", return_tensors="pt", padding=True 392 | ) 393 | list_of_hook_locations = [(self.cfg.block_layer, self.cfg.module_name)] 394 | vit_out, vit_cache_dict = self.vit.run_with_cache( 395 | list_of_hook_locations, **inputs 396 | ) 397 | vit_act = vit_cache_dict[(self.cfg.block_layer, self.cfg.module_name)] 398 | return vit_act 399 | 400 | def _run_sae_hook(self, vit_act): 401 | sae_out, sae_cache_dict = self.sae.run_with_cache(vit_act) 402 | sae_act = sae_cache_dict["hook_hidden_post"] 403 | if sae_act.shape[0] != 1: 404 | sae_act = sae_act.permute(1, 0, 2) 405 | return sae_act[:, :197, :] 406 | 407 | def _filter_out_nosiy_activation(self, features): 408 | noisy_features_indices = ( 409 | (self.mean_acts["imagenet"] > self.noisy_threshold).nonzero()[0].tolist() 410 | ) 411 | features_copy = deepcopy(features) 412 | if len(features_copy.shape) == 1: 413 | features_copy[noisy_features_indices] = 0 414 | elif len(features_copy.shape) == 2: 415 | features_copy[:, noisy_features_indices] = 0 416 | return features_copy 417 | -------------------------------------------------------------------------------- 
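The SAETester class above and load_sae_tester in the following file are the programmatic entry points the Gradio demo builds on. The snippet below is a minimal usage sketch rather than a documented API example: it assumes the SAE checkpoint at ./data/sae_weight/base/out.pt and the precomputed feature data under ./out/feature_data already exist (the same paths src/demo/app.py loads), and the image path is a placeholder.

import numpy as np

from src.demo.utils import load_sae_tester

# Assumes the checkpoint and feature data produced by scripts/01-03 are in place.
testers = load_sae_tester("./data/sae_weight/base/out.pt", include_imagenet=False)
clip_tester = testers["CLIP"]

clip_tester.register_image("path/to/your_image.jpg")  # placeholder; a URL or a local path
acts = clip_tester.get_activation_distribution()      # numpy array of shape (197 tokens, n_latents), noisy latents zeroed

top_latents = np.argsort(acts.mean(0))[::-1][:5]      # top-5 latents averaged over CLS + patch tokens
print("Top image-level SAE latents:", top_latents)
clip_tester.show_ref_images_of_neuron_indices(list(top_latents), top_k=5)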
/src/demo/utils.py: -------------------------------------------------------------------------------- 1 | from src.demo.core import SAETester 2 | from tasks.utils import ( 3 | get_all_classnames, 4 | get_max_acts_and_images, 5 | get_sae_and_vit, 6 | load_datasets, 7 | ) 8 | 9 | 10 | def load_sae_tester(sae_path, include_imagenet=False): 11 | datasets = load_datasets(include_imagenet=include_imagenet) 12 | classnames = get_all_classnames(datasets, data_root="./configs/classnames") 13 | 14 | root = "./out/feature_data" 15 | sae_runname = "sae_base" 16 | vit_name = "base" 17 | 18 | if include_imagenet is False: 19 | datasets["imagenet"] = None 20 | 21 | max_act_imgs, mean_acts = get_max_acts_and_images( 22 | datasets, root, sae_runname, vit_name 23 | ) 24 | 25 | sae_tester = {} 26 | 27 | sae, vit, cfg = get_sae_and_vit( 28 | sae_path=sae_path, 29 | vit_type="base", 30 | device="cpu", 31 | backbone="openai/clip-vit-base-patch16", 32 | model_path=None, 33 | classnames=None, 34 | ) 35 | sae_clip = SAETester(vit, cfg, sae, mean_acts, max_act_imgs, datasets, classnames) 36 | 37 | sae, vit, cfg = get_sae_and_vit( 38 | sae_path=sae_path, 39 | vit_type="maple", 40 | device="cpu", 41 | model_path="./data/clip/maple/imagenet/model.pth.tar-2", 42 | config_path="./configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml", 43 | backbone="openai/clip-vit-base-patch16", 44 | classnames=classnames["imagenet"], 45 | ) 46 | sae_maple = SAETester(vit, cfg, sae, mean_acts, max_act_imgs, datasets, classnames) 47 | sae_tester["CLIP"] = sae_clip 48 | sae_tester["MaPLE-imagenet"] = sae_maple 49 | return sae_tester 50 | -------------------------------------------------------------------------------- /src/models/architecture/maple.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Portions of this file are based on code from the “multimodal-prompt-learning” repository (MIT-licensed): 3 | # https://github.com/muzairkhattak/multimodal-prompt-learning/blob/main/trainers/maple.py 4 | """ 5 | 6 | import copy 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from src.models.clip.clip import tokenize 12 | from src.models.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 13 | 14 | _tokenizer = _Tokenizer() 15 | 16 | 17 | class TextEncoder(nn.Module): 18 | def __init__(self, clip_model): 19 | super().__init__() 20 | self.transformer = clip_model.transformer 21 | self.positional_embedding = clip_model.positional_embedding 22 | self.ln_final = clip_model.ln_final 23 | self.text_projection = clip_model.text_projection 24 | self.dtype = clip_model.dtype 25 | 26 | def forward(self, prompts, tokenized_prompts, compound_prompts_deeper_text): 27 | x = prompts + self.positional_embedding.type(self.dtype) 28 | x = x.permute(1, 0, 2) # NLD -> LND 29 | # Pass as the list, as nn.sequential cannot process multiple arguments in the forward pass 30 | combined = [ 31 | x, 32 | compound_prompts_deeper_text, 33 | 0, 34 | ] # third argument is the counter which denotes depth of prompt 35 | outputs = self.transformer(combined) 36 | x = outputs[0] # extract the x back from here 37 | x = x.permute(1, 0, 2) # LND -> NLD 38 | x = self.ln_final(x).type(self.dtype) 39 | 40 | # x.shape = [batch_size, n_ctx, transformer.width] 41 | # take features from the eot embedding (eot_token is the highest number in each sequence) 42 | x = ( 43 | x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] 44 | @ self.text_projection 45 | ) 46 | 47 | return x 48 | 49 | 50 | class 
MultiModalPromptLearner(nn.Module): 51 | def __init__(self, cfg, classnames, clip_model): 52 | super().__init__() 53 | n_cls = len(classnames) 54 | n_ctx = cfg.TRAINER.MAPLE.N_CTX 55 | ctx_init = cfg.TRAINER.MAPLE.CTX_INIT 56 | dtype = clip_model.dtype 57 | ctx_dim = clip_model.ln_final.weight.shape[0] 58 | clip_imsize = clip_model.visual.input_resolution 59 | cfg_imsize = cfg.INPUT.SIZE[0] 60 | # Default is 1, which is compound shallow prompting 61 | assert cfg.TRAINER.MAPLE.PROMPT_DEPTH >= 1, ( 62 | "For MaPLe, PROMPT_DEPTH should be >= 1" 63 | ) 64 | self.compound_prompts_depth = ( 65 | cfg.TRAINER.MAPLE.PROMPT_DEPTH 66 | ) # max=12, but will create 11 such shared prompts 67 | assert cfg_imsize == clip_imsize, ( 68 | f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 69 | ) 70 | 71 | if ctx_init and (n_ctx) <= 4: 72 | # use given words to initialize context vectors 73 | ctx_init = ctx_init.replace("_", " ") 74 | n_ctx = n_ctx 75 | prompt = tokenize(ctx_init) 76 | with torch.no_grad(): 77 | embedding = clip_model.token_embedding(prompt).type(dtype) 78 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :] 79 | prompt_prefix = ctx_init 80 | else: 81 | # random initialization 82 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 83 | nn.init.normal_(ctx_vectors, std=0.02) 84 | prompt_prefix = " ".join(["X"] * n_ctx) 85 | print("MaPLe design: Multi-modal Prompt Learning") 86 | print(f'Initial context: "{prompt_prefix}"') 87 | print(f"Number of MaPLe context words (tokens): {n_ctx}") 88 | # These below, related to the shallow prompts 89 | # Linear layer so that the tokens will project to 512 and will be initialized from 768 90 | self.proj = nn.Linear(ctx_dim, 768).type(dtype) 91 | # self.proj.half() 92 | self.ctx = nn.Parameter(ctx_vectors) 93 | # These below parameters related to the shared prompts 94 | # Define the compound prompts for the deeper layers 95 | 96 | # Minimum can be 1, which defaults to shallow MaPLe 97 | # compound prompts 98 | self.compound_prompts_text = nn.ParameterList( 99 | [ 100 | nn.Parameter(torch.empty(n_ctx, 512)) 101 | for _ in range(self.compound_prompts_depth - 1) 102 | ] 103 | ) 104 | for single_para in self.compound_prompts_text: 105 | nn.init.normal_(single_para, std=0.02) 106 | # Also make corresponding projection layers, for each prompt 107 | single_layer = nn.Linear(ctx_dim, 768) 108 | self.compound_prompt_projections = _get_clones( 109 | single_layer, self.compound_prompts_depth - 1 110 | ) 111 | 112 | classnames = [name.replace("_", " ") for name in classnames] 113 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 114 | prompts = [prompt_prefix + " " + name + "." 
for name in classnames] 115 | 116 | tokenized_prompts = torch.cat([tokenize(p) for p in prompts]) # (n_cls, n_tkn) 117 | with torch.no_grad(): 118 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 119 | 120 | # These token vectors will be saved when in save_model(), 121 | # but they should be ignored in load_model() as we want to use 122 | # those computed using the current class names 123 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 124 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS 125 | 126 | self.n_cls = n_cls 127 | self.n_ctx = n_ctx 128 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 129 | self.name_lens = name_lens 130 | 131 | def construct_prompts(self, ctx, prefix, suffix, label=None): 132 | # dim0 is either batch_size (during training) or n_cls (during testing) 133 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim) 134 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim) 135 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim) 136 | 137 | if label is not None: 138 | prefix = prefix[label] 139 | suffix = suffix[label] 140 | 141 | prompts = torch.cat( 142 | [ 143 | prefix, # (dim0, 1, dim) 144 | ctx, # (dim0, n_ctx, dim) 145 | suffix, # (dim0, *, dim) 146 | ], 147 | dim=1, 148 | ) 149 | 150 | return prompts 151 | 152 | def forward(self): 153 | ctx = self.ctx 154 | 155 | if ctx.dim() == 2: 156 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 157 | 158 | prefix = self.token_prefix 159 | suffix = self.token_suffix 160 | prompts = self.construct_prompts(ctx, prefix, suffix) 161 | 162 | # Before returning, need to transform 163 | # prompts to 768 for the visual side 164 | visual_deep_prompts = [] 165 | for index, layer in enumerate(self.compound_prompt_projections): 166 | visual_deep_prompts.append(layer(self.compound_prompts_text[index])) 167 | # Now the other way around 168 | # We will project the textual prompts from 512 to 768 169 | return ( 170 | prompts, 171 | self.proj(self.ctx), 172 | self.compound_prompts_text, 173 | visual_deep_prompts, 174 | ) # pass here original, as for visual 768 is required 175 | 176 | 177 | class CustomCLIP(nn.Module): 178 | def __init__(self, cfg, classnames, clip_model): 179 | super().__init__() 180 | self.prompt_learner = MultiModalPromptLearner(cfg, classnames, clip_model) 181 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 182 | self.image_encoder = clip_model.visual 183 | self.text_encoder = TextEncoder(clip_model) 184 | self.logit_scale = clip_model.logit_scale 185 | self.dtype = clip_model.dtype 186 | 187 | def get_text_features(self): 188 | prompts, _, deep_compound_prompts_text, _ = self.prompt_learner() 189 | tokenized_prompts = self.tokenized_prompts 190 | text_features = self.text_encoder( 191 | prompts, tokenized_prompts, deep_compound_prompts_text 192 | ) 193 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 194 | return text_features 195 | 196 | def forward(self, input_ids, attention_mask, pixel_values): 197 | # alhtough input_ids and attention_mask are not used, but they are required for the forward pass **kwargs 198 | _, shared_ctx, _, deep_compound_prompts_vision = self.prompt_learner() 199 | image_features = self.image_encoder( 200 | pixel_values.type(self.dtype), shared_ctx, deep_compound_prompts_vision 201 | ) 202 | 203 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 204 | return image_features 205 | 206 | 207 | def _get_clones(module, N): 208 
| return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 209 | -------------------------------------------------------------------------------- /src/models/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * # noqa: F403 2 | -------------------------------------------------------------------------------- /src/models/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamical-inference/patchsae/6b468b36fe4003724e1c5445180313c79b7a2d0c/src/models/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /src/models/clip/clip.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Portions of this file are based on code from the “openai/CLIP” repository (MIT-licensed): 3 | # https://github.com/openai/CLIP/blob/main/clip/clip.py 4 | """ 5 | 6 | import hashlib 7 | import os 8 | import urllib 9 | import warnings 10 | from typing import List, Union 11 | 12 | import torch 13 | from PIL import Image 14 | from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor 15 | from tqdm import tqdm 16 | 17 | from .model import build_model 18 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 19 | 20 | try: 21 | from torchvision.transforms import InterpolationMode 22 | 23 | BICUBIC = InterpolationMode.BICUBIC 24 | except ImportError: 25 | BICUBIC = Image.BICUBIC 26 | 27 | 28 | if torch.__version__.split(".") < ["1", "7", "1"]: 29 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 30 | 31 | 32 | __all__ = ["available_models", "load", "tokenize"] 33 | _tokenizer = _Tokenizer() 34 | 35 | _MODELS = { 36 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 37 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 38 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 39 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 40 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 41 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 42 | } 43 | 44 | 45 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 46 | os.makedirs(root, exist_ok=True) 47 | filename = os.path.basename(url) 48 | 49 | expected_sha256 = url.split("/")[-2] 50 | download_target = os.path.join(root, filename) 51 | 52 | if os.path.exists(download_target) and not os.path.isfile(download_target): 53 | raise RuntimeError(f"{download_target} exists and is not a regular file") 54 | 55 | if os.path.isfile(download_target): 56 | if ( 57 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 58 | == expected_sha256 59 | ): 60 | return download_target 61 | else: 62 | warnings.warn( 63 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 64 | ) 65 | 66 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 67 | with tqdm( 68 | 
total=int(source.info().get("Content-Length")), 69 | ncols=80, 70 | unit="iB", 71 | unit_scale=True, 72 | ) as loop: 73 | while True: 74 | buffer = source.read(8192) 75 | if not buffer: 76 | break 77 | 78 | output.write(buffer) 79 | loop.update(len(buffer)) 80 | 81 | if ( 82 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 83 | != expected_sha256 84 | ): 85 | raise RuntimeError( 86 | "Model has been downloaded but the SHA256 checksum does not not match" 87 | ) 88 | 89 | return download_target 90 | 91 | 92 | def _transform(n_px): 93 | return Compose( 94 | [ 95 | Resize(n_px, interpolation=BICUBIC), 96 | CenterCrop(n_px), 97 | lambda image: image.convert("RGB"), 98 | ToTensor(), 99 | Normalize( 100 | (0.48145466, 0.4578275, 0.40821073), 101 | (0.26862954, 0.26130258, 0.27577711), 102 | ), 103 | ] 104 | ) 105 | 106 | 107 | def available_models() -> List[str]: 108 | """Returns the names of available CLIP models""" 109 | return list(_MODELS.keys()) 110 | 111 | 112 | def load( 113 | name: str, 114 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 115 | jit=False, 116 | ): 117 | """Load a CLIP model 118 | 119 | Parameters 120 | ---------- 121 | name : str 122 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 123 | 124 | device : Union[str, torch.device] 125 | The device to put the loaded model 126 | 127 | jit : bool 128 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 129 | 130 | Returns 131 | ------- 132 | model : torch.nn.Module 133 | The CLIP model 134 | 135 | preprocess : Callable[[PIL.Image], torch.Tensor] 136 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 137 | """ 138 | if name in _MODELS: 139 | model_path = _download(_MODELS[name]) 140 | elif os.path.isfile(name): 141 | model_path = name 142 | else: 143 | raise RuntimeError( 144 | f"Model {name} not found; available models = {available_models()}" 145 | ) 146 | 147 | try: 148 | # loading JIT archive 149 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 150 | state_dict = None 151 | except RuntimeError: 152 | # loading saved state dict 153 | if jit: 154 | warnings.warn( 155 | f"File {model_path} is not a JIT archive. 
Loading as a state dict instead" 156 | ) 157 | jit = False 158 | state_dict = torch.load(model_path, map_location="cpu") 159 | 160 | if not jit: 161 | model = build_model(state_dict or model.state_dict()).to(device) 162 | if str(device) == "cpu": 163 | model.float() 164 | return model, _transform(model.visual.input_resolution) 165 | 166 | # patch the device names 167 | device_holder = torch.jit.trace( 168 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] 169 | ) 170 | device_node = [ 171 | n 172 | for n in device_holder.graph.findAllNodes("prim::Constant") 173 | if "Device" in repr(n) 174 | ][-1] 175 | 176 | def patch_device(module): 177 | try: 178 | graphs = [module.graph] if hasattr(module, "graph") else [] 179 | except RuntimeError: 180 | graphs = [] 181 | 182 | if hasattr(module, "forward1"): 183 | graphs.append(module.forward1.graph) 184 | 185 | for graph in graphs: 186 | for node in graph.findAllNodes("prim::Constant"): 187 | if "value" in node.attributeNames() and str(node["value"]).startswith( 188 | "cuda" 189 | ): 190 | node.copyAttributes(device_node) 191 | 192 | model.apply(patch_device) 193 | patch_device(model.encode_image) 194 | patch_device(model.encode_text) 195 | 196 | # patch dtype to float32 on CPU 197 | if str(device) == "cpu": 198 | float_holder = torch.jit.trace( 199 | lambda: torch.ones([]).float(), example_inputs=[] 200 | ) 201 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 202 | float_node = float_input.node() 203 | 204 | def patch_float(module): 205 | try: 206 | graphs = [module.graph] if hasattr(module, "graph") else [] 207 | except RuntimeError: 208 | graphs = [] 209 | 210 | if hasattr(module, "forward1"): 211 | graphs.append(module.forward1.graph) 212 | 213 | for graph in graphs: 214 | for node in graph.findAllNodes("aten::to"): 215 | inputs = list(node.inputs()) 216 | for i in [ 217 | 1, 218 | 2, 219 | ]: # dtype can be the second or third argument to aten::to() 220 | if inputs[i].node()["value"] == 5: 221 | inputs[i].node().copyAttributes(float_node) 222 | 223 | model.apply(patch_float) 224 | patch_float(model.encode_image) 225 | patch_float(model.encode_text) 226 | 227 | model.float() 228 | 229 | return model, _transform(model.input_resolution.item()) 230 | 231 | 232 | def tokenize( 233 | texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False 234 | ) -> torch.LongTensor: 235 | """ 236 | Returns the tokenized representation of given input string(s) 237 | 238 | Parameters 239 | ---------- 240 | texts : Union[str, List[str]] 241 | An input string or a list of input strings to tokenize 242 | 243 | context_length : int 244 | The context length to use; all CLIP models use 77 as the context length 245 | 246 | truncate: bool 247 | Whether to truncate the text in case its encoding is longer than the context length 248 | 249 | Returns 250 | ------- 251 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 252 | """ 253 | if isinstance(texts, str): 254 | texts = [texts] 255 | 256 | sot_token = _tokenizer.encoder["<|startoftext|>"] 257 | eot_token = _tokenizer.encoder["<|endoftext|>"] 258 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 259 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 260 | 261 | for i, tokens in enumerate(all_tokens): 262 | if len(tokens) > context_length: 263 | if truncate: 264 | tokens = tokens[:context_length] 265 | tokens[-1] = eot_token 266 | else: 267 | 
raise RuntimeError( 268 | f"Input {texts[i]} is too long for context length {context_length}" 269 | ) 270 | result[i, : len(tokens)] = torch.tensor(tokens) 271 | 272 | return result 273 | -------------------------------------------------------------------------------- /src/models/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Portions of this file are based on code from the “openai/CLIP” repository (MIT-licensed): 3 | # https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py 4 | """ 5 | 6 | import gzip 7 | import html 8 | import os 9 | from functools import lru_cache 10 | 11 | import ftfy 12 | import regex as re 13 | 14 | 15 | @lru_cache() 16 | def default_bpe(): 17 | return os.path.join( 18 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" 19 | ) 20 | 21 | 22 | @lru_cache() 23 | def bytes_to_unicode(): 24 | """ 25 | Returns a list of utf-8 bytes and a corresponding list of unicode strings. 26 | The reversible bpe codes work on unicode strings. 27 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 28 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 29 | This is a significant percentage of your normal, say, 32K bpe vocab. 30 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 31 | It also avoids mapping to whitespace/control characters the bpe code barfs on. 32 | """ 33 | bs = ( 34 | list(range(ord("!"), ord("~") + 1)) 35 | + list(range(ord("¡"), ord("¬") + 1)) 36 | + list(range(ord("®"), ord("ÿ") + 1)) 37 | ) 38 | cs = bs[:] 39 | n = 0 40 | for b in range(2**8): 41 | if b not in bs: 42 | bs.append(b) 43 | cs.append(2**8 + n) 44 | n += 1 45 | cs = [chr(n) for n in cs] 46 | return dict(zip(bs, cs)) 47 | 48 | 49 | def get_pairs(word): 50 | """Return set of symbol pairs in a word. 51 | Word is represented as tuple of symbols (symbols being variable-length strings). 
52 | """ 53 | pairs = set() 54 | prev_char = word[0] 55 | for char in word[1:]: 56 | pairs.add((prev_char, char)) 57 | prev_char = char 58 | return pairs 59 | 60 | 61 | def basic_clean(text): 62 | text = ftfy.fix_text(text) 63 | text = html.unescape(html.unescape(text)) 64 | return text.strip() 65 | 66 | 67 | def whitespace_clean(text): 68 | text = re.sub(r"\s+", " ", text) 69 | text = text.strip() 70 | return text 71 | 72 | 73 | class SimpleTokenizer(object): 74 | def __init__(self, bpe_path: str = default_bpe()): 75 | self.byte_encoder = bytes_to_unicode() 76 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 77 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 78 | merges = merges[1 : 49152 - 256 - 2 + 1] 79 | merges = [tuple(merge.split()) for merge in merges] 80 | vocab = list(bytes_to_unicode().values()) 81 | vocab = vocab + [v + "</w>" for v in vocab] 82 | for merge in merges: 83 | vocab.append("".join(merge)) 84 | vocab.extend(["<|startoftext|>", "<|endoftext|>"]) 85 | self.encoder = dict(zip(vocab, range(len(vocab)))) 86 | self.decoder = {v: k for k, v in self.encoder.items()} 87 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 88 | self.cache = { 89 | "<|startoftext|>": "<|startoftext|>", 90 | "<|endoftext|>": "<|endoftext|>", 91 | } 92 | self.pat = re.compile( 93 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 94 | re.IGNORECASE, 95 | ) 96 | 97 | def bpe(self, token): 98 | if token in self.cache: 99 | return self.cache[token] 100 | word = tuple(token[:-1]) + (token[-1] + "</w>",) 101 | pairs = get_pairs(word) 102 | 103 | if not pairs: 104 | return token + "</w>" 105 | 106 | while True: 107 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 108 | if bigram not in self.bpe_ranks: 109 | break 110 | first, second = bigram 111 | new_word = [] 112 | i = 0 113 | while i < len(word): 114 | try: 115 | j = word.index(first, i) 116 | new_word.extend(word[i:j]) 117 | i = j 118 | except: # noqa: E722 119 | new_word.extend(word[i:]) 120 | break 121 | 122 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 123 | new_word.append(first + second) 124 | i += 2 125 | else: 126 | new_word.append(word[i]) 127 | i += 1 128 | new_word = tuple(new_word) 129 | word = new_word 130 | if len(word) == 1: 131 | break 132 | else: 133 | pairs = get_pairs(word) 134 | word = " ".join(word) 135 | self.cache[token] = word 136 | return word 137 | 138 | def encode(self, text): 139 | bpe_tokens = [] 140 | text = whitespace_clean(basic_clean(text)).lower() 141 | for token in re.findall(self.pat, text): 142 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 143 | bpe_tokens.extend( 144 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") 145 | ) 146 | return bpe_tokens 147 | 148 | def decode(self, tokens): 149 | text = "".join([self.decoder[token] for token in tokens]) 150 | text = ( 151 | bytearray([self.byte_decoder[c] for c in text]) 152 | .decode("utf-8", errors="replace") 153 | .replace("</w>", " ") 154 | ) 155 | return text 156 | -------------------------------------------------------------------------------- /src/models/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .maple import get_maple_config # noqa: F401 2 | -------------------------------------------------------------------------------- /src/models/config/default_config.py: 
-------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode 2 | 3 | 4 | def get_default_config(): 5 | """Initialize default configuration. 6 | 7 | Returns: 8 | CfgNode: Default config object containing all settings 9 | """ 10 | cfg = CfgNode() 11 | 12 | # Basic settings 13 | cfg.VERSION = 1 14 | cfg.OUTPUT_DIR = "./output" # Directory for output files 15 | cfg.RESUME = "" # Path to previous output directory 16 | cfg.SEED = -1 # Negative for random, positive for fixed seed 17 | cfg.USE_CUDA = True 18 | cfg.VERBOSE = True # Print detailed info 19 | 20 | # Input settings 21 | cfg.INPUT = CfgNode() 22 | cfg.INPUT.SIZE = (224, 224) 23 | cfg.INPUT.INTERPOLATION = "bilinear" 24 | cfg.INPUT.TRANSFORMS = () 25 | cfg.INPUT.NO_TRANSFORM = False 26 | cfg.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] # ImageNet mean 27 | cfg.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] # ImageNet std 28 | cfg.INPUT.CROP_PADDING = 4 29 | cfg.INPUT.RRCROP_SCALE = (0.08, 1.0) 30 | cfg.INPUT.CUTOUT_N = 1 31 | cfg.INPUT.CUTOUT_LEN = 16 32 | cfg.INPUT.GN_MEAN = 0.0 33 | cfg.INPUT.GN_STD = 0.15 34 | cfg.INPUT.RANDAUGMENT_N = 2 35 | cfg.INPUT.RANDAUGMENT_M = 10 36 | cfg.INPUT.COLORJITTER_B = 0.4 37 | cfg.INPUT.COLORJITTER_C = 0.4 38 | cfg.INPUT.COLORJITTER_S = 0.4 39 | cfg.INPUT.COLORJITTER_H = 0.1 40 | cfg.INPUT.RGS_P = 0.2 41 | cfg.INPUT.GB_P = 0.5 42 | cfg.INPUT.GB_K = 21 43 | 44 | # Dataset settings 45 | cfg.DATASET = CfgNode() 46 | cfg.DATASET.ROOT = "" 47 | cfg.DATASET.NAME = "" 48 | cfg.DATASET.SOURCE_DOMAINS = () 49 | cfg.DATASET.TARGET_DOMAINS = () 50 | cfg.DATASET.NUM_LABELED = -1 51 | cfg.DATASET.NUM_SHOTS = -1 52 | cfg.DATASET.VAL_PERCENT = 0.1 53 | cfg.DATASET.STL10_FOLD = -1 54 | cfg.DATASET.CIFAR_C_TYPE = "" 55 | cfg.DATASET.CIFAR_C_LEVEL = 1 56 | cfg.DATASET.ALL_AS_UNLABELED = False 57 | 58 | # Dataloader settings 59 | cfg.DATALOADER = CfgNode() 60 | cfg.DATALOADER.NUM_WORKERS = 4 61 | cfg.DATALOADER.K_TRANSFORMS = 1 62 | cfg.DATALOADER.RETURN_IMG0 = False 63 | 64 | cfg.DATALOADER.TRAIN_X = CfgNode() 65 | cfg.DATALOADER.TRAIN_X.SAMPLER = "RandomSampler" 66 | cfg.DATALOADER.TRAIN_X.BATCH_SIZE = 32 67 | cfg.DATALOADER.TRAIN_X.N_DOMAIN = 0 68 | cfg.DATALOADER.TRAIN_X.N_INS = 16 69 | 70 | cfg.DATALOADER.TRAIN_U = CfgNode() 71 | cfg.DATALOADER.TRAIN_U.SAME_AS_X = True 72 | cfg.DATALOADER.TRAIN_U.SAMPLER = "RandomSampler" 73 | cfg.DATALOADER.TRAIN_U.BATCH_SIZE = 32 74 | cfg.DATALOADER.TRAIN_U.N_DOMAIN = 0 75 | cfg.DATALOADER.TRAIN_U.N_INS = 16 76 | 77 | cfg.DATALOADER.TEST = CfgNode() 78 | cfg.DATALOADER.TEST.SAMPLER = "SequentialSampler" 79 | cfg.DATALOADER.TEST.BATCH_SIZE = 32 80 | 81 | # Model settings 82 | cfg.MODEL = CfgNode() 83 | cfg.MODEL.INIT_WEIGHTS = "" 84 | cfg.MODEL.BACKBONE = CfgNode() 85 | cfg.MODEL.BACKBONE.NAME = "" 86 | cfg.MODEL.BACKBONE.PRETRAINED = True 87 | 88 | cfg.MODEL.HEAD = CfgNode() 89 | cfg.MODEL.HEAD.NAME = "" 90 | cfg.MODEL.HEAD.HIDDEN_LAYERS = () 91 | cfg.MODEL.HEAD.ACTIVATION = "relu" 92 | cfg.MODEL.HEAD.BN = True 93 | cfg.MODEL.HEAD.DROPOUT = 0.0 94 | 95 | # Optimization settings 96 | cfg.OPTIM = CfgNode() 97 | cfg.OPTIM.NAME = "adam" 98 | cfg.OPTIM.LR = 0.0003 99 | cfg.OPTIM.WEIGHT_DECAY = 5e-4 100 | cfg.OPTIM.MOMENTUM = 0.9 101 | cfg.OPTIM.SGD_DAMPNING = 0 102 | cfg.OPTIM.SGD_NESTEROV = False 103 | cfg.OPTIM.RMSPROP_ALPHA = 0.99 104 | cfg.OPTIM.ADAM_BETA1 = 0.9 105 | cfg.OPTIM.ADAM_BETA2 = 0.999 106 | cfg.OPTIM.STAGED_LR = False 107 | cfg.OPTIM.NEW_LAYERS = () 108 | cfg.OPTIM.BASE_LR_MULT = 0.1 109 | cfg.OPTIM.LR_SCHEDULER = 
"single_step" 110 | cfg.OPTIM.STEPSIZE = (-1,) 111 | cfg.OPTIM.GAMMA = 0.1 112 | cfg.OPTIM.MAX_EPOCH = 10 113 | cfg.OPTIM.WARMUP_EPOCH = -1 114 | cfg.OPTIM.WARMUP_TYPE = "linear" 115 | cfg.OPTIM.WARMUP_CONS_LR = 1e-5 116 | cfg.OPTIM.WARMUP_MIN_LR = 1e-5 117 | cfg.OPTIM.WARMUP_RECOUNT = True 118 | 119 | # Training settings 120 | cfg.TRAIN = CfgNode() 121 | cfg.TRAIN.CHECKPOINT_FREQ = 0 122 | cfg.TRAIN.PRINT_FREQ = 10 123 | cfg.TRAIN.COUNT_ITER = "train_x" 124 | 125 | # Testing settings 126 | cfg.TEST = CfgNode() 127 | cfg.TEST.EVALUATOR = "Classification" 128 | cfg.TEST.PER_CLASS_RESULT = False 129 | cfg.TEST.COMPUTE_CMAT = False 130 | cfg.TEST.NO_TEST = False 131 | cfg.TEST.SPLIT = "test" 132 | cfg.TEST.FINAL_MODEL = "last_step" 133 | 134 | # Trainer settings 135 | cfg.TRAINER = CfgNode() 136 | cfg.TRAINER.NAME = "" 137 | 138 | # Domain adaptation settings 139 | cfg.TRAINER.MCD = CfgNode() 140 | cfg.TRAINER.MCD.N_STEP_F = 4 141 | 142 | cfg.TRAINER.MME = CfgNode() 143 | cfg.TRAINER.MME.LMDA = 0.1 144 | 145 | cfg.TRAINER.CDAC = CfgNode() 146 | 147 | return cfg 148 | -------------------------------------------------------------------------------- /src/models/config/maple.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode 2 | 3 | from src.models.config.default_config import get_default_config 4 | 5 | 6 | def get_maple_config(custom_clip_cfg=None): 7 | """Get configuration for MaPLe model. 8 | 9 | Args: 10 | custom_clip_cfg: Optional custom CLIP config 11 | 12 | Returns: 13 | Config object with MaPLe settings 14 | """ 15 | cfg = get_default_config() 16 | cfg.TRAINER = CfgNode() 17 | cfg.TRAINER.MAPLE = CfgNode() 18 | cfg.TRAINER.MAPLE.N_CTX = 2 # number of context vectors at the vision branch 19 | cfg.TRAINER.MAPLE.CTX_INIT = ( 20 | "a photo of a" # initialization words (only for language prompts) 21 | ) 22 | cfg.TRAINER.MAPLE.PREC = "fp16" # fp16, fp32, amp 23 | # If both variables below are set to 0, 0, will the config will degenerate to COOP model 24 | cfg.TRAINER.MAPLE.PROMPT_DEPTH = ( 25 | 9 # Max 12, minimum 0, for 0 it will act as shallow IVLP prompting (J=1) 26 | ) 27 | cfg.DATASET = CfgNode() 28 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new 29 | 30 | cfg.TRAINER.NAME = "MaPLe" 31 | cfg.MODEL = CfgNode() 32 | cfg.MODEL.BACKBONE = CfgNode() 33 | cfg.MODEL.BACKBONE.NAME = "ViT-B/16" 34 | 35 | cfg.merge_from_file(custom_clip_cfg) 36 | 37 | return cfg 38 | -------------------------------------------------------------------------------- /src/models/templates/openai_imagenet_templates.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on "openai/CLIP" repository (MIT-licensed): 3 | https://github.com/openai/CLIP/blob/main/data/prompts.md 4 | """ 5 | 6 | openai_imagenet_template = [ 7 | lambda c: f"a bad photo of a {c}.", 8 | lambda c: f"a photo of many {c}.", 9 | lambda c: f"a sculpture of a {c}.", 10 | lambda c: f"a photo of the hard to see {c}.", 11 | lambda c: f"a low resolution photo of the {c}.", 12 | lambda c: f"a rendering of a {c}.", 13 | lambda c: f"graffiti of a {c}.", 14 | lambda c: f"a bad photo of the {c}.", 15 | lambda c: f"a cropped photo of the {c}.", 16 | lambda c: f"a tattoo of a {c}.", 17 | lambda c: f"the embroidered {c}.", 18 | lambda c: f"a photo of a hard to see {c}.", 19 | lambda c: f"a bright photo of a {c}.", 20 | lambda c: f"a photo of a clean {c}.", 21 | lambda c: f"a photo of a dirty {c}.", 22 | lambda c: f"a dark 
photo of the {c}.", 23 | lambda c: f"a drawing of a {c}.", 24 | lambda c: f"a photo of my {c}.", 25 | lambda c: f"the plastic {c}.", 26 | lambda c: f"a photo of the cool {c}.", 27 | lambda c: f"a close-up photo of a {c}.", 28 | lambda c: f"a black and white photo of the {c}.", 29 | lambda c: f"a painting of the {c}.", 30 | lambda c: f"a painting of a {c}.", 31 | lambda c: f"a pixelated photo of the {c}.", 32 | lambda c: f"a sculpture of the {c}.", 33 | lambda c: f"a bright photo of the {c}.", 34 | lambda c: f"a cropped photo of a {c}.", 35 | lambda c: f"a plastic {c}.", 36 | lambda c: f"a photo of the dirty {c}.", 37 | lambda c: f"a jpeg corrupted photo of a {c}.", 38 | lambda c: f"a blurry photo of the {c}.", 39 | lambda c: f"a photo of the {c}.", 40 | lambda c: f"a good photo of the {c}.", 41 | lambda c: f"a rendering of the {c}.", 42 | lambda c: f"a {c} in a video game.", 43 | lambda c: f"a photo of one {c}.", 44 | lambda c: f"a doodle of a {c}.", 45 | lambda c: f"a close-up photo of the {c}.", 46 | lambda c: f"a photo of a {c}.", 47 | lambda c: f"the origami {c}.", 48 | lambda c: f"the {c} in a video game.", 49 | lambda c: f"a sketch of a {c}.", 50 | lambda c: f"a doodle of the {c}.", 51 | lambda c: f"a origami {c}.", 52 | lambda c: f"a low resolution photo of a {c}.", 53 | lambda c: f"the toy {c}.", 54 | lambda c: f"a rendition of the {c}.", 55 | lambda c: f"a photo of the clean {c}.", 56 | lambda c: f"a photo of a large {c}.", 57 | lambda c: f"a rendition of a {c}.", 58 | lambda c: f"a photo of a nice {c}.", 59 | lambda c: f"a photo of a weird {c}.", 60 | lambda c: f"a blurry photo of a {c}.", 61 | lambda c: f"a cartoon {c}.", 62 | lambda c: f"art of a {c}.", 63 | lambda c: f"a sketch of the {c}.", 64 | lambda c: f"a embroidered {c}.", 65 | lambda c: f"a pixelated photo of a {c}.", 66 | lambda c: f"itap of the {c}.", 67 | lambda c: f"a jpeg corrupted photo of the {c}.", 68 | lambda c: f"a good photo of a {c}.", 69 | lambda c: f"a plushie {c}.", 70 | lambda c: f"a photo of the nice {c}.", 71 | lambda c: f"a photo of the small {c}.", 72 | lambda c: f"a photo of the weird {c}.", 73 | lambda c: f"the cartoon {c}.", 74 | lambda c: f"art of the {c}.", 75 | lambda c: f"a drawing of the {c}.", 76 | lambda c: f"a photo of the large {c}.", 77 | lambda c: f"a black and white photo of a {c}.", 78 | lambda c: f"the plushie {c}.", 79 | lambda c: f"a dark photo of a {c}.", 80 | lambda c: f"itap of a {c}.", 81 | lambda c: f"graffiti of the {c}.", 82 | lambda c: f"a toy {c}.", 83 | lambda c: f"itap of my {c}.", 84 | lambda c: f"a photo of a cool {c}.", 85 | lambda c: f"a photo of a small {c}.", 86 | lambda c: f"a tattoo of the {c}.", 87 | ] 88 | -------------------------------------------------------------------------------- /src/models/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from functools import partial 4 | 5 | import torch 6 | from transformers import CLIPModel, CLIPProcessor 7 | 8 | from src.models.clip import clip 9 | from src.models.config.maple import get_maple_config 10 | 11 | 12 | def load_clip_model(cfg, model_type: str): 13 | """Load and configure a CLIP model.""" 14 | backbone_name = cfg.MODEL.BACKBONE.NAME 15 | url = clip._MODELS[backbone_name] 16 | model_path = clip._download(url) 17 | try: 18 | # loading JIT archive 19 | model = torch.jit.load(model_path, map_location="cpu").eval() 20 | state_dict = None 21 | except RuntimeError: 22 | state_dict = torch.load(model_path, map_location="cpu") 23 | 24 | 
design_details = { 25 | "trainer": model_type, 26 | "vision_depth": 0, 27 | "language_depth": 0, 28 | "vision_ctx": 0, 29 | "language_ctx": 0, 30 | } 31 | 32 | if model_type == "maple": 33 | design_details["trainer"] = "MaPLe" 34 | design_details["maple_length"] = cfg.TRAINER.MAPLE.N_CTX 35 | 36 | model = clip.build_model(state_dict or model.state_dict(), design_details) 37 | return model 38 | 39 | 40 | def load_checkpoint(fpath: str) -> dict: 41 | """Load a model checkpoint file. 42 | 43 | Handles loading both Python 3 and Python 2 saved checkpoints by catching UnicodeDecodeError. 44 | 45 | Args: 46 | fpath: Path to the checkpoint file 47 | 48 | Returns: 49 | The loaded checkpoint dictionary 50 | 51 | Raises: 52 | ValueError: If fpath is None 53 | FileNotFoundError: If checkpoint file does not exist 54 | Exception: If checkpoint cannot be loaded 55 | 56 | Examples: 57 | >>> checkpoint = load_checkpoint('models/checkpoint.pth') 58 | """ 59 | if fpath is None: 60 | raise ValueError("Checkpoint path cannot be None") 61 | 62 | if not os.path.exists(fpath): 63 | raise FileNotFoundError(f"No checkpoint file found at {fpath}") 64 | 65 | device = None if torch.cuda.is_available() else "cpu" 66 | 67 | try: 68 | return torch.load(fpath, map_location=device) 69 | 70 | except UnicodeDecodeError: 71 | # Handle Python 2 checkpoints 72 | pickle.load = partial(pickle.load, encoding="latin1") 73 | pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1") 74 | return torch.load(fpath, pickle_module=pickle, map_location=device) 75 | 76 | except Exception as e: 77 | raise Exception(f"Failed to load checkpoint from {fpath}: {str(e)}") 78 | 79 | 80 | def _remove_prompt_learner_tokens(state_dict: dict) -> dict: 81 | """Remove prompt learner token vectors from state dict.""" 82 | token_keys = ["prompt_learner.token_prefix", "prompt_learner.token_suffix"] 83 | for key in token_keys: 84 | if key in state_dict: 85 | del state_dict[key] 86 | return state_dict 87 | 88 | 89 | def load_state_dict_without_prompt_learner(ckpt_path: str) -> dict: 90 | """Load checkpoint and remove prompt learner token vectors.""" 91 | checkpoint = load_checkpoint(ckpt_path) 92 | state_dict = checkpoint["state_dict"] 93 | return _remove_prompt_learner_tokens(state_dict) 94 | 95 | 96 | def get_base_clip(backbone: str) -> tuple[CLIPModel, CLIPProcessor]: 97 | """Load base CLIP model and processor.""" 98 | model = CLIPModel.from_pretrained(backbone) 99 | processor = CLIPProcessor.from_pretrained(backbone) 100 | return model, processor 101 | 102 | 103 | def get_adapted_clip( 104 | cfg, 105 | model_type: str, 106 | model_path: str, 107 | config_path: str, 108 | backbone: str, 109 | classnames: list[str], 110 | ) -> tuple[CLIPModel, CLIPProcessor]: 111 | """Load and configure adapted CLIP model with custom prompt learning. 
112 | 113 | Args: 114 | cfg: Model configuration 115 | model_type: Type of prompt learning ('maple') 116 | model_path: Path to model checkpoint 117 | classnames: Optional list of class names 118 | 119 | Returns: 120 | Tuple of (model, processor) 121 | """ 122 | if model_type == "maple": 123 | cfg = get_maple_config(custom_clip_cfg=config_path) 124 | 125 | clip_model = load_clip_model(cfg, model_type) 126 | model_statedict = load_state_dict_without_prompt_learner(model_path) 127 | 128 | model_types = { 129 | "maple": "src.models.architecture.maple", 130 | } 131 | 132 | if model_type not in model_types: 133 | raise ValueError(f"Unsupported model type: {model_type}") 134 | 135 | module = __import__(model_types[model_type], fromlist=["CustomCLIP"]) 136 | model = module.CustomCLIP(cfg, classnames, clip_model) 137 | 138 | model.load_state_dict(model_statedict, strict=False) 139 | 140 | processor = CLIPProcessor.from_pretrained(backbone) 141 | return model, processor 142 | -------------------------------------------------------------------------------- /src/sae_training/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Portions of this file are based on code from the “jbloomAus/SAELens” and "HugoFry/mats_sae_training_for_ViTs" repositories (MIT-licensed): 3 | https://github.com/jbloomAus/SAELens/blob/main/sae_lens/config.py 4 | https://github.com/HugoFry/mats_sae_training_for_ViTs/blob/main/sae_training/config.py 5 | """ 6 | 7 | from dataclasses import dataclass 8 | from typing import Optional 9 | 10 | import torch 11 | import wandb 12 | 13 | 14 | class Config: 15 | def __init__(self, config_dict): 16 | if not isinstance(config_dict, dict): 17 | config_dict = config_dict.__dict__ 18 | 19 | for key, value in config_dict.items(): 20 | if isinstance(value, dict): 21 | # Recursively convert nested dictionaries 22 | value = Config(value) 23 | setattr(self, key, value) 24 | 25 | 26 | @dataclass 27 | class ViTSAERunnerConfig: 28 | """ 29 | Configuration for training a sparse autoencoder on a vision transformer. 
30 | """ 31 | 32 | # Data Generating Function (Model + Training Distibuion) 33 | custom_clip_ckpt_path: str = None 34 | class_token: bool = True 35 | image_width: int = 224 36 | image_height: int = 224 37 | model_name: str = "openai/clip-vit-base-patch32" 38 | module_name: str = "resid" 39 | block_layer: int = 10 40 | dataset_path: str = "evanarlian/imagenet_1k_resized_256" 41 | image_key: str = "image" 42 | label_key: str = "label" 43 | use_cached_activations: bool = False 44 | cached_activations_path: Optional[str] = ( 45 | None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}" 46 | ) 47 | 48 | # SAE Parameters 49 | d_in: int = 768 50 | 51 | # Activation Store Parameters 52 | total_training_tokens: int = 2_000_000 53 | n_batches_in_store: int = 32 54 | store_size: Optional[int] = None 55 | max_batch_size_for_vit_forward_pass: int = 1024 56 | create_dataloader: bool = True 57 | 58 | # Misc 59 | device: str = "cpu" 60 | seed: int = 42 61 | dtype: torch.dtype = torch.float32 62 | 63 | # SAE Parameters 64 | b_dec_init_method: str = "geometric_median" 65 | expansion_factor: int = 4 66 | from_pretrained_path: Optional[str] = None 67 | gated_sae: bool = False 68 | 69 | # Training Parameters 70 | l1_coefficient: float = 1e-3 71 | lr: float = 3e-4 72 | lr_scheduler_name: str = "constant" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup 73 | lr_warm_up_steps: int = 500 74 | batch_size: int = 4096 75 | mse_cls_coefficient: float = 1.0 76 | 77 | # Resampling protocol args 78 | use_ghost_grads: bool = True 79 | feature_sampling_window: int = ( 80 | 2000 # May need to change this since by default I will use ghost grads 81 | ) 82 | feature_sampling_method: str = "anthropic" # None or Anthropic 83 | resample_batches: int = 32 84 | feature_reinit_scale: float = 0.2 85 | dead_feature_window: int = 1000 # unless this window is larger feature sampling, 86 | dead_feature_estimation_method: str = "no_fire" 87 | dead_feature_threshold: float = 1e-8 88 | 89 | # WANDB 90 | log_to_wandb: bool = True 91 | wandb_project: str = "mats-hugo" 92 | wandb_entity: str = None 93 | wandb_log_frequency: int = 10 94 | 95 | # Misc 96 | n_checkpoints: int = 0 97 | checkpoint_path: str = "checkpoints" 98 | 99 | image_key = "image" 100 | label_key = "label" 101 | 102 | def __post_init__(self): 103 | self.store_size = self.n_batches_in_store * self.batch_size 104 | 105 | # Autofill cached_activations_path unless the user overrode it 106 | if self.cached_activations_path is None: 107 | self.cached_activations_path = f"activations/{self.dataset_path.replace('/', '_')}/{self.model_name.replace('/', '_')}/{self.block_layer}_{self.module_name}" 108 | 109 | self.d_sae = self.d_in * self.expansion_factor 110 | 111 | self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}" 112 | 113 | if self.feature_sampling_method not in [None, "l2", "anthropic"]: 114 | raise ValueError( 115 | f"feature_sampling_method must be None, l2, or anthropic. Got {self.feature_sampling_method}" 116 | ) 117 | 118 | if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]: 119 | raise ValueError( 120 | f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}" 121 | ) 122 | if self.b_dec_init_method == "zeros": 123 | print( 124 | "Warning: We are initializing b_dec to zeros. This is probably not what you want." 
125 | ) 126 | 127 | self.device = torch.device(self.device) 128 | 129 | unique_id = wandb.util.generate_id() 130 | self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}" 131 | 132 | print( 133 | f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}" 134 | ) 135 | # Print out some useful info: 136 | 137 | total_training_steps = self.total_training_tokens // self.batch_size 138 | print(f"Total training steps: {total_training_steps}") 139 | 140 | total_wandb_updates = total_training_steps // self.wandb_log_frequency 141 | print(f"Total wandb updates: {total_wandb_updates}") 142 | 143 | # how many times will we sample dead neurons? 144 | # assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window" 145 | n_dead_feature_samples = total_training_steps // self.dead_feature_window 146 | n_feature_window_samples = total_training_steps // self.feature_sampling_window 147 | print( 148 | f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.batch_size) / 10**6}" 149 | ) 150 | print( 151 | f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.batch_size) / 10**6}" 152 | ) 153 | if self.feature_sampling_method is not None: 154 | print(f"We will reset neurons {n_dead_feature_samples} times.") 155 | 156 | if self.use_ghost_grads: 157 | print("Using Ghost Grads.") 158 | 159 | print( 160 | f"We will reset the sparsity calculation {n_feature_window_samples} times." 161 | ) 162 | print( 163 | f"Number of tokens when resampling: {self.resample_batches * self.batch_size}" 164 | ) 165 | # print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size) 166 | print( 167 | f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.batch_size:.2e}" 168 | ) 169 | -------------------------------------------------------------------------------- /src/sae_training/hooked_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Portions of this file are based on code from the "HugoFry/mats_sae_training_for_ViTs" repository (MIT-licensed): 3 | https://github.com/HugoFry/mats_sae_training_for_ViTs/blob/main/sae_training/hooked_vit.py 4 | """ 5 | 6 | from contextlib import contextmanager 7 | from functools import partial 8 | from typing import Callable, List, Tuple 9 | 10 | import torch 11 | from jaxtyping import Float 12 | from torch import Tensor 13 | from torch.nn import functional as F 14 | from transformers import CLIPModel 15 | 16 | 17 | # The Hook class currently only supports hooking on the following locations: 18 | # 1 - residual stream post transformer block. 19 | # 2 - mlp activations. 20 | # More hooks can be added at a later date, but only post-module. 21 | class Hook: 22 | def __init__( 23 | self, 24 | block_layer: int, 25 | module_name: str, 26 | hook_fn: Callable, 27 | is_custom: bool = None, 28 | return_module_output=True, 29 | ): 30 | self.path_dict = { 31 | "resid": "", 32 | } 33 | assert module_name in self.path_dict.keys(), ( 34 | f"Module name '{module_name}' not recognised." 
35 | ) 36 | self.return_module_output = return_module_output 37 | self.function = self.get_full_hook_fn(hook_fn) 38 | self.attr_path = self.get_attr_path(block_layer, module_name, is_custom) 39 | 40 | def get_full_hook_fn(self, hook_fn: Callable): 41 | def full_hook_fn(module, module_input, module_output): 42 | hook_fn_output = hook_fn(module_output[0]) 43 | if self.return_module_output: 44 | return module_output 45 | else: 46 | return hook_fn_output # Inexplicably, the module output is not a tensor of activations but a tuple (tensor,)...?? 47 | 48 | return full_hook_fn 49 | 50 | def get_attr_path( 51 | self, block_layer: int, module_name: str, is_custom: bool = None 52 | ) -> str: 53 | if is_custom: 54 | attr_path = f"image_encoder.transformer.resblocks[{block_layer}]" 55 | else: 56 | attr_path = f"vision_model.encoder.layers[{block_layer}]" 57 | attr_path += self.path_dict[module_name] 58 | return attr_path 59 | 60 | def get_module(self, model): 61 | return self.get_nested_attr(model, self.attr_path) 62 | 63 | def get_nested_attr(self, model, attr_path): 64 | """ 65 | Gets a nested attribute from an object using a dot-separated path. 66 | """ 67 | module = model 68 | attributes = attr_path.split(".") 69 | for attr in attributes: 70 | if "[" in attr: 71 | # Split at '[' and remove the trailing ']' from the index 72 | attr_name, index = attr[:-1].split("[") 73 | module = getattr(module, attr_name)[int(index)] 74 | else: 75 | module = getattr(module, attr) 76 | return module 77 | 78 | 79 | class HookedVisionTransformer: 80 | def __init__(self, model, processor, device="cuda"): 81 | self.model = model.to(device) 82 | self.processor = processor 83 | 84 | def run_with_cache( 85 | self, 86 | list_of_hook_locations: List[Tuple[int, str]], 87 | *args, 88 | return_type="output", 89 | **kwargs, 90 | ): 91 | cache_dict, list_of_hooks = self.get_caching_hooks(list_of_hook_locations) 92 | with self.hooks(list_of_hooks) as hooked_model: 93 | with torch.no_grad(): 94 | output = hooked_model(*args, **kwargs) 95 | 96 | if return_type == "output": 97 | return output, cache_dict 98 | if return_type == "loss": 99 | return ( 100 | self.contrastive_loss(output.logits_per_image, output.logits_per_text), 101 | cache_dict, 102 | ) 103 | else: 104 | raise Exception( 105 | f"Unrecognised keyword argument return_type='{return_type}'. Must be either 'output' or 'loss'." 106 | ) 107 | 108 | def get_caching_hooks(self, list_of_hook_locations: List[Tuple[int, str]]): 109 | """ 110 | Note that the cache dictionary is indexed by the tuple (block_layer, module_name). 
111 | """ 112 | cache_dict = {} 113 | list_of_hooks = [] 114 | 115 | def save_activations(name, activations): 116 | cache_dict[name] = activations.detach() 117 | 118 | for block_layer, module_name in list_of_hook_locations: 119 | hook_fn = partial(save_activations, (block_layer, module_name)) 120 | if isinstance(self.model, CLIPModel): 121 | is_custom = False 122 | else: 123 | is_custom = True 124 | hook = Hook(block_layer, module_name, hook_fn, is_custom=is_custom) 125 | list_of_hooks.append(hook) 126 | return cache_dict, list_of_hooks 127 | 128 | @torch.no_grad() 129 | def run_with_hooks( 130 | self, list_of_hooks: List[Hook], *args, return_type="output", **kwargs 131 | ): 132 | with self.hooks(list_of_hooks) as hooked_model: 133 | with torch.no_grad(): 134 | output = hooked_model(*args, **kwargs) 135 | if return_type == "output": 136 | return output 137 | if return_type == "loss": 138 | return self.contrastive_loss( 139 | output.logits_per_image, output.logits_per_text 140 | ) 141 | else: 142 | raise Exception( 143 | f"Unrecognised keyword argument return_type='{return_type}'. Must be either 'output' or 'loss'." 144 | ) 145 | 146 | def train_with_hooks( 147 | self, list_of_hooks: List[Hook], *args, return_type="output", **kwargs 148 | ): 149 | with self.hooks(list_of_hooks) as hooked_model: 150 | output = hooked_model(*args, **kwargs) 151 | if return_type == "output": 152 | return output 153 | if return_type == "loss": 154 | return self.contrastive_loss( 155 | output.logits_per_image, output.logits_per_text 156 | ) 157 | else: 158 | raise Exception( 159 | f"Unrecognised keyword argument return_type='{return_type}'. Must be either 'output' or 'loss'." 160 | ) 161 | 162 | def contrastive_loss( 163 | self, 164 | logits_per_image: Float[Tensor, "n_images n_prompts"], # noqa: F722 165 | logits_per_text: Float[Tensor, "n_prompts n_images"], # noqa: F722 166 | ): # Assumes square matrices 167 | assert logits_per_image.size()[0] == logits_per_image.size()[1], ( 168 | "The number of prompts does not match the number of images." 169 | ) 170 | batch_size = logits_per_image.size()[0] 171 | labels = torch.arange(batch_size).long().to(logits_per_image.device) 172 | image_loss = F.cross_entropy(logits_per_image, labels) 173 | text_loss = F.cross_entropy(logits_per_text, labels) 174 | total_loss = (image_loss + text_loss) / 2 175 | return total_loss 176 | 177 | @contextmanager 178 | def hooks(self, hooks: List[Hook]): 179 | """ 180 | 181 | This is a context manager for running a model with hooks. The function adds 182 | forward hooks to the model, and then returns the hooked model to be run with 183 | a forward pass. The function then cleans up by removing any hooks. 184 | 185 | Args: 186 | 187 | model VisionTransformer: The ViT (self.model) that the forward hooks are registered on 188 | 189 | hooks List[Hook]: A list of forward hooks to add to the model. 190 | Each Hook wraps the target module path and the hook function. 191 | 192 | """ 193 | hook_handles = [] 194 | try: 195 | for hook in hooks: 196 | # Create a full hook function, with all the arguments needed to run nn.module.register_forward_hook(). 197 | # The hook functions are added to the output of the module. 
198 | module = hook.get_module(self.model) 199 | handle = module.register_forward_hook(hook.function) 200 | hook_handles.append(handle) 201 | yield self.model 202 | finally: 203 | for handle in hook_handles: 204 | handle.remove() 205 | 206 | def to(self, device): 207 | self.model = self.model.to(device) 208 | 209 | def __call__(self, *args, return_type="output", **kwargs): 210 | return self.forward(*args, return_type=return_type, **kwargs) 211 | 212 | def forward(self, *args, return_type="output", **kwargs): 213 | if return_type == "output": 214 | return self.model(*args, **kwargs) 215 | elif return_type == "loss": 216 | output = self.model(*args, **kwargs) 217 | return self.contrastive_loss( 218 | output.logits_per_image, output.logits_per_text 219 | ) 220 | else: 221 | raise Exception( 222 | f"Unrecognised keyword argument return_type='{return_type}'. Must be either 'output' or 'loss'." 223 | ) 224 | 225 | def eval(self): 226 | self.model.eval() 227 | 228 | def train(self): 229 | self.model.train() 230 | -------------------------------------------------------------------------------- /src/sae_training/sae_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Portions of this file are based on code from the “jbloomAus/SAELens” and "HugoFry/mats_sae_training_for_ViTs" repositories (MIT-licensed): 3 | https://github.com/jbloomAus/SAELens/blob/main/sae_lens/training/sae_trainer.py 4 | https://github.com/HugoFry/mats_sae_training_for_ViTs/blob/main/sae_training/config.py 5 | """ 6 | 7 | from typing import Any 8 | 9 | import torch 10 | import wandb 11 | from tqdm import tqdm 12 | 13 | from src.sae_training.config import Config 14 | from src.sae_training.hooked_vit import Hook, HookedVisionTransformer 15 | from src.sae_training.sparse_autoencoder import SparseAutoencoder 16 | from src.sae_training.vit_activations_store import ViTActivationsStore 17 | 18 | 19 | class SAETrainer: 20 | def __init__( 21 | self, 22 | sae: SparseAutoencoder, 23 | model: HookedVisionTransformer, 24 | activation_store: ViTActivationsStore, 25 | cfg: Config, 26 | optimizer: torch.optim.Optimizer, 27 | scheduler: torch.optim.lr_scheduler._LRScheduler, 28 | device: torch.device, 29 | ): 30 | self.sae = sae 31 | self.model = model 32 | self.activation_store = activation_store 33 | self.cfg = cfg 34 | self.optimizer = optimizer 35 | self.scheduler = scheduler 36 | self.device = device 37 | 38 | self.act_freq_scores = torch.zeros(sae.cfg.d_sae, device=device) 39 | self.n_forward_passes_since_fired = torch.zeros(sae.cfg.d_sae, device=device) 40 | self.n_frac_active_tokens = 0 41 | self.n_training_tokens = 0 42 | self.ghost_grad_neuron_mask = None 43 | self.n_training_steps = 0 44 | 45 | self.checkpoint_thresholds = list( 46 | range( 47 | 0, 48 | cfg.total_training_tokens, 49 | cfg.total_training_tokens // self.cfg.n_checkpoints, 50 | ) 51 | )[1:] 52 | 53 | def _build_sparsity_log_dict(self) -> dict[str, Any]: 54 | feature_freq = self.act_freq_scores / self.n_frac_active_tokens 55 | log_feature_freq = torch.log10(feature_freq + 1e-10).detach().cpu() 56 | 57 | return { 58 | "plots/feature_density_line_chart": wandb.Histogram( 59 | log_feature_freq.numpy() 60 | ), 61 | "metrics/mean_log10_feature_sparsity": log_feature_freq.mean().item(), 62 | } 63 | 64 | @torch.no_grad() 65 | def _reset_running_sparsity_stats(self) -> None: 66 | self.act_freq_scores = torch.zeros(self.cfg.d_sae, device=self.device) 67 | self.n_frac_active_tokens = 0 68 | 69 | def _train_step( 70 | self, 71 | 
sae_in: torch.Tensor, 72 | ): 73 | self.optimizer.zero_grad() 74 | 75 | self.sae.train() 76 | self.sae.set_decoder_norm_to_unit_norm() 77 | 78 | # log and then reset the feature sparsity every feature_sampling_window steps 79 | if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0: 80 | if self.cfg.log_to_wandb: 81 | sparsity_log_dict = self._build_sparsity_log_dict() 82 | wandb.log(sparsity_log_dict, step=self.n_training_steps) 83 | self._reset_running_sparsity_stats() 84 | 85 | ghost_grad_neuron_mask = ( 86 | self.n_forward_passes_since_fired > self.cfg.dead_feature_window 87 | ).bool() 88 | sae_out, feature_acts, loss_dict = self.sae(sae_in, ghost_grad_neuron_mask) 89 | 90 | with torch.no_grad(): 91 | if self.cfg.class_token: 92 | did_fire = (feature_acts > 0).float().sum(-2) > 0 93 | self.act_freq_scores += (feature_acts.abs() > 0).float().sum(0) 94 | 95 | else: 96 | # default for PatchSAE 97 | did_fire = (((feature_acts > 0).float().sum(-2) > 0).sum(-2)) > 0 98 | self.act_freq_scores += (feature_acts.abs() > 0).float().sum(0).sum(0) 99 | 100 | self.n_forward_passes_since_fired += 1 101 | self.n_forward_passes_since_fired[did_fire] = 0 102 | self.n_frac_active_tokens += sae_out.size(0) 103 | 104 | self.ghost_grad_neuron_mask = ghost_grad_neuron_mask 105 | 106 | loss_dict["loss"].backward() 107 | self.sae.remove_gradient_parallel_to_decoder_directions() 108 | 109 | self.optimizer.step() 110 | self.scheduler.step() 111 | 112 | return sae_out, feature_acts, loss_dict 113 | 114 | def _calculate_sparsity_metrics(self) -> dict: 115 | """Calculate sparsity-related metrics.""" 116 | feature_freq = self.act_freq_scores / self.n_frac_active_tokens 117 | 118 | return { 119 | "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(), 120 | "sparsity/n_passes_since_fired_over_threshold": self.ghost_grad_neuron_mask.sum().item(), 121 | "sparsity/below_1e-5": (feature_freq < 1e-5).float().mean().item(), 122 | "sparsity/below_1e-6": (feature_freq < 1e-6).float().mean().item(), 123 | "sparsity/dead_features": (feature_freq < self.cfg.dead_feature_threshold) 124 | .float() 125 | .mean() 126 | .item(), 127 | } 128 | 129 | @torch.no_grad() 130 | def _log_train_step( 131 | self, 132 | feature_acts: torch.Tensor, 133 | loss_dict: dict[str, torch.Tensor], 134 | sae_out: torch.Tensor, 135 | sae_in: torch.Tensor, 136 | ): 137 | """Log training metrics to wandb.""" 138 | metrics = self._calculate_metrics(feature_acts, sae_out, sae_in) 139 | sparsity_metrics = self._calculate_sparsity_metrics() 140 | 141 | log_dict = { 142 | "losses/overall_loss": loss_dict["loss"].item(), 143 | "losses/mse_loss": loss_dict["mse_loss"].item(), 144 | "losses/l1_loss": loss_dict["l1_loss"].item(), 145 | "losses/ghost_grad_loss": loss_dict["mse_loss_ghost_resid"].item(), 146 | **metrics, 147 | **sparsity_metrics, 148 | "details/n_training_tokens": self.n_training_tokens, 149 | "details/current_learning_rate": self.optimizer.param_groups[0]["lr"], 150 | } 151 | 152 | wandb.log(log_dict, step=self.n_training_steps) 153 | 154 | @torch.no_grad() 155 | def _calculate_metrics( 156 | self, feature_acts: torch.Tensor, sae_out: torch.Tensor, sae_in: torch.Tensor 157 | ) -> dict: 158 | """Calculate model performance metrics.""" 159 | if self.cfg.class_token: 160 | l0 = (feature_acts > 0).float().sum(-1).mean() 161 | else: 162 | l0 = (feature_acts > 0).float().sum(-1).mean(-1).mean() 163 | per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).mean().squeeze() 164 | total_variance = 
sae_in.pow(2).sum(-1).mean() 165 | explained_variance = 1 - per_token_l2_loss / total_variance 166 | 167 | return { 168 | "metrics/explained_variance": explained_variance.mean().item(), 169 | "metrics/explained_variance_std": explained_variance.std().item(), 170 | "metrics/l0": l0.item(), 171 | } 172 | 173 | @torch.no_grad() 174 | def _update_pbar(self, loss_dict, pbar, batch_size): 175 | pbar.set_description( 176 | f"{self.n_training_steps}| MSE Loss {loss_dict['mse_loss'].item():.3f} | L1 {loss_dict['l1_loss'].item():.3f}" 177 | ) 178 | pbar.update(batch_size) 179 | 180 | @torch.no_grad() 181 | def _checkpoint_if_needed(self): 182 | if ( 183 | self.checkpoint_thresholds 184 | and self.n_training_tokens > self.checkpoint_thresholds[0] 185 | ): 186 | self.save_checkpoint() 187 | self.run_evals() # TODO: Implement this 188 | self.checkpoint_thresholds.pop(0) 189 | 190 | def save_checkpoint(self, is_final=False): 191 | if is_final: 192 | path = f"{self.cfg.checkpoint_path}/final_{self.sae.get_name()}.pt" 193 | else: 194 | path = f"{self.cfg.checkpoint_path}/{self.n_training_tokens}_{self.sae.get_name()}.pt" 195 | self.sae.save_model(path) 196 | 197 | def fit(self) -> SparseAutoencoder: 198 | pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE") 199 | 200 | try: 201 | # Train loop 202 | while self.n_training_tokens < self.cfg.total_training_tokens: 203 | # Do a training step. 204 | sae_acts = self.activation_store.get_batch_activations() 205 | self.n_training_tokens += sae_acts.size(0) 206 | 207 | sae_out, feature_acts, loss_dict = self._train_step(sae_in=sae_acts) 208 | 209 | if ( 210 | self.cfg.log_to_wandb 211 | and (self.n_training_steps + 1) % self.cfg.wandb_log_frequency == 0 212 | ): 213 | self._log_train_step( 214 | feature_acts=feature_acts, 215 | loss_dict=loss_dict, 216 | sae_out=sae_out, 217 | sae_in=sae_acts, 218 | ) 219 | 220 | self._checkpoint_if_needed() 221 | self.n_training_steps += 1 222 | self._update_pbar(loss_dict, pbar, sae_out.size(0)) 223 | finally: 224 | print("Saving final checkpoint") 225 | self.save_checkpoint(is_final=True) 226 | self.run_evals() 227 | 228 | pbar.close() 229 | return self.sae 230 | 231 | @torch.no_grad() 232 | def run_evals(self): 233 | self.sae.eval() 234 | 235 | def _create_hook(hook_fn): 236 | return Hook( 237 | self.sae.cfg.block_layer, 238 | self.sae.cfg.module_name, 239 | hook_fn, 240 | return_module_output=False, 241 | ) 242 | 243 | def _zero_ablation_hook(activations): 244 | activations[:, 0, :] = torch.zeros_like(activations[:, 0, :]).to( 245 | activations.device 246 | ) 247 | return (activations,) 248 | 249 | def _sae_reconstruction_hook(activations): 250 | activations[:, 0, :] = self.sae(activations[:, 0, :])[0] 251 | return (activations,) 252 | 253 | # Get model inputs and compute baseline loss 254 | # model_inputs = self.activation_store.get_batch_of_images_and_labels() 255 | model_inputs = self.activation_store.get_batch_model_inputs(process_labels=True) 256 | original_loss = self.model(return_type="loss", **model_inputs).item() 257 | 258 | # Compute loss with SAE reconstruction 259 | sae_hooks = [_create_hook(_sae_reconstruction_hook)] 260 | reconstruction_loss = self.model.run_with_hooks( 261 | sae_hooks, return_type="loss", **model_inputs 262 | ).item() 263 | 264 | # Compute loss with zeroed activations 265 | zero_hooks = [_create_hook(_zero_ablation_hook)] 266 | zero_ablation_loss = self.model.run_with_hooks( 267 | zero_hooks, return_type="loss", **model_inputs 268 | ).item() 269 | 270 | # Calculate 
reconstruction score 271 | reconstruction_score = (reconstruction_loss - original_loss) / ( 272 | zero_ablation_loss - original_loss 273 | ) 274 | 275 | # Log metrics if configured 276 | if self.cfg.log_to_wandb: 277 | wandb.log( 278 | { 279 | "metrics/contrastive_loss_score": reconstruction_score, 280 | "metrics/original_contrastive_loss": original_loss, 281 | "metrics/contrastive_loss_with_sae": reconstruction_loss, 282 | "metrics/contrastive_loss_with_ablation": zero_ablation_loss, 283 | }, 284 | step=self.n_training_steps, 285 | ) 286 | 287 | del model_inputs 288 | torch.cuda.empty_cache() 289 | 290 | self.sae.train() 291 | -------------------------------------------------------------------------------- /src/sae_training/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Optional 3 | 4 | import torch 5 | import torch.optim as optim 6 | import torch.optim.lr_scheduler as lr_scheduler 7 | 8 | from src.sae_training.hooked_vit import HookedVisionTransformer 9 | 10 | SAE_DIM = 49152 11 | 12 | 13 | def process_model_inputs( 14 | batch: Dict, vit: HookedVisionTransformer, device: str, process_labels: bool = False 15 | ) -> torch.Tensor: 16 | """Process input images through the ViT processor.""" 17 | if process_labels: 18 | labels = [f"A photo of a {label}" for label in batch["label"]] 19 | return vit.processor( 20 | images=batch["image"], text=labels, return_tensors="pt", padding=True 21 | ).to(device) 22 | 23 | return vit.processor( 24 | images=batch["image"], text="", return_tensors="pt", padding=True 25 | ).to(device) 26 | 27 | 28 | def get_model_activations( 29 | model: HookedVisionTransformer, inputs: dict, block_layer, module_name, class_token 30 | ) -> torch.Tensor: 31 | """Extract activations from a specific layer of the vision transformer model.""" 32 | hook_location = (block_layer, module_name) 33 | 34 | # Run model forward pass and extract activations from cache 35 | _, cache = model.run_with_cache([hook_location], **inputs) 36 | activations = cache[hook_location] 37 | 38 | batch_size = inputs["pixel_values"].shape[0] 39 | if activations.shape[0] != batch_size: 40 | activations = activations.transpose(0, 1) 41 | 42 | # Extract class token if specified 43 | if class_token: 44 | activations = activations[0, :, :] 45 | 46 | return activations 47 | 48 | 49 | def get_scheduler(scheduler_name: Optional[str], optimizer: optim.Optimizer, **kwargs): 50 | def get_warmup_lambda(warm_up_steps, training_steps): 51 | def lr_lambda(steps): 52 | if steps < warm_up_steps: 53 | return (steps + 1) / warm_up_steps 54 | else: 55 | return (training_steps - steps) / (training_steps - warm_up_steps) 56 | 57 | return lr_lambda 58 | 59 | # heavily derived from hugging face although copilot helped. 
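# The helper below builds the multiplicative LR factor used by LambdaLR for the
# "cosineannealingwarmup" option: a linear ramp from 1/warm_up_steps up to 1.0 during
# warmup, followed by a cosine decay from 1.0 down to lr_end over the remaining
# (training_steps - warm_up_steps) steps.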
60 | def get_warmup_cosine_lambda(warm_up_steps, training_steps, lr_end): 61 | def lr_lambda(steps): 62 | if steps < warm_up_steps: 63 | return (steps + 1) / warm_up_steps 64 | else: 65 | progress = (steps - warm_up_steps) / (training_steps - warm_up_steps) 66 | return lr_end + 0.5 * (1 - lr_end) * (1 + math.cos(math.pi * progress)) 67 | 68 | return lr_lambda 69 | 70 | if scheduler_name is None or scheduler_name.lower() == "constant": 71 | return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda steps: 1.0) 72 | elif scheduler_name.lower() == "constantwithwarmup": 73 | warm_up_steps = kwargs.get("warm_up_steps", 500) 74 | return lr_scheduler.LambdaLR( 75 | optimizer, 76 | lr_lambda=lambda steps: min(1.0, (steps + 1) / warm_up_steps), 77 | ) 78 | elif scheduler_name.lower() == "linearwarmupdecay": 79 | warm_up_steps = kwargs.get("warm_up_steps", 0) 80 | training_steps = kwargs.get("training_steps") 81 | lr_lambda = get_warmup_lambda(warm_up_steps, training_steps) 82 | return lr_scheduler.LambdaLR(optimizer, lr_lambda) 83 | elif scheduler_name.lower() == "cosineannealing": 84 | training_steps = kwargs.get("training_steps") 85 | eta_min = kwargs.get("lr_end", 0) 86 | return lr_scheduler.CosineAnnealingLR( 87 | optimizer, T_max=training_steps, eta_min=eta_min 88 | ) 89 | elif scheduler_name.lower() == "cosineannealingwarmup": 90 | warm_up_steps = kwargs.get("warm_up_steps", 0) 91 | training_steps = kwargs.get("training_steps") 92 | eta_min = kwargs.get("lr_end", 0) 93 | lr_lambda = get_warmup_cosine_lambda(warm_up_steps, training_steps, eta_min) 94 | return lr_scheduler.LambdaLR(optimizer, lr_lambda) 95 | elif scheduler_name.lower() == "cosineannealingwarmrestarts": 96 | training_steps = kwargs.get("training_steps") 97 | eta_min = kwargs.get("lr_end", 0) 98 | num_cycles = kwargs.get("num_cycles", 1) 99 | T_0 = training_steps // num_cycles 100 | return lr_scheduler.CosineAnnealingWarmRestarts( 101 | optimizer, T_0=T_0, eta_min=eta_min 102 | ) 103 | else: 104 | raise ValueError(f"Unsupported scheduler: {scheduler_name}") 105 | -------------------------------------------------------------------------------- /src/sae_training/vit_activations_store.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Portions of this file are based on code from the “jbloomAus/SAELens” and "HugoFry/mats_sae_training_for_ViTs" repositories (MIT-licensed): 3 | https://github.com/jbloomAus/SAELens/blob/main/sae_lens/training/activations_store.py 4 | https://github.com/HugoFry/mats_sae_training_for_ViTs/blob/main/sae_training/vit_activations_store.py 5 | """ 6 | 7 | from torch.utils.data import DataLoader, TensorDataset 8 | 9 | from src.sae_training.hooked_vit import HookedVisionTransformer 10 | from src.sae_training.utils import get_model_activations, process_model_inputs 11 | 12 | 13 | class ViTActivationsStore: 14 | """ 15 | Class for streaming tokens and generating and storing activations 16 | while training SAEs. 
17 | """ 18 | 19 | def __init__( 20 | self, 21 | dataset, 22 | batch_size: int, 23 | device: str, 24 | seed: int, 25 | model: HookedVisionTransformer, 26 | block_layer: int, 27 | module_name: str, 28 | class_token: bool, 29 | ): 30 | self.device = device 31 | self.model = model 32 | self.batch_size = batch_size 33 | self.block_layer = block_layer 34 | self.module_name = module_name 35 | self.class_token = class_token 36 | 37 | self.dataset = dataset.shuffle(seed=seed) 38 | self.dataset_iter = iter(self.dataset) 39 | 40 | def get_batch_model_inputs(self, process_labels=False): 41 | """Get processed model inputs for a batch of data""" 42 | batch_dict = {"image": [], "label": []} 43 | 44 | def _add_data(current_item, batch_dict): 45 | for key, value in current_item.items(): 46 | batch_dict[key].append(value) 47 | 48 | for _ in range(self.batch_size): 49 | try: 50 | current_item = next(self.dataset_iter) 51 | except StopIteration: 52 | self.dataset_iter = iter(self.dataset) 53 | current_item = next(self.dataset_iter) 54 | _add_data(current_item, batch_dict) 55 | 56 | inputs = process_model_inputs( 57 | batch_dict, self.model, self.device, process_labels=process_labels 58 | ) 59 | return inputs 60 | 61 | def get_batch_activations(self): 62 | """Get model activations for a batch of data""" 63 | inputs = self.get_batch_model_inputs() 64 | return get_model_activations( 65 | self.model, inputs, self.block_layer, self.module_name, self.class_token 66 | ) 67 | 68 | def _create_new_dataloader(self) -> DataLoader: 69 | """Create a new dataloader with fresh activations""" 70 | activations = self.get_batch_activations() 71 | dataset = TensorDataset(activations) 72 | return iter(DataLoader(dataset, batch_size=self.batch_size, shuffle=True)) 73 | 74 | def get_next_batch(self): 75 | """Get next batch, creating new dataloader if current one is exhausted""" 76 | try: 77 | return self.get_batch_activations() 78 | except StopIteration: 79 | self.dataloader = self._create_new_dataloader() 80 | return next(self.dataloader) 81 | -------------------------------------------------------------------------------- /tasks/README.md: -------------------------------------------------------------------------------- 1 | # PatchSAE: Training & Usage Guide 2 | 3 | ## 📋 Table of Contents 4 | - [Train SAE](#-train-sae) 5 | - [Extract SAE Latent Data](#-extract-sae-latent-data) 6 | - [Compute Class-Level SAE Latents](#-compute-class-level-sae-latents) 7 | - [Steer Classification](#-steer-classification) 8 | 9 | ## 🔧 Train SAE 10 | 11 | Train a sparse autoencoder on CLIP features. 12 | 13 | ### Prerequisites 14 | - CLIP checkpoint 15 | - Training dataset (images only) 16 | 17 | ### Training Command 18 | ```bash 19 | PYTHONPATH=./ python tasks/train_sae_vit.py 20 | ``` 21 | > 📝 Configuration files will be added soon 22 | 23 | ### Outputs 24 | - SAE checkpoint (`.pt` file) 25 | 26 | ### Monitoring 27 | - View our [training logs on W&B](https://api.wandb.ai/links/hyesulim-hs/7dx90sq0) 28 | 29 | --- 30 | 31 | ## 📊 Extract SAE Latent Data 32 | 33 | Extract and save SAE latent activations for downstream analysis. 
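Under the hood, this step registers a caching hook on the configured CLIP block, runs a forward pass over the dataset, and encodes the cached residual-stream activations with the trained SAE; the saved statistics below are aggregated from those latents. A minimal sketch of that inner loop, reusing the helpers from `src/sae_training` — variable names, shapes, and the exact SAE forward signature here are illustrative, not the script's internals:

```python
# Sketch only: mirrors HookedVisionTransformer.run_with_cache and the SAE encode step.
hook_location = (cfg.block_layer, cfg.module_name)        # e.g. (10, "resid")
_, cache = vit.run_with_cache([hook_location], **inputs)  # forward pass with a caching hook
acts = cache[hook_location]                               # residual-stream activations for the batch
_, feature_acts, *_ = sae(acts)                           # per-token SAE latent activations
```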
34 | 35 | ### Prerequisites 36 | - CLIP checkpoint 37 | - SAE checkpoint (from training step) 38 | - Dataset (can differ from training dataset) 39 | 40 | ### Run with Original CLIP 41 | ```bash 42 | PYTHONPATH=./ python tasks/compute_sae_feature_data.py \ 43 | --root_dir ./ \ 44 | --dataset_name imagenet \ 45 | --sae_path /PATH/TO/SAE_CKPT.pt \ 46 | --vit_type base 47 | ``` 48 | 49 | ### Run with Adapted CLIP (e.g., MaPLe) 50 | 1. Download MaPLe from the [official repo](https://github.com/muzairkhattak/multimodal-prompt-learning/tree/main?tab=readme-ov-file#model-zoo) or [Google Drive](https://drive.google.com/drive/folders/1EvuvgR8566bL0T7ucvAL3LFVwuUPMRas) 51 | 52 | 2. Run extraction: 53 | ```bash 54 | PYTHONPATH=./ python tasks/compute_sae_feature_data.py \ 55 | --root_dir ./ \ 56 | --dataset_name imagenet \ 57 | --sae_path /PATH/TO/SAE_CKPT.pt \ 58 | --vit_type maple \ 59 | --model_path /PATH/TO/MAPLE_CKPT \ # e.g., .../model.pth.tar-5 60 | --config_path /PATH/TO/MAPLE_CFG \ # e.g., .../configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml 61 | ``` 62 | 63 | ### Output Files 64 | All files will be saved to: `{root_dir}/out/feature_data/{vit_type}/{dataset_name}/` 65 | 66 | - `max_activating_image_indices.pt` 67 | - `max_activating_image_label_indices.pt` 68 | - `max_activating_image_values.pt` 69 | - `sae_mean_acts.pt` 70 | - `sae_sparsity.pt` 71 | 72 | ### Analysis 73 | Explore the extracted features with our [patchsae/analysis/analysis.ipynb](https://github.com/hyesulim/patchsae/blob/9f28fdc6ffb7beccb5c2b8ee629b6752b904aa23/analysis/analysis.ipynb) 74 | 75 | --- 76 | 77 | ## 🧩 Compute Class-Level SAE Latents 78 | 79 | Compute class-level SAE activation patterns. 80 | 81 | ### Prerequisites 82 | - CLIP checkpoint 83 | - SAE checkpoint 84 | - SAE feature data (from previous step) 85 | - Dataset (must be the SAME dataset used in the extraction step) 86 | 87 | ### Run with Original CLIP 88 | ```bash 89 | PYTHONPATH=./ python tasks/compute_class_wise_sae_activation.py \ 90 | --root_dir ./ \ 91 | --dataset_name imagenet \ 92 | --threshold 0.2 \ 93 | --sae_path /PATH/TO/SAE_CKPT.pt \ 94 | --vit_type base 95 | ``` 96 | 97 | ### Run with Adapted CLIP (e.g., MaPLe) 98 | ```bash 99 | PYTHONPATH=./ python tasks/compute_class_wise_sae_activation.py \ 100 | --root_dir ./ \ 101 | --dataset_name imagenet \ 102 | --threshold 0.2 \ 103 | --sae_path /PATH/TO/SAE_CKPT.pt \ 104 | --vit_type maple \ 105 | --model_path /PATH/TO/MAPLE_CKPT \ # e.g., .../model.pth.tar-5 106 | --config_path /PATH/TO/MAPLE_CFG \ # e.g., .../configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml 107 | ``` 108 | 109 | ### Output File 110 | - `cls_sae_cnt.npy` - Matrix of shape `(num_sae_latents, num_classes)` 111 | 112 | --- 113 | 114 | ## 🎯 Steer Classification 115 | 116 | Evaluate classification using feature steering with SAE latents. 
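Steering works by clamping SAE latents during the forward pass: for each class, the latents that fire most often for that class (taken from `cls_sae_cnt.npy`) are either kept "on" while all other latents are suppressed, or switched "off" while the rest are left untouched. A rough sketch of the latent selection and the "on" clamp, loosely following `tasks/classification_with_top_k_masking.py` (the script loops over all classes and several values of k, installs the clamp as a forward hook on the chosen CLIP block, and additionally subtracts a small SAE bias term):

```python
# Sketch only: `sae`, `activations`, `SAE_DIM`, `class_idx`, and `topk` are assumed
# to exist in the surrounding script; names are illustrative.
import numpy as np
import torch

cls_sae_cnt = np.load("cls_sae_cnt.npy")                      # class-level activation counts
top_latents = cls_sae_cnt[class_idx].argsort()[::-1][:topk]   # latents most associated with class_idx

clamp_value = torch.zeros(SAE_DIM)   # "on": keep only the selected class latents active
clamp_value[top_latents] = 1.0       # (for "off": start from ones and zero these indices instead)

reconstructed = sae.forward_clamp(
    activations, clamp_feat_dim=torch.ones(SAE_DIM).bool(), clamp_value=clamp_value
)[0]
```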
117 | 118 | ### Prerequisites 119 | - CLIP checkpoint 120 | - SAE checkpoint 121 | - Class-level activation data (`cls_sae_cnt.npy` from previous step) 122 | - Dataset (must be the SAME dataset used for class-level activations, though can be a different split) 123 | 124 | ### Run with Original CLIP 125 | ```bash 126 | PYTHONPATH=./ python tasks/classification_with_top_k_masking.py \ 127 | --root_dir ./ \ 128 | --dataset_name imagenet \ 129 | --sae_path /PATH/TO/SAE_CKPT.pt \ 130 | --cls_wise_sae_activation_path /PATH/TO/cls_sae_cnt.npy 131 | ``` 132 | 133 | ### Run with Adapted CLIP (e.g., MaPLe) 134 | ```bash 135 | PYTHONPATH=./ python tasks/classification_with_top_k_masking.py \ 136 | --root_dir ./ \ 137 | --dataset_name imagenet \ 138 | --sae_path /PATH/TO/SAE_CKPT.pt \ 139 | --cls_wise_sae_activation_path /PATH/TO/cls_sae_cnt.npy \ 140 | --vit_type maple \ 141 | --model_path /PATH/TO/MAPLE_CKPT \ # e.g., .../model.pth.tar-5 142 | --config_path /PATH/TO/MAPLE_CFG \ # e.g., .../configs/models/maple/vit_b16_c2_ep5_batch4_2ctx.yaml 143 | ``` 144 | 145 | ### Output File 146 | Output will be saved to `eval_outputs/`: 147 | - `metrics.csv` - Contains class-wise True Positive Rate (TPR = TP/(TP+FP+TN+FN)) for each masking configuration 148 | - Results for both "on" and "off" conditions 149 | - For k values in [1, 2, 5, 10, 50, 100, 500, 1000, 2000, SAE_DIM] 150 | -------------------------------------------------------------------------------- /tasks/classification_with_top_k_masking.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from torch.nn.utils.rnn import pad_sequence 8 | from tqdm import tqdm 9 | 10 | from src.models.templates.openai_imagenet_templates import openai_imagenet_template 11 | from src.sae_training.config import Config 12 | from src.sae_training.hooked_vit import Hook, HookedVisionTransformer 13 | from src.sae_training.sparse_autoencoder import SparseAutoencoder 14 | from tasks.utils import ( 15 | SAE_DIM, 16 | get_sae_and_vit, 17 | load_and_organize_dataset, 18 | process_batch, 19 | setup_save_directory, 20 | ) 21 | 22 | TOPK_LIST = [1, 2, 5, 10, 50, 100, 500, 1000, 2000, SAE_DIM] 23 | SAE_BIAS = -0.105131256516992 24 | 25 | 26 | def calculate_text_features(model, device, classnames): 27 | """Calculate mean text features across templates for each class.""" 28 | mean_text_features = 0 29 | 30 | for template_fn in openai_imagenet_template: 31 | # Generate prompts and convert to token IDs 32 | prompts = [template_fn(c) for c in classnames] 33 | prompt_ids = [ 34 | model.processor( 35 | text=p, return_tensors="pt", padding=False, truncation=True 36 | ).input_ids[0] 37 | for p in prompts 38 | ] 39 | 40 | # Process batch 41 | padded_prompts = pad_sequence(prompt_ids, batch_first=True).to(device) 42 | 43 | # Get text features 44 | with torch.no_grad(): 45 | text_features = model.model.get_text_features(padded_prompts) 46 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 47 | mean_text_features += text_features 48 | 49 | return mean_text_features / len(openai_imagenet_template) 50 | 51 | 52 | def create_sae_hooks(vit_type, cfg, cls_features, sae, device, hook_type="on"): 53 | """Create SAE hooks based on model type and hook type.""" 54 | # Setup clamping parameters 55 | clamp_feat_dim = torch.ones(SAE_DIM).bool() 56 | clamp_value = torch.zeros(SAE_DIM) if hook_type == "on" else 
torch.ones(SAE_DIM) 57 | clamp_value = clamp_value.to(device) 58 | clamp_value[cls_features] = 1.0 if hook_type == "on" else 0.0 59 | 60 | def process_activations(activations, is_maple=False): 61 | """Helper function to process activations with SAE""" 62 | act = activations.transpose(0, 1) if is_maple else activations 63 | processed = ( 64 | sae.forward_clamp( 65 | act[:, :, :], clamp_feat_dim=clamp_feat_dim, clamp_value=clamp_value 66 | )[0] 67 | - SAE_BIAS 68 | ) 69 | return processed 70 | 71 | def hook_fn_default(activations): 72 | activations[:, :, :] = process_activations(activations) 73 | return (activations,) 74 | 75 | def hook_fn_maple(activations): 76 | activations = process_activations(activations, is_maple=True) 77 | return activations.transpose(0, 1) 78 | 79 | # Create appropriate hook based on model type 80 | hook_fn = hook_fn_maple if vit_type == "maple" else hook_fn_default 81 | is_custom = vit_type == "maple" 82 | 83 | return [ 84 | Hook( 85 | cfg.block_layer, 86 | cfg.module_name, 87 | hook_fn, 88 | return_module_output=False, 89 | is_custom=is_custom, 90 | ) 91 | ] 92 | 93 | 94 | def get_predictions(vit, inputs, text_features, vit_type, hooks=None): 95 | """Get model predictions with optional hooks.""" 96 | with torch.no_grad(): 97 | if hooks: 98 | vit_out = vit.run_with_hooks(hooks, return_type="output", **inputs) 99 | else: 100 | vit_out = vit(return_type="output", **inputs) 101 | 102 | image_features = vit_out.image_embeds if vit_type == "base" else vit_out 103 | logit_scale = vit.model.logit_scale.exp() 104 | logits = logit_scale * image_features @ text_features.t() 105 | preds = logits.argmax(dim=-1) 106 | 107 | return preds.cpu().numpy().tolist() 108 | 109 | 110 | def classify_with_top_k_masking( 111 | class_data: list, 112 | cls_idx: int, 113 | sae: SparseAutoencoder, 114 | vit: HookedVisionTransformer, 115 | cls_sae_cnt: torch.Tensor, 116 | text_features: torch.Tensor, 117 | batch_size: int, 118 | device: str, 119 | vit_type: str, 120 | cfg: Config, 121 | ): 122 | """Classify images with top-k feature masking.""" 123 | num_batches = (len(class_data) + batch_size - 1) // batch_size 124 | 125 | preds_dict = defaultdict(list) 126 | 127 | for batch_idx in range(num_batches): 128 | batch_start = batch_idx * batch_size 129 | batch_end = min((batch_idx + 1) * batch_size, len(class_data)) 130 | batch_data = class_data[batch_start:batch_end] 131 | 132 | batch_inputs = process_batch(vit, batch_data, device) 133 | 134 | # Get predictions without SAE 135 | preds_dict["no_sae"].extend( 136 | get_predictions(vit, batch_inputs, text_features, vit_type) 137 | ) 138 | torch.cuda.empty_cache() 139 | 140 | # Get top features for current class 141 | loaded_cls_sae_idx = cls_sae_cnt[cls_idx].argsort()[::-1] 142 | 143 | for topk in TOPK_LIST: 144 | cls_features = loaded_cls_sae_idx[:topk].tolist() 145 | 146 | # Get predictions with features ON 147 | hooks_on = create_sae_hooks(vit_type, cfg, cls_features, sae, device, "on") 148 | preds_dict[f"on_{topk}"].extend( 149 | get_predictions(vit, batch_inputs, text_features, vit_type, hooks_on) 150 | ) 151 | torch.cuda.empty_cache() 152 | 153 | # Get predictions with features OFF 154 | hooks_off = create_sae_hooks( 155 | vit_type, cfg, cls_features, sae, device, "off" 156 | ) 157 | preds_dict[f"off_{topk}"].extend( 158 | get_predictions(vit, batch_inputs, text_features, vit_type, hooks_off) 159 | ) 160 | torch.cuda.empty_cache() 161 | 162 | return preds_dict 163 | 164 | 165 | def main( 166 | sae_path: str, 167 | vit_type: str, 168 | device: 
str, 169 | dataset_name: str, 170 | root_dir: str, 171 | save_name: str, 172 | backbone: str = "openai/clip-vit-base-patch16", 173 | batch_size: int = 8, 174 | model_path: str = None, 175 | config_path: str = None, 176 | cls_wise_sae_activation_path: str = None, 177 | ): 178 | class_feature_type = cls_wise_sae_activation_path.split("/")[-3] 179 | save_directory = setup_save_directory( 180 | root_dir, save_name, sae_path, f"{class_feature_type}_{vit_type}", dataset_name 181 | ) 182 | 183 | classnames, data_by_class = load_and_organize_dataset(dataset_name) 184 | 185 | sae, vit, cfg = get_sae_and_vit( 186 | sae_path, 187 | vit_type, 188 | device, 189 | backbone, 190 | model_path=model_path, 191 | config_path=config_path, 192 | classnames=classnames, 193 | ) 194 | 195 | cls_sae_cnt = np.load(cls_wise_sae_activation_path) 196 | 197 | if vit_type == "base": 198 | text_features = calculate_text_features(vit, device, classnames) 199 | else: 200 | text_features = vit.model.get_text_features() 201 | 202 | metrics_dict = {} 203 | for class_idx, classname in enumerate(tqdm(classnames)): 204 | preds_dict = classify_with_top_k_masking( 205 | data_by_class[classname], 206 | class_idx, 207 | sae, 208 | vit, 209 | cls_sae_cnt, 210 | text_features, 211 | batch_size, 212 | device, 213 | vit_type, 214 | cfg, 215 | ) 216 | 217 | metrics_dict[class_idx] = {} 218 | for k, v in preds_dict.items(): 219 | metrics_dict[class_idx][k] = v.count(class_idx) / len(v) * 100 220 | 221 | metrics_df = pd.DataFrame(metrics_dict) 222 | metrics_df.to_csv(f"{save_directory}/metrics.csv", index=False) 223 | print(f"metrics.csv saved at {save_directory}") 224 | 225 | 226 | if __name__ == "__main__": 227 | parser = argparse.ArgumentParser( 228 | description="Perform classification with top-k masking" 229 | ) 230 | parser.add_argument("--root_dir", type=str, default=".", help="Root directory") 231 | parser.add_argument("--dataset_name", type=str, default="imagenet") 232 | parser.add_argument( 233 | "--sae_path", type=str, required=True, help="SAE ckpt path (ends with xxx.pt)" 234 | ) 235 | parser.add_argument("--batch_size", type=int, default=128, help="batch size") 236 | parser.add_argument( 237 | "--cls_wise_sae_activation_path", 238 | type=str, 239 | help="path for cls_sae_cnt.npy", 240 | ) 241 | parser.add_argument( 242 | "--vit_type", type=str, default="base", help="choose between [base, maple]" 243 | ) 244 | parser.add_argument( 245 | "--model_path", 246 | type=str, 247 | help="CLIP model path in the case of not using the default", 248 | ) 249 | parser.add_argument( 250 | "--config_path", 251 | type=str, 252 | help="CLIP config path in the case of using maple", 253 | ) 254 | parser.add_argument("--device", type=str, default="cuda") 255 | 256 | args = parser.parse_args() 257 | 258 | main( 259 | sae_path=args.sae_path, 260 | vit_type=args.vit_type, 261 | device=args.device, 262 | dataset_name=args.dataset_name, 263 | root_dir=args.root_dir, 264 | save_name="out/feature_data", 265 | batch_size=args.batch_size, 266 | model_path=args.model_path, 267 | config_path=args.config_path, 268 | cls_wise_sae_activation_path=args.cls_wise_sae_activation_path, 269 | ) 270 | -------------------------------------------------------------------------------- /tasks/compute_class_wise_sae_activation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Dict 4 | 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from 
src.sae_training.config import Config 10 | from src.sae_training.hooked_vit import HookedVisionTransformer 11 | from src.sae_training.sparse_autoencoder import SparseAutoencoder 12 | from src.sae_training.utils import get_model_activations 13 | from tasks.utils import ( 14 | SAE_DIM, 15 | get_sae_and_vit, 16 | load_and_organize_dataset, 17 | process_batch, 18 | setup_save_directory, 19 | ) 20 | 21 | 22 | def get_sae_activations( 23 | model_activations: torch.Tensor, sae: SparseAutoencoder, threshold: float 24 | ) -> torch.Tensor: 25 | """Get binary SAE activations above threshold. 26 | 27 | Args: 28 | model_activations: Input activations from vision transformer 29 | sae: The sparse autoencoder model 30 | threshold: Activation threshold 31 | 32 | Returns: 33 | Binary tensor indicating which features were active 34 | """ 35 | _, cache = sae.run_with_cache(model_activations) 36 | activations = cache["hook_hidden_post"] > threshold 37 | return activations.sum(dim=0).sum(dim=0) 38 | 39 | 40 | def process_class_batch( 41 | batch_data: list, 42 | sae: SparseAutoencoder, 43 | vit: HookedVisionTransformer, 44 | cfg: Config, 45 | threshold: float, 46 | device: str, 47 | ) -> np.ndarray: 48 | """Process a single batch of class data and get SAE feature activations.""" 49 | batch_inputs = process_batch(vit, batch_data, device) 50 | transformer_activations = get_model_activations( 51 | vit, batch_inputs, cfg.block_layer, cfg.module_name, cfg.class_token 52 | ) 53 | active_features = get_sae_activations(transformer_activations, sae, threshold) 54 | return active_features.cpu().numpy() 55 | 56 | 57 | def compute_class_feature_counts( 58 | class_data: list, 59 | sae: SparseAutoencoder, 60 | vit: HookedVisionTransformer, 61 | cfg: Config, 62 | batch_size: int, 63 | threshold: float, 64 | device: str, 65 | ) -> np.ndarray: 66 | """Compute SAE feature activation counts for a single class.""" 67 | feature_counts = np.zeros(SAE_DIM) 68 | num_batches = (len(class_data) + batch_size - 1) // batch_size 69 | 70 | for batch_idx in range(num_batches): 71 | batch_start = batch_idx * batch_size 72 | batch_end = min((batch_idx + 1) * batch_size, len(class_data)) 73 | batch_data = class_data[batch_start:batch_end] 74 | 75 | batch_counts = process_class_batch(batch_data, sae, vit, cfg, threshold, device) 76 | feature_counts += batch_counts 77 | torch.cuda.empty_cache() 78 | 79 | return feature_counts 80 | 81 | 82 | def compute_all_class_activations( 83 | classnames: list, 84 | data_by_class: Dict, 85 | sae: SparseAutoencoder, 86 | vit: HookedVisionTransformer, 87 | cfg: Config, 88 | batch_size: int, 89 | threshold: float, 90 | device: str, 91 | ) -> np.ndarray: 92 | """Compute SAE activation counts across all classes.""" 93 | class_activation_counts = np.zeros((len(classnames), SAE_DIM)) 94 | 95 | for class_idx, classname in enumerate(tqdm(classnames)): 96 | class_data = data_by_class[classname] 97 | class_counts = compute_class_feature_counts( 98 | class_data, sae, vit, cfg, batch_size, threshold, device 99 | ) 100 | class_activation_counts[class_idx] = class_counts 101 | 102 | return class_activation_counts 103 | 104 | 105 | def main( 106 | sae_path: str, 107 | vit_type: str, 108 | device: str, 109 | dataset_name: str, 110 | root_dir: str, 111 | save_name: str, 112 | backbone: str = "openai/clip-vit-base-patch16", 113 | batch_size: int = 8, 114 | model_path: str = None, 115 | config_path: str = None, 116 | threshold: float = 0.2, 117 | ): 118 | """Main function to compute and save class-wise SAE activation counts.""" 
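    # Overview: resolve the output directory, group the dataset by class, load the
    # SAE together with the hooked ViT, count how often each SAE latent fires above
    # `threshold` for the images of each class, and save the result as
    # cls_sae_cnt.npy with shape (num_classes, SAE_DIM).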
119 | 120 | save_directory = setup_save_directory( 121 | root_dir, save_name, sae_path, vit_type, dataset_name 122 | ) 123 | 124 | classnames, data_by_class = load_and_organize_dataset(dataset_name) 125 | 126 | sae, vit, cfg = get_sae_and_vit( 127 | sae_path, 128 | vit_type, 129 | device, 130 | backbone, 131 | model_path=model_path, 132 | config_path=config_path, 133 | classnames=classnames, 134 | ) 135 | 136 | class_activation_counts = compute_all_class_activations( 137 | classnames, data_by_class, sae, vit, cfg, batch_size, threshold, device 138 | ) 139 | 140 | # Save results 141 | save_path = os.path.join(save_directory, "cls_sae_cnt.npy") 142 | np.save(save_path, class_activation_counts) 143 | print(f"Class activation counts saved at {save_directory}") 144 | 145 | 146 | if __name__ == "__main__": 147 | parser = argparse.ArgumentParser(description="Save class-wise SAE activation count") 148 | parser.add_argument("--root_dir", type=str, default=".", help="Root directory") 149 | parser.add_argument("--dataset_name", type=str, default="imagenet") 150 | parser.add_argument( 151 | "--sae_path", type=str, required=True, help="SAE ckpt path (ends with xxx.pt)" 152 | ) 153 | parser.add_argument( 154 | "--vit_type", type=str, default="base", help="choose between [base, maple]" 155 | ) 156 | parser.add_argument( 157 | "--threshold", type=float, default=0.2, help="threshold for SAE activation" 158 | ) 159 | parser.add_argument("--batch_size", type=int, default=128) 160 | parser.add_argument( 161 | "--model_path", 162 | type=str, 163 | help="CLIP model path in the case of not using the default", 164 | ) 165 | parser.add_argument( 166 | "--config_path", 167 | type=str, 168 | help="CLIP config path in the case of using maple", 169 | ) 170 | parser.add_argument("--device", type=str, default="cuda") 171 | 172 | args = parser.parse_args() 173 | 174 | main( 175 | sae_path=args.sae_path, 176 | vit_type=args.vit_type, 177 | device=args.device, 178 | dataset_name=args.dataset_name, 179 | root_dir=args.root_dir, 180 | save_name="out/feature_data", 181 | batch_size=args.batch_size, 182 | model_path=args.model_path, 183 | config_path=args.config_path, 184 | threshold=args.threshold, 185 | ) 186 | -------------------------------------------------------------------------------- /tasks/compute_sae_feature_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Dict, Optional, Tuple 3 | 4 | import torch 5 | from datasets import Dataset, load_dataset 6 | from tqdm import tqdm 7 | 8 | from src.sae_training.config import Config 9 | from src.sae_training.hooked_vit import HookedVisionTransformer 10 | from src.sae_training.sparse_autoencoder import SparseAutoencoder 11 | from src.sae_training.utils import get_model_activations, process_model_inputs 12 | from tasks.utils import ( 13 | DATASET_INFO, 14 | get_classnames, 15 | get_sae_activations, 16 | get_sae_and_vit, 17 | setup_save_directory, 18 | ) 19 | 20 | 21 | def initialize_storage_tensors( 22 | d_sae: int, num_max: int, device: str 23 | ) -> Dict[str, torch.Tensor]: 24 | """Initialize tensors for storing results.""" 25 | return { 26 | "max_activating_image_values": torch.zeros([d_sae, num_max]).to(device), 27 | "max_activating_image_indices": torch.zeros([d_sae, num_max]).to(device), 28 | "sae_sparsity": torch.zeros([d_sae]).to(device), 29 | "sae_mean_acts": torch.zeros([d_sae]).to(device), 30 | } 31 | 32 | 33 | def get_new_top_k( 34 | first_values: torch.Tensor, 35 | first_indices: 
torch.Tensor, 36 | second_values: torch.Tensor, 37 | second_indices: torch.Tensor, 38 | k: int, 39 | ) -> Tuple[torch.Tensor, torch.Tensor]: 40 | """Get top k values and indices from two sets of values/indices.""" 41 | total_values = torch.cat([first_values, second_values], dim=1) 42 | total_indices = torch.cat([first_indices, second_indices], dim=1) 43 | new_values, indices_of_indices = torch.topk(total_values, k=k, dim=1) 44 | new_indices = torch.gather(total_indices, 1, indices_of_indices) 45 | return new_values, new_indices 46 | 47 | 48 | def compute_sae_statistics( 49 | sae_activations: torch.Tensor, 50 | ) -> Tuple[torch.Tensor, torch.Tensor]: 51 | """Compute mean activations and sparsity statistics for SAE features.""" 52 | mean_acts = sae_activations.sum(dim=1) 53 | sparsity = (sae_activations > 0).sum(dim=1) 54 | return mean_acts, sparsity 55 | 56 | 57 | def get_top_activations( 58 | sae_activations: torch.Tensor, num_top_images: int, images_processed: int 59 | ) -> Tuple[torch.Tensor, torch.Tensor]: 60 | """Get top activating images and their indices.""" 61 | top_k = min(num_top_images, sae_activations.size(1)) 62 | values, indices = torch.topk(sae_activations, k=top_k, dim=1) 63 | indices += images_processed 64 | return values, indices 65 | 66 | 67 | def process_batch( 68 | batch: Dict, 69 | vit: HookedVisionTransformer, 70 | sae: SparseAutoencoder, 71 | cfg: Config, 72 | device: str, 73 | num_top_images: int, 74 | images_processed: int, 75 | storage: Dict[str, torch.Tensor], 76 | ) -> Tuple[Dict[str, torch.Tensor], int]: 77 | """Process a single batch of images and update feature statistics.""" 78 | # Get model activations 79 | inputs = process_model_inputs(batch, vit, device) 80 | model_acts = get_model_activations( 81 | vit, inputs, cfg.block_layer, cfg.module_name, cfg.class_token 82 | ) 83 | sae_acts = get_sae_activations(model_acts, sae).transpose(0, 1) 84 | 85 | # Update statistics 86 | mean_acts, sparsity = compute_sae_statistics(sae_acts) 87 | storage["sae_mean_acts"] += mean_acts 88 | storage["sae_sparsity"] += sparsity 89 | 90 | # Get top activating images 91 | values, indices = get_top_activations(sae_acts, num_top_images, images_processed) 92 | 93 | top_values, top_indices = get_new_top_k( 94 | storage["max_activating_image_values"], 95 | storage["max_activating_image_indices"], 96 | values, 97 | indices, 98 | num_top_images, 99 | ) 100 | 101 | # Update processed image count 102 | images_processed += model_acts.size(0) 103 | 104 | return { 105 | "max_activating_image_values": top_values, 106 | "max_activating_image_indices": top_indices, 107 | "sae_sparsity": storage["sae_sparsity"], 108 | "sae_mean_acts": storage["sae_mean_acts"], 109 | }, images_processed 110 | 111 | 112 | def save_results( 113 | save_directory: str, 114 | storage: Dict[str, torch.Tensor], 115 | dataset: Dataset, 116 | label_field: Optional[str] = None, 117 | ) -> None: 118 | """Save results to disk.""" 119 | if label_field and label_field in dataset.features: 120 | max_activating_image_label_indices = torch.tensor( 121 | [ 122 | dataset[int(index)][label_field] 123 | for index in tqdm( 124 | storage["max_activating_image_indices"].flatten(), 125 | desc="getting image labels", 126 | ) 127 | ] 128 | ).view(storage["max_activating_image_indices"].shape) 129 | 130 | torch.save( 131 | max_activating_image_label_indices, 132 | f"{save_directory}/max_activating_image_label_indices.pt", 133 | ) 134 | 135 | torch.save( 136 | storage["max_activating_image_indices"], 137 | 
f"{save_directory}/max_activating_image_indices.pt", 138 | ) 139 | torch.save( 140 | storage["max_activating_image_values"], 141 | f"{save_directory}/max_activating_image_values.pt", 142 | ) 143 | torch.save(storage["sae_sparsity"], f"{save_directory}/sae_sparsity.pt") 144 | torch.save(storage["sae_mean_acts"], f"{save_directory}/sae_mean_acts.pt") 145 | 146 | print(f"Results saved to {save_directory}") 147 | 148 | 149 | @torch.inference_mode() 150 | def main( 151 | sae_path: str, 152 | vit_type: str, 153 | device: str, 154 | dataset_name: str, 155 | root_dir: str, 156 | save_name: str, 157 | backbone: str = "openai/clip-vit-base-patch16", 158 | number_of_max_activating_images: int = 10, 159 | seed: int = 1, 160 | batch_size: int = 8, 161 | model_path: str = None, 162 | config_path: str = None, 163 | ): 164 | """Main function to extract and save feature data.""" 165 | torch.set_float32_matmul_precision("high") 166 | torch.cuda.empty_cache() 167 | 168 | dataset = load_dataset(**DATASET_INFO[dataset_name]) 169 | dataset = dataset.shuffle(seed=seed) 170 | classnames = get_classnames(dataset_name, dataset) 171 | 172 | sae, vit, cfg = get_sae_and_vit( 173 | sae_path, vit_type, device, backbone, model_path, config_path, classnames 174 | ) 175 | 176 | storage = initialize_storage_tensors( 177 | sae.cfg.d_sae, number_of_max_activating_images, device 178 | ) 179 | 180 | # Process batches 181 | total_iterations = (len(dataset) + batch_size - 1) // batch_size 182 | num_processed = 0 183 | 184 | for iteration in tqdm(range(total_iterations)): 185 | batch_start = iteration * batch_size 186 | batch_end = (iteration + 1) * batch_size 187 | current_batch = dataset[batch_start:batch_end] 188 | 189 | storage, num_processed = process_batch( 190 | current_batch, 191 | vit, 192 | sae, 193 | cfg, 194 | device, 195 | number_of_max_activating_images, 196 | num_processed, 197 | storage, 198 | ) 199 | 200 | # Finalize statistics 201 | storage["sae_mean_acts"] /= storage["sae_sparsity"] 202 | storage["sae_sparsity"] /= num_processed 203 | 204 | # Save results 205 | save_directory = setup_save_directory( 206 | root_dir, save_name, sae_path, vit_type, dataset_name 207 | ) 208 | save_results( 209 | save_directory, 210 | storage, 211 | dataset, 212 | label_field="label" if "label" in dataset.features else None, 213 | ) 214 | 215 | 216 | if __name__ == "__main__": 217 | parser = argparse.ArgumentParser( 218 | description="Process ViT SAE images and save feature data" 219 | ) 220 | parser.add_argument("--root_dir", type=str, default=".", help="Root directory") 221 | parser.add_argument("--dataset_name", type=str, default="imagenet") 222 | parser.add_argument( 223 | "--sae_path", type=str, required=True, help="SAE ckpt path (ends with xxx.pt)" 224 | ) 225 | parser.add_argument( 226 | "--batch_size", 227 | type=int, 228 | default=32, 229 | help="Batch size to compute model activations and sae features", 230 | ) 231 | parser.add_argument( 232 | "--model_path", 233 | type=str, 234 | help="CLIP model path in the case of not using the default", 235 | ) 236 | parser.add_argument( 237 | "--config_path", 238 | type=str, 239 | help="CLIP config path in the case of using maple", 240 | ) 241 | parser.add_argument("--device", type=str, default="cuda") 242 | parser.add_argument( 243 | "--vit_type", type=str, default="base", help="choose between [base, maple]" 244 | ) 245 | args = parser.parse_args() 246 | 247 | main( 248 | sae_path=args.sae_path, 249 | vit_type=args.vit_type, 250 | device=args.device, 251 | 
dataset_name=args.dataset_name, 252 | root_dir=args.root_dir, 253 | save_name="out/feature_data", 254 | batch_size=args.batch_size, 255 | model_path=args.model_path, 256 | config_path=args.config_path, 257 | ) 258 | -------------------------------------------------------------------------------- /tasks/train_sae_vit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import wandb 5 | from datasets import load_dataset 6 | 7 | from src.sae_training.config import ViTSAERunnerConfig 8 | from src.sae_training.sae_trainer import SAETrainer 9 | from src.sae_training.sparse_autoencoder import SparseAutoencoder 10 | from src.sae_training.utils import get_scheduler 11 | from src.sae_training.vit_activations_store import ViTActivationsStore 12 | from tasks.utils import ( 13 | DATASET_INFO, 14 | get_classnames, 15 | load_hooked_vit, 16 | ) 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--class_token", action="store_true", default=None) 21 | parser.add_argument("--image_width", type=int, default=224) 22 | parser.add_argument("--image_height", type=int, default=224) 23 | parser.add_argument( 24 | "--model_name", type=str, default="openai/clip-vit-base-patch16" 25 | ) 26 | parser.add_argument("--module_name", type=str, default="resid") 27 | parser.add_argument("--block_layer", type=int, default=-2) 28 | parser.add_argument("--clip_dim", type=int, default=768) 29 | 30 | parser.add_argument("--dataset", type=str, default="imagenet") 31 | parser.add_argument("--use_cached_activations", action="store_true", default=None) 32 | parser.add_argument("--cached_activations_path", type=str) 33 | parser.add_argument("--expansion_factor", type=int, default=64) 34 | parser.add_argument("--b_dec_init_method", type=str, default="geometric_median") 35 | parser.add_argument("--gated_sae", action="store_true", default=None) 36 | # Training Parameters 37 | parser.add_argument("--lr", type=float, default=0.0004) 38 | parser.add_argument("--l1_coefficient", type=float, default=0.00008) 39 | parser.add_argument("--lr_scheduler_name", type=str, default="constantwithwarmup") 40 | parser.add_argument("--batch_size", type=int, default=16) 41 | parser.add_argument("--lr_warm_up_steps", type=int, default=500) 42 | parser.add_argument("--total_training_tokens", type=int, default=2_621_440) 43 | parser.add_argument("--n_batches_in_store", type=int, default=15) 44 | parser.add_argument("--mse_cls_coefficient", type=float, default=1.0) 45 | # Dead Neurons and Sparsity 46 | parser.add_argument("--use_ghost_grads", action="store_true", default=None) 47 | parser.add_argument("--feature_sampling_method") 48 | parser.add_argument("--feature_sampling_window", type=int, default=64) 49 | parser.add_argument("--dead_feature_window", type=int, default=64) 50 | parser.add_argument("--dead_feature_threshold", type=float, default=1e-6) 51 | # WANDB 52 | parser.add_argument("--log_to_wandb", action="store_true", default=None) 53 | parser.add_argument("--wandb_project", type=str, default="patch_sae") 54 | parser.add_argument("--wandb_entity", type=str, default="test") 55 | parser.add_argument("--wandb_log_frequency", type=int, default=20) 56 | # Misc 57 | parser.add_argument("--seed", type=int, default=42) 58 | parser.add_argument("--n_checkpoints", type=int, default=1) 59 | parser.add_argument("--checkpoint_path", type=str, default="out/checkpoints") 60 | parser.add_argument("--device", type=str, default="cuda") 61 | # 
resume 62 | parser.add_argument("--root_dir", type=str, default="") 63 | parser.add_argument("--resume", action="store_true", default=None) 64 | parser.add_argument("--run_name", type=str, default="train") 65 | parser.add_argument("--start_training_steps", type=int, default=0) 66 | parser.add_argument("--pt_name", type=str) 67 | 68 | parser.add_argument( 69 | "--vit_type", type=str, default="base", help="choose between [base, maple]" 70 | ) 71 | parser.add_argument( 72 | "--model_path", 73 | type=str, 74 | help="CLIP model path in the case of not using the default", 75 | ) 76 | parser.add_argument( 77 | "--config_path", 78 | type=str, 79 | help="CLIP config path in the case of using maple", 80 | ) 81 | args = parser.parse_args() 82 | 83 | cfg = ViTSAERunnerConfig( 84 | class_token=args.class_token, 85 | image_width=args.image_width, 86 | image_height=args.image_height, 87 | model_name=f"openai/{args.model_name}", 88 | module_name=args.module_name, 89 | block_layer=args.block_layer, 90 | dataset_path=DATASET_INFO[args.dataset]["path"], 91 | image_key="image", 92 | label_key="label", 93 | use_cached_activations=args.use_cached_activations, 94 | cached_activations_path=args.cached_activations_path, 95 | d_in=args.clip_dim, 96 | expansion_factor=args.expansion_factor, 97 | b_dec_init_method=args.b_dec_init_method, 98 | gated_sae=args.gated_sae, 99 | lr=args.lr, 100 | l1_coefficient=args.l1_coefficient, 101 | lr_scheduler_name=args.lr_scheduler_name, 102 | batch_size=args.batch_size, 103 | lr_warm_up_steps=args.lr_warm_up_steps, 104 | total_training_tokens=args.total_training_tokens, 105 | n_batches_in_store=args.n_batches_in_store, 106 | mse_cls_coefficient=args.mse_cls_coefficient, 107 | use_ghost_grads=args.use_ghost_grads, 108 | feature_sampling_method=args.feature_sampling_method, 109 | feature_sampling_window=args.feature_sampling_window, 110 | dead_feature_window=args.dead_feature_window, 111 | dead_feature_threshold=args.dead_feature_threshold, 112 | log_to_wandb=args.log_to_wandb, 113 | wandb_project=args.wandb_project, 114 | wandb_entity=args.wandb_entity, 115 | wandb_log_frequency=args.wandb_log_frequency, 116 | device=args.device, 117 | seed=args.seed, 118 | n_checkpoints=args.n_checkpoints, 119 | checkpoint_path=args.checkpoint_path, 120 | dtype=torch.float32, 121 | ) 122 | 123 | print("Loading dataset") 124 | classnames = get_classnames(args.dataset) 125 | dataset = load_dataset(**DATASET_INFO[args.dataset]) 126 | 127 | print("Loading SAE and ViT models") 128 | sae = SparseAutoencoder(cfg, args.device) 129 | 130 | vit = load_hooked_vit( 131 | cfg, 132 | args.vit_type, 133 | args.model_name, 134 | args.device, 135 | args.model_path, 136 | args.config_path, 137 | classnames, 138 | ) 139 | 140 | print("Initializing ViTActivationsStore") 141 | activation_store = ViTActivationsStore( 142 | dataset, 143 | args.batch_size, 144 | args.device, 145 | args.seed, 146 | vit, 147 | args.block_layer, 148 | cfg.module_name, 149 | args.class_token, 150 | ) 151 | 152 | optimizer = torch.optim.Adam(sae.parameters(), lr=sae.cfg.lr) 153 | scheduler = get_scheduler(args.lr_scheduler_name, optimizer=optimizer) 154 | 155 | print("Initializing SAE b_dec using activation_store") 156 | sae.initialize_b_dec(activation_store) 157 | sae.train() 158 | 159 | if cfg.log_to_wandb: 160 | wandb.init(project=cfg.wandb_project, config=cfg, name=cfg.run_name) 161 | 162 | sae_trainer = SAETrainer( 163 | sae, vit, activation_store, cfg, optimizer, scheduler, args.device 164 | ) 165 | sae_trainer.fit() 166 | 
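# Typical launch (see scripts/01_run_train.sh for the project's own launch script);
# the flag values below simply restate the argparse defaults above and are illustrative only:
#   PYTHONPATH=./ python tasks/train_sae_vit.py \
#       --dataset imagenet --expansion_factor 64 --batch_size 16 \
#       --lr 0.0004 --l1_coefficient 0.00008 --use_ghost_grads --log_to_wandb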
-------------------------------------------------------------------------------- /tasks/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | from typing import Dict, Tuple 5 | 6 | import torch 7 | from datasets import Dataset, load_dataset 8 | from tqdm import tqdm 9 | 10 | from src.models.utils import get_adapted_clip, get_base_clip 11 | from src.sae_training.config import Config 12 | from src.sae_training.hooked_vit import HookedVisionTransformer 13 | from src.sae_training.sparse_autoencoder import SparseAutoencoder 14 | 15 | # Dataset configurations 16 | DATASET_INFO = { 17 | "imagenet": { 18 | "path": "evanarlian/imagenet_1k_resized_256", 19 | "split": "train", 20 | }, 21 | "imagenet-sketch": { 22 | "path": "clip-benchmark/wds_imagenet_sketch", 23 | "split": "train", 24 | }, 25 | "oxford_flowers": { 26 | "path": "nelorth/oxford-flowers", 27 | "split": "train", 28 | }, 29 | "caltech101": { 30 | "path": "HuggingFaceM4/Caltech-101", 31 | "split": "train", 32 | "name": "with_background_category", 33 | }, 34 | } 35 | 36 | SAE_DIM = 49152 37 | 38 | 39 | def load_sae(sae_path: str, device: str) -> tuple[SparseAutoencoder, Config]: 40 | """Load a sparse autoencoder model from a checkpoint file.""" 41 | checkpoint = torch.load(sae_path, map_location="cpu") 42 | 43 | if "cfg" in checkpoint: 44 | cfg = Config(checkpoint["cfg"]) 45 | else: 46 | cfg = Config(checkpoint["config"]) 47 | sae = SparseAutoencoder(cfg, device) 48 | sae.load_state_dict(checkpoint["state_dict"]) 49 | sae.eval().to(device) 50 | 51 | return sae, cfg 52 | 53 | 54 | def load_hooked_vit( 55 | cfg: Config, 56 | vit_type: str, 57 | backbone: str, 58 | device: str, 59 | model_path: str = None, 60 | config_path: str = None, 61 | classnames: list[str] = None, 62 | ) -> HookedVisionTransformer: 63 | """Load a vision transformer model with hooks.""" 64 | if vit_type == "base": 65 | model, processor = get_base_clip(backbone) 66 | else: 67 | model, processor = get_adapted_clip( 68 | cfg, vit_type, model_path, config_path, backbone, classnames 69 | ) 70 | 71 | return HookedVisionTransformer(model, processor, device=device) 72 | 73 | 74 | def get_sae_and_vit( 75 | sae_path: str, 76 | vit_type: str, 77 | device: str, 78 | backbone: str, 79 | model_path: str = None, 80 | config_path: str = None, 81 | classnames: list[str] = None, 82 | ) -> tuple[SparseAutoencoder, HookedVisionTransformer, Config]: 83 | """Load both SAE and ViT models.""" 84 | sae, cfg = load_sae(sae_path, device) 85 | vit = load_hooked_vit( 86 | cfg, vit_type, backbone, device, model_path, config_path, classnames 87 | ) 88 | return sae, vit, cfg 89 | 90 | 91 | def load_and_organize_dataset(dataset_name: str) -> Tuple[list, Dict]: 92 | # TODO: ERR for imagenet (gets killed after 75%) 93 | """ 94 | Load dataset and organize data by class. 95 | Return classnames and data by class. 
96 | Requried for classification_with_top_k_masking.py and compute_class_wise_sae_activation.py 97 | """ 98 | dataset = load_dataset(**DATASET_INFO[dataset_name]) 99 | classnames = get_classnames(dataset_name, dataset) 100 | 101 | data_by_class = defaultdict(list) 102 | for data_item in tqdm(dataset): 103 | classname = classnames[data_item["label"]] 104 | data_by_class[classname].append(data_item) 105 | 106 | return classnames, data_by_class 107 | 108 | 109 | def get_classnames( 110 | dataset_name: str, dataset: Dataset = None, data_root: str = "./configs/classnames" 111 | ) -> list[str]: 112 | """Get class names for a dataset.""" 113 | 114 | filename = f"{data_root}/{dataset_name}_classnames" 115 | txt_filename = filename + ".txt" 116 | json_filename = filename + ".json" 117 | 118 | if not os.path.exists(txt_filename) and not os.path.exists(json_filename): 119 | raise ValueError(f"Dataset {dataset_name} not supported") 120 | 121 | filename = json_filename if os.path.exists(json_filename) else txt_filename 122 | 123 | with open(filename, "r") as file: 124 | if dataset_name == "caltech101": 125 | class_names = [line.strip() for line in file.readlines()] 126 | elif dataset_name == "imagenet" or dataset_name == "imagenet-sketch": 127 | class_names = [ 128 | " ".join(line.strip().split(" ")[1:]) for line in file.readlines() 129 | ] 130 | elif dataset_name == "oxford_flowers": 131 | assert dataset is not None, "Dataset must be provided for Oxford Flowers" 132 | new_class_dict = {} 133 | class_names = json.load(file) 134 | classnames_from_hf = dataset.features["label"].names 135 | for i, class_name in enumerate(classnames_from_hf): 136 | new_class_dict[i] = class_names[class_name] 137 | class_names = list(new_class_dict.values()) 138 | 139 | else: 140 | raise ValueError(f"Dataset {dataset_name} not supported") 141 | 142 | return class_names 143 | 144 | 145 | def setup_save_directory( 146 | root_dir: str, save_name: str, sae_path: str, vit_type: str, dataset_name: str 147 | ) -> str: 148 | """Set and create the save directory path.""" 149 | sae_run_name = sae_path.split("/")[-2] 150 | save_directory = ( 151 | f"{root_dir}/{save_name}/sae_{sae_run_name}/{vit_type}/{dataset_name}" 152 | ) 153 | os.makedirs(save_directory, exist_ok=True) 154 | return save_directory 155 | 156 | 157 | def get_sae_activations( 158 | model_activations: torch.Tensor, sae: SparseAutoencoder 159 | ) -> torch.Tensor: 160 | """Extract and process activations from the sparse autoencoder.""" 161 | hook_name = "hook_hidden_post" 162 | 163 | # Run SAE forward pass and get activations from cache 164 | _, cache = sae.run_with_cache(model_activations) 165 | sae_activations = cache[hook_name] 166 | 167 | # Average across sequence length dimension if needed 168 | if len(sae_activations.size()) > 2: 169 | sae_activations = sae_activations.mean(dim=1) 170 | 171 | return sae_activations 172 | 173 | 174 | def process_batch(vit, batch_data, device): 175 | """Process a single batch of images.""" 176 | images = [data["image"] for data in batch_data] 177 | 178 | inputs = vit.processor( 179 | images=images, text="", return_tensors="pt", padding=True 180 | ).to(device) 181 | return inputs 182 | 183 | 184 | def get_max_acts_and_images( 185 | datasets: dict, feat_data_root: str, sae_runname: str, vit_name: str 186 | ) -> tuple[dict, dict]: 187 | """Load and return maximum activations and mean activations for each dataset.""" 188 | max_act_imgs = {} 189 | mean_acts = {} 190 | 191 | for dataset_name in datasets: 192 | # Load max activating image 
indices 193 | max_act_path = os.path.join( 194 | feat_data_root, 195 | f"{sae_runname}/{vit_name}/{dataset_name}", 196 | "max_activating_image_indices.pt", 197 | ) 198 | max_act_imgs[dataset_name] = torch.load(max_act_path, map_location="cpu").to( 199 | torch.int32 200 | ) 201 | 202 | # Load mean activations 203 | mean_acts_path = os.path.join( 204 | feat_data_root, 205 | f"{sae_runname}/{vit_name}/{dataset_name}", 206 | "sae_mean_acts.pt", 207 | ) 208 | mean_acts[dataset_name] = torch.load(mean_acts_path, map_location="cpu").numpy() 209 | 210 | return max_act_imgs, mean_acts 211 | 212 | 213 | def load_datasets(include_imagenet: bool = False, seed: int = 1): 214 | """Load multiple datasets from HuggingFace.""" 215 | if include_imagenet: 216 | return { 217 | "imagenet": load_dataset( 218 | "evanarlian/imagenet_1k_resized_256", split="train" 219 | ).shuffle(seed=seed), 220 | "imagenet-sketch": load_dataset( 221 | "clip-benchmark/wds_imagenet_sketch", split="test" 222 | ).shuffle(seed=seed), 223 | "caltech101": load_dataset( 224 | "HuggingFaceM4/Caltech-101", 225 | "with_background_category", 226 | split="train", 227 | ).shuffle(seed=seed), 228 | } 229 | else: 230 | return { 231 | "imagenet-sketch": load_dataset( 232 | "clip-benchmark/wds_imagenet_sketch", split="test" 233 | ).shuffle(seed=seed), 234 | "caltech101": load_dataset( 235 | "HuggingFaceM4/Caltech-101", 236 | "with_background_category", 237 | split="train", 238 | ).shuffle(seed=seed), 239 | } 240 | 241 | 242 | def get_all_classnames(datasets, data_root): 243 | """Get class names for all datasets.""" 244 | class_names = {} 245 | for dataset_name, dataset in datasets.items(): 246 | class_names[dataset_name] = get_classnames(dataset_name, dataset, data_root) 247 | 248 | # imagenet classnames are required to classnames for maple 249 | if "imagenet" not in class_names: 250 | filename = f"{data_root}/imagenet_classnames" 251 | txt_filename = filename + ".txt" 252 | json_filename = filename + ".json" 253 | 254 | if not os.path.exists(txt_filename) and not os.path.exists(json_filename): 255 | raise ValueError(f"Dataset {dataset_name} not supported") 256 | 257 | filename = json_filename if os.path.exists(json_filename) else txt_filename 258 | 259 | with open(filename, "r") as file: 260 | class_names["imagenet"] = [ 261 | " ".join(line.strip().split(" ")[1:]) for line in file.readlines() 262 | ] 263 | 264 | return class_names 265 | --------------------------------------------------------------------------------
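The task scripts above all go through these helpers. A minimal sketch of how they fit together is shown below; the SAE checkpoint path is a placeholder, the dataset is just one of the entries registered in `DATASET_INFO`, and a GPU is assumed:

```python
from src.sae_training.utils import get_model_activations
from tasks.utils import get_sae_and_vit, load_and_organize_dataset, process_batch

device = "cuda"

# Group one of the registered datasets by class name.
classnames, data_by_class = load_and_organize_dataset("oxford_flowers")

# Load the SAE and the hooked ViT it was trained on; the checkpoint path is hypothetical.
sae, vit, cfg = get_sae_and_vit(
    "out/checkpoints/RUN_NAME/final_sparse_autoencoder.pt",
    vit_type="base",
    device=device,
    backbone="openai/clip-vit-base-patch16",
    classnames=classnames,
)

# Run a small batch through the ViT up to the hooked layer, i.e. the activations the SAE sees.
batch = data_by_class[classnames[0]][:4]
inputs = process_batch(vit, batch, device)
acts = get_model_activations(vit, inputs, cfg.block_layer, cfg.module_name, cfg.class_token)
print(acts.shape)
```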