├── src ├── __init__.py ├── models │ ├── __init__.py │ ├── load.py │ ├── deit_ensemble.py │ └── vit_edited.py ├── plots │ ├── __init__.py │ ├── explainability.py │ ├── linear_probing.py │ ├── class_ident.py │ └── mech_interp.py ├── datasets │ ├── __init__.py │ ├── cifar.py │ └── imagenet.py ├── perturbation │ ├── __init__.py │ ├── attn_perturbation.py │ ├── tokens_perturbation.py │ └── explain_perturbation.py ├── linear_probing │ ├── __init__.py │ └── prober.py ├── accuracy.py ├── vis.py ├── gradients.py ├── identifiability.py ├── extractor.py └── memories.py ├── framework.png ├── notebooks └── images │ ├── sample_0.JPEG │ ├── sample_1.JPEG │ ├── framework_1.png │ └── framework_2.png ├── setup.py ├── requirements.txt ├── .gitignore └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/plots/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/perturbation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/linear_probing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinagvilas/vit-cls_emb/HEAD/framework.png -------------------------------------------------------------------------------- /notebooks/images/sample_0.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinagvilas/vit-cls_emb/HEAD/notebooks/images/sample_0.JPEG -------------------------------------------------------------------------------- /notebooks/images/sample_1.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinagvilas/vit-cls_emb/HEAD/notebooks/images/sample_1.JPEG -------------------------------------------------------------------------------- /notebooks/images/framework_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinagvilas/vit-cls_emb/HEAD/notebooks/images/framework_1.png -------------------------------------------------------------------------------- /notebooks/images/framework_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinagvilas/vit-cls_emb/HEAD/notebooks/images/framework_2.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = "vit_cls_emb", 5 | version = "0.0.1", 6 | 
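    # find_packages() picks up `src` and its subpackages (each ships an
    # __init__.py), so after `pip install .` modules resolve as `src.*`,
    # matching the `from src...` imports used across the repository.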
packages=find_packages(), 7 | ) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | opencv-python 4 | jupyterlab 5 | pandas == 1.5.2 6 | seaborn 7 | scipy 8 | 9 | einops 10 | scikit-learn 11 | #timm == 0.9.2 12 | timm == 0.6.12 13 | tqdm == 4.64.1 14 | torch == 1.13.1 15 | transformers == 4.26.0 16 | -------------------------------------------------------------------------------- /src/datasets/cifar.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from PIL import Image 3 | 4 | from torchvision.datasets import CIFAR100 5 | 6 | 7 | class MyCIFAR100(CIFAR100): 8 | def __init__(self, imgs_path, n=5): 9 | self.imgs_path = imgs_path 10 | super().__init__(imgs_path, train=False, download=False, transform=None) 11 | self.n_imgs = n 12 | self.stim_info = self._get_stimuli_info() 13 | return 14 | 15 | def _get_stimuli_info(self): 16 | stim_info = defaultdict(list) 17 | for img, label in zip(self.data, self.targets): 18 | label = str(label) 19 | if len(stim_info[label]) < self.n_imgs: 20 | stim_info[label].append(Image.fromarray(img)) 21 | return stim_info 22 | -------------------------------------------------------------------------------- /src/accuracy.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | 5 | from src.datasets.imagenet import ImagenetDatasetS 6 | 7 | 8 | def compute_accuracy( 9 | proj_path, dataset_path, model, device='cpu', by_concept=False, 10 | concept=None, percentage=True 11 | ): 12 | # Load accuracy 13 | acc_file = Path(proj_path) / 'results/class_embed' / model / 'acc.pt' 14 | acc = torch.load(acc_file, map_location=device) 15 | 16 | # Compute accuracy by concept 17 | if concept: 18 | stim_info = ImagenetDatasetS(Path(dataset_path)).stim_info 19 | c_idx = torch.Tensor( 20 | stim_info.loc[stim_info['imagenet_id'] == concept].index 21 | ).long() 22 | acc = acc[c_idx] 23 | if percentage: 24 | return torch.sum(acc) / len(acc) * 100 25 | else: 26 | return acc 27 | elif by_concept == True: 28 | return acc 29 | else: 30 | return torch.sum(acc) / len(acc) * 100 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Experiment generated 2 | stim_info.csv 3 | figures/ 4 | output/ 5 | manuscript/ 6 | data/ 7 | dataset/ 8 | results/ 9 | model_cktp/ 10 | weights/ 11 | wandb/ 12 | .tmp/ 13 | *.pdf 14 | *.zip 15 | 16 | # OS generated files 17 | .DS_Store 18 | 19 | # Byte-compiled / optimized / DLL files 20 | **/__pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | pip-wheel-metadata/ 42 | share/python-wheels/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | MANIFEST 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .nox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | *.py,cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | 145 | # VSCODE 146 | .vscode/ 147 | 148 | # Local History for Visual Studio Code 149 | .history/ 150 | 151 | # Built Visual Studio Code Extensions 152 | *.vsix -------------------------------------------------------------------------------- /src/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from PIL import Image 3 | 4 | import pandas as pd 5 | from torch.utils.data import Dataset 6 | 7 | 8 | # Exclude categories of ImageNet that are not present in ImageNetS 9 | CATS_EXCLUDE = [ 10 | 'n04356056', 'n04355933', 'n04493381', 'n02808440', 'n03642806', 11 | 'n03832673', 'n04008634', 'n03773504', 'n03887697', 'n15075141' 12 | ] 13 | 14 | 15 | class ImagenetDatasetS(Dataset): 16 | def __init__(self, imagenet_path, partition='validation', n=5): 17 | self.path = Path(imagenet_path) / 'ImageNetS919' 18 | self.imgs_path = self.path / partition 19 | self.partition = partition 20 | self.n_imgs = n 21 | self.stim_info = self.get_stimuli_info() 22 | 23 | def get_stimuli_info(self): 24 | cats = self.get_category_info() 25 | file = self.path / f'{self.partition}_stim_info.csv' 26 | if file.exists(): 27 | stim_info = pd.read_csv(file) 28 | else: 29 | stim_info = [] 30 | cat_dirs = [d for d in self.imgs_path.iterdir() if d.is_dir()] 31 | for c in cat_dirs: 32 | imagenet_id = c.name 33 | if imagenet_id in CATS_EXCLUDE: 34 | continue 35 | else: 36 | idx = cats.loc[imagenet_id]['index'] 37 | cat = cats.loc[imagenet_id]['cat'] 38 | imgs_paths = [i for i in c.iterdir() if i.suffix == '.JPEG'] 39 | if len(imgs_paths) < self.n_imgs: 40 | continue 41 | else: 42 | for i_idx, i in enumerate(imgs_paths[:self.n_imgs]): 43 | i_name = i.name 44 | stim_info.append([imagenet_id, idx, cat, i_name, i_idx]) 45 
| stim_info = pd.DataFrame( 46 | stim_info, 47 | columns=['imagenet_id', 'index', 'cat', 'img_name', 'img_index'] 48 | ) 49 | stim_info.to_csv(file, index=None) 50 | return stim_info 51 | 52 | def get_category_info(self): 53 | cats_info = pd.read_csv( 54 | self.path / 'LOC_synset_mapping.txt', sep='\t', header=None 55 | ) 56 | cats_info[['imagenet_id', 'cat']] = cats_info[0].str.split(' ', n=1, expand=True) 57 | cats_info = cats_info.drop(columns=0).reset_index(drop=False) 58 | cats_info = cats_info.set_index('imagenet_id') 59 | return cats_info 60 | 61 | def __len__(self): 62 | return len(self.stim_info) 63 | 64 | def __getitem__(self, idx): 65 | item_info = self.stim_info.iloc[idx] 66 | item = {} 67 | item['imagenet_id'] = item_info['imagenet_id'] 68 | item['index'] = item_info['index'] 69 | item['cat'] = item_info['cat'] 70 | item['img_index'] = item_info['img_index'] 71 | img_path = self.imgs_path / item_info['imagenet_id'] / item_info['img_name'] 72 | item['img'] = Image.open(img_path).convert('RGB') 73 | return item 74 | 75 | def get_item_by_concept(self, concept, img_idx): 76 | idx = ( 77 | self.stim_info.loc[self.stim_info['imagenet_id'] == concept] 78 | .iloc[img_idx].name 79 | ) 80 | return self.__getitem__(idx) 81 | -------------------------------------------------------------------------------- /src/models/load.py: -------------------------------------------------------------------------------- 1 | import timm 2 | from timm import create_model 3 | import torch 4 | import torchvision.transforms as T 5 | from transformers import AutoImageProcessor 6 | 7 | from src.models.deit_ensemble import base_patch16_224_hierarchical 8 | 9 | 10 | def load_vit( 11 | model_name, device, proj_path=None, return_transform=False, 12 | pretrained=True 13 | ): 14 | """_summary_ 15 | 16 | Parameters 17 | ---------- 18 | model_name : str 19 | Can be one of the following options: vit_b_16, vit_b_32, vit_large_16, 20 | vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16. 21 | device : str 22 | 'cpu' or 'cuda' 23 | proj_path : pathlib Path, optional 24 | Path to the folder containing the source code, by default None 25 | return_transform : bool, optional 26 | Return image transform, by default False 27 | 28 | Returns 29 | ------- 30 | list 31 | Containing model, number of tokens, dimension of hidden states 32 | and image transform. 
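    Examples
    --------
    A minimal sketch (assumes the pretrained weights and image-processor
    configs can be downloaded via timm/transformers, as set up in the README):

    >>> model, n_tokens, hs_dim, img_transform = load_vit(
    ...     'vit_b_32', 'cpu', return_transform=True
    ... )
    >>> n_tokens, hs_dim
    (50, 768)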
33 | """ 34 | # Get model source and info 35 | if model_name == 'vit_b_32': 36 | msource = 'vit_base_patch32_224' 37 | psource = 'google/vit-base-patch32-224-in21k' 38 | n_tokens = 50 39 | hs_dim = 768 40 | elif model_name == 'vit_b_16': 41 | msource = 'vit_base_patch16_224' 42 | psource = 'google/vit-base-patch16-224-in21k' 43 | n_tokens = 197 44 | hs_dim = 768 45 | elif model_name == 'vit_large_16': 46 | msource = 'vit_large_patch16_224' 47 | psource = 'google/vit-large-patch16-224-in21k' 48 | n_tokens = 197 49 | hs_dim = 1024 50 | elif model_name == 'vit_miil_16': 51 | msource = 'vit_base_patch16_224_miil' 52 | n_tokens = 197 53 | hs_dim = 768 54 | elif model_name == 'vit_cifar_16': 55 | n_tokens = 197 56 | hs_dim = 768 57 | elif model_name == 'deit_ensemble_16': 58 | psource = 'facebook/deit-base-distilled-patch16-224' 59 | n_tokens = 197 60 | hs_dim = 768 61 | elif model_name == 'vit_gap_16': 62 | msource = 'vit_base_patch16_rpn_224' 63 | psource = 'google/vit-base-patch16-224-in21k' 64 | n_tokens = 196 65 | hs_dim = 768 66 | 67 | # Load model 68 | if model_name.startswith('deit'): 69 | model = base_patch16_224_hierarchical(pretrained=pretrained).to(device) 70 | elif 'cifar' in model_name: 71 | model = create_model('vit_base_patch16_224', num_classes=100).to(device) 72 | if pretrained == True: 73 | state_dict = torch.load(proj_path / 'model_cktp' / 'vit_base_patch16_224-CIFAR100.pt') 74 | state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()} 75 | model.load_state_dict(state_dict) 76 | else: 77 | model = create_model(msource, pretrained=pretrained).to(device) 78 | 79 | # Get image transform 80 | if return_transform == True: 81 | if 'miil' in model_name: 82 | data_config = timm.data.resolve_data_config({}, model=model) 83 | img_transform = timm.data.create_transform(**data_config, is_training=False) 84 | elif 'cifar' in model_name: 85 | img_transform = T.Compose([ 86 | T.Resize((224, 224),), 87 | T.CenterCrop(224), 88 | T.ToTensor(), 89 | T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 90 | ]) 91 | else: 92 | img_transform = AutoImageProcessor.from_pretrained(psource) 93 | else: 94 | img_transform = None 95 | 96 | return model, n_tokens, hs_dim, img_transform -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Analyzing Vision Transformers in Class Embedding Space (NeurIPS '23) 2 | _by [Martina G. Vilas](https://martinagvilas.github.io/), Timothy Schaumlöffel and [Gemma Roig](http://www.cvai.cs.uni-frankfurt.de/team.html)_ 3 | 4 | __*Links*__: [Paper](https://arxiv.org/abs/2310.18969) | [Video presentation]() _(coming soon)_ | [Poster]() _(coming soon)_ 5 | 6 | > __Abstract__: Despite the growing use of transformer models in computer vision, a mechanistic 7 | understanding of these networks is still needed. This work introduces a method to 8 | reverse-engineer Vision Transformers trained to solve image classification tasks. 9 | Inspired by previous research in NLP, we demonstrate how the inner representations 10 | at any level of the hierarchy can be projected onto the learned class embedding 11 | space to uncover how these networks build categorical representations for their pre- 12 | dictions. 
We use our framework to show how image tokens develop class-specific 13 | representations that depend on attention mechanisms and contextual information, 14 | and give insights on how self-attention and MLP layers differentially contribute to 15 | this categorical composition. We additionally demonstrate that this method (1) can 16 | be used to determine the parts of an image that would be important for detecting 17 | the class of interest, and (2) exhibits significant advantages over traditional linear 18 | probing approaches. Taken together, our results position our proposed framework 19 | as a powerful tool for mechanistic interpretability and explainability research. 20 | 21 | ![framework](framework.png) 22 | 23 |
24 | Schematic of our framework 25 |
26 | 27 | ## :paperclip: Contents 28 | 29 | - [Tutorial](#tutorial) 30 | - [Running the experiments](#running-the-experiments) 31 | - [Citing our work](#citing-our-work) 32 | - [Acknowledgements](#acknowledgements) 33 | 34 | ## Tutorial 35 | 36 | You can access a tutorial of our method here: 37 | Open In Colab 38 | 39 | 40 | ## Running the experiments 41 | 42 | #### Step 1: Get a local working copy of this code 43 | __1.1.__ Clone this repository in your local machine. 44 | 45 | __1.2.__ Install the required software using conda, by running: 46 | ``` 47 | conda create --name vit-cls python=3.9 48 | conda activate vit-cls 49 | pip install -r requirements.txt 50 | pip install . 51 | ``` 52 | 53 | #### Step 2: Download the dataset and model checkpoints 54 | __2.1.__ Download the ImageNet-S dataset from [here](https://github.com/LUSSeg/ImageNet-S). 55 | 56 | __2.2.__ Download the stimuli info file from [here](https://drive.google.com/drive/folders/1bkJeOGMxU2Ta0CrtKeY9JBLArwmQM9mu?usp=sharing), and place it inside the `ImageNet-S/ImageNetS919` 57 | folder downloaded in the previous step. 58 | 59 | __2.3.__ Download the model checkpoint folder from [here](https://drive.google.com/drive/folders/1bkJeOGMxU2Ta0CrtKeY9JBLArwmQM9mu?usp=sharing), and place it inside the project folder. 60 | 61 | #### Step 3: Run experiments for extracting code 62 | __3.1.__ Project hidden states to class embedding space and save key coefficients, by running: 63 | ``` 64 | python extractor.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -m {MODEL} -pretrained 65 | ``` 66 | - The model can be one of `vit_b_32`, `vit_b_16`, `vit_large_16`, `vit_cifar_16`, `vit_miil_16`, `deit_ensemble_16` (_Refinement_ model) and `vit_gap_16`. 67 | - You can reproduce the results of the random model by removing the `-pretrained` flag. 68 | 69 | 70 | __3.2.__ Run attention perturbation studies, by: 71 | ``` 72 | python perturbation/attn_perturbation.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -m vit_b_32 -pt {PERTURBATION TYPE} 73 | ``` 74 | - Perturbation type can be one of `self_only` or `no_cls`. 75 | 76 | __3.3.__ Run context perturbation studies, by: 77 | ``` 78 | python perturbation/tokens_perturbation.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -m vit_b_32 -mt {MASK TYPE} 79 | ``` 80 | - Mask type can be one of `context` or `class label`. 81 | 82 | __3.4.__ Run memory extractor, by: 83 | ``` 84 | python memories.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -m {MODEL} -lt {LAYER TYPE} 85 | ``` 86 | - Layer type can be one of `attn` or `mlp`. 87 | 88 | __3.5.__ Run comparison with a linear probing approach, by: 89 | ``` 90 | python linear_probing/prober.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -l {LAYER INDEX} 91 | ``` 92 | 93 | #### Step 4: Reproduce the results 94 | After running the above code, 95 | head to the [notebooks](https://github.com/martinagvilas/vit-cls_emb/tree/main/notebooks) section to reproduce and visualize the reported results. 96 | 97 | ## Citing our work 98 | Please cite this work as: 99 | ``` 100 | @inproceedings{vilas2023analyzing_vit, 101 | title = {Analyzing Vision Transformers for Image Classification in Class Embedding Space}, 102 | author = {Vilas, Martina G. 
and Schauml\"{o}ffel, Timothy and Roig, Gemma}, 103 | booktitle = {Advances in Neural Information Processing Systems}, 104 | pages = {40030--40041}, 105 | volume = {36}, 106 | year = {2023} 107 | url = {https://proceedings.neurips.cc/paper_files/paper/2023/file/7dd309df03d37643b96f5048b44da798-Paper-Conference.pdf}, 108 | } 109 | 110 | ``` 111 | 112 | 113 | ## Acknowledgements 114 | - The pre-trained models are extracted from the [timm](https://github.com/huggingface/pytorch-image-models/tree/main) library. 115 | - Our readme is inspired by [IPViT](https://github.com/Muzammal-Naseer/IPViT). 116 | -------------------------------------------------------------------------------- /src/perturbation/attn_perturbation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | from pathlib import Path 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from transformers import ViTImageProcessor 9 | 10 | from src.datasets.imagenet import ImagenetDatasetS 11 | from src.models.vit_edited import vit_base_patch32_224, vit_base_patch16_224 12 | 13 | 14 | class ViTPerturbAttn(): 15 | def __init__( 16 | self, version, project_path, imgs_path, device='cpu', 17 | perturb_type='no_cls' 18 | ): 19 | self.version = version 20 | self.model_name = f'vit_b_{version}' 21 | self.device = device 22 | self.perturb_type = perturb_type 23 | 24 | self.project_path = project_path 25 | self.imgs_path = imgs_path 26 | self.perturb_path = project_path / 'results' / 'perturbation' / model 27 | self.perturb_path.mkdir(parents=True, exist_ok=True) 28 | 29 | self._get_model_layers() 30 | self._load_model() 31 | 32 | def _get_model_layers(self): 33 | self.layers = ['blocks.11'] 34 | return 35 | 36 | def _load_model(self): 37 | if self.version == '32': 38 | self.model = vit_base_patch32_224( 39 | pretrained=True, pretrained_cfg=True, 40 | perturb_type=self.perturb_type, block_perturb='all' 41 | ).to(self.device) 42 | source = 'google/vit-base-patch32-224-in21k' 43 | self.n_tokens = 49 44 | elif self.version == '16': 45 | self.model = vit_base_patch16_224( 46 | pretrained=True, pretrained_cfg=True, 47 | perturb_type=self.perturb_type, block_perturb='all' 48 | ).to(self.device) 49 | source = 'google/vit-base-patch16-224-in21k' 50 | self.n_tokens = 196 51 | self.model.eval() 52 | 53 | # Add feature extractor 54 | self._add_feature_extractor() 55 | 56 | # Get image transform 57 | self.img_transform = ViTImageProcessor.from_pretrained(source) 58 | return 59 | 60 | def _add_feature_extractor(self): 61 | def get_activation(layer_name): 62 | def hook(_, input, output): 63 | try: 64 | self.model._features[layer_name] = output.detach() 65 | except AttributeError: 66 | # attention layer can output a tuple 67 | self.model._features[layer_name] = output[0].detach() 68 | return hook 69 | 70 | for layer_name, layer in self.model.named_modules(): 71 | if layer_name in self.layers: 72 | layer.register_forward_hook(get_activation(layer_name)) 73 | return 74 | 75 | def _get_dataset(self): 76 | dataset = ImagenetDatasetS(self.imgs_path) 77 | self.dataloader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0]) 78 | return 79 | 80 | def compute(self): 81 | self._get_dataset() 82 | acc = [] 83 | dec = defaultdict(list) 84 | for id, data in tqdm( 85 | enumerate(self.dataloader), total=len(self.dataloader) 86 | ): 87 | # Get image features 88 | img_ft = self.img_transform(data['img'], return_tensors="pt") 89 | 
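            # The Hugging Face processor returns a BatchFeature mapping; only
            # its 'pixel_values' tensor is kept and fed to the timm model below.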
img_ft = img_ft['pixel_values'].to(self.device) 90 | 91 | # Compute hidden states 92 | self.model._features = {} 93 | with torch.no_grad(): 94 | out = self.model(img_ft) 95 | 96 | # Turn hidden states into decodability scores 97 | for l_name, l_repr in self.model._features.items(): 98 | l_split = l_name.split('.') 99 | try: 100 | l_name = f'hs-{l_split[2]}_{l_split[1]}' 101 | except: 102 | l_name = f'hs_{l_split[1]}' 103 | with torch.no_grad(): 104 | preds = self.model.head(self.model.norm(l_repr)) 105 | ordered_idx = torch.argsort(preds, dim=2, descending=True) 106 | label_idx = (ordered_idx == data['index']).nonzero() 107 | dec[l_name].append(label_idx[:, 2]) 108 | 109 | # Compute accuracy 110 | pred = out.topk(1)[1] 111 | cat_acc = torch.squeeze((pred == data['index']).long()) 112 | acc.append(cat_acc) 113 | 114 | # Save accuracy and hidden states 115 | acc = torch.hstack(acc).to(self.device) 116 | torch.save(acc, self.perturb_path / f'attn-{self.perturb_type}_acc.pt') 117 | 118 | dec = torch.stack(dec['hs_11']).to(self.device) 119 | torch.save(dec, self.perturb_path / f'attn-{self.perturb_type}_dec.pt') 120 | 121 | return 122 | 123 | 124 | if __name__ == '__main__': 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument( 127 | '-pp', action='store', required=True, 128 | help='Path to the folder containing the source code.' 129 | ) 130 | parser.add_argument( 131 | '-dp', action='store', required=True, 132 | help='Path to the folder containing the dataset.' 133 | ) 134 | parser.add_argument( 135 | '-m', action='store', required=True, 136 | help='Select which model to run. Can be one of the following options: \ 137 | vit_16, vit_32.' 138 | ) 139 | parser.add_argument( 140 | '-pt', action='store', required=True, 141 | help='Perturbation type. Can be one of [no_cls], [self_only].' 
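        # Example invocation (mirrors README step 3.2; paths are placeholders):
        #   python perturbation/attn_perturbation.py -pp /path/to/repo \
        #       -dp /path/to/ImageNet-S -m vit_b_32 -pt no_cls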
142 | ) 143 | args = parser.parse_args() 144 | 145 | project_path = Path(args.pp) 146 | data_path = Path(args.dp) 147 | 148 | model = args.m 149 | version = model.split('_')[-1] 150 | device = "cuda" if torch.cuda.is_available() else "cpu" 151 | perturb_type = args.pt 152 | 153 | perturb = ViTPerturbAttn( 154 | version, project_path, data_path, device, perturb_type=perturb_type 155 | ) 156 | perturb.compute() -------------------------------------------------------------------------------- /src/linear_probing/prober.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from PIL import Image 4 | 5 | import numpy as np 6 | from sklearn.linear_model import Ridge 7 | from sklearn.metrics import accuracy_score 8 | from timm import create_model 9 | from tqdm import tqdm 10 | import torch 11 | from torchvision.models.feature_extraction import create_feature_extractor 12 | from transformers import AutoImageProcessor 13 | 14 | from src.datasets.imagenet import ImagenetDatasetS 15 | 16 | LAYERS = [f'{layer}-{i}' for layer in ['hs', 'attn', 'mlp'] for i in range(12)] 17 | 18 | CATS_EXCLUDE = [ 19 | 'n04356056', 'n04355933', 'n04493381', 'n02808440', 'n03642806', 20 | 'n03832673', 'n04008634', 'n03773504', 'n03887697', 'n15075141' 21 | ] 22 | 23 | 24 | class LinearProber: 25 | def __init__(self, model_name, layer, project_path, imgs_path, device): 26 | self.model_name = model_name 27 | self.layer = layer 28 | self.device = device 29 | self._load_model() 30 | 31 | self.project_path = Path(project_path) 32 | self.imgs_path = Path(imgs_path) 33 | 34 | return 35 | 36 | def _load_model(self): 37 | # Load model 38 | self.model = create_model( 39 | 'vit_base_patch32_224', pretrained=True 40 | ).to(self.device) 41 | 42 | # Add feature extractor 43 | block = self.layer.split('-')[-1] 44 | if 'attn' in self.layer: 45 | self.node = f'blocks.{block}.attn.proj' 46 | elif 'mlp' in self.layer: 47 | self.node = f'blocks.{block}.mlp.fc2' 48 | elif 'hs' in self.layer: 49 | self.node = f'blocks.{block}.add_1' 50 | self.extractor = create_feature_extractor(self.model, [self.node]).to(self.device) 51 | 52 | # Get image processor 53 | self.img_transform = AutoImageProcessor.from_pretrained( 54 | 'google/vit-base-patch32-224-in21k' 55 | ) 56 | 57 | return 58 | 59 | def _get_training_data(self, token): 60 | dataset = ImagenetDatasetS(self.imgs_path, partition='test', n=10) 61 | 62 | fts = [] 63 | targets = [] 64 | for _, row in tqdm( 65 | dataset.stim_info.iterrows(), total=len(dataset.stim_info), 66 | desc=f'{self.node}/{token}' 67 | ): 68 | # Get image features 69 | img_path = self.imgs_path / 'ImageNetS919/test' / row['imagenet_id'] / row['img_name'] 70 | img = Image.open(img_path).convert('RGB') 71 | 72 | img_ft = self.img_transform(img, return_tensors="pt") 73 | img_ft = img_ft['pixel_values'].to(self.device) 74 | 75 | # Run model 76 | with torch.no_grad(): 77 | out = self.extractor(img_ft) 78 | 79 | # Save features 80 | out = out[self.node][0, token] 81 | fts.append(out) 82 | 83 | # Get one hot encoder 84 | target = torch.zeros(1000).to(self.device) - 1 85 | target[row['index']] = 1 86 | targets.append(target) 87 | 88 | fts = torch.stack(fts) 89 | targets = torch.stack(targets) 90 | 91 | return fts, targets 92 | 93 | def _get_testing_data(self, token): 94 | dataset = ImagenetDatasetS(self.imgs_path) 95 | 96 | fts = [] 97 | targets = [] 98 | for _, row in tqdm(dataset.stim_info.iterrows(), total=len(dataset.stim_info)): 99 | # Get image 
features 100 | img_path = self.imgs_path / 'ImageNetS919/validation' / row['imagenet_id'] / row['img_name'] 101 | img = Image.open(img_path).convert('RGB') 102 | 103 | img_ft = self.img_transform(img, return_tensors="pt") 104 | img_ft = img_ft['pixel_values'].to(self.device) 105 | 106 | # Run model 107 | with torch.no_grad(): 108 | out = self.extractor(img_ft) 109 | 110 | # Save features 111 | out = out[self.node][0, token] 112 | fts.append(out) 113 | 114 | # Get index 115 | targets.append(row['index']) 116 | 117 | fts = torch.stack(fts) 118 | 119 | return fts, targets 120 | 121 | def compute(self): 122 | path = self.project_path / 'results' / 'linear_probing' / self.layer 123 | path.mkdir(parents=True, exist_ok=True) 124 | 125 | accs = [] 126 | for token in range(50): 127 | # Get training data 128 | fts, targets = self._get_training_data(token) 129 | 130 | # Train model 131 | lm = Ridge(alpha=1.0, solver='lsqr') 132 | lm.fit(fts.cpu().numpy(), targets.cpu().numpy()) 133 | 134 | # Get testing data 135 | fts, targets = self._get_testing_data(token) 136 | 137 | # Test accuracy 138 | preds = lm.predict(fts.cpu().numpy()) 139 | topk = np.argmax(preds, axis=1) 140 | acc = accuracy_score(targets, topk) 141 | accs.append(acc) 142 | 143 | # Save position of correct prediction 144 | pos = 999 - np.where(np.argsort(preds) == np.expand_dims(targets, axis=1))[1] 145 | np.save(path / f'pos_t{token}.npy', pos) 146 | 147 | # Save accuracy for token and layer 148 | np.save(path / f'acc.npy', accs) 149 | 150 | return 151 | 152 | 153 | if __name__ == '__main__': 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument( 156 | '-pp', action='store', required=True, 157 | help='Path to the folder containing the source code.' 158 | ) 159 | parser.add_argument( 160 | '-dp', action='store', required=True, 161 | help='Path to the folder containing the dataset.' 162 | ) 163 | parser.add_argument( 164 | '-l', action='store', required=True, 165 | help='Layer index.' 
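        # Indexes into LAYERS defined at the top of this file:
        # 0-11 -> 'hs-*', 12-23 -> 'attn-*', 24-35 -> 'mlp-*'.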
166 | ) 167 | 168 | args = parser.parse_args() 169 | project_path = Path(args.pp) 170 | data_path = Path(args.dp) 171 | 172 | device = "cuda" if torch.cuda.is_available() else "cpu" 173 | 174 | layer = LAYERS[int(args.l)] 175 | LinearProber('vit_b_32', layer, project_path, data_path, device).compute() 176 | -------------------------------------------------------------------------------- /src/models/deit_ensemble.py: -------------------------------------------------------------------------------- 1 | ### 2 | # From https://github.com/Muzammal-Naseer/ATViT/blob/c7f947831decc43cc2b73884b16e6e79b02ad547/vit_models/deit_ensemble.py#L111 3 | ### 4 | 5 | from functools import partial 6 | 7 | import torch 8 | import torch.nn as nn 9 | import math 10 | from einops import reduce, rearrange 11 | from timm.models.registry import register_model 12 | from timm.models.vision_transformer import VisionTransformer, _cfg 13 | 14 | import torch.nn.functional as F 15 | 16 | __all__ = [ 17 | "tiny_patch16_224_hierarchical", "small_patch16_224_hierarchical", "base_patch16_224_hierarchical" 18 | ] 19 | 20 | 21 | class TransformerHead(nn.Module): 22 | expansion = 1 23 | 24 | def __init__(self, token_dim, num_patches=196, num_classes=1000, stride=1): 25 | super(TransformerHead, self).__init__() 26 | 27 | self.token_dim = token_dim 28 | self.num_patches = num_patches 29 | self.num_classes = num_classes 30 | 31 | # To process patches 32 | self.conv = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=stride, padding=1, bias=False) 33 | self.bn = nn.BatchNorm2d(self.token_dim) 34 | self.conv = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=1, padding=1, bias=False) 35 | self.bn = nn.BatchNorm2d(self.token_dim) 36 | 37 | self.shortcut = nn.Sequential() 38 | if stride != 1 or self.token_dim != self.expansion * self.token_dim: 39 | self.shortcut = nn.Sequential( 40 | nn.Conv2d(self.token_dim, self.expansion * self.token_dim, kernel_size=1, stride=stride, bias=False), 41 | nn.BatchNorm2d(self.expansion * self.token_dim) 42 | ) 43 | 44 | self.token_fc = nn.Linear(self.token_dim, self.token_dim) 45 | 46 | def forward(self, x): 47 | """ 48 | x : (B, num_patches + 1, D) -> (B, C=num_classes) 49 | """ 50 | cls_token, patch_tokens = x[:, 0], x[:, 1:] 51 | size = int(math.sqrt(x.shape[1])) 52 | 53 | patch_tokens = rearrange(patch_tokens, 'b (h w) d -> b d h w', h=size, w=size) # B, D, H, W 54 | features = F.relu(self.bn(self.conv(patch_tokens))) 55 | features = self.bn(self.conv(features)) 56 | features += self.shortcut(patch_tokens) 57 | features = F.relu(features) 58 | patch_tokens = F.avg_pool2d(features, 14).view(-1, self.token_dim) 59 | cls_token = self.token_fc(cls_token) 60 | 61 | out = patch_tokens + cls_token 62 | 63 | return out 64 | 65 | 66 | class VisionTransformer_hierarchical(VisionTransformer): 67 | def __init__(self, *args, **kwargs): 68 | super().__init__(*args, **kwargs) 69 | 70 | # # Transformer heads 71 | # self.transformerheads = nn.Sequential(*[ 72 | # TransformerHead(self.embed_dim) 73 | # for i in range(11)]) 74 | 75 | def forward_features(self, x): 76 | B, nc, w, h = x.shape 77 | x = self.patch_embed(x) 78 | 79 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 80 | x = torch.cat((cls_tokens, x), dim=1) 81 | x = x + self.pos_embed 82 | x = self.pos_drop(x) 83 | 84 | # Store transformer outputs 85 | #transformerheads_outputs = [] 86 | 87 | for idx, blk in enumerate(self.blocks): 88 | x = blk(x) 89 | # if idx <= 10: 90 | # out = 
self.norm(x) 91 | # out = self.transformerheads[idx](out) 92 | # transformerheads_outputs.append(out) 93 | 94 | x = self.norm(x) 95 | return x 96 | #return x, transformerheads_outputs 97 | 98 | def forward(self, x): 99 | x = self.forward_features(x) 100 | #x, transformerheads_outputs = self.forward_features(x) 101 | output = self.head(x[:, 0]) 102 | # output = [] 103 | # for y in transformerheads_outputs: 104 | # output.append(self.head(y)) 105 | # output.append(self.head(x[:, 0])) 106 | return output 107 | 108 | 109 | @register_model 110 | def tiny_patch16_224_hierarchical(pretrained=False, **kwargs): 111 | model = VisionTransformer_hierarchical( 112 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 113 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 114 | model.default_cfg = _cfg() 115 | if pretrained: 116 | checkpoint = torch.hub.load_state_dict_from_url( 117 | url="https://github.com/Muzammal-Naseer/Improving-Adversarial-Transferability-of-Vision-Transformers" 118 | "/releases/download/v0/deit_tiny_trm.pth", 119 | map_location="cpu", check_hash=True 120 | ) 121 | model.load_state_dict(checkpoint["state_dict"]) 122 | return model 123 | 124 | 125 | @register_model 126 | def small_patch16_224_hierarchical(pretrained=False, **kwargs): 127 | model = VisionTransformer_hierarchical( 128 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 129 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 130 | model.default_cfg = _cfg() 131 | 132 | if pretrained: 133 | checkpoint = torch.hub.load_state_dict_from_url( 134 | url="https://github.com/Muzammal-Naseer/Improving-Adversarial-Transferability-of-Vision-Transformers" 135 | "/releases/download/v0/deit_small_trm.pth", 136 | map_location="cpu", check_hash=True 137 | ) 138 | model.load_state_dict(checkpoint["state_dict"]) 139 | return model 140 | 141 | 142 | @register_model 143 | def base_patch16_224_hierarchical(pretrained=False, **kwargs): 144 | model = VisionTransformer_hierarchical( 145 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 146 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 147 | model.default_cfg = _cfg() 148 | 149 | if pretrained: 150 | checkpoint = torch.hub.load_state_dict_from_url( 151 | url="https://github.com/Muzammal-Naseer/Improving-Adversarial-Transferability-of-Vision-Transformers" 152 | "/releases/download/v0/deit_base_trm.pth", 153 | map_location="cpu", check_hash=True 154 | ) 155 | model.load_state_dict(checkpoint["state_dict"], strict=False) 156 | return model 157 | -------------------------------------------------------------------------------- /src/perturbation/tokens_perturbation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | from pathlib import Path 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from transformers import ViTImageProcessor 9 | 10 | from src.datasets.imagenet import ImagenetDatasetS 11 | from src.models.vit_edited import vit_base_patch32_224, vit_base_patch16_224 12 | from src.vis import Vis 13 | 14 | 15 | class ViTPerturbTokens(): 16 | def __init__(self, version, project_path, imgs_path, device='cpu'): 17 | self.version = version 18 | self.model_name = f'vit_b_{version}' 19 | self.device = device 20 | 21 | self.project_path = Path(project_path) 22 | self.imgs_path = Path(imgs_path) 23 | self.perturb_path = project_path / 'results' / 
'perturbation' / model 24 | self.perturb_path.mkdir(parents=True, exist_ok=True) 25 | 26 | self._get_model_layers() 27 | self._load_model() 28 | 29 | def _get_model_layers(self): 30 | # Compute identifiability of tokens in last block 31 | self.layers = ['blocks.11'] 32 | return 33 | 34 | def _load_model(self): 35 | # Load model 36 | if self.version == '32': 37 | self.model = vit_base_patch32_224( 38 | pretrained=True, pretrained_cfg=True 39 | ).to(self.device) 40 | self.n_tokens = 49 41 | source = 'google/vit-base-patch32-224-in21k' 42 | elif self.version == '16': 43 | self.model = vit_base_patch16_224( 44 | pretrained=True, pretrained_cfg=True 45 | ).to(self.device) 46 | source = 'google/vit-base-patch16-224-in21k' 47 | self.n_tokens = 196 48 | self.model.eval() 49 | 50 | # Add feature extractor 51 | self._add_feature_extractor() 52 | 53 | # Get image transform 54 | self.img_transform = ViTImageProcessor.from_pretrained(source) 55 | return 56 | 57 | def _add_feature_extractor(self): 58 | def get_activation(layer_name): 59 | def hook(_, input, output): 60 | try: 61 | self.model._features[layer_name] = output.detach() 62 | except AttributeError: 63 | # attention layer can output a tuple 64 | self.model._features[layer_name] = output[0].detach() 65 | return hook 66 | 67 | for layer_name, layer in self.model.named_modules(): 68 | if layer_name in self.layers: 69 | layer.register_forward_hook(get_activation(layer_name)) 70 | return 71 | 72 | 73 | def _get_dataset(self): 74 | self.dataset = ImagenetDatasetS(self.imgs_path) 75 | self.dataloader = DataLoader( 76 | self.dataset, batch_size=1, collate_fn=lambda x: x[0] 77 | ) 78 | return 79 | 80 | def compute(self, mask_type='context'): 81 | self._get_dataset() 82 | 83 | vis = Vis(self.project_path, self.imgs_path, self.model_name, self.device) 84 | 85 | acc = [] 86 | dec = defaultdict(list) 87 | for id, data in tqdm( 88 | enumerate(self.dataloader), total=len(self.dataloader) 89 | ): 90 | # Get image features 91 | img_ft = self.img_transform(data['img'], return_tensors="pt") 92 | img_ft = img_ft['pixel_values'].to(self.device) 93 | 94 | # Get segmentations of discarded tokens 95 | mask = vis.get_segmentation(data['imagenet_id'], idx=data['img_index']) 96 | mask = mask.flatten() 97 | if mask_type == 'context': 98 | mask = torch.tensor((mask == 1).nonzero()[0] + 1).to(self.device) 99 | elif mask_type == 'class_label': 100 | mask = torch.tensor((mask == 0).nonzero()[0] + 1).to(self.device) 101 | tokens = torch.hstack((torch.tensor([0]), mask)) 102 | 103 | # Compute hidden states 104 | self.model._features = {} 105 | with torch.no_grad(): 106 | out = self.model(img_ft, tokens) 107 | 108 | # Turn hidden states into decodability scores 109 | for l_name, l_repr in self.model._features.items(): 110 | l_split = l_name.split('.') 111 | try: 112 | l_name = f'hs-{l_split[2]}_{l_split[1]}' 113 | except: 114 | l_name = f'hs_{l_split[1]}' 115 | with torch.no_grad(): 116 | preds = self.model.head(self.model.norm(l_repr)) 117 | ordered_idx = torch.argsort(preds, dim=2, descending=True) 118 | label_idx = (ordered_idx == data['index']).nonzero() 119 | dec[l_name].append(label_idx[:, 2]) 120 | 121 | # Compute accuracy 122 | pred = out.topk(1)[1] 123 | cat_acc = torch.squeeze((pred == data['index']).long()) 124 | acc.append(cat_acc) 125 | 126 | # Save accuracy and hidden states 127 | acc = torch.hstack(acc).to(self.device) 128 | torch.save(acc, self.perturb_path / f'no_{mask_type}_tokens_acc.pt') 129 | torch.save(dec, self.perturb_path / 
f'no_{mask_type}_tokens_dec.pt') 130 | 131 | return 132 | 133 | 134 | if __name__ == '__main__': 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument( 137 | '-pp', action='store', required=True, 138 | help='Path to the folder containing the source code.' 139 | ) 140 | parser.add_argument( 141 | '-dp', action='store', required=True, 142 | help='Path to the folder containing the dataset.' 143 | ) 144 | parser.add_argument( 145 | '-m', action='store', required=True, 146 | help='Select which model to run. Can be one of the following options: \ 147 | vit_16, vit_32' 148 | ) 149 | parser.add_argument( 150 | '-mt', action='store', required=True, 151 | help='Mask type. Can be one of [context], [class_label].' 152 | ) 153 | args = parser.parse_args() 154 | 155 | project_path = Path(args.pp) 156 | data_path = Path(args.dp) 157 | 158 | model = args.m 159 | version = model.split('_')[-1] 160 | device = "cuda" if torch.cuda.is_available() else "cpu" 161 | mask_type = args.mt 162 | 163 | perturb = ViTPerturbTokens(version, project_path, data_path, device) 164 | if mask_type: 165 | perturb.compute(mask_type) 166 | else: 167 | perturb.compute() -------------------------------------------------------------------------------- /src/plots/explainability.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import torch 5 | 6 | from src.gradients import ViTGradients 7 | from src.vis import Vis 8 | 9 | 10 | def plot_specific_ft_importance( 11 | model_name, block, head, imgs_info, stim_info, proj_path, dataset_path 12 | ): 13 | """ 14 | Plot block- and head-specific feature importance results. 15 | """ 16 | # Get gradients 17 | vis = Vis(proj_path, dataset_path, model_name, 'cpu') 18 | g = ViTGradients(model_name, proj_path, dataset_path) 19 | 20 | # Plot 21 | fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5, 3)) 22 | for i, ax in zip(imgs_info, axes.flat): 23 | # Compute gradients 24 | grads, _ = g.compute(i[0], i[1], i[2]) 25 | mask = - grads[block][head, 0, 1:] 26 | mask_vis = vis.mask(i[0], i[1], mask) 27 | 28 | ax.imshow(mask_vis) 29 | ax.set_xticks([]) 30 | ax.set_yticks([]) 31 | cat = stim_info[stim_info['index'] == i[2]]['cat'].unique()[0].split(',')[0].lower() 32 | ax.set_title(cat) 33 | 34 | f = proj_path / 'results/figures' / f'explainability_{model_name}_head_1.png' 35 | plt.tight_layout() 36 | plt.savefig(f, dpi=300) 37 | plt.show() 38 | return 39 | 40 | 41 | def plot_sum_ft_importance(model_name, imgs_info, stim_info, proj_path, dataset_path): 42 | """ 43 | Plot sum of gradients feature importance results. 
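    Each element of `imgs_info` is assumed to be of the form
    [imagenet_id, image_index, class_index], matching how entries are
    indexed below (cf. the `img_info` lists in `compare_explainability`).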
44 | """ 45 | # Visualize feature importance 46 | if len(imgs_info) <= 6: 47 | fig, axes = plt.subplots(nrows=1, ncols=len(imgs_info), figsize=(10, 3)) 48 | else: 49 | fig, axes = plt.subplots(nrows=2, ncols=int(len(imgs_info)/2), figsize=(12, 5)) 50 | 51 | for i, ax in zip(imgs_info, axes.flat): 52 | 53 | # Compute importance by gradients 54 | grads, _ = ViTGradients(model_name, proj_path, dataset_path).compute(i[0], i[1], i[2]) 55 | importance = [] 56 | for b in range(12): 57 | importance.append(torch.sum(grads[b], dim=0)[0]) 58 | mask = torch.sum(torch.stack(importance), dim=0)[1:] 59 | mask = - mask 60 | 61 | # Plot heatmap over image 62 | vis = Vis(proj_path, dataset_path, model_name, 'cpu') 63 | mask_vis = vis.mask(i[0], i[1], mask) 64 | ax.imshow(mask_vis) 65 | ax.set_xticks([]) 66 | ax.set_yticks([]) 67 | if i[2] == 582: 68 | cat = 'grocery store' 69 | else: 70 | cat = stim_info[stim_info['index'] == i[2]]['cat'].unique()[0].split(',')[0].lower() 71 | ax.set_title(cat) 72 | 73 | f = proj_path / 'results/figures' / f'explainability_{model_name}.png' 74 | plt.tight_layout() 75 | plt.savefig(f, dpi=300) 76 | plt.show() 77 | 78 | return 79 | 80 | 81 | def compare_explainability(proj_path, dataset_path, stim_info): 82 | model_name = 'vit_b_32' 83 | 84 | # Select different images, classes, blocks and head 85 | img_info = [ 86 | ['n02422699', 0, 352, 7, 0], ['n02422699', 0, 350, 7, 0], 87 | ['n04404412', 1, 851, 11, 6], ['n04404412', 1, 831, 11, 6], 88 | ] # [imagenet_id, image_id, class_id, block, attention head] 89 | 90 | # Compute gradients 91 | vis = Vis(proj_path, dataset_path, model_name, 'cpu') 92 | g = ViTGradients(model_name, proj_path, dataset_path) 93 | 94 | # Plot 95 | fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(10, 6)) 96 | #fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(10, 5)) 97 | for i, ax in zip(img_info, axes.flat[:4]): 98 | b = i[3] 99 | h = i[4] 100 | grads, _ = g.compute(i[0], i[1], i[2]) 101 | mask = - grads[b][h, 0, 1:] 102 | mask_vis = vis.mask(i[0], i[1], mask) 103 | ax.imshow(mask_vis) 104 | ax.set_xticks([]) 105 | ax.set_yticks([]) 106 | cat = stim_info[stim_info['index'] == i[2]]['cat'].unique()[0].split(',')[0].lower() 107 | ax.set_title(f'{cat} - bl. {b}, h. 
{h}', fontsize=10) 108 | 109 | for i, ax in zip(img_info, axes.flat[4:]): 110 | # Compute importance by gradients 111 | grads, _ = g.compute(i[0], i[1], i[2]) 112 | importance = [] 113 | for b in range(12): 114 | importance.append(torch.sum(grads[b], dim=0)[0]) 115 | mask = torch.sum(torch.stack(importance), dim=0)[1:] 116 | mask = - mask 117 | 118 | # Plot heatmap over image 119 | mask_vis = vis.mask(i[0], i[1], mask) 120 | ax.imshow(mask_vis) 121 | ax.set_xticks([]) 122 | ax.set_yticks([]) 123 | if i[2] == 582: 124 | cat = 'grocery store' 125 | else: 126 | cat = stim_info[stim_info['index'] == i[2]]['cat'].unique()[0].split(',')[0].lower() 127 | ax.set_title(cat, fontsize=10) 128 | 129 | fig.text(x=0.1, y=0.92, s='(a) Block- and head-specific visualization', weight='bold', fontsize=12) 130 | fig.text(x=0.1, y=0.5, s='(b) Sum of gradients visualization', weight='bold', fontsize=12) 131 | fig.text(x=0.11, y=0.86, s='(1)', weight='bold', fontsize=11) 132 | fig.text(x=0.51, y=0.86, s='(2)', weight='bold', fontsize=11) 133 | fig.text(x=0.11, y=0.44, s='(1)', weight='bold', fontsize=11) 134 | fig.text(x=0.51, y=0.44, s='(2)', weight='bold', fontsize=11) 135 | 136 | f = proj_path / 'results/figures' / f'explainability_{model_name}_v2.png' 137 | plt.savefig(f, dpi=200) 138 | plt.show() 139 | return 140 | 141 | 142 | def plot_perturbation(model_name, perturb_type, res_path, random=False): 143 | """ 144 | Plot perturbation experiment results. 145 | """ 146 | labels = torch.arange(48) / 49 * 100 147 | 148 | fig, ax = plt.subplots(figsize=(6,4)) 149 | 150 | f = res_path / 'perturbation' / model_name / f'{perturb_type}_grads.pt' 151 | emb_perturb = torch.load(f, map_location='cpu') 152 | emb_perturb = torch.flip(emb_perturb, dims=(0,)).detach().numpy() 153 | sns.lineplot(x=labels, y=emb_perturb, ax=ax, label='cls-emb removal') 154 | print(f'AUC emb: {np.sum(emb_perturb) / (np.max(emb_perturb) * 49)}') 155 | 156 | f = res_path / 'perturbation' / model_name / f'{perturb_type}_grads_chefer.pt' 157 | chefer_perturb = torch.load(f, map_location='cpu') 158 | chefer_perturb = torch.flip(chefer_perturb, dims=(0,)).detach().numpy() 159 | sns.lineplot(x=labels, y=chefer_perturb, ax=ax, label='chefer removal') 160 | print(f'AUC chefer: {np.sum(chefer_perturb) / (np.max(chefer_perturb) * 49)}') 161 | 162 | if random == True: 163 | f = res_path / 'perturbation' / model_name / 'perturb_random.pt' 164 | rand_perturb = torch.mean(torch.load(f, map_location='cpu'), dim=0) 165 | rand_perturb = torch.flip(rand_perturb, dims=(0,)).detach().numpy() 166 | sns.lineplot(x=labels, y=rand_perturb, ax=ax, label='random removal') 167 | print(f'AUC random: {np.sum(rand_perturb) / (np.max(rand_perturb) * 49)}') 168 | 169 | ax.hlines( 170 | xmin=-0.5, xmax=100.5, y=emb_perturb[0], colors='dimgray', linestyles='--', lw=2, 171 | label='baseline accuracy' 172 | ) 173 | ax.set_xticks(np.arange(0, 100, 10)) 174 | ax.set_xlim(-0.5, 95) 175 | ax.set_ylim(0,0.9) 176 | ax.set_xlabel('percentage of tokens removed') 177 | ax.set_ylabel('accuracy') 178 | ax.legend() 179 | 180 | plt.tight_layout() 181 | f = res_path / 'figures' / f'neg_perturb_{model_name}.png' 182 | plt.savefig(f, dpi=300) 183 | plt.show() 184 | 185 | return -------------------------------------------------------------------------------- /src/vis.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from PIL import Image 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | 9 | 
from src.datasets.imagenet import ImagenetDatasetS 10 | from src.models.load import load_vit 11 | 12 | 13 | class Vis(): 14 | def __init__(self, project_path, dataset_path, model, device): 15 | # Get paths 16 | self.project_path = Path(project_path) 17 | self.res_path = self.project_path / 'results' 18 | 19 | # Get model image processor 20 | self.model = model 21 | self.device = device 22 | self._get_img_processor() 23 | 24 | # Get stimuli information 25 | self.dataset_path = Path(dataset_path) 26 | dataset = ImagenetDatasetS(self.dataset_path) 27 | self.imgs_path = dataset.imgs_path 28 | self.segmentations_path = dataset.path / 'validation-segmentation' 29 | self.stim_info = dataset.stim_info 30 | self.concepts = self.stim_info['imagenet_id'].unique().tolist() 31 | 32 | # Get segmentation mapping 33 | f = dataset_path / 'data/categories/ImageNetS_categories_im919.txt' 34 | seg_map = pd.read_csv(f, header=None) 35 | self.seg_map = {row[0]:idx+1 for idx, row in seg_map.iterrows()} 36 | 37 | return 38 | 39 | def _get_img_processor(self): 40 | """Load image processor and get image dimensions. 41 | """ 42 | 43 | # Load model 44 | _, _, _, self.processor = load_vit( 45 | self.model, self.device, self.project_path, return_transform=True 46 | ) 47 | 48 | # Get image dimension 49 | if '32' in self.model: 50 | self.img_dim = 7 51 | elif '16' in self.model: 52 | self.img_dim = 14 53 | 54 | return 55 | 56 | def get_array(self, concept, idx): 57 | """Get image as array. 58 | 59 | Parameters 60 | ---------- 61 | concept : str 62 | Imagenet id. 63 | idx : int 64 | Image id of Imagenet-S. 65 | 66 | Returns 67 | ------- 68 | np.array 69 | Image of concept and idx. 70 | """ 71 | file = self.imgs_path / concept / ( 72 | self.stim_info.loc[self.stim_info['imagenet_id'] == concept] 73 | .iloc[idx]['img_name'] 74 | ) 75 | img = Image.open(file).convert('RGB') 76 | img = self.processor(img) 77 | try: 78 | img = np.transpose(img['pixel_values'][0], (1, 2, 0)) 79 | except: 80 | img = img 81 | img = (img - img.min()) / (img.max() - img.min()) 82 | return img 83 | 84 | def get_pil(self, concept, idx): 85 | """Get image as PIL Image. 86 | 87 | Parameters 88 | ---------- 89 | concept : str 90 | Imagenet id. 91 | idx : int 92 | Image id of Imagenet-S. 93 | 94 | Returns 95 | ------- 96 | PIL Image 97 | Image of concept and idx. 98 | """ 99 | img = self.get_array(concept, idx) 100 | img = Image.fromarray((img * 255).astype(np.uint8)) 101 | return img 102 | 103 | def get_token(self, concept, idx, token_idx, as_pil=True): 104 | """Return token from image. 105 | 106 | Parameters 107 | ---------- 108 | concept : str 109 | Imagenet id. 110 | idx : int 111 | Image id of Imagenet-S. 112 | token_idx : int 113 | Index of token in image. 114 | as_pil : bool, optional 115 | Whether to return the token as image, by default True. 116 | 117 | Returns 118 | ------- 119 | np.array or PIL Image 120 | Token of image. 121 | """ 122 | # Get token position 123 | row_idx = token_idx // self.img_dim 124 | col_idx = token_idx % self.img_dim 125 | 126 | # Split image 127 | img = self.get_array(concept, idx) 128 | token = np.split(img, self.img_dim, axis=0)[row_idx] 129 | token = np.split(token, self.img_dim, axis=1)[col_idx] 130 | 131 | # Return token 132 | if as_pil: 133 | return Image.fromarray((token * 255).astype(np.uint8)).resize((100,100)) 134 | else: 135 | return token 136 | 137 | def get_segmentation(self, concept, idx): 138 | """Get class segmentation. 139 | 140 | Parameters 141 | ---------- 142 | concept : str 143 | Imagenet id. 
144 | idx : int 145 | Image id of Imagenet-S. 146 | 147 | Returns 148 | ------- 149 | np.array 150 | Mask of class in the image. 151 | """ 152 | 153 | # Get segmentation information 154 | f = self.stim_info.loc[self.stim_info['imagenet_id'] == concept] 155 | f = f"{Path(f.iloc[idx]['img_name']).stem}.png" 156 | sgt = Image.open(self.segmentations_path / concept / f) 157 | 158 | # Get mask class 159 | mask = np.array(sgt) 160 | mask = mask[:, :, 1] * 256 + mask[:, :, 0] 161 | mask = (mask == self.seg_map[concept]).astype(int) 162 | 163 | # Resize mask 164 | mask = cv2.resize( 165 | mask, dsize=(self.img_dim, self.img_dim), 166 | interpolation=cv2.INTER_NEAREST 167 | ) 168 | return mask 169 | 170 | def mask(self, concept, idx, weights, prepro='normalize_minmax', invert=False): 171 | """Generate heatmap over image from weights. 172 | Code adapted from https://github.com/hila-chefer/Transformer-MM-Explainability 173 | 174 | Parameters 175 | ---------- 176 | concept : str 177 | Imagenet id. 178 | idx : int 179 | Image id of Imagenet-S. 180 | weights : torch Tensor 181 | Weights to plot the heatmap. Dimension should be same as the number 182 | of image tokens. 183 | prepro : str, optional 184 | Wheter to preprocess the weights, by default 'normalize_minmax'. 185 | Can be one of the following: 'normalize_minmax' to scale the 186 | weights in the range of 0-1, or 'normalize_decoding' to scale the 187 | weights with respect to the number of classes in Imagenet. 188 | invert : bool, optional 189 | Whether to invert the weights, by default False. 190 | 191 | Returns 192 | ------- 193 | np.array 194 | Image with heatmap. 195 | """ 196 | 197 | # Preprocess weights if needed 198 | if prepro == 'normalize_minmax': 199 | weights = (weights - weights.min()) / (weights.max() - weights.min()) 200 | elif prepro == 'normalize_decoding': 201 | weights = weights / 1000 202 | elif prepro == None: 203 | pass 204 | 205 | if invert == True: 206 | weights = - weights 207 | print('inverted') 208 | 209 | # Transform weights into same size as image 210 | weights = weights.reshape(1, 1, self.img_dim, self.img_dim).float() 211 | weights = torch.nn.functional.interpolate(weights, size=224, mode='bilinear') 212 | weights = torch.squeeze(weights).detach().numpy() 213 | 214 | # Create heatmap 215 | heatmap = cv2.applyColorMap(np.uint8(255 * weights), cv2.COLORMAP_JET) 216 | heatmap = np.float32(heatmap) / 255 217 | 218 | # Get image 219 | img = self.get_array(concept, idx) 220 | 221 | # Apply heatmap to image 222 | vis = heatmap + np.float32(img) 223 | vis = vis / np.max(vis) 224 | vis = np.uint8(255 * vis) 225 | vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR) 226 | 227 | return vis 228 | -------------------------------------------------------------------------------- /src/gradients.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from torch.nn import CrossEntropyLoss 5 | from torch.nn.functional import softmax 6 | from transformers import ViTImageProcessor 7 | 8 | from src.datasets.imagenet import ImagenetDatasetS 9 | from src.models.vit_edited import vit_base_patch32_224, vit_base_patch16_224 10 | 11 | 12 | class ViTGradients(): 13 | """Compute feature importance of image tokens. 
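    Illustrative usage (mirrors the calls in `src/plots/explainability.py`;
    `proj_path` and `dataset_path` are placeholder pathlib Paths):

    >>> g = ViTGradients('vit_b_32', proj_path, dataset_path)
    >>> grads, attns = g.compute('n02422699', 0, 352)
    >>> # grads[b] has shape (num_heads, n_tokens, n_tokens) for block b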
14 | """ 15 | def __init__(self, model_name, project_path, imgs_path, device='cpu'): 16 | self.model_name = model_name 17 | self.device = device 18 | 19 | self.project_path = Path(project_path) 20 | self.imgs_path = Path(imgs_path) 21 | self.res_path = project_path / 'results' 22 | self.grad_path = self.res_path / 'gradients' / self.model_name 23 | self.grad_path.mkdir(parents=True, exist_ok=True) 24 | 25 | self._load_model() 26 | 27 | def _load_model(self): 28 | if self.model_name == 'vit_b_32': 29 | self.model = vit_base_patch32_224( 30 | pretrained=True, pretrained_cfg=True 31 | ).to(self.device) 32 | self.n_tokens = 50 33 | source = 'google/vit-base-patch32-224-in21k' 34 | elif self.model_name == 'vit_b_16': 35 | self.model = vit_base_patch16_224( 36 | pretrained=True, pretrained_cfg=True 37 | ).to(self.device) 38 | source = 'google/vit-base-patch16-224-in21k' 39 | self.n_tokens = 197 40 | self.model.eval() 41 | self.img_transform = ViTImageProcessor.from_pretrained(source) 42 | return 43 | 44 | def _get_dataset(self): 45 | self.dataset = ImagenetDatasetS(self.imgs_path) 46 | return 47 | 48 | def compute( 49 | self, concept, img_idx, cat_idx=None, grad_type='cross_entropy', 50 | input_type='attn_probs' 51 | ): 52 | # Get data 53 | self._get_dataset() 54 | data = self.dataset.get_item_by_concept(concept, img_idx) 55 | self.cat_idx = cat_idx 56 | 57 | # Get image features 58 | img_ft = self.img_transform(data['img'], return_tensors="pt") 59 | img_ft = img_ft['pixel_values'].to(self.device) 60 | 61 | # Compute hidden states 62 | self.model.zero_grad() 63 | for b in range(12): 64 | self.model.blocks[b].attn.attn_probs = None 65 | self.model.blocks[b].attn.key_proj_vals = None 66 | self.model.blocks[b].attn_cls = None 67 | _ = self.model(img_ft) 68 | 69 | # Get gradients 70 | self.input_type = input_type 71 | if grad_type == 'cross_entropy': 72 | grads, attns = self._compute_cross_entropy(data) 73 | elif grad_type == 'cat_prob': 74 | grads, attns = self._compute_cat_prob(data) 75 | 76 | return grads, attns 77 | 78 | def _compute_cross_entropy(self, data): 79 | loss = CrossEntropyLoss(reduction='none') 80 | 81 | grads = {} 82 | attns = {} 83 | for b in range(12): 84 | if self.input_type == 'attn_probs': 85 | inp = self.model.blocks[b].attn.attn_probs 86 | elif self.input_type == 'key_proj_vals': 87 | inp = self.model.blocks[b].attn.key_proj_vals 88 | out = self.model.blocks[b].attn_cls 89 | 90 | # Collect grads of all tokens 91 | out = torch.squeeze(out) 92 | out = softmax(out, dim=1) 93 | target = torch.zeros(1, 1000).to(self.device) 94 | if self.cat_idx != None: 95 | target[0, self.cat_idx] = 1 96 | else: 97 | target[0, data['index']] = 1 98 | target = target.repeat(self.n_tokens, 1) 99 | l = loss(out, target) 100 | 101 | b_grads = [] 102 | for t in range(self.n_tokens): 103 | grad = torch.autograd.grad(l[t], inp, retain_graph=True) 104 | if self.input_type == 'attn_probs': 105 | b_grads.append(grad[0][0, :, t, :].detach()) 106 | elif self.input_type == 'key_proj_vals': 107 | b_grads.append(grad[0][0, t, :].detach()) 108 | 109 | grads[b] = torch.stack(b_grads).transpose(0,1) 110 | 111 | if self.input_type == 'attn_probs': 112 | attns[b] = torch.squeeze(inp) 113 | else: 114 | continue 115 | 116 | return grads, attns 117 | 118 | def _compute_cat_prob(self, data): 119 | grads = {} 120 | attns = {} 121 | for b in range(12): 122 | if self.input_type == 'attn_probs': 123 | inp = self.model.blocks[b].attn.attn_probs 124 | elif self.input_type == 'key_proj_vals': 125 | inp = 
self.model.blocks[b].attn.key_proj_vals 126 | out = self.model.blocks[b].attn_cls 127 | 128 | # Collect grads of all tokens 129 | b_grads = [] 130 | for t in range(self.n_tokens): 131 | cat_prob = torch.zeros(1, self.n_tokens, 1000).to(self.device) 132 | cat_prob[0, t, data['index']] = 1 133 | cat_prob.requires_grad_(True) 134 | cat_prob = torch.sum(cat_prob * out).to(self.device) 135 | 136 | # Compute gradients 137 | grad = torch.autograd.grad(cat_prob, inp, retain_graph=True) 138 | 139 | if self.input_type == 'attn_probs': 140 | b_grads.append(grad[0][0, :, t, :].detach()) 141 | elif self.input_type == 'key_proj_vals': 142 | b_grads.append(grad[0][0, t, :].detach()) 143 | 144 | grads[b] = torch.stack(b_grads).transpose(0,1) 145 | 146 | if self.input_type == 'attn_probs': 147 | attns[b] = torch.squeeze(inp) 148 | else: 149 | continue 150 | 151 | return grads, attns 152 | 153 | 154 | class CheferGradients(ViTGradients): 155 | """Compute feature importance of image tokens using method from 156 | https://github.com/hila-chefer/Transformer-MM-Explainability 157 | """ 158 | def __init__(self, model_name, project_path, imgs_path, device='cpu'): 159 | super().__init__(model_name, project_path, imgs_path, device) 160 | return 161 | 162 | def compute(self, concept, img_idx, cat_idx=None): 163 | 164 | # Get data 165 | self._get_dataset() 166 | data = self.dataset.get_item_by_concept(concept, img_idx) 167 | self.cat_idx = cat_idx 168 | 169 | # Get image features 170 | img_ft = self.img_transform(data['img'], return_tensors="pt") 171 | img_ft = img_ft['pixel_values'].to(self.device) 172 | 173 | # Compute hidden states 174 | self.model.zero_grad() 175 | for b in range(12): 176 | self.model.blocks[b].attn.attn_probs = None 177 | out = self.model(img_ft) 178 | del img_ft 179 | 180 | # Get one hot vector solution 181 | cat_prob = torch.zeros(1, 1000).to(self.device) 182 | if self.cat_idx != None: 183 | cat_prob[0, self.cat_idx] = 1 184 | else: 185 | cat_prob[0, data['index']] = 1 186 | cat_prob.requires_grad_(True) 187 | cat_prob = torch.sum(cat_prob * out).to(self.device) 188 | 189 | # Get gradients 190 | R = torch.eye(self.n_tokens, self.n_tokens).to(self.device) 191 | for b in range(12): 192 | cam = self.model.blocks[b].attn.attn_probs 193 | grad = torch.autograd.grad(cat_prob, cam, retain_graph=True)[0] 194 | 195 | # Average over heads 196 | cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]) 197 | grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1]) 198 | cam = grad * cam 199 | cam = cam.clamp(min=0).mean(dim=0).detach() 200 | 201 | # Apply self-attention rules 202 | r_add = torch.matmul(cam, R) 203 | R += r_add 204 | 205 | R = R[0, 1:] 206 | 207 | return R 208 | -------------------------------------------------------------------------------- /src/plots/linear_probing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | import seaborn as sns 5 | import torch 6 | 7 | from src.identifiability import get_class_embed 8 | from src.linear_probing.prober import LAYERS 9 | 10 | 11 | def plot_top1_acc(res_path, dataset_path): 12 | """ 13 | Plot top-1 accuracies of linear probing and cls projection. 
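    Examples
    --------
    Illustrative call; the paths are placeholders, and the function assumes the
    'results/linear_probing' and 'results/class_embed' folders have already been
    generated for vit_b_32.

    >>> from pathlib import Path
    >>> from src.plots.linear_probing import plot_top1_acc
    >>> plot_top1_acc(Path('/project/results'), Path('/datasets/imagenet-s'))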
14 | """ 15 | plt.rcParams.update({'font.size': 12}) 16 | fig, axes = plt.subplots(ncols=2, figsize=(13, 5)) 17 | 18 | # Get linear probing accuracies 19 | accs = [] 20 | for layer in LAYERS: 21 | file = res_path / 'linear_probing' / layer / 'acc.npy' 22 | acc = np.load(file)[1:] 23 | df = pd.DataFrame(acc, columns=['top-1 acc']) 24 | df['layer'] = layer.split('-')[0] 25 | df['block'] = int(layer.split('-')[1]) + 1 26 | accs.append(df) 27 | accs = pd.concat(accs) 28 | ## Plot 29 | sns.lineplot(accs, y='top-1 acc', x='block', hue='layer', ax=axes[0]) 30 | axes[0].set_title('linear probing') 31 | axes[0].set_xticks(np.arange(1, 13)) 32 | 33 | # Get cls projection accuracies 34 | accs = [] 35 | for layer in LAYERS: 36 | block = layer.split('-')[1] 37 | layer_type = layer.split('-')[0] 38 | if 'attn' in layer: 39 | layer = f'hs-attn_{block}' 40 | elif 'mlp' in layer: 41 | layer = f'hs-mlp_{block}' 42 | else: 43 | layer = f'hs_{block}' 44 | dec = get_class_embed( 45 | res_path, dataset_path, 'vit_b_32', layer, 'pos', normalize=False 46 | ) 47 | dec = dec[:, 1:] 48 | acc = torch.sum((dec == 0), axis=0) / dec.shape[0] 49 | df = pd.DataFrame(acc.detach().numpy(), columns=['top-1 acc']) 50 | df['layer'] = layer_type 51 | df['block'] = int(block) + 1 52 | accs.append(df) 53 | accs = pd.concat(accs) 54 | ## Plot 55 | sns.lineplot(accs, y='top-1 acc', x='block', hue='layer', ax=axes[1]) 56 | axes[1].set_title('cls projection') 57 | axes[1].set_xticks(np.arange(1, 13)) 58 | 59 | plt.tight_layout() 60 | plt.show() 61 | return 62 | 63 | 64 | def plot_perturbation(res_path, model_name='vit_b_32'): 65 | """ 66 | Plot and compare perturbation experiments. 67 | """ 68 | plt.rcParams.update({'font.size': 12}) 69 | 70 | labels = torch.arange(48) / 49 * 100 71 | 72 | fig, ax = plt.subplots(figsize=(6,4)) 73 | 74 | # Plot our method 75 | f = res_path / 'perturbation' / model_name / 'negative_grads.pt' 76 | neg_perturb = torch.load(f, map_location='cpu') 77 | neg_perturb = torch.flip(neg_perturb, dims=(0,)).detach().numpy() 78 | sns.lineplot(x=labels, y=neg_perturb, ax=ax, label='NEG cls-based removal') 79 | print(f'negative AUC emb: {np.sum(neg_perturb) / (np.max(neg_perturb) * 49)}') 80 | 81 | # Plot linear probing 82 | f = res_path / 'perturbation' / model_name / 'negative_linear-probe.pt' 83 | linear_perturb = torch.load(f, map_location='cpu') 84 | linear_perturb = torch.flip(linear_perturb, dims=(0,)).detach().numpy() 85 | sns.lineplot(x=labels, y=linear_perturb, ax=ax, label='NEG probe removal') 86 | print(f'negative AUC linear: {np.sum(linear_perturb) / (np.max(linear_perturb) * 49)}') 87 | 88 | # Plot our method 89 | f = res_path / 'perturbation' / model_name / 'positive_grads.pt' 90 | neg_perturb = torch.load(f, map_location='cpu') 91 | neg_perturb = torch.flip(neg_perturb, dims=(0,)).detach().numpy() 92 | sns.lineplot(x=labels, y=neg_perturb, ax=ax, label='POS cls-based removal') 93 | print(f'positive AUC emb: {np.sum(neg_perturb) / (np.max(neg_perturb) * 49)}') 94 | 95 | # Plot linear probing 96 | f = res_path / 'perturbation' / model_name / 'positive_linear_probe.pt' 97 | linear_perturb = torch.load(f, map_location='cpu') 98 | linear_perturb = torch.flip(linear_perturb, dims=(0,)).detach().numpy() 99 | sns.lineplot(x=labels, y=linear_perturb, ax=ax, label='POS probe removal') 100 | print(f'positive AUC linear: {np.sum(linear_perturb) / (np.max(linear_perturb) * 49)}') 101 | 102 | # Plot random perturbation 103 | f = res_path / 'perturbation' / model_name / 'perturb_random.pt' 104 | rand_perturb = 
torch.mean(torch.load(f, map_location='cpu'), dim=0) 105 | rand_perturb = torch.flip(rand_perturb, dims=(0,)).detach().numpy() 106 | sns.lineplot(x=labels, y=rand_perturb, ax=ax, label='random removal') 107 | print(f'random AUC: {np.sum(rand_perturb) / (np.max(rand_perturb) * 49)}') 108 | 109 | 110 | ax.hlines( 111 | xmin=-0.5, xmax=100.5, y=neg_perturb[0], colors='dimgray', linestyles='--', lw=2, 112 | label='baseline accuracy' 113 | ) 114 | ax.set_xticks(np.arange(0, 100, 10)) 115 | ax.set_xlim(-0.5, 95) 116 | ax.set_ylim(0,0.9) 117 | ax.set_xlabel('percentage of tokens removed') 118 | ax.set_ylabel('accuracy') 119 | ax.legend(fontsize=12) 120 | 121 | plt.tight_layout() 122 | plt.show() 123 | return 124 | 125 | 126 | def plot_all_linear_probing(res_path, dataset_path): 127 | plt.rcParams.update({'font.size': 15}) 128 | fig, axes = plt.subplots(ncols=3, figsize=(13, 4)) 129 | 130 | ## ACCURACY 131 | accs = [] 132 | for layer in LAYERS: 133 | file = res_path / 'linear_probing' / layer / 'acc.npy' 134 | acc = np.load(file)[1:] 135 | df = pd.DataFrame(acc, columns=['top-1 acc']) 136 | df['layer'] = layer.split('-')[0] 137 | df['block'] = int(layer.split('-')[1]) + 1 138 | accs.append(df) 139 | accs = pd.concat(accs) 140 | sns.lineplot(accs, y='top-1 acc', x='block', hue='layer', ax=axes[0]) 141 | axes[0].set_title('linear probing') 142 | axes[0].set_xticks(np.arange(1, 13)) 143 | 144 | accs = [] 145 | for layer in LAYERS: 146 | block = layer.split('-')[1] 147 | layer_type = layer.split('-')[0] 148 | if 'attn' in layer: 149 | layer = f'hs-attn_{block}' 150 | elif 'mlp' in layer: 151 | layer = f'hs-mlp_{block}' 152 | else: 153 | layer = f'hs_{block}' 154 | dec = get_class_embed( 155 | res_path, dataset_path, 'vit_b_32', layer, 'pos', normalize=False 156 | ) 157 | dec = dec[:, 1:] 158 | acc = torch.sum((dec == 0), axis=0) / dec.shape[0] 159 | df = pd.DataFrame(acc.detach().numpy(), columns=['top-1 acc']) 160 | df['layer'] = layer_type 161 | df['block'] = int(block) + 1 162 | accs.append(df) 163 | accs = pd.concat(accs) 164 | sns.lineplot(accs, y='top-1 acc', x='block', hue='layer', ax=axes[1]) 165 | axes[1].set_title('cls projection') 166 | axes[1].set_xticks(np.arange(1, 13)) 167 | 168 | ## PERTURBATION 169 | model_name = 'vit_b_32' 170 | 171 | labels = torch.arange(48) / 49 * 100 172 | 173 | # Plot our method 174 | f = res_path / 'perturbation' / model_name / 'negative_grads.pt' 175 | neg_perturb = torch.load(f, map_location='cpu') 176 | neg_perturb = torch.flip(neg_perturb, dims=(0,)).detach().numpy() 177 | sns.lineplot(x=labels, y=neg_perturb, ax=axes[2], label='NEG cls-based removal') 178 | 179 | # Plot linear probing 180 | f = res_path / 'perturbation' / model_name / 'negative_linear-probe.pt' 181 | linear_perturb = torch.load(f, map_location='cpu') 182 | linear_perturb = torch.flip(linear_perturb, dims=(0,)).detach().numpy() 183 | sns.lineplot(x=labels, y=linear_perturb, ax=axes[2], label='NEG probe removal') 184 | 185 | # Plot our method 186 | f = res_path / 'perturbation' / model_name / 'positive_grads.pt' 187 | neg_perturb = torch.load(f, map_location='cpu') 188 | neg_perturb = torch.flip(neg_perturb, dims=(0,)).detach().numpy() 189 | sns.lineplot(x=labels, y=neg_perturb, ax=axes[2], label='POS cls-based removal') 190 | 191 | # Plot linear probing 192 | f = res_path / 'perturbation' / model_name / 'positive_linear_probe.pt' 193 | linear_perturb = torch.load(f, map_location='cpu') 194 | linear_perturb = torch.flip(linear_perturb, dims=(0,)).detach().numpy() 195 | sns.lineplot(x=labels, 
y=linear_perturb, ax=axes[2], label='POS probe removal') 196 | 197 | # Plot random perturbation 198 | f = res_path / 'perturbation' / model_name / 'perturb_random.pt' 199 | rand_perturb = torch.mean(torch.load(f, map_location='cpu'), dim=0) 200 | rand_perturb = torch.flip(rand_perturb, dims=(0,)).detach().numpy() 201 | sns.lineplot(x=labels, y=rand_perturb, ax=axes[2], label='random removal') 202 | 203 | ax = axes[2] 204 | ax.hlines( 205 | xmin=-0.5, xmax=100.5, y=neg_perturb[0], colors='dimgray', linestyles='--', lw=2, 206 | label='baseline accuracy' 207 | ) 208 | ax.set_xticks(np.arange(0, 100, 10)) 209 | ax.set_xlim(-0.5, 95) 210 | ax.set_ylim(0,0.9) 211 | ax.set_xlabel('percentage of tokens removed') 212 | ax.set_ylabel('accuracy') 213 | ax.legend(prop={'size': 10}) 214 | 215 | plt.tight_layout() 216 | f = res_path / 'figures' / f'compare_linear.png' 217 | plt.savefig(f, dpi=300) 218 | plt.show() 219 | return -------------------------------------------------------------------------------- /src/perturbation/explain_perturbation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from transformers import ViTImageProcessor 9 | 10 | from src.datasets.imagenet import ImagenetDatasetS 11 | from src.gradients import CheferGradients, ViTGradients 12 | from src.models.vit_edited import vit_base_patch32_224, vit_base_patch16_224 13 | 14 | 15 | class ViTPerturb(): 16 | def __init__(self, model_name, project_path, imgs_path, device='cpu'): 17 | self.model_name = model_name 18 | self.device = device 19 | 20 | self.project_path = project_path 21 | self.imgs_path = imgs_path 22 | self.perturb_path = project_path / 'results' / 'perturbation' / model 23 | self.perturb_path.mkdir(parents=True, exist_ok=True) 24 | (self.perturb_path / 'masks').mkdir(parents=True, exist_ok=True) 25 | 26 | self._load_model() 27 | 28 | def _load_model(self): 29 | # Get model 30 | if self.model_name == 'vit_b_32': 31 | self.model = vit_base_patch32_224( 32 | pretrained=True, pretrained_cfg=True 33 | ).to(self.device) 34 | source = 'google/vit-base-patch32-224-in21k' 35 | self.n_tokens = 49 36 | elif self.model_name == 'vit_b_16': 37 | self.model = vit_base_patch16_224( 38 | pretrained=True, pretrained_cfg=True 39 | ).to(self.device) 40 | source = 'google/vit-base-patch16-224-in21k' 41 | self.n_tokens = 196 42 | self.model.eval() 43 | 44 | # Get image transform 45 | self.img_transform = ViTImageProcessor.from_pretrained(source) 46 | return 47 | 48 | def _get_dataset(self): 49 | self.dataset = ImagenetDatasetS(self.imgs_path) 50 | self.dataloader = DataLoader(self.dataset, batch_size=1, collate_fn=lambda x: x[0]) 51 | return 52 | 53 | def compute(self, perturb_type='negative', source_type='grads'): 54 | self.perturb_type = perturb_type 55 | self.source_type = source_type 56 | 57 | # Get mask 58 | self._get_dataset() 59 | if self.source_type == 'grads': 60 | tokens_mask = self.get_accumulated_grads() 61 | elif self.source_type == 'grads_chefer': 62 | tokens_mask = self.get_grads_chefer() 63 | elif self.source_type == 'linear_probe': 64 | tokens_mask = self.get_accumulated_linear_probe() 65 | 66 | # Compute accuracy 67 | accs = [] 68 | for n in range(1, self.n_tokens): 69 | accs.append(self._compute(n, tokens_mask)) 70 | accs = torch.stack(accs).to(self.device) 71 | 72 | # Save accuracy 73 | f = self.perturb_path / 
f'{self.perturb_type}_{self.source_type}.pt' 74 | torch.save(accs, f) 75 | 76 | return 77 | 78 | def _compute(self, n_tokens, mask): 79 | self._get_dataset() 80 | acc = [] 81 | for id, data in tqdm( 82 | enumerate(self.dataloader), total=len(self.dataloader) 83 | ): 84 | # Get image features 85 | img_ft = self.img_transform(data['img'], return_tensors="pt") 86 | img_ft = img_ft['pixel_values'].to(self.device) 87 | 88 | # Get decoded tokens and add CLS 89 | tokens_mask = mask[id][:n_tokens].to(self.device) 90 | tokens_mask = torch.cat((torch.tensor([0,]).to(self.device), tokens_mask)) 91 | 92 | # Compute hidden states 93 | with torch.no_grad(): 94 | out = self.model(img_ft, tokens_mask) 95 | 96 | # Compute accuracy 97 | pred = out.topk(1)[1] 98 | cat_acc = torch.squeeze((pred == data['index']).long()) 99 | acc.append(cat_acc) 100 | 101 | # Save accuracy 102 | acc = torch.hstack(acc).to(self.device) 103 | acc = torch.sum(acc) / 4550 104 | print(f'acc {n_tokens}: {acc}', flush=True) 105 | 106 | return acc 107 | 108 | def get_accumulated_grads(self): 109 | f = self.perturb_path / 'masks' / f'{self.source_type}.pt' 110 | if f.is_file(): 111 | tokens_mask = torch.load(f).to(self.device) 112 | else: 113 | tokens_mask = [] 114 | for data in tqdm(self.dataloader, total=len(self.dataloader)): 115 | concept=data['imagenet_id'] 116 | img_idx=data['img_index'] 117 | cat_idx=data['index'] 118 | 119 | # Compute grads 120 | g = ViTGradients(self.model_name, self.project_path, self.imgs_path) 121 | grads, _ = g.compute(concept, img_idx, cat_idx) 122 | 123 | # Accumulate over blocks 124 | importance = [] 125 | for b in range(12): 126 | importance.append(torch.sum(grads[b], dim=0)[0]) 127 | mask = torch.sum(torch.stack(importance), dim=0)[1:] 128 | mask = - mask 129 | 130 | # Orden in terms of importance 131 | mask = mask.topk(self.n_tokens, dim=0)[1] + 1 132 | tokens_mask.append(mask) 133 | 134 | tokens_mask = torch.stack(tokens_mask).to(self.device) 135 | torch.save(tokens_mask, f) 136 | 137 | # Flip order if positive perturbation 138 | if self.perturb_type == 'positive': 139 | tokens_mask = torch.flip(tokens_mask, dims=[1]) 140 | 141 | return tokens_mask 142 | 143 | def get_grads_chefer(self): 144 | f = self.perturb_path / 'masks' / f'{self.source_type}.pt' 145 | if f.is_file(): 146 | tokens_mask = torch.load(f).to(self.device) 147 | else: 148 | tokens_mask = [] 149 | for data in tqdm(self.dataloader, total=len(self.dataloader)): 150 | concept=data['imagenet_id'] 151 | img_idx=data['img_index'] 152 | cat_idx=data['index'] 153 | 154 | # Compute grads 155 | g = CheferGradients(self.model_name, self.project_path, self.imgs_path) 156 | R = g.compute(concept, img_idx, cat_idx) 157 | mask = R.topk(self.n_tokens, dim=0)[1] + 1 158 | tokens_mask.append(mask) 159 | 160 | tokens_mask = torch.stack(tokens_mask).to(self.device) 161 | torch.save(tokens_mask, f) 162 | 163 | # Flip order if positive perturbation 164 | if self.perturb_type == 'positive': 165 | tokens_mask = torch.flip(tokens_mask, dims=[1]) 166 | 167 | return tokens_mask 168 | 169 | def get_accumulated_linear_probe(self): 170 | # Accumulate linear probing results over blocks 171 | decs = [] 172 | for b in range(12): 173 | path = project_path / 'results/linear_probing' / f'hs-{b}' 174 | b_decs = [] 175 | for t in range(1, 50): 176 | f = path / f'pos_t{t}.npy' 177 | dec = np.load(f) 178 | b_decs.append(dec) 179 | decs.append(torch.Tensor(np.vstack(b_decs)).to(self.device).T) 180 | decs = torch.sum(torch.stack(decs), dim=0) 181 | 182 | tokens_mask = 
decs.topk(k=decs.shape[1], dim=1, largest=False)[1]
183 |         tokens_mask = tokens_mask + 1
184 | 
185 |         # Flip order if positive perturbation
186 |         if self.perturb_type == 'positive':
187 |             tokens_mask = torch.flip(tokens_mask, dims=[1])
188 | 
189 |         return tokens_mask
190 | 
191 |     def compute_random(self, perms=10):
192 |         random_accs = []
193 |         for p in range(perms):
194 |             # Get random token order (one seeded permutation of all image tokens)
195 |             random_mask = torch.randperm(self.n_tokens, generator=torch.manual_seed(p)) + 1
196 | 
197 |             # Compute accuracy
198 |             p_accs = []
199 |             for n in range(1, self.n_tokens):
200 |                 mask = random_mask[:n]
201 |                 mask = torch.cat((torch.tensor([0,]).to(self.device), mask.to(self.device)))
202 |                 p_accs.append(self._compute_random(mask))
203 |             random_accs.append(torch.stack(p_accs).to(self.device))
204 | 
205 |         random_accs = torch.stack(random_accs).to(self.device)
206 |         f = self.perturb_path / 'perturb_random.pt'
207 |         torch.save(random_accs, f)
208 | 
209 |         return
210 | 
211 |     def _compute_random(self, tokens_mask):
212 |         acc = []
213 |         for _, data in tqdm(
214 |             enumerate(self.dataloader), total=len(self.dataloader)
215 |         ):
216 |             # Get image features
217 |             img_ft = self.img_transform(data['img'], return_tensors="pt")
218 |             img_ft = img_ft['pixel_values'].to(self.device)
219 | 
220 |             # Compute hidden states
221 |             with torch.no_grad():
222 |                 out = self.model(img_ft, tokens_mask)
223 | 
224 |             # Get sample accuracy
225 |             pred = out.topk(1)[1]
226 |             cat_acc = torch.squeeze((pred == data['index']).long())
227 |             acc.append(cat_acc)
228 | 
229 |         # Compute accuracy
230 |         acc = torch.hstack(acc).to(self.device)
231 |         acc = torch.sum(acc) / 4550
232 |         print(f'acc {len(tokens_mask)}: {acc}', flush=True)
233 | 
234 |         return acc
235 | 
236 | 
237 | if __name__ == '__main__':
238 |     parser = argparse.ArgumentParser()
239 |     parser.add_argument(
240 |         '-pp', action='store', required=True,
241 |         help='Path to the folder containing the source code.'
242 |     )
243 |     parser.add_argument(
244 |         '-dp', action='store', required=True,
245 |         help='Path to the folder containing the dataset.'
246 |     )
247 |     parser.add_argument(
248 |         '-m', action='store', required=True,
249 |         help='Select which model to run. Can be one of the following options: \
250 |             vit_b_16, vit_b_32.'
251 | ) 252 | parser.add_argument( 253 | '-random', action='store_true',help='Compute random mask' 254 | ) 255 | parser.add_argument( 256 | '-pt', action='store', required=True, 257 | help='Perturbation type: positive or negative' 258 | ) 259 | parser.add_argument( 260 | '-st', action='store', required=True, 261 | help='Source for the importance weights: grads, grads_chefer or linear_probe' 262 | ) 263 | 264 | args = parser.parse_args() 265 | model = args.m 266 | device = "cuda" if torch.cuda.is_available() else "cpu" 267 | random = args.random 268 | perturb_type = args.pt 269 | source_type = args.st 270 | 271 | project_path = Path(args.pp) 272 | data_path = Path(args.dp) 273 | 274 | if random == True: 275 | p = ViTPerturb(model, project_path, data_path, device) 276 | p.compute_random(perturb_type=perturb_type, source_type=source_type) 277 | else: 278 | p = ViTPerturb(model, project_path, data_path, device) 279 | p.compute(perturb_type=perturb_type, source_type=source_type) -------------------------------------------------------------------------------- /src/plots/class_ident.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | from scipy.stats import wilcoxon 5 | import seaborn as sns 6 | import torch 7 | 8 | from src.identifiability import MODEL_MAP 9 | from src.identifiability import ( 10 | get_class_embed, get_ident_change_rate, get_ident_mean_evolution, 11 | get_ident_segmented, compute_context_diff 12 | ) 13 | from src.vis import Vis 14 | 15 | 16 | def print_identifiability(model_name, res_path, dataset_path, layer=11): 17 | """ 18 | Compute and print class identifiability measures. 19 | """ 20 | # Get class identifiability 21 | dec = get_class_embed(res_path, dataset_path, model_name, f'hs_{layer}', 'pos')[:, 1:] 22 | 23 | # Compute class identifiability rate 24 | avg_percentage = (torch.sum(dec == 0, dim=1) / dec.shape[1]).float().mean() * 100 25 | print(f'- Average percentage of class identifiable image tokens per image: {avg_percentage}') 26 | 27 | # Compute percentage of images with at least one identifiable token 28 | image_percentage = torch.sum((torch.sum((dec == 0), dim=1) > 0)) / dec.shape[0] * 100 29 | print( 30 | f'- Percentage of images with at least one class identifiable token: {image_percentage}' 31 | ) 32 | 33 | return 34 | 35 | 36 | def label_token(row): 37 | """ 38 | Assign to token its correct label. 39 | """ 40 | if row['token'] == 0: 41 | return '[CLS]' 42 | else: 43 | return 'image' 44 | 45 | 46 | def plot_identifiability_evolution(res_path, dataset_path): 47 | """ 48 | Plot identifiability evolution over blocks. 
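    Examples
    --------
    Illustrative call with placeholder paths; the class-projection results are
    assumed to exist under 'results/class_embed' for every model in MODEL_MAP,
    including the '_random' baselines.

    >>> from pathlib import Path
    >>> from src.plots.class_ident import plot_identifiability_evolution
    >>> plot_identifiability_evolution(Path('/project/results'), Path('/datasets/imagenet-s'))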
49 | """ 50 | plt.rcParams.update({'font.size': 16}) 51 | fig, axes = plt.subplots(nrows=7, figsize=(13, 30)) 52 | for model_name, ax in zip(MODEL_MAP.keys(), axes.flat): 53 | if 'large' in model_name: 54 | n_layers = 24 55 | else: 56 | n_layers = 12 57 | 58 | dfs = [] 59 | pvals = [] 60 | for b in range(n_layers): 61 | # Get class identifiability of block 62 | dec = get_class_embed( 63 | res_path, dataset_path, model_name, f'hs_{b}', 'pos', normalize=True 64 | ) 65 | 66 | # Compare to random model 67 | if 'miil' in model_name: 68 | random_model = f'vit_b_16_random' 69 | else: 70 | random_model = f'{model_name}_random' 71 | rand_dec = get_class_embed( 72 | res_path, dataset_path, random_model, f'hs_{b}', 'pos', normalize=True 73 | ) 74 | wx = wilcoxon(dec.flatten(), rand_dec.flatten(), alternative='greater') 75 | pvals.append(np.round(wx.pvalue, 3)) 76 | 77 | # Add to df 78 | df = ( 79 | pd.DataFrame(pd.DataFrame(dec).stack()) 80 | .reset_index(names=['image', 'token']) 81 | .rename(columns={0:'class identifiability'}) 82 | ) 83 | if 'gap' not in model_name: 84 | df['token label'] = df.apply(lambda row: label_token(row), axis=1) 85 | df['block'] = b + 1 86 | dfs.append(df) 87 | dfs = pd.concat(dfs) 88 | 89 | # Plot 90 | if 'gap' not in model_name: 91 | sns.boxplot( 92 | dfs, x='block', y='class identifiability', hue='token label', 93 | fliersize=0, palette=sns.color_palette("Set2"), ax=ax 94 | ) 95 | sns.move_legend(ax, 'lower right') 96 | else: 97 | sns.boxplot( 98 | dfs, x='block', y='class identifiability', fliersize=0, 99 | palette=sns.color_palette(['#fc8d62']), ax=ax 100 | ) 101 | ax.hlines(xmin=-1, xmax=n_layers, y=0.5, colors='dimgray', linestyles='--', lw=2) 102 | 103 | # Add significance stars 104 | sig = np.where(np.array(pvals) < 0.05)[0] 105 | ax.scatter(sig, [1.1] * len(sig), marker='*', c='grey', s=45) 106 | 107 | ax.set_xlim(-0.5, int(n_layers)) 108 | ax.set_title(f'{MODEL_MAP[model_name]}') 109 | 110 | plt.tight_layout() 111 | f = res_path / 'figures' / f'class_identifiability_evolution.png' 112 | plt.savefig(f, dpi=300) 113 | plt.show() 114 | 115 | return 116 | 117 | 118 | def compare_identifiability_evolution(res_path, dataset_path, cifar_path): 119 | # Get evolution 120 | df = [] 121 | for model in MODEL_MAP.keys(): 122 | if 'cifar' in model: 123 | ident = get_ident_mean_evolution(model, cifar_path, res_path) 124 | else: 125 | ident = get_ident_mean_evolution(model, dataset_path, res_path) 126 | df.append(ident) 127 | df = pd.concat(df) 128 | df['class identifiability'] = df['class identifiability'].astype('float') 129 | 130 | # Plot 131 | plt.rcParams.update({'font.size': 17}) 132 | fig, axes = plt.subplots(ncols=2, figsize=(13, 3.5)) 133 | 134 | cls_df = df.loc[df['token type'] == '[CLS]'] 135 | sns.lineplot( 136 | cls_df, x='block', y='class identifiability', hue='model', 137 | ax=axes[0], marker='*', markersize=10 138 | ) 139 | axes[0].set_title('[CLS] token') 140 | 141 | img_df = df.loc[df['token type'] == 'image'] 142 | sns.lineplot( 143 | img_df, x='block', y='class identifiability', hue='model', 144 | ax=axes[1], marker='*', markersize=10 145 | ) 146 | axes[1].set_title('image tokens', fontsize=20) 147 | 148 | for ax in axes.flat: 149 | ax.set_ylim(0.4, 1.05) 150 | ax.set_ylabel('class identifiability', fontsize=20) 151 | ax.set_xticks(np.arange(1, 13)) 152 | ax.set_xticklabels(np.arange(1, 13)) 153 | ax.set_xticklabels([]) 154 | ax.set_xlim((0.8, 12.2)) 155 | ax.set_xlabel('normalized block', fontsize=20) 156 | ax.hlines( 157 | xmin=0.8, xmax=12.2, y=0.5, 
colors='dimgray', linestyles='--', lw=2, 158 | label='Chance Level' 159 | ) 160 | 161 | # Unify legend 162 | lines, labels = axes[1].get_legend_handles_labels() 163 | lgd = fig.legend( 164 | lines, labels, loc='center right', #nrow=7, 165 | bbox_to_anchor=(1.15, 0.5), 166 | ) 167 | for ax in axes.flat: 168 | ax.get_legend().remove() 169 | 170 | f = res_path / 'figures' / f'mean_evolution.png' 171 | plt.savefig(f, dpi=300, bbox_inches='tight') 172 | plt.show() 173 | return 174 | 175 | 176 | def plot_logits_increment(res_path, dataset_path): 177 | """ 178 | Plot percentage of image tokens that increment the logits of the correct class. 179 | """ 180 | 181 | fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(13, 13)) 182 | for model_name, ax in zip(MODEL_MAP.keys(), axes.flat): 183 | 184 | if 'large' in model_name: 185 | n_layers = 24 186 | else: 187 | n_layers = 12 188 | 189 | # Get change rate 190 | pers = get_ident_change_rate(model_name, dataset_path, res_path, n_layers) 191 | 192 | # Plot 193 | sns.lineplot(pers, x='block', y='change rate', marker='o', markersize=9, ax=ax) 194 | ax.set_xlim(1.8, int(n_layers) + 0.2) 195 | ax.set_xticks(np.arange(2, int(n_layers) + 1)) 196 | ax.xaxis.set_tick_params(labelsize=10) 197 | ax.set_ylim(0.5, 1) 198 | ax.set_yticks(np.arange(0.6, 1, 0.1)) 199 | ax.yaxis.set_tick_params(labelsize=12) 200 | ax.set_title(MODEL_MAP[model_name]) 201 | 202 | fig.delaxes(axes[3,1]) 203 | 204 | plt.tight_layout() 205 | f = res_path / 'figures' / f'prob_increment.png' 206 | plt.savefig(f, dpi=300) 207 | plt.show() 208 | 209 | return 210 | 211 | 212 | def print_attn_perturb(model_name, perturb_type, res_path, dataset_path): 213 | """ 214 | Print class identifiability of attention perturbation studies. 215 | """ 216 | # Get perturbed identifiability 217 | f = res_path / 'perturbation' / model_name / f'attn-{perturb_type}_dec.pt' 218 | attn_dec = torch.load(f, map_location='cpu') 219 | attn_dec = 1 - (attn_dec / 1000) 220 | attn_dec_acc = torch.sum(attn_dec == 1) / attn_dec.flatten().shape[0] * 100 221 | print(f'Class identifiability in {perturb_type}: {attn_dec_acc}') 222 | 223 | # Get unperturbed identifiability 224 | dec = get_class_embed( 225 | res_path, dataset_path, model_name, 'hs_11', decoding_type='pos', normalize=True 226 | )[:, 1:] 227 | dec_acc = torch.sum(dec == 1) / dec.flatten().shape[0] * 100 228 | print(f'Class identifiability in unperturbed model: {dec_acc}') 229 | 230 | return 231 | 232 | 233 | def plot_context_diff(model_name, proj_path, dataset_path): 234 | """ 235 | Plot identifiability evolution separately for class- and context-labeled tokens. 
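    Examples
    --------
    Illustrative call with placeholder paths; 'proj_path' is the project root
    holding a 'results' folder and 'dataset_path' points to the ImageNet-S data
    with its validation segmentations.

    >>> from pathlib import Path
    >>> from src.plots.class_ident import plot_context_diff
    >>> plot_context_diff('vit_b_32', Path('/project'), Path('/datasets/imagenet-s'))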
236 | """ 237 | dfs = get_ident_segmented(model_name, proj_path, dataset_path) 238 | 239 | if 'large' in model_name: 240 | nrows = 4 241 | n_layers = 24 242 | height = 10 243 | else: 244 | nrows = 2 245 | n_layers = 12 246 | height = 5 247 | 248 | plt.rcParams.update({'font.size': 12}) 249 | fig, axes = plt.subplots(nrows=nrows, ncols=6, figsize=(13, height)) 250 | for ax, b in zip(axes.flat, np.arange(n_layers)): 251 | df = dfs.loc[dfs['block'] == b] 252 | _, pval = compute_context_diff(df) 253 | sns.boxplot( 254 | df, y='class identifiability', x='token location', 255 | fliersize=0, ax=ax 256 | ) 257 | if pval < 0.05: 258 | ax.set_title(f'block {b+1} *') 259 | else: 260 | ax.set_title(f'block {b+1}') 261 | ax.set_xlabel('') 262 | for ax in axes.flat[1:]: 263 | ax.set_ylabel('') 264 | fig.suptitle(MODEL_MAP[model_name]) 265 | 266 | plt.tight_layout() 267 | f = proj_path / 'results/figures' / f'context_{model_name}.png' 268 | plt.savefig(f, dpi=300) 269 | plt.show() 270 | 271 | return 272 | 273 | 274 | def print_context_perturb(model_name, mask_type, proj_path, dataset_path): 275 | """ 276 | Print class identifiability of context perturbation studies. 277 | """ 278 | # Get identifiability of perturbed model 279 | res_path = proj_path / 'results' 280 | f = res_path / 'perturbation' / model_name / f'no_{mask_type}_tokens_dec.pt' 281 | context_dec = torch.load(f, map_location='cpu')['hs_11'] 282 | context_dec_acc = [] 283 | for i in context_dec: 284 | i = 1 - (i / 1000) 285 | context_dec_acc.append(torch.sum(i == 1) / i.shape[0]) 286 | context_dec_acc = torch.mean(torch.stack(context_dec_acc)) * 100 287 | print(f'Class identifiability with no {mask_type} tokens: {context_dec_acc}') 288 | 289 | # Get identifiability of unperturbed model 290 | im = Vis(proj_path, dataset_path, model_name, device='cpu') 291 | stim_info = im.stim_info 292 | concepts = stim_info['imagenet_id'].unique().tolist() 293 | sgts = [] 294 | for c in concepts: 295 | for i in range(5): 296 | gt = im.get_segmentation(c, i).flatten() 297 | sgts.append(gt) 298 | sgts = np.hstack(sgts) 299 | 300 | dec = get_class_embed( 301 | res_path, dataset_path, model_name, 'hs_11', decoding_type='pos', normalize=True 302 | )[:, 1:] 303 | dec = dec.flatten() 304 | if mask_type == 'context': 305 | dec = dec[(sgts == 1).nonzero()] 306 | elif mask_type == 'class_label': 307 | dec = dec[(sgts == 0).nonzero()] 308 | dec_acc = torch.sum(dec == 1) / dec.shape[0] * 100 309 | print(f'Class identifiability in unperturbed model: {dec_acc}') 310 | 311 | return -------------------------------------------------------------------------------- /src/identifiability.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from src.datasets.cifar import MyCIFAR100 7 | from src.datasets.imagenet import ImagenetDatasetS 8 | from src.vis import Vis 9 | 10 | 11 | MODEL_MAP = { 12 | 'vit_b_32': 'ViT-B/32', 13 | 'vit_b_16': 'ViT-B/16', 14 | 'vit_large_16': 'ViT-L/16', 15 | 'vit_miil_16': 'ViT-B/16-MIIL', 16 | 'vit_cifar_16': 'ViT-B/16-CIFAR', 17 | 'deit_ensemble_16': 'ViT-B/16-Refinement', 18 | 'vit_gap_16': 'ViT-B/16-GAP', 19 | } 20 | 21 | 22 | def get_class_embed( 23 | res_path, img_path, model, layer, decoding_type='pos', normalize=False 24 | ): 25 | """Get class projection data. 26 | 27 | Parameters 28 | ---------- 29 | res_path : pathlib.Path 30 | Path to results. 31 | img_path : pathlib.Path 32 | Path to dataset. 33 | model : str 34 | Model name. 
Can be one of the following options: vit_b_16, vit_b_32, 35 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16. 36 | layer : str 37 | Layer index. 38 | decoding_type : str, optional 39 | Type of projection, by default 'pos'. Can be one of the following: 40 | 'pos', 'probs'. 41 | normalize : bool, optional 42 | Whether to normalize over number of classes, by default False. 43 | 44 | Returns 45 | ------- 46 | torch.Tensor 47 | Class projection data. 48 | """ 49 | net_path = res_path / 'class_embed' 50 | 51 | # Get dataset info 52 | if 'cifar' in model: 53 | dataset = MyCIFAR100(img_path) 54 | stim_info = dataset.stim_info 55 | concepts = list(stim_info.keys()) 56 | else: 57 | dataset = ImagenetDatasetS(img_path) 58 | stim_info = dataset.stim_info 59 | concepts = stim_info['imagenet_id'].unique() 60 | 61 | # Stack decoding info across categories 62 | dec = [] 63 | for c in concepts: 64 | f = net_path / model / c / f'{decoding_type}_{layer}.pt' 65 | dec.append(torch.load(f, map_location='cpu')) 66 | dec = torch.vstack(dec) 67 | if normalize == True: 68 | if 'cifar' in model: 69 | dec = 1 - (dec / 100) 70 | else: 71 | dec = 1 - (dec / 1000) 72 | return dec 73 | else: 74 | return dec 75 | 76 | 77 | def get_ident_mean_evolution(model_name, dataset_path, res_path): 78 | """Get mean identifiability evolution. 79 | 80 | Parameters 81 | ---------- 82 | model_name : str 83 | Model name. Can be one of the following options: vit_b_16, vit_b_32, 84 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16. 85 | dataset_path : pathlib.Path 86 | Path to dataset. 87 | res_path : pathlib.Path 88 | Path to results. 89 | 90 | Returns 91 | ------- 92 | pandas.DataFrame 93 | Evolution of mean identifiability across layers. 94 | """ 95 | if 'large' in model_name: 96 | n_layers = 24 97 | else: 98 | n_layers = 12 99 | 100 | df = [] 101 | for b in range(n_layers): 102 | dec = get_class_embed( 103 | res_path, dataset_path, model_name, f'hs_{b}', 'pos', normalize=True 104 | ) 105 | if 'gap' in model_name: 106 | dec = dec.mean() 107 | df.append([b+1, 'image', dec]) 108 | else: 109 | dec = dec.mean(dim=0) 110 | cls_dec = dec[0].numpy() 111 | df.append([b+1, '[CLS]', cls_dec]) 112 | token_dec = torch.mean(dec[1:]).numpy() 113 | df.append([b+1, 'image', token_dec]) 114 | 115 | df = pd.DataFrame(df, columns=['block', 'token type', 'class identifiability']) 116 | df['model'] = MODEL_MAP[model_name] 117 | 118 | if model_name == 'vit_large_16': 119 | new_df = [] 120 | for idx, b in enumerate(np.arange(1, 25, 2)): 121 | for tt in ['[CLS]', 'image']: 122 | t_df = df.loc[df['token type'] == tt] 123 | mean = t_df.loc[t_df['block'].isin([b, b+1])]['class identifiability'].mean() 124 | new_df.append([idx+1, tt, mean, MODEL_MAP['vit_large_16']]) 125 | new_df = pd.DataFrame( 126 | new_df, columns=['block', 'token type', 'class identifiability', 'model'] 127 | ) 128 | df = new_df 129 | 130 | return df 131 | 132 | 133 | def get_ident_change_rate(model_name, dataset_path, res_path, n_layers=12): 134 | """Get change rate of class identifiability across layers. 135 | 136 | Parameters 137 | ---------- 138 | model_name : str 139 | Model name. Can be one of the following options: vit_b_16, vit_b_32, 140 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16. 141 | dataset_path : pathlib.Path 142 | Path to dataset. 143 | res_path : pathlib.Path 144 | Path to results. 145 | n_layers : int, optional 146 | Number of layers in model, by default 12. 
147 | 148 | Returns 149 | ------- 150 | pandas.DataFrame 151 | Containing the change rate across layers. 152 | """ 153 | pers = [] 154 | for b in range(1, n_layers): 155 | i_dec = get_class_embed(res_path, dataset_path, model_name, f'hs_{b-1}', 'probs')[:, 1:] 156 | j_dec = get_class_embed(res_path, dataset_path, model_name, f'hs_{b}', 'probs')[:, 1:] 157 | per = (torch.sum((j_dec - i_dec) > 0) / j_dec.flatten().shape[0]).detach().numpy() 158 | pers.append([b+1, per]) 159 | pers = pd.DataFrame(pers, columns=['block', 'change rate']) 160 | pers['change rate'] = pers['change rate'].astype('float') 161 | return pers 162 | 163 | 164 | def get_ident_segmented(model_name, proj_path, dataset_path): 165 | """Get class identifiability separately for class- and context-labeled tokens. 166 | 167 | Parameters 168 | ---------- 169 | model_name : str 170 | Model name. Can be one of the following options: vit_b_16, vit_b_32, 171 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16. 172 | proj_path : pathlib.Path 173 | Path to source code. 174 | dataset_path : pathlib.Path 175 | Path to dataset. 176 | 177 | Returns 178 | ------- 179 | pandas DataFrame 180 | Containing class identifiability of class- and context-labeled tokens. 181 | """ 182 | # Get segmentation annotations 183 | im = Vis(proj_path, dataset_path, model_name, device='cpu') 184 | stim_info = im.stim_info 185 | concepts = stim_info['imagenet_id'].unique().tolist() 186 | sgts = [] 187 | for c in concepts: 188 | for i in range(5): 189 | gt = im.get_segmentation(c, i).flatten() 190 | sgts.append(gt) 191 | sgts = np.hstack(sgts) 192 | 193 | # Get identifiability 194 | if 'large' in model_name: 195 | n_layers = 24 196 | else: 197 | n_layers = 12 198 | 199 | # Save in dataframe 200 | dfs = [] 201 | for b in range(n_layers): 202 | dec = get_class_embed( 203 | proj_path / 'results', dataset_path, model_name, f'hs_{b}', 204 | decoding_type='pos', normalize=True 205 | ) 206 | if 'gap' not in model_name: 207 | dec = dec[:, 1:] # remove cls token 208 | df = pd.DataFrame( 209 | {'class identifiability': dec.flatten().detach().numpy(), 'token location': sgts} 210 | ) 211 | df['block'] = b 212 | dfs.append(df) 213 | dfs = pd.concat(dfs) 214 | dfs['token location'] = dfs['token location'].replace({0: 'context', 1: 'class'}) 215 | 216 | return dfs 217 | 218 | 219 | def compute_context_diff(df): 220 | """Compute significant difference in identifiability between class- and context-labeled tokens 221 | 222 | Parameters 223 | ---------- 224 | df : pandas DataFrame 225 | Containing class identifiability of class- and context-labeled tokens. 
226 | 227 | Returns 228 | ------- 229 | tuple 230 | Containing difference value and pvalue, 231 | """ 232 | context = df.loc[df['token location'] == 'context']['class identifiability'].to_numpy() 233 | category = df.loc[df['token location'] == 'class']['class identifiability'].to_numpy() 234 | 235 | # Compute true difference 236 | true_diff = np.mean(category) - np.mean(context) 237 | 238 | # Compute difference with shuffled labeles 239 | context_len = context.shape[0] 240 | all_data = np.concatenate((context, category)) 241 | random_diffs = [] 242 | n_perm = 300 243 | for _ in range(300): 244 | np.random.shuffle(all_data) 245 | context, category = all_data[:context_len], all_data[context_len:] 246 | random_diffs.append(np.mean(category) - np.mean(context)) 247 | pval = np.sum(true_diff < random_diffs) / n_perm 248 | 249 | return true_diff, pval 250 | 251 | 252 | def compute_class_similarity_change(model_name, block, layer_type, dataset_path, res_path): 253 | """Compute class similarity change rate of layer. 254 | 255 | Parameters 256 | ---------- 257 | model_name : str 258 | Model name. Can be one of the following options: vit_b_16, vit_b_32, 259 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16. 260 | block : int 261 | Index of block. 262 | layer_type : str 263 | Layer type. Can be one of 'attn' or 'mlp'. 264 | dataset_path : pathlib.Path 265 | Path to dataset. 266 | res_path : pathlib.Path 267 | Path to results. 268 | 269 | Returns 270 | ------- 271 | torch.Tensor 272 | Class similarity change rate. 273 | """ 274 | dec_layer = get_class_embed( 275 | res_path, dataset_path, model_name, f'hs-{layer_type}_{block}', 'probs' 276 | ) 277 | dec = get_class_embed( 278 | res_path, dataset_path, model_name, f'hs_{block-1}', 'probs' 279 | ) 280 | return (dec_layer - dec) 281 | 282 | 283 | def compute_residual_match(model_name, dataset_path, res_path, token_type='all'): 284 | """Compute match with the predictions of the residual. 285 | 286 | Parameters 287 | ---------- 288 | model_name : str 289 | Model name. Can be one of the following options: vit_b_16, vit_b_32, 290 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16. 291 | dataset_path : pathlib.Path 292 | Path to dataset. 293 | res_path : pathlib.Path 294 | Path to results. 295 | token_type : str, optional 296 | Compute match with all tokens, or with cls only, by default 'all'. 
297 | 
298 |     Returns
299 |     -------
300 |     pd.DataFrame
301 |         Match with the residual stream prediction.
302 |     """
303 | 
304 |     if 'large' in model_name:
305 |         n_layers = 24
306 |     else:
307 |         n_layers = 12
308 | 
309 |     if 'gap' in model_name:
310 |         idxs = torch.arange(196)
311 |     elif token_type == 'cls':
312 |         idxs = 0
313 |     elif (token_type == 'all') & ('32' in model_name):
314 |         idxs = torch.arange(50)
315 |     elif (token_type == 'all') & ('16' in model_name):
316 |         idxs = torch.arange(197)
317 | 
318 |     data = []
319 |     for block in range(1, n_layers):
320 |         # Get residual stream prediction
321 |         topk_b = get_class_embed(
322 |             res_path, dataset_path, model_name, f'hs_{block}', decoding_type='topk'
323 |         )[:, idxs, 0]
324 | 
325 |         # Get attention layer prediction and compute match
326 |         topk_attn = get_class_embed(
327 |             res_path, dataset_path, model_name, f'hs-attn_{block}', decoding_type='topk'
328 |         )[:, idxs, 0]
329 |         attn_match = torch.sum(topk_b == topk_attn) / topk_b.flatten().shape[0] * 100
330 |         data.append(['attn', block+1, attn_match.detach().numpy()])
331 | 
332 |         # Get MLP layer prediction and compute match
333 |         topk_mlp = get_class_embed(
334 |             res_path, dataset_path, model_name, f'hs-mlp_{block}', decoding_type='topk'
335 |         )[:, idxs, 0]
336 |         mlp_match = torch.sum(topk_b == topk_mlp) / topk_b.flatten().shape[0] * 100
337 |         data.append(['mlp', block+1, mlp_match.detach().numpy()])
338 | 
339 |         # Get previous block residual stream prediction and compute match
340 |         topk_prev = get_class_embed(
341 |             res_path, dataset_path, model_name, f'hs_{block-1}', decoding_type='topk'
342 |         )[:, idxs, 0]
343 |         prev_match = torch.sum(topk_b == topk_prev) / topk_b.flatten().shape[0] * 100
344 |         data.append(['prev', block+1, prev_match.detach().numpy()])
345 | 
346 |         # Compute rate of tokens that do not match any prediction of the above
347 |         comp_tokens = (topk_b != topk_prev) & (topk_b != topk_attn) & (topk_b != topk_mlp)
348 |         comp = torch.sum(comp_tokens) / topk_b.flatten().shape[0] * 100
349 |         data.append(['comp', block+1, comp.detach().numpy()])
350 | 
351 |     data = pd.DataFrame(data, columns=['Source', 'block', 'match'])
352 |     data['match'] = data['match'].astype('float')
353 |     return data
--------------------------------------------------------------------------------
/src/extractor.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from itertools import product
3 | from pathlib import Path
4 | import re
5 | 
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import DataLoader
9 | from torchvision.datasets import CIFAR100
10 | from torchvision.models.feature_extraction import create_feature_extractor
11 | from tqdm import tqdm
12 | 
13 | from src.datasets.imagenet import ImagenetDatasetS
14 | from src.datasets.cifar import MyCIFAR100
15 | from src.models.load import load_vit
16 | 
17 | 
18 | class ViTExtractor():
19 |     """Project hidden states to class embedding space and save key coefficients.
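    Examples
    --------
    A minimal usage sketch with placeholder paths; it mirrors what the
    command-line entry point at the bottom of this file does.

    >>> from pathlib import Path
    >>> extractor = ViTExtractor(
    ...     'vit_b_32', Path('/project'), Path('/datasets/imagenet-s'),
    ...     device='cpu', pretrained=True
    ... )
    >>> extractor.extract_hidden_states()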
20 | """ 21 | def __init__( 22 | self, model_name, project_path, imgs_path, device='cpu', 23 | pretrained=True 24 | ): 25 | self.pretrained = pretrained 26 | if self.pretrained == True: 27 | self.model_name = model_name 28 | elif self.pretrained == False: 29 | self.model_name = f'{model_name}_random' 30 | self.device = device 31 | 32 | self.project_path = project_path 33 | self.imgs_path = imgs_path 34 | self.hs_path = project_path / 'results' / 'class_embed' / model_name 35 | self.hs_path.mkdir(parents=True, exist_ok=True) 36 | 37 | # super().__init__(self.model_name, project_path, imgs_path, device) 38 | 39 | self._get_model_layers() 40 | self._load_model() 41 | 42 | def _get_model_layers(self): 43 | prefix = 'blocks.' 44 | join = '.' 45 | blocks_layer_types = [ 46 | 'add_1', 'attn.getitem_4', 'attn.getitem_5', 'attn.softmax', 47 | 'attn.proj', 'mlp.fc1', 'mlp.fc2' 48 | ] 49 | 50 | if '_large_' in self.model_name: 51 | n_blocks = np.arange(24) 52 | else: 53 | n_blocks = np.arange(12) 54 | 55 | layers = ['head'] 56 | for b, l in product(n_blocks, blocks_layer_types): 57 | if l == None: 58 | layers.append(f'{prefix}{b}') 59 | else: 60 | layers.append(f'{prefix}{b}{join}{l}') 61 | 62 | self.layers = layers 63 | return 64 | 65 | def _get_layer_name(self, layer_model): 66 | """Create layer name from layer id.""" 67 | if layer_model == 'head': 68 | layer_name = 'cls-out' 69 | else: 70 | block = re.findall(r'\d+', layer_model)[0] 71 | if any(w in layer_model for w in ['.k_', 'getitem_4']): 72 | suffix = f'key' 73 | elif any(w in layer_model for w in ['.q_', 'getitem_3']): 74 | suffix = f'query' 75 | elif any(w in layer_model for w in ['.v_', 'getitem_5']): 76 | suffix = f'value' 77 | elif any(w in layer_model for w in ['attn.softmax']): 78 | suffix = f'attn-w' 79 | elif any(w in layer_model for w in ['attn.proj']): 80 | suffix = f'hs-attn' 81 | elif any(w in layer_model for w in ['mlp.fc1']): 82 | suffix = 'key-mlp' 83 | elif any(w in layer_model for w in ['.mlp', 'mlp.fc2']): 84 | suffix = f'hs-mlp' 85 | else: 86 | suffix = 'hs' 87 | layer_name = f'{suffix}_{block}' 88 | return layer_name 89 | 90 | def _load_model(self): 91 | # Get model 92 | self.model, self.n_tokens, _, self.img_transform = load_vit( 93 | self.model_name, self.device, self.project_path, return_transform=True 94 | ) 95 | self.model.eval() 96 | 97 | # Add feature extractor 98 | layers_map = {l: self._get_layer_name(l) for l in self.layers} 99 | self.extractor = create_feature_extractor(self.model, layers_map).to(self.device) 100 | 101 | # Get normalization 102 | if self.model_name == 'vit_gap_16': 103 | self.ln = self.model.fc_norm 104 | else: 105 | self.ln = self.model.norm 106 | 107 | # Get projection matrix 108 | self.cls_proj = self.model.head 109 | 110 | return 111 | 112 | def _get_dataset(self): 113 | self.dataset = ImagenetDatasetS(self.imgs_path) 114 | self.dataloader = DataLoader( 115 | self.dataset, batch_size=5, collate_fn=self._imagenet_collate_batch 116 | ) 117 | return 118 | 119 | def _imagenet_collate_batch(self, batch): 120 | assert all(i['imagenet_id']==batch[0]['imagenet_id'] for i in batch) 121 | data = {} 122 | data['imagenet_id'] = batch[0]['imagenet_id'] 123 | data['index'] = batch[0]['index'] 124 | data['cat'] = batch[0]['cat'] 125 | data['imgs'] = [i['img'] for i in batch] 126 | return data 127 | 128 | def extract_hidden_states(self): 129 | self._get_dataset() 130 | acc = [] 131 | for _, data in tqdm( 132 | enumerate(self.dataloader), total=len(self.dataloader) 133 | ): 134 | # Prepare concept path 135 
| cls_emb_path = self.hs_path / data['imagenet_id'] 136 | cls_emb_path.mkdir(parents=True, exist_ok=True) 137 | if self.pretrained == True: 138 | net_ft_path = self.project_path / 'results' / 'net_ft' / self.model_name / data['imagenet_id'] 139 | net_ft_path.mkdir(parents=True, exist_ok=True) 140 | 141 | # Get image features 142 | try: 143 | img_ft = self.img_transform(data['imgs'], return_tensors="pt") 144 | img_ft = img_ft['pixel_values'].to(self.device) 145 | except: 146 | img_ft = [self.img_transform(i) for i in data['imgs']] 147 | img_ft = torch.stack(img_ft).to(self.device) 148 | 149 | # Compute hidden states 150 | with torch.no_grad(): 151 | out = self.extractor(img_ft) 152 | 153 | # Compute and save projections 154 | for l_name, l_repr in out.items(): 155 | if l_name == 'cls-out': 156 | pred = l_repr.topk(1)[1] 157 | cat_acc = torch.squeeze((pred == data['index']).long()) 158 | acc.append(cat_acc) 159 | 160 | elif 'hs' in l_name: 161 | # Project to class embedding space 162 | if 'gap' in self.model_name: # add normalization 163 | block = int(l_name.split('_')[-1]) 164 | with torch.no_grad(): 165 | if 'attn' in l_name: 166 | l_repr = self.model.blocks[block].norm1(l_repr) 167 | elif 'mlp' in l_name: 168 | l_repr = self.model.blocks[block].norm2(l_repr) 169 | preds = self.cls_proj(self.ln(l_repr)) 170 | else: 171 | with torch.no_grad(): 172 | preds = self.cls_proj(self.ln(l_repr)) 173 | 174 | # Get top-5 predictions 175 | top_k = preds.topk(5, dim=-1)[1] 176 | torch.save(top_k, (cls_emb_path / f'topk_{l_name}.pt')) 177 | 178 | # Get correct label position 179 | ordered_idx = torch.argsort(preds, dim=2, descending=True) 180 | label_idx = (ordered_idx == data['index']).nonzero() 181 | pos = label_idx[:, 2].reshape(self.dataset.n_imgs, (self.n_tokens)) 182 | torch.save(pos, (cls_emb_path / f'pos_{l_name}.pt')) 183 | 184 | # Get correct label probability 185 | probs = preds[:, :, data['index']].clone() 186 | torch.save(probs, (cls_emb_path / f'probs_{l_name}.pt')) 187 | 188 | # Save key coefficients 189 | elif self.pretrained == True: 190 | l_repr = l_repr.clone() 191 | torch.save(l_repr, (net_ft_path / f'{l_name}.pt')) 192 | 193 | else: 194 | continue 195 | 196 | # Save accuracy 197 | acc = torch.hstack(acc) 198 | file = self.hs_path / 'acc.pt' 199 | torch.save(acc, file) 200 | 201 | return 202 | 203 | 204 | class ExtractorCIFAR100(ViTExtractor): 205 | """Project hidden states to class embedding space and save key coefficients 206 | of the CIFAR model. 
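    Examples
    --------
    Illustrative call with placeholder paths; 'imgs_path' is the folder holding
    the CIFAR-100 data expected by MyCIFAR100.

    >>> from pathlib import Path
    >>> extractor = ExtractorCIFAR100(
    ...     'vit_cifar_16', Path('/project'), Path('/data/cifar100'),
    ...     device='cpu', pretrained=True
    ... )
    >>> extractor.extract_hidden_states()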
207 | """ 208 | def __init__(self, model_name, project_path, imgs_path, device, pretrained=True): 209 | self.model_name = model_name 210 | self.pretrained = pretrained 211 | super().__init__(self.model_name, project_path, imgs_path, device, pretrained=pretrained) 212 | self._load_model() 213 | return 214 | 215 | def _get_dataset(self): 216 | self.dataset = CIFAR100( 217 | self.imgs_path, train=False, download=True, transform=self.img_transform 218 | ) 219 | self.dataloader = DataLoader(self.dataset, batch_size=1, shuffle=False) 220 | return 221 | 222 | def extract_hidden_states(self): 223 | dataset = MyCIFAR100(self.imgs_path) 224 | 225 | acc = [] 226 | for label, data in tqdm(dataset.stim_info.items(), total=len(dataset)): 227 | # Prepare concept path 228 | cls_emb_path = self.hs_path / label 229 | cls_emb_path.mkdir(parents=True, exist_ok=True) 230 | 231 | if self.pretrained == True: 232 | net_ft_path = self.project_path / 'results' / 'net_ft' / self.model_name / label 233 | net_ft_path.mkdir(parents=True, exist_ok=True) 234 | 235 | # Get image features 236 | img_ft = torch.stack([self.img_transform(img) for img in data]).to(self.device) 237 | 238 | # Compute hidden states 239 | with torch.no_grad(): 240 | out = self.extractor(img_ft) 241 | 242 | # Save hidden states 243 | for l_name, l_repr in out.items(): 244 | if l_name == 'cls-out': 245 | pred = l_repr.topk(1)[1] 246 | cat_acc = torch.squeeze((pred == int(label)).long()) 247 | acc.append(cat_acc) 248 | 249 | elif 'hs' in l_name: 250 | with torch.no_grad(): 251 | preds = self.cls_proj(self.ln(l_repr)) 252 | 253 | top_k = preds.topk(5, dim=-1)[1] 254 | torch.save(top_k, (cls_emb_path / f'topk_{l_name}.pt')) 255 | 256 | ordered_idx = torch.argsort(preds, dim=2, descending=True) 257 | label_idx = (ordered_idx == int(label)).nonzero() 258 | pos = label_idx[:, 2].reshape(dataset.n_imgs, (self.n_tokens)) 259 | torch.save(pos, (cls_emb_path / f'pos_{l_name}.pt')) 260 | 261 | probs = preds[:, :, int(label)].clone() 262 | torch.save(probs, (cls_emb_path / f'probs_{l_name}.pt')) 263 | 264 | elif self.pretrained == True: 265 | l_repr = l_repr.clone() 266 | torch.save(l_repr, (net_ft_path / f'{l_name}.pt')) 267 | 268 | else: 269 | continue 270 | 271 | # Save accuracy 272 | acc = torch.hstack(acc) 273 | file = self.hs_path / 'acc.pt' 274 | torch.save(acc, file) 275 | return 276 | 277 | 278 | if __name__ == '__main__': 279 | parser = argparse.ArgumentParser() 280 | parser.add_argument( 281 | '-pp', action='store', required=True, 282 | help='Path to the folder containing the source code.' 283 | ) 284 | parser.add_argument( 285 | '-dp', action='store', required=True, 286 | help='Path to the folder containing the dataset.' 287 | ) 288 | parser.add_argument( 289 | '-m', action='store', required=True, 290 | help='Select which model to run. Can be one of the following options: \ 291 | vit_b_16, vit_b_32, vit_large_16, vit_miil_16, vit_cifar_16, \ 292 | deit_ensemble_16, vit_gap_16.' 293 | ) 294 | parser.add_argument( 295 | '-pretrained', action='store_true', help='Use pretrained model.' 
296 | ) 297 | 298 | args = parser.parse_args() 299 | project_path = Path(args.pp) 300 | data_path = Path(args.dp) 301 | 302 | model = args.m 303 | pretrained = args.pretrained 304 | 305 | device = "cuda" if torch.cuda.is_available() else "cpu" 306 | 307 | if model == 'vit_cifar_16': 308 | extractor = ExtractorCIFAR100( 309 | model, project_path, data_path, device=device, pretrained=pretrained 310 | ) 311 | extractor.extract_hidden_states() 312 | else: 313 | extractor = ViTExtractor( 314 | model, project_path, data_path, device=device, pretrained=pretrained 315 | ) 316 | extractor.extract_hidden_states() -------------------------------------------------------------------------------- /src/memories.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from scipy.stats import wilcoxon 7 | import torch 8 | 9 | from src.datasets.cifar import MyCIFAR100 10 | from src.datasets.imagenet import ImagenetDatasetS 11 | from src.models.load import load_vit 12 | 13 | 14 | class Memory: 15 | def __init__(self, proj_path, dataset_path, model, device='cpu'): 16 | self.model_name = model 17 | self.device = device 18 | 19 | self.proj_path = Path(proj_path) 20 | self.dataset_path = Path(dataset_path) 21 | self.net_ft_path = proj_path / 'results' / 'net_ft' / model 22 | self.dec_path = proj_path / 'results' / 'class_embed' / model 23 | 24 | if 'cifar' in model: 25 | self.stim_info = MyCIFAR100(self.dataset_path).stim_info 26 | else: 27 | self.stim_info = ImagenetDatasetS(self.dataset_path).stim_info 28 | 29 | self._load_model() 30 | 31 | def _load_model(self): 32 | # Load model 33 | self.model, self.n_tokens, self.hs_dim, _ = load_vit( 34 | self.model_name, self.device, self.proj_path 35 | ) 36 | self.model.eval() 37 | 38 | # Load random model 39 | self.random_model, _, _, _ = load_vit( 40 | self.model_name, self.device, self.proj_path, pretrained=False 41 | ) 42 | self.random_model.eval() 43 | 44 | # Define number of layers 45 | if 'large' in self.model_name: 46 | self.n_layers = 24 47 | else: 48 | self.n_layers = 12 49 | 50 | # Define number of classes 51 | if 'cifar' in self.model_name: 52 | self.n_classes = 100 53 | else: 54 | self.n_classes = 1000 55 | 56 | return 57 | 58 | def compute_class_value_agr(self, layer_type): 59 | """ 60 | Compute class-value agreement scores. 
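        Examples
        --------
        A minimal usage sketch with placeholder paths; 'layer_type' can be
        'mlp' or 'attn'.

        >>> from pathlib import Path
        >>> mem = Memory(Path('/project'), Path('/datasets/imagenet-s'), 'vit_b_16')
        >>> df, pvals, random_k = mem.compute_class_value_agr('mlp')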
61 | """ 62 | # Normalize class embedding 63 | with torch.no_grad(): 64 | self.model.head.weight /= self.model.head.weight.norm(dim=-1, keepdim=True) 65 | self.random_model.head.weight /= self.random_model.head.weight.norm(dim=-1, keepdim=True) 66 | 67 | # Get top-1 logit per class 68 | max_logits = [] 69 | pvals = [] 70 | random_k = [] 71 | for b in range(self.n_layers): 72 | if layer_type == 'mlp': 73 | with torch.no_grad(): 74 | self.model.blocks[b].mlp.fc2.weight /= self.model.blocks[b].mlp.fc2.weight.norm(dim=0, keepdim=True) 75 | val_proj = self.model.head.weight @ self.model.blocks[b].mlp.fc2.weight 76 | else: 77 | with torch.no_grad(): 78 | self.model.blocks[b].attn.proj.weight /= self.model.blocks[b].attn.proj.weight.norm(dim=0, keepdim=True) 79 | val_proj = self.model.head.weight @ self.model.blocks[b].attn.proj.weight 80 | 81 | logits_top_k = torch.squeeze(val_proj.topk(1, dim=1)[0]) 82 | max_logits.append(logits_top_k) 83 | 84 | # Get top-1 logit per class in random model 85 | if layer_type == 'mlp': 86 | with torch.no_grad(): 87 | self.random_model.blocks[b].mlp.fc2.weight /= self.random_model.blocks[b].mlp.fc2.weight.norm(dim=0, keepdim=True) 88 | val_proj = self.random_model.head.weight @ self.random_model.blocks[b].mlp.fc2.weight 89 | else: 90 | with torch.no_grad(): 91 | self.random_model.blocks[b].attn.proj.weight /= self.random_model.blocks[b].attn.proj.weight.norm(dim=0, keepdim=True) 92 | val_proj = self.random_model.head.weight @ self.random_model.blocks[b].attn.proj.weight 93 | 94 | random_top_k = torch.squeeze(val_proj.topk(1, dim=1)[0]).detach().numpy() 95 | random_k.append(np.mean(random_top_k)) 96 | 97 | # Compute statistical difference 98 | wx = wilcoxon(logits_top_k.detach().numpy(), random_top_k, alternative='greater') 99 | pvals.append(wx.pvalue) 100 | 101 | # Save results in dataframe 102 | max_logits = torch.stack(max_logits).flatten().detach().numpy() 103 | blocks = torch.arange(1, self.n_layers+1).repeat_interleave(self.n_classes) 104 | df = pd.DataFrame({'top-1 logits': max_logits, 'block': blocks}) 105 | 106 | return df, pvals, random_k 107 | 108 | def compute_key_value_agreement(self, block, layer_type): 109 | """Compute key value agreement rates. 
110 | """ 111 | res_path = self.proj_path / 'results' / 'memories' / self.model_name 112 | res_path.mkdir(parents=True, exist_ok=True) 113 | 114 | if layer_type == 'mlp': 115 | with torch.no_grad(): 116 | val_proj = self.model.head.weight @ self.model.blocks[block].mlp.fc2.weight 117 | elif layer_type == 'attn': 118 | with torch.no_grad(): 119 | val_proj = self.model.head.weight @ self.model.blocks[block].attn.proj.weight 120 | key_topk = val_proj.topk(5, dim=1)[1] 121 | 122 | if 'cifar' in self.model_name: 123 | concepts = list(self.stim_info.keys()) 124 | indexes = concepts 125 | n_classes = 100 126 | else: 127 | concepts = self.stim_info['imagenet_id'].unique() 128 | indexes = self.stim_info['index'].unique() 129 | n_classes = 1000 130 | 131 | agreement = [] 132 | agreement_random = [] 133 | logits = [] 134 | logits_random = [] 135 | for c, c_idx in zip(concepts, indexes): 136 | c_idx = int(c_idx) 137 | 138 | # Get top-5 keys of hidden state 139 | if layer_type == 'mlp': 140 | file = self.net_ft_path / c / f'key-{layer_type}_{block}.pt' 141 | key_val = torch.load(file, map_location=self.device) 142 | 143 | elif layer_type == 'attn': 144 | attn_file = self.net_ft_path / c / f'attn-w_{block}.pt' 145 | attn_data = torch.load(attn_file, map_location=self.device) 146 | val_file = self.net_ft_path / c / f'value_{block}.pt' 147 | val_data = torch.load(val_file, map_location=self.device) 148 | key_val = attn_data @ val_data 149 | key_val = key_val.transpose(1,2).reshape(5, self.n_tokens, self.hs_dim) 150 | 151 | hs_key_topk = key_val.topk(5, dim=-1)[1] 152 | 153 | # Compute top-k value agreement 154 | agreement.append(self._compute_agreement(hs_key_topk, key_topk[c_idx])) 155 | agreement_random.append(self._compute_agreement( 156 | hs_key_topk, key_topk[torch.randint(0, n_classes, (1,))]) 157 | ) 158 | 159 | # Compute average logits 160 | logits.append(self._get_logits(hs_key_topk, val_proj, c_idx)) 161 | logits_random.append( 162 | self._get_logits(hs_key_topk, val_proj, torch.randint(0, n_classes, (1,))) 163 | ) 164 | 165 | f_agreement = res_path / f'{layer_type}_{block}_value-class_agreement.pt' 166 | torch.save(torch.stack(agreement), f_agreement) 167 | f_agreement_random = res_path / f'{layer_type}_{block}_value-class_agreement-random.pt' 168 | torch.save(torch.stack(agreement_random), f_agreement_random) 169 | 170 | f_logits = res_path / f'{layer_type}_{block}_value-class_logits.pt' 171 | torch.save(torch.stack(logits), f_logits) 172 | f_logits_random = res_path / f'{layer_type}_{block}_value-class_logits-random.pt' 173 | torch.save(torch.stack(logits_random), f_logits_random) 174 | 175 | return 176 | 177 | def _compute_agreement(self, hs_key_topk, c_key_topk): 178 | agr = torch.isin(hs_key_topk, c_key_topk) 179 | agr = torch.any(agr, dim=-1) 180 | return agr 181 | 182 | def _get_logits(self, hs_key_topk, val_proj, c_idx): 183 | c_logits = [] 184 | for img in range(hs_key_topk.shape[0]): 185 | for t in range(hs_key_topk.shape[1]): 186 | ks = hs_key_topk[img, t] 187 | c_logits.append(val_proj[c_idx, ks]) 188 | c_logits = torch.stack(c_logits).view(hs_key_topk.shape[0], hs_key_topk.shape[1], -1) 189 | return c_logits 190 | 191 | def compute_composition(self, block, layer_type): 192 | # Create paths 193 | res_path = self.proj_path / 'results' / 'memories' / self.model_name 194 | res_path.mkdir(parents=True, exist_ok=True) 195 | 196 | # Get concepts 197 | if 'cifar' in self.model_name: 198 | concepts = list(self.stim_info.keys()) 199 | else: 200 | concepts = 
self.stim_info['imagenet_id'].unique() 201 | 202 | # Get most activating class per memory value 203 | if layer_type == 'mlp': 204 | with torch.no_grad(): 205 | val_proj = self.model.head.weight @ self.model.blocks[block].mlp.fc2.weight 206 | elif layer_type == 'attn': 207 | with torch.no_grad(): 208 | val_proj = self.model.head.weight @ self.model.blocks[block].attn.proj.weight 209 | key_topk = torch.squeeze(val_proj.topk(1, dim=0)[1]) 210 | 211 | # Get top-5 most activated memories per image and token 212 | match = [] 213 | for c in concepts: 214 | if layer_type == 'mlp': 215 | file = self.net_ft_path / c / f'key-{layer_type}_{block}.pt' 216 | key_val = torch.load(file, map_location=self.device) 217 | 218 | elif layer_type == 'attn': 219 | attn_file = self.net_ft_path / c / f'attn-w_{block}.pt' 220 | attn_data = torch.load(attn_file, map_location=self.device) 221 | val_file = self.net_ft_path / c / f'value_{block}.pt' 222 | val_data = torch.load(val_file, map_location=self.device) 223 | key_val = attn_data @ val_data 224 | key_val = key_val.transpose(1,2).reshape(5, self.n_tokens, self.hs_dim) 225 | 226 | hs_key_topk = key_val.topk(5, dim=-1)[1] 227 | 228 | # Get most activating classes for the top-5 activated memories 229 | for img in range(key_val.shape[0]): 230 | for t in range(key_val.shape[1]): 231 | for k in range(5): 232 | k_top = key_topk[hs_key_topk[img, t, k]] 233 | hs_key_topk[img, t, k] = k_top 234 | 235 | # Compute agreement with predictions at output 236 | file = self.dec_path / c / f'topk_hs-{layer_type}_{block}.pt' 237 | preds = torch.load(file, map_location=self.device)[:, :, 0] 238 | img_match = [] 239 | for img in range(key_val.shape[0]): 240 | for t in range(key_val.shape[1]): 241 | layer_pred = preds[img, t] 242 | img_match.append(torch.isin(hs_key_topk[img, t], layer_pred)) 243 | img_match = torch.stack(img_match) 244 | img_match = img_match.reshape(key_val.shape[0], key_val.shape[1], -1) 245 | match.append(img_match) 246 | 247 | f = res_path / f'{layer_type}_{block}_top5_pred-match.pt' 248 | torch.save(torch.stack(match), f) 249 | 250 | return 251 | 252 | 253 | def compute_key_value_agr_rate(model_name, layer_type, res_path, tokens='all'): 254 | """Compute key-value agreement rate. 255 | 256 | Parameters 257 | ---------- 258 | model_name : str 259 | Model name. Can be one of the following options: vit_b_16, vit_b_32, 260 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16. 261 | layer_type : str 262 | Type of layer. Can be one of 'attn' or 'mlp'. 263 | res_path : pathlib.Path 264 | Path to results. 265 | tokens : str, optional 266 | Use all tokens or cls only, by default 'all'. 267 | 268 | Returns 269 | ------- 270 | pd.DataFrame 271 | Key-value agreement rate. 
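    Notes
    -----
    Expects the per-block agreement tensors written by
    ``Memory.compute_key_value_agreement`` under
    ``res_path / 'memories' / model_name``. The rate is the percentage of
    tokens (or of the [CLS] token only, when ``tokens='cls'``) whose top-5
    activated memories include one of the class's top-5 memories.

    Example
    -------
    A minimal sketch, assuming the agreement tensors were already computed
    into a local ``results`` folder::

        rate = compute_key_value_agr_rate('vit_b_16', 'mlp', Path('results'))
        print(rate.groupby('block')['agreement rate'].mean())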
272 | """ 273 | if 'large' in model_name: 274 | n_layers = 24 275 | else: 276 | n_layers = 12 277 | 278 | # Get key-value agreement scores 279 | agreement_rate = [] 280 | for b in range(n_layers): 281 | f = res_path / 'memories/' / model_name / f'{layer_type}_{b}_value-class_agreement.pt' 282 | agr = torch.load(f, map_location='cpu') 283 | if tokens == 'all': 284 | agr = agr.flatten(start_dim=1) 285 | elif tokens == 'cls': 286 | agr = agr[:,:, 0] 287 | # Compute agreement rate 288 | tokens_agr = torch.sum(agr, dim=1) / agr.shape[1] * 100 289 | agreement_rate.append(tokens_agr) 290 | 291 | # Save results in dataframe 292 | agreement_rate = torch.stack(agreement_rate).flatten().detach().numpy() 293 | agreement_rate = pd.DataFrame( 294 | {'block': torch.arange(1,n_layers+1).repeat_interleave(agr.shape[0]), 'agreement rate': agreement_rate} 295 | ) 296 | agreement_rate['agreement rate'] = agreement_rate['agreement rate'].astype('float') 297 | 298 | return agreement_rate 299 | 300 | 301 | if __name__ == '__main__': 302 | parser = argparse.ArgumentParser() 303 | parser.add_argument( 304 | '-pp', action='store', required=True, 305 | help='Path to the folder containing the source code.' 306 | ) 307 | parser.add_argument( 308 | '-dp', action='store', required=True, 309 | help='Path to the folder containing the dataset.' 310 | ) 311 | parser.add_argument( 312 | '-m', action='store', required=True, 313 | help='Select which model to run. Can be one of the following options: \ 314 | vit_16, vit_32' 315 | ) 316 | parser.add_argument( 317 | '-lt', action='store', required=True, 318 | help='Hidden state type' 319 | ) 320 | 321 | args = parser.parse_args() 322 | 323 | project_path = Path(args.pp) 324 | dataset_path = Path(args.dp) 325 | 326 | model = args.m 327 | layer_type = args.lt 328 | device = "cuda" if torch.cuda.is_available() else "cpu" 329 | 330 | if 'large' in model: 331 | n_layers = 24 332 | else: 333 | n_layers = 12 334 | 335 | for b in range(n_layers): 336 | memory = Memory(project_path, dataset_path, model, device=device) 337 | memory.compute_key_value_agreement(b, layer_type) 338 | memory.compute_composition(b, layer_type) -------------------------------------------------------------------------------- /src/plots/mech_interp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import pandas as pd 5 | import torch 6 | 7 | from src.identifiability import MODEL_MAP 8 | from src.identifiability import ( 9 | compute_class_similarity_change, compute_residual_match 10 | ) 11 | from src.memories import Memory, compute_key_value_agr_rate 12 | from src.vis import Vis 13 | 14 | 15 | def plot_class_building(model_name, res_path, dataset_path): 16 | """ 17 | Plot categorical building over blocks and layers. 
18 | """ 19 | plt.rcParams.update({'font.size': 16}) 20 | 21 | if 'large' in model_name: 22 | n_layers = 24 23 | else: 24 | n_layers = 12 25 | 26 | fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(13,5)) 27 | 28 | # Plot class similarity change rate 29 | pers = [] 30 | for layer_type in ['attn', 'mlp']: 31 | for b in range(1, n_layers): 32 | 33 | diff = compute_class_similarity_change( 34 | model_name, b, layer_type, dataset_path, res_path 35 | ) 36 | 37 | if 'gap' not in model_name: 38 | diff_cls = diff[:, 0] 39 | cls_change_rate = torch.sum(diff_cls > 0) / diff_cls.flatten().shape[0] 40 | pers.append([f'{layer_type} - [CLS] ', b+1, cls_change_rate.detach().numpy()]) 41 | 42 | diff_img = diff[:, 1:] 43 | diff_change_rate = torch.sum(diff_img > 0) / diff_img.flatten().shape[0] 44 | pers.append([f'{layer_type} - Image', b+1, diff_change_rate.detach().numpy()]) 45 | 46 | else: 47 | diff_img = diff 48 | diff_change_rate = torch.sum(diff_img > 0) / diff_img.flatten().shape[0] 49 | pers.append([f'{layer_type} - Image', b+1, diff_change_rate.detach().numpy()]) 50 | 51 | pers = pd.DataFrame(pers, columns=['Layer Type', 'block', 'change rate']) 52 | pers['change rate'] = pers['change rate'].astype('float') 53 | 54 | if model_name == 'vit_gap_16': 55 | colors = ["#1f78b4", "#33a02c"] 56 | sns.lineplot( 57 | pers, x='block', y='change rate', hue='Layer Type', marker='*', ax=axes[0], 58 | palette=sns.color_palette(colors) 59 | ) 60 | else: 61 | sns.lineplot( 62 | pers, x='block', y='change rate', hue='Layer Type', marker='*', ax=axes[0], 63 | palette=sns.color_palette("Paired", 4) 64 | ) 65 | axes[0].hlines(xmin=-1, xmax=n_layers, y=0.5, colors='dimgray', linestyles='--', lw=2) 66 | 67 | axes[0].set_xlim(1.9, int(n_layers) + 0.1) 68 | axes[0].set_xticks(np.arange(2, n_layers + 1)) 69 | axes[0].xaxis.set_tick_params(labelsize=10) 70 | axes[0].set_ylim(0, 1) 71 | 72 | axes[0].legend(loc='lower left').set_title('') 73 | axes[0].set_title('class similarity change rate') 74 | 75 | # Plot residual composition 76 | match = compute_residual_match(model_name, dataset_path, res_path) 77 | match = match.pivot(index='block', columns=['Source']) 78 | match.columns = match.columns.droplevel() 79 | match = match[['prev', 'attn', 'mlp', 'comp']] 80 | match.plot(kind='bar', stacked=True, rot=0, ax=axes[1], cmap="Set2") 81 | 82 | axes[1].set_ylabel('% of tokens') 83 | axes[1].xaxis.set_tick_params(labelsize=10) 84 | axes[1].legend( 85 | ['previous', 'attn','mlp', 'composition'], loc='upper left', ncol=2 86 | ) 87 | axes[1].set_title('residual stream composition') 88 | 89 | fig.suptitle(MODEL_MAP[model_name]) 90 | plt.tight_layout() 91 | f = res_path / 'figures' / f'compositionality_{model_name}.png' 92 | plt.savefig(f, dpi=300) 93 | plt.show() 94 | 95 | return 96 | 97 | 98 | def get_segmented_increases(model_name, layer_type, proj_path, dataset_path): 99 | """ 100 | Compute class similarity change separately for class- and context-labeled tokens. 
101 | """ 102 | # Get segmentations 103 | im = Vis(proj_path, dataset_path, model_name, device='cpu') 104 | stim_info = im.stim_info 105 | concepts = stim_info['imagenet_id'].unique().tolist() 106 | sgts = [] 107 | for c in concepts: 108 | for i in range(5): 109 | gt = im.get_segmentation(c, i).flatten() 110 | sgts.append(gt) 111 | sgts = np.hstack(sgts) 112 | 113 | # Compute increases 114 | if 'large' in model_name: 115 | n_layers = 24 116 | else: 117 | n_layers = 12 118 | 119 | dfs = [] 120 | for b in range(1, n_layers): 121 | res_path = proj_path / 'results' 122 | dec = compute_class_similarity_change(model_name, b, layer_type, dataset_path, res_path) 123 | if 'gap' not in model_name: 124 | dec = dec[:, 1:] 125 | df = pd.DataFrame({'decoding': dec.flatten().detach().numpy(), 'token location': sgts}) 126 | df['block'] = b 127 | dfs.append(df) 128 | dfs = pd.concat(dfs) 129 | dfs['token location'] = dfs['token location'].replace({0: 'context', 1: 'category'}) 130 | return dfs 131 | 132 | 133 | def compute_context_diff(model_name, layer_type, proj_path, dataset_path): 134 | """ 135 | Compare class similarity change between class- and context-labeled tokens. 136 | """ 137 | if 'large' in model_name: 138 | n_layers = 24 139 | else: 140 | n_layers = 12 141 | 142 | dfs = get_segmented_increases(model_name, layer_type, proj_path, dataset_path) 143 | diffs = [] 144 | pvals = [] 145 | for b in range(1, n_layers): 146 | df = dfs.loc[dfs['block'] == b] 147 | context = df.loc[df['token location'] == 'context']['decoding'].to_numpy() 148 | category = df.loc[df['token location'] == 'category']['decoding'].to_numpy() 149 | 150 | # Compute true difference 151 | true_diff = np.mean(category) - np.mean(context) 152 | 153 | # Compare to random model 154 | context_len = context.shape[0] 155 | all_data = np.concatenate((context, category)) 156 | random_diffs = [] 157 | n_perm = 300 158 | for p in range(300): 159 | np.random.shuffle(all_data) 160 | context, category = all_data[:context_len], all_data[context_len:] 161 | random_diffs.append(np.mean(category) - np.mean(context)) 162 | pval = np.sum(true_diff < random_diffs) / n_perm 163 | 164 | diffs.append(true_diff) 165 | pvals.append(pval) 166 | 167 | df = pd.DataFrame({'diffs': diffs, 'pvals': pvals}) 168 | df['model'] = model_name 169 | df['layer'] = layer_type 170 | df['block'] = np.arange(len(diffs)) 171 | 172 | return df 173 | 174 | 175 | def plot_categorical_updates(proj_path, dataset_path): 176 | """ 177 | Plot categorical updates for all models. 
178 | """ 179 | plt.rcParams.update({'font.size': 14}) 180 | fig, axes = plt.subplots(nrows=7, ncols=2, figsize=(10, 20)) 181 | for m_idx, model in enumerate(MODEL_MAP.keys()): 182 | mem = Memory(proj_path, dataset_path, model) 183 | 184 | for l_idx, layer_type in enumerate(['attn', 'mlp']): 185 | df, pvals_l, _ = mem.compute_class_value_agr(layer_type) 186 | 187 | # Plot pvalues 188 | sig = np.where(np.array(pvals_l) < 0.05)[0] 189 | axes[m_idx, l_idx].scatter(sig, [0.8] * len(sig), marker='*', c='grey', s=50) 190 | 191 | # Plot agreement values 192 | sns.boxplot(df, x='block', y='top-1 logits', ax=axes[m_idx, l_idx]) 193 | axes[m_idx, l_idx].set_ylim((0,0.85)) 194 | axes[m_idx, l_idx].xaxis.set_tick_params(labelsize=9) 195 | axes[m_idx, l_idx].yaxis.set_tick_params(labelsize=12) 196 | 197 | axes[m_idx, l_idx].set_title(f'{MODEL_MAP[model]} - {layer_type}') 198 | 199 | plt.tight_layout() 200 | f = proj_path / 'results/figures' / f'match_score.png' 201 | plt.savefig(f, dpi=300) 202 | plt.show() 203 | return 204 | 205 | 206 | def plot_key_value_agreement_rate(res_path): 207 | """ 208 | Plot key-value agreement rate. 209 | """ 210 | plt.rcParams.update({'font.size': 14}) 211 | fig, axes = plt.subplots(nrows=7, ncols=2, figsize=(10, 20)) 212 | for m_idx, model_name in enumerate(MODEL_MAP.keys()): 213 | for l_idx, layer_type in enumerate(['attn', 'mlp']): 214 | df = compute_key_value_agr_rate(model_name, layer_type, res_path) 215 | sns.stripplot( 216 | df, x='block', y='agreement rate', alpha=0.3, zorder=1, 217 | ax=axes[m_idx, l_idx] 218 | ) 219 | sns.pointplot( 220 | df, x='block', y='agreement rate', color='red', 221 | markers="d", scale=.75, errorbar=None, ax=axes[m_idx, l_idx] 222 | ) 223 | axes[m_idx, l_idx].set_title(f'{MODEL_MAP[model_name]} - {layer_type}') 224 | axes[m_idx, l_idx].set_ylim((0,100)) 225 | 226 | axes[m_idx, l_idx].xaxis.set_tick_params(labelsize=9) 227 | axes[m_idx, l_idx].yaxis.set_tick_params(labelsize=12) 228 | 229 | plt.tight_layout() 230 | f = res_path / 'figures' / f'key_val_agr.png' 231 | plt.savefig(f, dpi=300) 232 | plt.show() 233 | return 234 | 235 | 236 | def compare_memory_pairs(proj_path, dataset_path): 237 | """ 238 | Compare memory results across ViT variants. 
239 | """ 240 | plt.rcParams.update({'font.size': 15}) 241 | fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 3)) 242 | 243 | # Plot class-value agreement 244 | for layer_type, ax in zip(['attn', 'mlp'], axes.flat[:2]): 245 | dfs = [] 246 | random_k = [] 247 | for model_name in MODEL_MAP.keys(): 248 | mem = Memory(proj_path, dataset_path, model_name) 249 | df, _, rk = mem.compute_class_value_agr(layer_type) 250 | random_k.append(rk) 251 | 252 | if model_name == 'vit_large_16': 253 | new_df = [] 254 | for idx, b in enumerate(np.arange(1, 25, 2)): 255 | mean = df.loc[df['block'].isin([b, b+1])]['top-1 logits'].mean() 256 | new_df.append([mean, idx+1, MODEL_MAP['vit_large_16']]) 257 | new_df = pd.DataFrame( 258 | new_df, columns=['top-1 logits', 'block', 'model'] 259 | ) 260 | df = new_df 261 | else: 262 | df['model'] = MODEL_MAP[model_name] 263 | 264 | dfs.append(df) 265 | 266 | dfs = pd.concat(dfs) 267 | sns.lineplot( 268 | dfs, x='block', y='top-1 logits', hue='model', ax=ax, 269 | marker='*', markersize=10 270 | ) 271 | ax.set_title(layer_type) 272 | ax.set_ylim((0.08, 0.45)) 273 | ax.set_xticks(np.arange(1, 13)) 274 | ax.set_xticklabels([]) 275 | ax.set_xlim((0.8, 12.2)) 276 | ax.set_xlabel('normalized layer') 277 | 278 | axes[0].set_ylabel('class-value agreement') 279 | axes[1].set_ylabel('') 280 | 281 | # Plot key-value agreement rate 282 | for layer_type, ax in zip(['attn', 'mlp'], axes[2:].flat): 283 | dfs = [] 284 | for model_name in MODEL_MAP.keys(): 285 | res_path = proj_path / 'results' 286 | df = compute_key_value_agr_rate(model_name, layer_type, res_path) 287 | 288 | if model_name == 'vit_large_16': 289 | new_df = [] 290 | for idx, b in enumerate(np.arange(1, 25, 2)): 291 | mean = df.loc[df['block'].isin([b, b+1])]['agreement rate'].mean() 292 | new_df.append([mean, idx+1, MODEL_MAP['vit_large_16']]) 293 | new_df = pd.DataFrame( 294 | new_df, columns=['agreement rate', 'block', 'model'] 295 | ) 296 | df = new_df 297 | else: 298 | df['model'] = MODEL_MAP[model_name] 299 | 300 | dfs.append(df) 301 | 302 | dfs = pd.concat(dfs) 303 | sns.lineplot( 304 | dfs, x='block', y='agreement rate', hue='model', ax=ax, 305 | marker='*', markersize=10 306 | ) 307 | ax.set_title(layer_type) 308 | ax.set_ylim((0,85)) 309 | ax.set_xticks(np.arange(1, 13)) 310 | ax.set_xticklabels([]) 311 | ax.set_xlim((0.8, 12.2)) 312 | ax.set_xlabel('normalized layer') 313 | 314 | axes[2].set_ylabel('key-value agreement rate') 315 | axes[3].set_ylabel('') 316 | 317 | lines, labels = axes[1].get_legend_handles_labels() 318 | lgd = fig.legend( 319 | lines, labels, loc='upper center', ncol=4, 320 | bbox_to_anchor=(0.5, 1.25), 321 | ) 322 | 323 | for ax in axes.flat: 324 | ax.get_legend().remove() 325 | 326 | plt.tight_layout() 327 | f = proj_path / 'results/figures' / f'compare_all.png' 328 | plt.savefig(f, dpi=300, bbox_inches='tight') 329 | plt.show() 330 | return 331 | 332 | 333 | def plot_agr_rate_diff(res_path, device='cpu'): 334 | """ 335 | Plot difference in key-value agreement rate between accurate and non-accurate samples. 
336 | """ 337 | plt.rcParams.update({'font.size': 14}) 338 | fig, axes = plt.subplots(nrows=7, ncols=2, figsize=(10,20)) 339 | for m_idx, model_name in enumerate(MODEL_MAP.keys()): 340 | for l_idx, layer_type in enumerate(['attn', 'mlp']): 341 | if 'large' in model_name: 342 | n_layers = 24 343 | else: 344 | n_layers = 12 345 | 346 | # Get accuracy 347 | acc_file = res_path / 'class_embed'/ model_name / 'acc.pt' 348 | acc = torch.load(acc_file, map_location=device) 349 | 350 | diffs = [] 351 | pvals = [] 352 | for b in range(n_layers): 353 | # Compute difference in accuracy 354 | f = res_path / 'memories/' / model_name / f'{layer_type}_{b}_value-class_agreement.pt' 355 | agr = torch.load(f, map_location='cpu')[:,:, 0].flatten() 356 | 357 | acc_1 = agr[acc==1].float() 358 | agr_acc = torch.sum(acc_1) / acc_1.shape[0] * 100 359 | acc_2 = agr[acc==0].float() 360 | agr_inacc = torch.sum(acc_2) / acc_2.shape[0] * 100 361 | 362 | true_diff = agr_acc - agr_inacc 363 | diffs.append(true_diff.detach().numpy()) 364 | 365 | # Compare to random model 366 | random_diffs = [] 367 | for p in range(300): 368 | rand_idxs = torch.randperm(len(acc)) 369 | r_acc_1 = agr[rand_idxs[:len(acc_1)]].float() 370 | r_agr_acc = torch.sum(r_acc_1) / r_acc_1.shape[0] * 100 371 | r_acc_2 = agr[rand_idxs[len(acc_1):]].float() 372 | r_agr_inacc = torch.sum(r_acc_2) / r_acc_2.shape[0] * 100 373 | 374 | random_diffs.append(r_agr_acc - r_agr_inacc) 375 | 376 | pval = torch.sum(true_diff < torch.stack(random_diffs)) / 300 377 | pvals.append(pval.detach().numpy()) 378 | 379 | df = pd.DataFrame({'diff': diffs, 'block': np.arange(1, n_layers+1)}) 380 | df['model'] = model_name 381 | df['layer'] = layer_type 382 | df['diff'] = df['diff'].astype('float') 383 | 384 | # Plot pvalues 385 | sig = np.where(np.array(pvals) < 0.05)[0] + 1 386 | axes[m_idx, l_idx].scatter(sig, [69] * len(sig), marker='*', c='grey', s=50) 387 | 388 | sns.lineplot(df, x='block', y='diff', marker='o', ax=axes[m_idx, l_idx]) 389 | axes[m_idx, l_idx].set_ylim(-5, 72) 390 | axes[m_idx, l_idx].set_xticks(np.arange(1, n_layers+1)) 391 | axes[m_idx, l_idx].set_xlim((0.8, n_layers + 0.2)) 392 | 393 | axes[m_idx, l_idx].hlines( 394 | xmin=1, xmax=n_layers+1, y=0, colors='dimgray', linestyles='--', lw=2 395 | ) 396 | 397 | axes[m_idx, l_idx].set_ylabel('difference (%)') 398 | axes[m_idx, l_idx].xaxis.set_tick_params(labelsize=9) 399 | axes[m_idx, l_idx].yaxis.set_tick_params(labelsize=12) 400 | 401 | axes[m_idx, l_idx].set_title(f'{MODEL_MAP[model_name]} - {layer_type}') 402 | 403 | plt.tight_layout() 404 | f = res_path / 'figures' / f'acc_agr_rate.png' 405 | plt.savefig(f, dpi=300) 406 | plt.show() 407 | 408 | return 409 | 410 | 411 | def plot_memory_compositionality(res_path): 412 | """ 413 | Plot compositionality of memory pairs. 
414 | """ 415 | plt.rcParams.update({'font.size': 14}) 416 | fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(13, 15)) 417 | for model_name, ax in zip(MODEL_MAP.keys(), axes.flat): 418 | if 'large' in model_name: 419 | n_layers = 24 420 | else: 421 | n_layers = 12 422 | 423 | # Get compositionality 424 | comps = [] 425 | for layer_type in ['attn', 'mlp']: 426 | for b in range(n_layers): 427 | f = res_path / 'memories' / model_name / f'{layer_type}_{b}_top5_pred-match.pt' 428 | comp = torch.load(f, map_location='cpu')#[:, :, 0] 429 | preds = torch.any(comp, dim=-1) 430 | per = torch.sum(torch.any(comp, dim=-1)) / preds.flatten().shape[0] * 100 431 | comps.append([(b+1), layer_type, per.detach().numpy()]) 432 | 433 | comps = pd.DataFrame(comps, columns=['block', 'layer type', 'percentage']) 434 | comps['percentage'] = comps['percentage'].astype('float') 435 | 436 | # Plot 437 | sns.lineplot(comps, x='block', y='percentage', hue='layer type', marker='o', ax=ax) 438 | ax.set_xticks(np.arange(1, n_layers+1)) 439 | ax.xaxis.set_tick_params(labelsize=10) 440 | ax.set_xlim(1, n_layers) 441 | ax.set_yticks(np.arange(0, 75, 10)) 442 | ax.yaxis.set_tick_params(labelsize=12) 443 | ax.set_ylim((0, 75)) 444 | ax.set_ylabel('%') 445 | 446 | ax.legend(loc='upper left') 447 | ax.set_title(MODEL_MAP[model_name]) 448 | 449 | fig.delaxes(axes[3,1]) 450 | 451 | plt.tight_layout() 452 | f = res_path / 'figures' / f'composition.png' 453 | plt.savefig(f, dpi=300) 454 | plt.show() 455 | return -------------------------------------------------------------------------------- /src/models/vit_edited.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of Vision Transformers as described in: 4 | 5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' 6 | - https://arxiv.org/abs/2010.11929 7 | 8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` 9 | - https://arxiv.org/abs/2106.10270 10 | 11 | The official jax code is released and available at https://github.com/google-research/vision_transformer 12 | 13 | Acknowledgments: 14 | * The paper authors for releasing code and weights, thanks! 15 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out 16 | for some einops/einsum fun 17 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 18 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 19 | 20 | Hacked together by / Copyright 2020, Ross Wightman 21 | """ 22 | import math 23 | import logging 24 | from functools import partial 25 | from collections import OrderedDict 26 | from typing import Optional 27 | 28 | import torch 29 | import torch.nn as nn 30 | import torch.nn.functional as F 31 | import torch.utils.checkpoint 32 | 33 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\ 34 | OPENAI_CLIP_MEAN, OPENAI_CLIP_STD 35 | from timm.models.helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq 36 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 37 | from timm.models.registry import register_model 38 | 39 | _logger = logging.getLogger(__name__) 40 | 41 | 42 | def _cfg(url='', **kwargs): 43 | return { 44 | 'url': url, 45 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 46 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 47 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 48 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 49 | **kwargs 50 | } 51 | 52 | 53 | default_cfgs = { 54 | # patch models (weights from official Google JAX impl) 55 | 'vit_tiny_patch16_224': _cfg( 56 | url='https://storage.googleapis.com/vit_models/augreg/' 57 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 58 | 'vit_tiny_patch16_384': _cfg( 59 | url='https://storage.googleapis.com/vit_models/augreg/' 60 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 61 | input_size=(3, 384, 384), crop_pct=1.0), 62 | 'vit_small_patch32_224': _cfg( 63 | url='https://storage.googleapis.com/vit_models/augreg/' 64 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 65 | 'vit_small_patch32_384': _cfg( 66 | url='https://storage.googleapis.com/vit_models/augreg/' 67 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 68 | input_size=(3, 384, 384), crop_pct=1.0), 69 | 'vit_small_patch16_224': _cfg( 70 | url='https://storage.googleapis.com/vit_models/augreg/' 71 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 72 | 'vit_small_patch16_384': _cfg( 73 | url='https://storage.googleapis.com/vit_models/augreg/' 74 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 75 | input_size=(3, 384, 384), crop_pct=1.0), 76 | 'vit_base_patch32_224': _cfg( 77 | url='https://storage.googleapis.com/vit_models/augreg/' 78 | 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 79 | 'vit_base_patch32_384': _cfg( 80 | url='https://storage.googleapis.com/vit_models/augreg/' 81 | 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 82 | input_size=(3, 384, 384), crop_pct=1.0), 83 | 'vit_base_patch16_224': _cfg( 84 | url='https://storage.googleapis.com/vit_models/augreg/' 85 | 
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 86 | 'vit_base_patch16_384': _cfg( 87 | url='https://storage.googleapis.com/vit_models/augreg/' 88 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', 89 | input_size=(3, 384, 384), crop_pct=1.0), 90 | 'vit_base_patch8_224': _cfg( 91 | url='https://storage.googleapis.com/vit_models/augreg/' 92 | 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 93 | 'vit_large_patch32_224': _cfg( 94 | url='', # no official model weights for this combo, only for in21k 95 | ), 96 | 'vit_large_patch32_384': _cfg( 97 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 98 | input_size=(3, 384, 384), crop_pct=1.0), 99 | 'vit_large_patch16_224': _cfg( 100 | url='https://storage.googleapis.com/vit_models/augreg/' 101 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 102 | 'vit_large_patch16_384': _cfg( 103 | url='https://storage.googleapis.com/vit_models/augreg/' 104 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', 105 | input_size=(3, 384, 384), crop_pct=1.0), 106 | 107 | 'vit_large_patch14_224': _cfg(url=''), 108 | 'vit_huge_patch14_224': _cfg(url=''), 109 | 'vit_giant_patch14_224': _cfg(url=''), 110 | 'vit_gigantic_patch14_224': _cfg(url=''), 111 | 112 | 113 | # patch models, imagenet21k (weights from official Google JAX impl) 114 | 'vit_tiny_patch16_224_in21k': _cfg( 115 | url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', 116 | num_classes=21843), 117 | 'vit_small_patch32_224_in21k': _cfg( 118 | url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 119 | num_classes=21843), 120 | 'vit_small_patch16_224_in21k': _cfg( 121 | url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 122 | num_classes=21843), 123 | 'vit_base_patch32_224_in21k': _cfg( 124 | url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', 125 | num_classes=21843), 126 | 'vit_base_patch16_224_in21k': _cfg( 127 | url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', 128 | num_classes=21843), 129 | 'vit_base_patch8_224_in21k': _cfg( 130 | url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', 131 | num_classes=21843), 132 | 'vit_large_patch32_224_in21k': _cfg( 133 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', 134 | num_classes=21843), 135 | 'vit_large_patch16_224_in21k': _cfg( 136 | url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', 137 | num_classes=21843), 138 | 'vit_huge_patch14_224_in21k': _cfg( 139 | url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', 140 | hf_hub_id='timm/vit_huge_patch14_224_in21k', 141 | num_classes=21843), 142 | 143 | # SAM trained models (https://arxiv.org/abs/2106.01548) 144 | 'vit_base_patch32_224_sam': _cfg( 145 | 
url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), 146 | 'vit_base_patch16_224_sam': _cfg( 147 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), 148 | 149 | # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) 150 | 'vit_small_patch16_224_dino': _cfg( 151 | url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', 152 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 153 | 'vit_small_patch8_224_dino': _cfg( 154 | url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', 155 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 156 | 'vit_base_patch16_224_dino': _cfg( 157 | url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', 158 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 159 | 'vit_base_patch8_224_dino': _cfg( 160 | url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', 161 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 162 | 163 | 164 | # ViT ImageNet-21K-P pretraining by MILL 165 | 'vit_base_patch16_224_miil_in21k': _cfg( 166 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth', 167 | mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221), 168 | 'vit_base_patch16_224_miil': _cfg( 169 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth', 170 | mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'), 171 | 172 | 'vit_base_patch16_rpn_224': _cfg( 173 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'), 174 | 175 | # experimental (may be removed) 176 | 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), 177 | 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), 178 | 'vit_small_patch16_36x1_224': _cfg(url=''), 179 | 'vit_small_patch16_18x2_224': _cfg(url=''), 180 | 'vit_base_patch16_18x2_224': _cfg(url=''), 181 | 182 | 'vit_base_patch32_224_clip_laion2b': _cfg( 183 | hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K', 184 | hf_hub_filename='open_clip_pytorch_model.bin', 185 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), 186 | 'vit_large_patch14_224_clip_laion2b': _cfg( 187 | hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K', 188 | hf_hub_filename='open_clip_pytorch_model.bin', 189 | mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=768), 190 | 'vit_huge_patch14_224_clip_laion2b': _cfg( 191 | hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', 192 | hf_hub_filename='open_clip_pytorch_model.bin', 193 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024), 194 | 'vit_giant_patch14_224_clip_laion2b': _cfg( 195 | hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', 196 | hf_hub_filename='open_clip_pytorch_model.bin', 197 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024), 198 | 199 | } 200 | 201 | 202 | class Attention(nn.Module): 203 | def __init__( 204 | self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., 205 | perturb=False 206 | 207 | ): 208 | super().__init__() 209 | assert dim % num_heads == 0, 'dim should be divisible by 
num_heads' 210 | self.num_heads = num_heads 211 | head_dim = dim // num_heads 212 | self.scale = head_dim ** -0.5 213 | 214 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 215 | self.attn_drop = nn.Dropout(attn_drop) 216 | self.proj = nn.Linear(dim, dim) 217 | self.proj_drop = nn.Dropout(proj_drop) 218 | 219 | self.perturb = perturb 220 | 221 | def save_attn_gradients(self, attn_gradients): 222 | self.attn_gradients = attn_gradients 223 | 224 | def forward(self, x): 225 | B, N, C = x.shape 226 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 227 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 228 | 229 | attn = (q @ k.transpose(-2, -1)) * self.scale 230 | attn = attn.softmax(dim=-1) 231 | attn = self.attn_drop(attn) 232 | 233 | self.attn_probs = attn 234 | 235 | #attn.register_hook(self.save_attn_gradients) 236 | 237 | if self.perturb == 'self_only': 238 | for t in range(1, attn.shape[-1]): 239 | attn[:, :, t, :] = 0 240 | attn[:, :, t, t] = 1 241 | elif self.perturb == 'no_cls': 242 | for t in range(1, attn.shape[-1]): 243 | attn[:, :, t, 0] = 0 244 | 245 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 246 | self.key_proj_vals = x 247 | 248 | x = self.proj(x) 249 | x = self.proj_drop(x) 250 | return x 251 | 252 | 253 | class LayerScale(nn.Module): 254 | def __init__(self, dim, init_values=1e-5, inplace=False): 255 | super().__init__() 256 | self.inplace = inplace 257 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 258 | 259 | def forward(self, x): 260 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 261 | 262 | 263 | class Block(nn.Module): 264 | def __init__( 265 | self, 266 | dim, 267 | num_heads, 268 | mlp_ratio=4., 269 | qkv_bias=False, 270 | drop=0., 271 | attn_drop=0., 272 | init_values=None, 273 | drop_path=0., 274 | act_layer=nn.GELU, 275 | norm_layer=nn.LayerNorm, 276 | perturb_type=None, 277 | norm_cls=None, 278 | func_cls=None 279 | ): 280 | super().__init__() 281 | self.norm1 = norm_layer(dim) 282 | self.attn = Attention( 283 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, 284 | proj_drop=drop, perturb=perturb_type 285 | ) 286 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 287 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 288 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 289 | 290 | self.norm2 = norm_layer(dim) 291 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) 292 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 293 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 294 | 295 | self.norm_cls = norm_cls 296 | self.func_cls = func_cls 297 | 298 | def forward(self, x): 299 | attn_x = self.ls1(self.attn(self.norm1(x))) 300 | self.attn_cls = self.func_cls(self.norm_cls(attn_x)) 301 | #self.attn_cls = self.func_cls(self.norm_cls(self.attn(self.norm1(x)))) 302 | #x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 303 | x = x + self.drop_path1(attn_x) 304 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 305 | return x 306 | 307 | 308 | class ResPostBlock(nn.Module): 309 | 310 | def __init__( 311 | self, 312 | dim, 313 | num_heads, 314 | mlp_ratio=4., 315 | qkv_bias=False, 316 | drop=0., 317 | attn_drop=0., 318 | init_values=None, 319 | drop_path=0., 320 | act_layer=nn.GELU, 321 | norm_layer=nn.LayerNorm 322 | ): 323 | super().__init__() 324 | self.init_values = init_values 325 | 326 | self.attn = Attention( 327 | dim, num_heads=num_heads, qkv_bias=qkv_bias, 328 | attn_drop=attn_drop, proj_drop=drop 329 | ) 330 | self.norm1 = norm_layer(dim) 331 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 332 | 333 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) 334 | self.norm2 = norm_layer(dim) 335 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 336 | 337 | self.init_weights() 338 | 339 | def init_weights(self): 340 | # NOTE this init overrides that base model init with specific changes for the block type 341 | if self.init_values is not None: 342 | nn.init.constant_(self.norm1.weight, self.init_values) 343 | nn.init.constant_(self.norm2.weight, self.init_values) 344 | 345 | def forward(self, x): 346 | x = x + self.drop_path1(self.norm1(self.attn(x))) 347 | x = x + self.drop_path2(self.norm2(self.mlp(x))) 348 | return x 349 | 350 | 351 | class ParallelBlock(nn.Module): 352 | 353 | def __init__( 354 | self, 355 | dim, 356 | num_heads, 357 | num_parallel=2, 358 | mlp_ratio=4., 359 | qkv_bias=False, 360 | init_values=None, 361 | drop=0., 362 | attn_drop=0., 363 | drop_path=0., 364 | act_layer=nn.GELU, 365 | norm_layer=nn.LayerNorm 366 | ): 367 | super().__init__() 368 | self.num_parallel = num_parallel 369 | self.attns = nn.ModuleList() 370 | self.ffns = nn.ModuleList() 371 | for _ in range(num_parallel): 372 | self.attns.append(nn.Sequential(OrderedDict([ 373 | ('norm', norm_layer(dim)), 374 | ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), 375 | ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), 376 | ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) 377 | ]))) 378 | self.ffns.append(nn.Sequential(OrderedDict([ 379 | ('norm', norm_layer(dim)), 380 | ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), 381 | ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), 382 | ('drop_path', DropPath(drop_path) if drop_path > 0. 
else nn.Identity()) 383 | ]))) 384 | 385 | def _forward_jit(self, x): 386 | x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) 387 | x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) 388 | return x 389 | 390 | @torch.jit.ignore 391 | def _forward(self, x): 392 | x = x + sum(attn(x) for attn in self.attns) 393 | x = x + sum(ffn(x) for ffn in self.ffns) 394 | return x 395 | 396 | def forward(self, x): 397 | if torch.jit.is_scripting() or torch.jit.is_tracing(): 398 | return self._forward_jit(x) 399 | else: 400 | return self._forward(x) 401 | 402 | 403 | class VisionTransformer(nn.Module): 404 | """ Vision Transformer 405 | 406 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 407 | - https://arxiv.org/abs/2010.11929 408 | """ 409 | 410 | def __init__( 411 | self, 412 | img_size=224, 413 | patch_size=16, 414 | in_chans=3, 415 | num_classes=1000, 416 | global_pool='token', 417 | embed_dim=768, 418 | depth=12, 419 | num_heads=12, 420 | mlp_ratio=4., 421 | qkv_bias=True, 422 | init_values=None, 423 | class_token=True, 424 | no_embed_class=False, 425 | pre_norm=False, 426 | fc_norm=None, 427 | drop_rate=0., 428 | attn_drop_rate=0., 429 | drop_path_rate=0., 430 | weight_init='', 431 | embed_layer=PatchEmbed, 432 | norm_layer=None, 433 | act_layer=None, 434 | block_fn=Block, 435 | perturb_type=None, 436 | block_perturb=None 437 | ): 438 | """ 439 | Args: 440 | img_size (int, tuple): input image size 441 | patch_size (int, tuple): patch size 442 | in_chans (int): number of input channels 443 | num_classes (int): number of classes for classification head 444 | global_pool (str): type of global pooling for final sequence (default: 'token') 445 | embed_dim (int): embedding dimension 446 | depth (int): depth of transformer 447 | num_heads (int): number of attention heads 448 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 449 | qkv_bias (bool): enable bias for qkv if True 450 | init_values: (float): layer-scale init values 451 | class_token (bool): use class token 452 | fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) 453 | drop_rate (float): dropout rate 454 | attn_drop_rate (float): attention dropout rate 455 | drop_path_rate (float): stochastic depth rate 456 | weight_init (str): weight init scheme 457 | embed_layer (nn.Module): patch embedding layer 458 | norm_layer: (nn.Module): normalization layer 459 | act_layer: (nn.Module): MLP activation layer 460 | """ 461 | super().__init__() 462 | assert global_pool in ('', 'avg', 'token') 463 | assert class_token or global_pool != 'token' 464 | use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm 465 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 466 | act_layer = act_layer or nn.GELU 467 | 468 | self.num_classes = num_classes 469 | self.global_pool = global_pool 470 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 471 | self.num_prefix_tokens = 1 if class_token else 0 472 | self.no_embed_class = no_embed_class 473 | self.grad_checkpointing = False 474 | 475 | self.patch_embed = embed_layer( 476 | img_size=img_size, 477 | patch_size=patch_size, 478 | in_chans=in_chans, 479 | embed_dim=embed_dim, 480 | bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) 481 | ) 482 | num_patches = self.patch_embed.num_patches 483 | 484 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None 485 | embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens 486 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) 487 | self.pos_drop = nn.Dropout(p=drop_rate) 488 | self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() 489 | 490 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 491 | 492 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() 493 | 494 | # Classifier Head 495 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() 496 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 497 | 498 | if block_perturb == 'all': 499 | self.blocks = nn.Sequential(*[ 500 | block_fn( 501 | dim=embed_dim, 502 | num_heads=num_heads, 503 | mlp_ratio=mlp_ratio, 504 | qkv_bias=qkv_bias, 505 | init_values=init_values, 506 | drop=drop_rate, 507 | attn_drop=attn_drop_rate, 508 | drop_path=dpr[i], 509 | norm_layer=norm_layer, 510 | act_layer=act_layer, 511 | perturb_type=perturb_type, 512 | norm_cls=self.norm, 513 | func_cls =self.head, 514 | ) 515 | for i in range(depth)]) 516 | else: 517 | self.blocks = nn.Sequential(*[ 518 | block_fn( 519 | dim=embed_dim, 520 | num_heads=num_heads, 521 | mlp_ratio=mlp_ratio, 522 | qkv_bias=qkv_bias, 523 | init_values=init_values, 524 | drop=drop_rate, 525 | attn_drop=attn_drop_rate, 526 | drop_path=dpr[i], 527 | norm_layer=norm_layer, 528 | act_layer=act_layer, 529 | perturb_type=perturb_type, 530 | norm_cls=self.norm, 531 | func_cls =self.head 532 | ) if i == block_perturb 533 | else block_fn( 534 | dim=embed_dim, 535 | num_heads=num_heads, 536 | mlp_ratio=mlp_ratio, 537 | qkv_bias=qkv_bias, 538 | init_values=init_values, 539 | drop=drop_rate, 540 | attn_drop=attn_drop_rate, 541 | drop_path=dpr[i], 542 | norm_layer=norm_layer, 543 | act_layer=act_layer, 544 | norm_cls=self.norm, 545 | func_cls =self.head 546 | ) 547 | for i in range(depth)]) 548 | 549 | if weight_init != 'skip': 550 | self.init_weights(weight_init) 551 | 552 | def init_weights(self, mode=''): 553 | assert mode in ('jax', 'jax_nlhb', 'moco', '') 554 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
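        # 'nlhb' ("negative log head bias") modes initialize the classifier bias
        # to -log(num_classes), so the initial class probabilities are roughly uniform.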
555 | trunc_normal_(self.pos_embed, std=.02) 556 | if self.cls_token is not None: 557 | nn.init.normal_(self.cls_token, std=1e-6) 558 | named_apply(get_init_weights_vit(mode, head_bias), self) 559 | 560 | def _init_weights(self, m): 561 | # this fn left here for compat with downstream users 562 | init_weights_vit_timm(m) 563 | 564 | @torch.jit.ignore() 565 | def load_pretrained(self, checkpoint_path, prefix=''): 566 | _load_weights(self, checkpoint_path, prefix) 567 | 568 | @torch.jit.ignore 569 | def no_weight_decay(self): 570 | return {'pos_embed', 'cls_token', 'dist_token'} 571 | 572 | @torch.jit.ignore 573 | def group_matcher(self, coarse=False): 574 | return dict( 575 | stem=r'^cls_token|pos_embed|patch_embed', # stem and embed 576 | blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] 577 | ) 578 | 579 | @torch.jit.ignore 580 | def set_grad_checkpointing(self, enable=True): 581 | self.grad_checkpointing = enable 582 | 583 | @torch.jit.ignore 584 | def get_classifier(self): 585 | return self.head 586 | 587 | def reset_classifier(self, num_classes: int, global_pool=None): 588 | self.num_classes = num_classes 589 | if global_pool is not None: 590 | assert global_pool in ('', 'avg', 'token') 591 | self.global_pool = global_pool 592 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 593 | 594 | def _pos_embed(self, x): 595 | if self.no_embed_class: 596 | # deit-3, updated JAX (big vision) 597 | # position embedding does not overlap with class token, add then concat 598 | x = x + self.pos_embed 599 | if self.cls_token is not None: 600 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 601 | else: 602 | # original timm, JAX, and deit vit impl 603 | # pos_embed has entry for class token, concat then add 604 | if self.cls_token is not None: 605 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 606 | x = x + self.pos_embed 607 | return self.pos_drop(x) 608 | 609 | def forward_features(self, x, tokens_mask=None): 610 | x = self.patch_embed(x) 611 | x = self._pos_embed(x) 612 | 613 | if tokens_mask != None: 614 | x = x[:, tokens_mask, :] 615 | 616 | x = self.norm_pre(x) 617 | if self.grad_checkpointing and not torch.jit.is_scripting(): 618 | x = checkpoint_seq(self.blocks, x) 619 | else: 620 | x = self.blocks(x) 621 | x = self.norm(x) 622 | return x 623 | 624 | def forward_head(self, x, pre_logits: bool = False): 625 | if self.global_pool: 626 | x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] 627 | x = self.fc_norm(x) 628 | return x if pre_logits else self.head(x) 629 | 630 | def forward(self, x, tokens_mask=None): 631 | x = self.forward_features(x, tokens_mask) 632 | x = self.forward_head(x) 633 | return x 634 | 635 | 636 | def init_weights_vit_timm(module: nn.Module, name: str = ''): 637 | """ ViT weight initialization, original timm impl (for reproducibility) """ 638 | if isinstance(module, nn.Linear): 639 | trunc_normal_(module.weight, std=.02) 640 | if module.bias is not None: 641 | nn.init.zeros_(module.bias) 642 | elif hasattr(module, 'init_weights'): 643 | module.init_weights() 644 | 645 | 646 | def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): 647 | """ ViT weight initialization, matching JAX (Flax) impl """ 648 | if isinstance(module, nn.Linear): 649 | if name.startswith('head'): 650 | nn.init.zeros_(module.weight) 651 | nn.init.constant_(module.bias, head_bias) 652 | else: 653 | nn.init.xavier_uniform_(module.weight) 654 | if 
module.bias is not None: 655 | nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) 656 | elif isinstance(module, nn.Conv2d): 657 | lecun_normal_(module.weight) 658 | if module.bias is not None: 659 | nn.init.zeros_(module.bias) 660 | elif hasattr(module, 'init_weights'): 661 | module.init_weights() 662 | 663 | 664 | def init_weights_vit_moco(module: nn.Module, name: str = ''): 665 | """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ 666 | if isinstance(module, nn.Linear): 667 | if 'qkv' in name: 668 | # treat the weights of Q, K, V separately 669 | val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) 670 | nn.init.uniform_(module.weight, -val, val) 671 | else: 672 | nn.init.xavier_uniform_(module.weight) 673 | if module.bias is not None: 674 | nn.init.zeros_(module.bias) 675 | elif hasattr(module, 'init_weights'): 676 | module.init_weights() 677 | 678 | 679 | def get_init_weights_vit(mode='jax', head_bias: float = 0.): 680 | if 'jax' in mode: 681 | return partial(init_weights_vit_jax, head_bias=head_bias) 682 | elif 'moco' in mode: 683 | return init_weights_vit_moco 684 | else: 685 | return init_weights_vit_timm 686 | 687 | 688 | @torch.no_grad() 689 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 690 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 691 | """ 692 | import numpy as np 693 | 694 | def _n2p(w, t=True): 695 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 696 | w = w.flatten() 697 | if t: 698 | if w.ndim == 4: 699 | w = w.transpose([3, 2, 0, 1]) 700 | elif w.ndim == 3: 701 | w = w.transpose([2, 0, 1]) 702 | elif w.ndim == 2: 703 | w = w.transpose([1, 0]) 704 | return torch.from_numpy(w) 705 | 706 | w = np.load(checkpoint_path) 707 | if not prefix and 'opt/target/embedding/kernel' in w: 708 | prefix = 'opt/target/' 709 | 710 | if hasattr(model.patch_embed, 'backbone'): 711 | # hybrid 712 | backbone = model.patch_embed.backbone 713 | stem_only = not hasattr(backbone, 'stem') 714 | stem = backbone if stem_only else backbone.stem 715 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 716 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 717 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 718 | if not stem_only: 719 | for i, stage in enumerate(backbone.stages): 720 | for j, block in enumerate(stage.blocks): 721 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 722 | for r in range(3): 723 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 724 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 725 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 726 | if block.downsample is not None: 727 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 728 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 729 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 730 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 731 | else: 732 | embed_conv_w = adapt_input_conv( 733 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 734 | model.patch_embed.proj.weight.copy_(embed_conv_w) 735 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 736 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 737 | pos_embed_w = 
_n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 738 | if pos_embed_w.shape != model.pos_embed.shape: 739 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 740 | pos_embed_w, 741 | model.pos_embed, 742 | getattr(model, 'num_prefix_tokens', 1), 743 | model.patch_embed.grid_size 744 | ) 745 | model.pos_embed.copy_(pos_embed_w) 746 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 747 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 748 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 749 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 750 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 751 | # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights 752 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 753 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 754 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 755 | for i, block in enumerate(model.blocks.children()): 756 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 757 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 758 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 759 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 760 | block.attn.qkv.weight.copy_(torch.cat([ 761 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 762 | block.attn.qkv.bias.copy_(torch.cat([ 763 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 764 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 765 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 766 | for r in range(2): 767 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 768 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 769 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 770 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 771 | 772 | 773 | def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): 774 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 775 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 776 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 777 | ntok_new = posemb_new.shape[1] 778 | if num_prefix_tokens: 779 | posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] 780 | ntok_new -= num_prefix_tokens 781 | else: 782 | posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] 783 | gs_old = int(math.sqrt(len(posemb_grid))) 784 | if not len(gs_new): # backwards compatibility 785 | gs_new = [int(math.sqrt(ntok_new))] * 2 786 | assert len(gs_new) >= 2 787 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) 788 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 789 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) 790 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 791 | posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) 792 | return posemb 793 | 794 | 795 | def _convert_openai_clip(state_dict, model): 796 | out_dict = {} 797 | swaps = [ 798 | ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'), 799 | ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'), 800 | ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'), 801 | ] 802 | for k, v in state_dict.items(): 803 | if not k.startswith('visual.'): 804 | continue 805 | for sp in swaps: 806 | k = k.replace(sp[0], sp[1]) 807 | 808 | if k == 'proj': 809 | k = 'head.weight' 810 | v = v.transpose(0, 1) 811 | out_dict['head.bias'] = torch.zeros(v.shape[0]) 812 | elif k == 'class_embedding': 813 | k = 'cls_token' 814 | v = v.unsqueeze(0).unsqueeze(1) 815 | elif k == 'pos_embed': 816 | v = v.unsqueeze(0) 817 | if v.shape[1] != model.pos_embed.shape[1]: 818 | # To resize pos embedding when using model at different size from pretrained weights 819 | v = resize_pos_embed( 820 | v, 821 | model.pos_embed, 822 | 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), 823 | model.patch_embed.grid_size 824 | ) 825 | out_dict[k] = v 826 | return out_dict 827 | 828 | 829 | def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): 830 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 831 | import re 832 | out_dict = {} 833 | if 'model' in state_dict: 834 | # For deit models 835 | state_dict = state_dict['model'] 836 | 837 | if 'visual.class_embedding' in state_dict: 838 | return _convert_openai_clip(state_dict, model) 839 | 840 | for k, v in state_dict.items(): 841 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 842 | # For old models that I trained prior to conv based patchification 843 | O, I, H, W = model.patch_embed.proj.weight.shape 844 | v = v.reshape(O, -1, H, W) 845 | elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: 846 | # To resize pos embedding when using model at different size from pretrained weights 847 | v = resize_pos_embed( 848 | v, 849 | model.pos_embed, 850 | 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), 851 | model.patch_embed.grid_size 852 | ) 853 | elif adapt_layer_scale and 'gamma_' in k: 854 | # remap layer-scale gamma into sub-module (deit3 models) 855 | k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) 856 | elif 
'pre_logits' in k: 857 | # NOTE representation layer removed as not used in latest 21k/1k pretrained weights 858 | continue 859 | out_dict[k] = v 860 | return out_dict 861 | 862 | 863 | def _create_vision_transformer(variant, pretrained=False, **kwargs): 864 | if kwargs.get('features_only', None): 865 | raise RuntimeError('features_only not implemented for Vision Transformer models.') 866 | 867 | pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) 868 | model = build_model_with_cfg( 869 | VisionTransformer, variant, pretrained, 870 | pretrained_cfg=pretrained_cfg, 871 | pretrained_filter_fn=checkpoint_filter_fn, 872 | pretrained_custom_load='npz' in pretrained_cfg['url'], 873 | **kwargs) 874 | return model 875 | 876 | 877 | @register_model 878 | def vit_tiny_patch16_224(pretrained=False, **kwargs): 879 | """ ViT-Tiny (Vit-Ti/16) 880 | """ 881 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 882 | model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) 883 | return model 884 | 885 | 886 | @register_model 887 | def vit_tiny_patch16_384(pretrained=False, **kwargs): 888 | """ ViT-Tiny (Vit-Ti/16) @ 384x384. 889 | """ 890 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 891 | model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) 892 | return model 893 | 894 | 895 | @register_model 896 | def vit_small_patch32_224(pretrained=False, **kwargs): 897 | """ ViT-Small (ViT-S/32) 898 | """ 899 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 900 | model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) 901 | return model 902 | 903 | 904 | @register_model 905 | def vit_small_patch32_384(pretrained=False, **kwargs): 906 | """ ViT-Small (ViT-S/32) at 384x384. 907 | """ 908 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 909 | model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) 910 | return model 911 | 912 | 913 | @register_model 914 | def vit_small_patch16_224(pretrained=False, **kwargs): 915 | """ ViT-Small (ViT-S/16) 916 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 917 | """ 918 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 919 | model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) 920 | return model 921 | 922 | 923 | @register_model 924 | def vit_small_patch16_384(pretrained=False, **kwargs): 925 | """ ViT-Small (ViT-S/16) 926 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 927 | """ 928 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 929 | model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) 930 | return model 931 | 932 | 933 | @register_model 934 | def vit_base_patch32_224(pretrained=False, **kwargs): 935 | """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 936 | ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. 
937 | """ 938 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 939 | model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) 940 | return model 941 | 942 | 943 | @register_model 944 | def vit_base_patch32_384(pretrained=False, **kwargs): 945 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 946 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 947 | """ 948 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 949 | model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) 950 | return model 951 | 952 | 953 | @register_model 954 | def vit_base_patch16_224(pretrained=False, **kwargs): 955 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 956 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 957 | """ 958 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 959 | model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) 960 | return model 961 | 962 | 963 | @register_model 964 | def vit_base_patch16_384(pretrained=False, **kwargs): 965 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 966 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 967 | """ 968 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 969 | model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) 970 | return model 971 | 972 | 973 | @register_model 974 | def vit_base_patch8_224(pretrained=False, **kwargs): 975 | """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). 976 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 977 | """ 978 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) 979 | model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs) 980 | return model 981 | 982 | 983 | @register_model 984 | def vit_large_patch32_224(pretrained=False, **kwargs): 985 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. 986 | """ 987 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 988 | model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) 989 | return model 990 | 991 | 992 | @register_model 993 | def vit_large_patch32_384(pretrained=False, **kwargs): 994 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 995 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 996 | """ 997 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 998 | model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) 999 | return model 1000 | 1001 | 1002 | @register_model 1003 | def vit_large_patch16_224(pretrained=False, **kwargs): 1004 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 
1005 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 1006 | """ 1007 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 1008 | model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) 1009 | return model 1010 | 1011 | 1012 | @register_model 1013 | def vit_large_patch16_384(pretrained=False, **kwargs): 1014 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 1015 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 1016 | """ 1017 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 1018 | model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) 1019 | return model 1020 | 1021 | 1022 | @register_model 1023 | def vit_large_patch14_224(pretrained=False, **kwargs): 1024 | """ ViT-Large model (ViT-L/14) 1025 | """ 1026 | model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs) 1027 | model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs) 1028 | return model 1029 | 1030 | 1031 | @register_model 1032 | def vit_huge_patch14_224(pretrained=False, **kwargs): 1033 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). 1034 | """ 1035 | model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) 1036 | model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs) 1037 | return model 1038 | 1039 | 1040 | @register_model 1041 | def vit_giant_patch14_224(pretrained=False, **kwargs): 1042 | """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 1043 | """ 1044 | model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs) 1045 | model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs) 1046 | return model 1047 | 1048 | 1049 | @register_model 1050 | def vit_gigantic_patch14_224(pretrained=False, **kwargs): 1051 | """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 1052 | """ 1053 | model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs) 1054 | model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs) 1055 | return model 1056 | 1057 | 1058 | @register_model 1059 | def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): 1060 | """ ViT-Tiny (ViT-Ti/16). 1061 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 1062 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 1063 | """ 1064 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 1065 | model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 1066 | return model 1067 | 1068 | 1069 | @register_model 1070 | def vit_small_patch32_224_in21k(pretrained=False, **kwargs): 1071 | """ ViT-Small (ViT-S/32) 1072 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
1073 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 1074 | """ 1075 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 1076 | model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 1077 | return model 1078 | 1079 | 1080 | @register_model 1081 | def vit_small_patch16_224_in21k(pretrained=False, **kwargs): 1082 | """ ViT-Small (ViT-S/16) 1083 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 1084 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 1085 | """ 1086 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 1087 | model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 1088 | return model 1089 | 1090 | 1091 | @register_model 1092 | def vit_base_patch32_224_in21k(pretrained=False, **kwargs): 1093 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 1094 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 1095 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 1096 | """ 1097 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 1098 | model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 1099 | return model 1100 | 1101 | 1102 | @register_model 1103 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs): 1104 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 1105 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 1106 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 1107 | """ 1108 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 1109 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 1110 | return model 1111 | 1112 | 1113 | @register_model 1114 | def vit_base_patch8_224_in21k(pretrained=False, **kwargs): 1115 | """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). 1116 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 1117 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 1118 | """ 1119 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) 1120 | model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs) 1121 | return model 1122 | 1123 | 1124 | @register_model 1125 | def vit_large_patch32_224_in21k(pretrained=False, **kwargs): 1126 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 1127 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
1128 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights 1129 | """ 1130 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 1131 | model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 1132 | return model 1133 | 1134 | 1135 | @register_model 1136 | def vit_large_patch16_224_in21k(pretrained=False, **kwargs): 1137 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 1138 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 1139 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 1140 | """ 1141 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 1142 | model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 1143 | return model 1144 | 1145 | 1146 | @register_model 1147 | def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): 1148 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). 1149 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 1150 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights 1151 | """ 1152 | model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) 1153 | model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) 1154 | return model 1155 | 1156 | 1157 | @register_model 1158 | def vit_base_patch16_224_sam(pretrained=False, **kwargs): 1159 | """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 1160 | """ 1161 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 1162 | model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs) 1163 | return model 1164 | 1165 | 1166 | @register_model 1167 | def vit_base_patch32_224_sam(pretrained=False, **kwargs): 1168 | """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. 
Paper: https://arxiv.org/abs/2106.01548 1169 | """ 1170 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 1171 | model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs) 1172 | return model 1173 | 1174 | 1175 | @register_model 1176 | def vit_small_patch16_224_dino(pretrained=False, **kwargs): 1177 | """ ViT-Small (ViT-S/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 1178 | """ 1179 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 1180 | model = _create_vision_transformer('vit_small_patch16_224_dino', pretrained=pretrained, **model_kwargs) 1181 | return model 1182 | 1183 | 1184 | @register_model 1185 | def vit_small_patch8_224_dino(pretrained=False, **kwargs): 1186 | """ ViT-Small (ViT-S/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 1187 | """ 1188 | model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) 1189 | model = _create_vision_transformer('vit_small_patch8_224_dino', pretrained=pretrained, **model_kwargs) 1190 | return model 1191 | 1192 | 1193 | @register_model 1194 | def vit_base_patch16_224_dino(pretrained=False, **kwargs): 1195 | """ ViT-Base (ViT-B/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 1196 | """ 1197 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 1198 | model = _create_vision_transformer('vit_base_patch16_224_dino', pretrained=pretrained, **model_kwargs) 1199 | return model 1200 | 1201 | 1202 | @register_model 1203 | def vit_base_patch8_224_dino(pretrained=False, **kwargs): 1204 | """ ViT-Base (ViT-B/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 1205 | """ 1206 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) 1207 | model = _create_vision_transformer('vit_base_patch8_224_dino', pretrained=pretrained, **model_kwargs) 1208 | return model 1209 | 1210 | 1211 | @register_model 1212 | def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs): 1213 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 1214 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K 1215 | """ 1216 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) 1217 | model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs) 1218 | return model 1219 | 1220 | 1221 | @register_model 1222 | def vit_base_patch16_224_miil(pretrained=False, **kwargs): 1223 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
1224 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K 1225 | """ 1226 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) 1227 | model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) 1228 | return model 1229 | 1230 | 1231 | # Experimental models below 1232 | 1233 | @register_model 1234 | def vit_base_patch32_plus_256(pretrained=False, **kwargs): 1235 | """ ViT-Base (ViT-B/32+) 1236 | """ 1237 | model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) 1238 | model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs) 1239 | return model 1240 | 1241 | 1242 | @register_model 1243 | def vit_base_patch16_plus_240(pretrained=False, **kwargs): 1244 | """ ViT-Base (ViT-B/16+) 1245 | """ 1246 | model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) 1247 | model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs) 1248 | return model 1249 | 1250 | 1251 | @register_model 1252 | def vit_base_patch16_rpn_224(pretrained=False, **kwargs): 1253 | """ ViT-Base (ViT-B/16) w/ residual post-norm 1254 | """ 1255 | model_kwargs = dict( 1256 | patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False, 1257 | block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs) 1258 | model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) 1259 | return model 1260 | 1261 | 1262 | @register_model 1263 | def vit_small_patch16_36x1_224(pretrained=False, **kwargs): 1264 | """ ViT-Small w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. 1265 | Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 1266 | Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. 1267 | """ 1268 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs) 1269 | model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs) 1270 | return model 1271 | 1272 | 1273 | @register_model 1274 | def vit_small_patch16_18x2_224(pretrained=False, **kwargs): 1275 | """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. 1276 | Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 1277 | Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. 1278 | """ 1279 | model_kwargs = dict( 1280 | patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs) 1281 | model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs) 1282 | return model 1283 | 1284 | 1285 | @register_model 1286 | def vit_base_patch16_18x2_224(pretrained=False, **kwargs): 1287 | """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. 
1288 | Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 1289 | """ 1290 | model_kwargs = dict( 1291 | patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) 1292 | model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) 1293 | return model 1294 | 1295 | 1296 | @register_model 1297 | def vit_base_patch32_224_clip_laion2b(pretrained=False, **kwargs): 1298 | """ ViT-B/32 1299 | Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. 1300 | """ 1301 | model_kwargs = dict( 1302 | patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) 1303 | model = _create_vision_transformer('vit_base_patch32_224_clip_laion2b', pretrained=pretrained, **model_kwargs) 1304 | return model 1305 | 1306 | 1307 | @register_model 1308 | def vit_large_patch14_224_clip_laion2b(pretrained=False, **kwargs): 1309 | """ ViT-Large model (ViT-L/14) 1310 | Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. 1311 | """ 1312 | model_kwargs = dict( 1313 | patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) 1314 | model = _create_vision_transformer('vit_large_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs) 1315 | return model 1316 | 1317 | 1318 | @register_model 1319 | def vit_huge_patch14_224_clip_laion2b(pretrained=False, **kwargs): 1320 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). 1321 | Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. 1322 | """ 1323 | model_kwargs = dict( 1324 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) 1325 | model = _create_vision_transformer('vit_huge_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs) 1326 | return model 1327 | 1328 | 1329 | @register_model 1330 | def vit_giant_patch14_224_clip_laion2b(pretrained=False, **kwargs): 1331 | """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 1332 | Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. 1333 | """ 1334 | model_kwargs = dict( 1335 | patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, 1336 | pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) 1337 | model = _create_vision_transformer('vit_giant_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs) 1338 | return model 1339 | --------------------------------------------------------------------------------
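A note on `resize_pos_embed` in the file above: it splits off the prefix (class) token embeddings, folds the remaining position embeddings back into their 2D patch grid, interpolates that grid bicubically to the new resolution, and concatenates the prefix tokens back on. The standalone sketch below replays those same steps for a hypothetical ViT-B/16 checkpoint pretrained at 224x224 (1 class token + 14x14 patches) being loaded at 384x384 (24x24 patches); the shapes and the random tensor are purely illustrative, not taken from any actual checkpoint.

import math

import torch
import torch.nn.functional as F

# Hypothetical pretrained position embeddings: 1 class token + 14*14 patch tokens, dim 768.
posemb = torch.randn(1, 197, 768)
prefix, grid = posemb[:, :1], posemb[0, 1:]   # split prefix token(s) from the patch grid
gs_old = int(math.sqrt(grid.shape[0]))        # old grid side: 14
gs_new = (24, 24)                             # target grid for 384x384 inputs at patch size 16

# Fold tokens into an (N, C, H, W) grid, resize bicubically, then flatten back into a token sequence.
grid = grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
grid = F.interpolate(grid, size=gs_new, mode='bicubic', align_corners=False)
grid = grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)

resized = torch.cat([prefix, grid], dim=1)
print(resized.shape)  # torch.Size([1, 577, 768])

The function in the file does the same thing, but reads the old and new grid sizes from the checkpoint's and model's actual `pos_embed` shapes (and `num_prefix_tokens`), which is why both `_load_weights` and `checkpoint_filter_fn` route mismatched position embeddings through it.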