├── src
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── load.py
│   │   ├── deit_ensemble.py
│   │   └── vit_edited.py
│   ├── plots
│   │   ├── __init__.py
│   │   ├── explainability.py
│   │   ├── linear_probing.py
│   │   ├── class_ident.py
│   │   └── mech_interp.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cifar.py
│   │   └── imagenet.py
│   ├── perturbation
│   │   ├── __init__.py
│   │   ├── attn_perturbation.py
│   │   ├── tokens_perturbation.py
│   │   └── explain_perturbation.py
│   ├── linear_probing
│   │   ├── __init__.py
│   │   └── prober.py
│   ├── accuracy.py
│   ├── vis.py
│   ├── gradients.py
│   ├── identifiability.py
│   ├── extractor.py
│   └── memories.py
├── framework.png
├── notebooks
│   └── images
│       ├── sample_0.JPEG
│       ├── sample_1.JPEG
│       ├── framework_1.png
│       └── framework_2.png
├── setup.py
├── requirements.txt
├── .gitignore
└── README.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/plots/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/perturbation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/linear_probing/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinagvilas/vit-cls_emb/HEAD/framework.png
--------------------------------------------------------------------------------
/notebooks/images/sample_0.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinagvilas/vit-cls_emb/HEAD/notebooks/images/sample_0.JPEG
--------------------------------------------------------------------------------
/notebooks/images/sample_1.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinagvilas/vit-cls_emb/HEAD/notebooks/images/sample_1.JPEG
--------------------------------------------------------------------------------
/notebooks/images/framework_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinagvilas/vit-cls_emb/HEAD/notebooks/images/framework_1.png
--------------------------------------------------------------------------------
/notebooks/images/framework_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinagvilas/vit-cls_emb/HEAD/notebooks/images/framework_2.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = "vit_cls_emb",
5 | version = "0.0.1",
6 | packages=find_packages(),
7 | )
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | numpy
3 | opencv-python
4 | jupyterlab
5 | pandas == 1.5.2
6 | seaborn
7 | scipy
8 |
9 | einops
10 | scikit-learn
11 | #timm == 0.9.2
12 | timm == 0.6.12
13 | tqdm == 4.64.1
14 | torch == 1.13.1
15 | transformers == 4.26.0
16 |
--------------------------------------------------------------------------------
/src/datasets/cifar.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from PIL import Image
3 |
4 | from torchvision.datasets import CIFAR100
5 |
6 |
7 | class MyCIFAR100(CIFAR100):
8 | def __init__(self, imgs_path, n=5):
9 | self.imgs_path = imgs_path
10 | super().__init__(imgs_path, train=False, download=False, transform=None)
11 | self.n_imgs = n
12 | self.stim_info = self._get_stimuli_info()
13 | return
14 |
15 | def _get_stimuli_info(self):
16 | stim_info = defaultdict(list)
17 | for img, label in zip(self.data, self.targets):
18 | label = str(label)
19 | if len(stim_info[label]) < self.n_imgs:
20 | stim_info[label].append(Image.fromarray(img))
21 | return stim_info
22 |
--------------------------------------------------------------------------------
/src/accuracy.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 |
5 | from src.datasets.imagenet import ImagenetDatasetS
6 |
7 |
8 | def compute_accuracy(
9 | proj_path, dataset_path, model, device='cpu', by_concept=False,
10 | concept=None, percentage=True
11 | ):
12 | # Load accuracy
13 | acc_file = Path(proj_path) / 'results/class_embed' / model / 'acc.pt'
14 | acc = torch.load(acc_file, map_location=device)
15 |
16 | # Compute accuracy by concept
17 | if concept:
18 | stim_info = ImagenetDatasetS(Path(dataset_path)).stim_info
19 | c_idx = torch.Tensor(
20 | stim_info.loc[stim_info['imagenet_id'] == concept].index
21 | ).long()
22 | acc = acc[c_idx]
23 | if percentage:
24 | return torch.sum(acc) / len(acc) * 100
25 | else:
26 | return acc
27 | elif by_concept == True:
28 | return acc
29 | else:
30 | return torch.sum(acc) / len(acc) * 100
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Experiment generated
2 | stim_info.csv
3 | figures/
4 | output/
5 | manuscript/
6 | data/
7 | dataset/
8 | results/
9 | model_cktp/
10 | weights/
11 | wandb/
12 | .tmp/
13 | *.pdf
14 | *.zip
15 |
16 | # OS generated files
17 | .DS_Store
18 |
19 | # Byte-compiled / optimized / DLL files
20 | **/__pycache__/
21 | *.py[cod]
22 | *$py.class
23 |
24 | # C extensions
25 | *.so
26 |
27 | # Distribution / packaging
28 | .Python
29 | build/
30 | develop-eggs/
31 | dist/
32 | downloads/
33 | eggs/
34 | .eggs/
35 | lib/
36 | lib64/
37 | parts/
38 | sdist/
39 | var/
40 | wheels/
41 | pip-wheel-metadata/
42 | share/python-wheels/
43 | *.egg-info/
44 | .installed.cfg
45 | *.egg
46 | MANIFEST
47 |
48 | # PyInstaller
49 | # Usually these files are written by a python script from a template
50 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
51 | *.manifest
52 | *.spec
53 |
54 | # Installer logs
55 | pip-log.txt
56 | pip-delete-this-directory.txt
57 |
58 | # Unit test / coverage reports
59 | htmlcov/
60 | .tox/
61 | .nox/
62 | .coverage
63 | .coverage.*
64 | .cache
65 | nosetests.xml
66 | coverage.xml
67 | *.cover
68 | *.py,cover
69 | .hypothesis/
70 | .pytest_cache/
71 |
72 | # Translations
73 | *.mo
74 | *.pot
75 |
76 | # Django stuff:
77 | *.log
78 | local_settings.py
79 | db.sqlite3
80 | db.sqlite3-journal
81 |
82 | # Flask stuff:
83 | instance/
84 | .webassets-cache
85 |
86 | # Scrapy stuff:
87 | .scrapy
88 |
89 | # Sphinx documentation
90 | docs/_build/
91 |
92 | # PyBuilder
93 | target/
94 |
95 | # Jupyter Notebook
96 | .ipynb_checkpoints
97 |
98 | # IPython
99 | profile_default/
100 | ipython_config.py
101 |
102 | # pyenv
103 | .python-version
104 |
105 | # pipenv
106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
109 | # install all needed dependencies.
110 | #Pipfile.lock
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Rope project settings
132 | .ropeproject
133 |
134 | # mkdocs documentation
135 | /site
136 |
137 | # mypy
138 | .mypy_cache/
139 | .dmypy.json
140 | dmypy.json
141 |
142 | # Pyre type checker
143 | .pyre/
144 |
145 | # VSCODE
146 | .vscode/
147 |
148 | # Local History for Visual Studio Code
149 | .history/
150 |
151 | # Built Visual Studio Code Extensions
152 | *.vsix
--------------------------------------------------------------------------------
/src/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from PIL import Image
3 |
4 | import pandas as pd
5 | from torch.utils.data import Dataset
6 |
7 |
8 | # Exclude categories of ImageNet that are not present in ImageNetS
9 | CATS_EXCLUDE = [
10 | 'n04356056', 'n04355933', 'n04493381', 'n02808440', 'n03642806',
11 | 'n03832673', 'n04008634', 'n03773504', 'n03887697', 'n15075141'
12 | ]
13 |
14 |
15 | class ImagenetDatasetS(Dataset):
16 | def __init__(self, imagenet_path, partition='validation', n=5):
17 | self.path = Path(imagenet_path) / 'ImageNetS919'
18 | self.imgs_path = self.path / partition
19 | self.partition = partition
20 | self.n_imgs = n
21 | self.stim_info = self.get_stimuli_info()
22 |
23 | def get_stimuli_info(self):
24 | cats = self.get_category_info()
25 | file = self.path / f'{self.partition}_stim_info.csv'
26 | if file.exists():
27 | stim_info = pd.read_csv(file)
28 | else:
29 | stim_info = []
30 | cat_dirs = [d for d in self.imgs_path.iterdir() if d.is_dir()]
31 | for c in cat_dirs:
32 | imagenet_id = c.name
33 | if imagenet_id in CATS_EXCLUDE:
34 | continue
35 | else:
36 | idx = cats.loc[imagenet_id]['index']
37 | cat = cats.loc[imagenet_id]['cat']
38 | imgs_paths = [i for i in c.iterdir() if i.suffix == '.JPEG']
39 | if len(imgs_paths) < self.n_imgs:
40 | continue
41 | else:
42 | for i_idx, i in enumerate(imgs_paths[:self.n_imgs]):
43 | i_name = i.name
44 | stim_info.append([imagenet_id, idx, cat, i_name, i_idx])
45 | stim_info = pd.DataFrame(
46 | stim_info,
47 | columns=['imagenet_id', 'index', 'cat', 'img_name', 'img_index']
48 | )
49 | stim_info.to_csv(file, index=None)
50 | return stim_info
51 |
52 | def get_category_info(self):
53 | cats_info = pd.read_csv(
54 | self.path / 'LOC_synset_mapping.txt', sep='\t', header=None
55 | )
56 | cats_info[['imagenet_id', 'cat']] = cats_info[0].str.split(' ', n=1, expand=True)
57 | cats_info = cats_info.drop(columns=0).reset_index(drop=False)
58 | cats_info = cats_info.set_index('imagenet_id')
59 | return cats_info
60 |
61 | def __len__(self):
62 | return len(self.stim_info)
63 |
64 | def __getitem__(self, idx):
65 | item_info = self.stim_info.iloc[idx]
66 | item = {}
67 | item['imagenet_id'] = item_info['imagenet_id']
68 | item['index'] = item_info['index']
69 | item['cat'] = item_info['cat']
70 | item['img_index'] = item_info['img_index']
71 | img_path = self.imgs_path / item_info['imagenet_id'] / item_info['img_name']
72 | item['img'] = Image.open(img_path).convert('RGB')
73 | return item
74 |
75 | def get_item_by_concept(self, concept, img_idx):
76 | idx = (
77 | self.stim_info.loc[self.stim_info['imagenet_id'] == concept]
78 | .iloc[img_idx].name
79 | )
80 | return self.__getitem__(idx)
81 |
--------------------------------------------------------------------------------
/src/models/load.py:
--------------------------------------------------------------------------------
1 | import timm
2 | from timm import create_model
3 | import torch
4 | import torchvision.transforms as T
5 | from transformers import AutoImageProcessor
6 |
7 | from src.models.deit_ensemble import base_patch16_224_hierarchical
8 |
9 |
10 | def load_vit(
11 | model_name, device, proj_path=None, return_transform=False,
12 | pretrained=True
13 | ):
14 | """_summary_
15 |
16 | Parameters
17 | ----------
18 | model_name : str
19 | Can be one of the following options: vit_b_16, vit_b_32, vit_large_16,
20 | vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16.
21 | device : str
22 | 'cpu' or 'cuda'
23 | proj_path : pathlib Path, optional
24 | Path to the folder containing the source code, by default None
25 | return_transform : bool, optional
26 | Return image transform, by default False
27 |
28 | Returns
29 | -------
30 | list
31 | Containing model, number of tokens, dimension of hidden states
32 | and image transform.
33 | """
34 | # Get model source and info
35 | if model_name == 'vit_b_32':
36 | msource = 'vit_base_patch32_224'
37 | psource = 'google/vit-base-patch32-224-in21k'
38 | n_tokens = 50
39 | hs_dim = 768
40 | elif model_name == 'vit_b_16':
41 | msource = 'vit_base_patch16_224'
42 | psource = 'google/vit-base-patch16-224-in21k'
43 | n_tokens = 197
44 | hs_dim = 768
45 | elif model_name == 'vit_large_16':
46 | msource = 'vit_large_patch16_224'
47 | psource = 'google/vit-large-patch16-224-in21k'
48 | n_tokens = 197
49 | hs_dim = 1024
50 | elif model_name == 'vit_miil_16':
51 | msource = 'vit_base_patch16_224_miil'
52 | n_tokens = 197
53 | hs_dim = 768
54 | elif model_name == 'vit_cifar_16':
55 | n_tokens = 197
56 | hs_dim = 768
57 | elif model_name == 'deit_ensemble_16':
58 | psource = 'facebook/deit-base-distilled-patch16-224'
59 | n_tokens = 197
60 | hs_dim = 768
61 | elif model_name == 'vit_gap_16':
62 | msource = 'vit_base_patch16_rpn_224'
63 | psource = 'google/vit-base-patch16-224-in21k'
64 | n_tokens = 196
65 | hs_dim = 768
66 |
67 | # Load model
68 | if model_name.startswith('deit'):
69 | model = base_patch16_224_hierarchical(pretrained=pretrained).to(device)
70 | elif 'cifar' in model_name:
71 | model = create_model('vit_base_patch16_224', num_classes=100).to(device)
72 | if pretrained == True:
73 | state_dict = torch.load(proj_path / 'model_cktp' / 'vit_base_patch16_224-CIFAR100.pt')
74 | state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
75 | model.load_state_dict(state_dict)
76 | else:
77 | model = create_model(msource, pretrained=pretrained).to(device)
78 |
79 | # Get image transform
80 | if return_transform == True:
81 | if 'miil' in model_name:
82 | data_config = timm.data.resolve_data_config({}, model=model)
83 | img_transform = timm.data.create_transform(**data_config, is_training=False)
84 | elif 'cifar' in model_name:
85 | img_transform = T.Compose([
86 | T.Resize((224, 224),),
87 | T.CenterCrop(224),
88 | T.ToTensor(),
89 | T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
90 | ])
91 | else:
92 | img_transform = AutoImageProcessor.from_pretrained(psource)
93 | else:
94 | img_transform = None
95 |
96 | return model, n_tokens, hs_dim, img_transform
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Analyzing Vision Transformers in Class Embedding Space (NeurIPS '23)
2 | _by [Martina G. Vilas](https://martinagvilas.github.io/), Timothy Schaumlöffel and [Gemma Roig](http://www.cvai.cs.uni-frankfurt.de/team.html)_
3 |
4 | __*Links*__: [Paper](https://arxiv.org/abs/2310.18969) | [Video presentation]() _(coming soon)_ | [Poster]() _(coming soon)_
5 |
6 | > __Abstract__: Despite the growing use of transformer models in computer vision, a mechanistic
7 | understanding of these networks is still needed. This work introduces a method to
8 | reverse-engineer Vision Transformers trained to solve image classification tasks.
9 | Inspired by previous research in NLP, we demonstrate how the inner representations
10 | at any level of the hierarchy can be projected onto the learned class embedding
11 | space to uncover how these networks build categorical representations for their
12 | predictions. We use our framework to show how image tokens develop class-specific
13 | representations that depend on attention mechanisms and contextual information,
14 | and give insights on how self-attention and MLP layers differentially contribute to
15 | this categorical composition. We additionally demonstrate that this method (1) can
16 | be used to determine the parts of an image that would be important for detecting
17 | the class of interest, and (2) exhibits significant advantages over traditional linear
18 | probing approaches. Taken together, our results position our proposed framework
19 | as a powerful tool for mechanistic interpretability and explainability research.
20 |
21 | 
22 |
23 |
24 | Schematic of our framework
25 |
26 |
27 | ## :paperclip: Contents
28 |
29 | - [Tutorial](#tutorial)
30 | - [Running the experiments](#running-the-experiments)
31 | - [Citing our work](#citing-our-work)
32 | - [Acknowledgements](#acknowledgements)
33 |
34 | ## Tutorial
35 |
36 | You can access a tutorial of our method here:
37 |
38 |
39 |
40 | ## Running the experiments
41 |
42 | #### Step 1: Get a local working copy of this code
43 | __1.1.__ Clone this repository to your local machine.
44 |
45 | __1.2.__ Install the required software using conda by running:
46 | ```
47 | conda create --name vit-cls python=3.9
48 | conda activate vit-cls
49 | pip install -r requirements.txt
50 | pip install .
51 | ```
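
To sanity-check the installation (optional; `src` is the package installed by `setup.py` via `find_packages()`), you can try:
```
python -c "import src, torch, timm; print(torch.__version__, timm.__version__)"
```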
52 |
53 | #### Step 2: Download the dataset and model checkpoints
54 | __2.1.__ Download the ImageNet-S dataset from [here](https://github.com/LUSSeg/ImageNet-S).
55 |
56 | __2.2.__ Download the stimuli info file from [here](https://drive.google.com/drive/folders/1bkJeOGMxU2Ta0CrtKeY9JBLArwmQM9mu?usp=sharing), and place it inside the `ImageNet-S/ImageNetS919`
57 | folder downloaded in the previous step.
58 |
59 | __2.3.__ Download the model checkpoint folder from [here](https://drive.google.com/drive/folders/1bkJeOGMxU2Ta0CrtKeY9JBLArwmQM9mu?usp=sharing), and place it inside the project folder.
60 |
61 | #### Step 3: Run the experiment code
62 | __3.1.__ Project the hidden states onto the class embedding space and save the key coefficients by running:
63 | ```
64 | python extractor.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -m {MODEL} -pretrained
65 | ```
66 | - The model can be one of `vit_b_32`, `vit_b_16`, `vit_large_16`, `vit_cifar_16`, `vit_miil_16`, `deit_ensemble_16` (_Refinement_ model) and `vit_gap_16`.
67 | - You can reproduce the results of the random model by removing the `-pretrained` flag.
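
For example, a filled-in call (the paths below are placeholders for your local clone and ImageNet-S folder; the same substitution applies to the commands in steps 3.2–3.5) could look like:
```
python extractor.py -pp /path/to/vit-cls_emb -dp /path/to/ImageNet-S -m vit_b_16 -pretrained
```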
68 |
69 |
70 | __3.2.__ Run the attention perturbation studies by running:
71 | ```
72 | python perturbation/attn_perturbation.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -m vit_b_32 -pt {PERTURBATION TYPE}
73 | ```
74 | - Perturbation type can be one of `self_only` or `no_cls`.
75 |
76 | __3.3.__ Run the context perturbation studies by running:
77 | ```
78 | python perturbation/tokens_perturbation.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -m vit_b_32 -mt {MASK TYPE}
79 | ```
80 | - Mask type can be one of `context` or `class_label`.
81 |
82 | __3.4.__ Run the memory extractor by running:
83 | ```
84 | python memories.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -m {MODEL} -lt {LAYER TYPE}
85 | ```
86 | - Layer type can be one of `attn` or `mlp`.
87 |
88 | __3.5.__ Run the comparison with a linear probing approach by running:
89 | ```
90 | python linear_probing/prober.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -l {LAYER INDEX}
91 | ```
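- The layer index selects an entry of `LAYERS` in `src/linear_probing/prober.py` (hidden-state, attention and MLP outputs of blocks 0–11), so it ranges from 0 to 35.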
92 |
93 | #### Step 4: Reproduce the results
94 | After running the above code,
95 | head to the [notebooks](https://github.com/martinagvilas/vit-cls_emb/tree/main/notebooks) section to reproduce and visualize the reported results.
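
For instance, `src/accuracy.py` provides a helper for reading back the overall top-1 accuracy (a minimal sketch; the paths are placeholders and it assumes step 3.1 has already written `results/class_embed/{MODEL}/acc.pt`):
```
from src.accuracy import compute_accuracy

acc = compute_accuracy(
    proj_path='/path/to/vit-cls_emb', dataset_path='/path/to/ImageNet-S',
    model='vit_b_16'
)
print(f'Top-1 accuracy: {float(acc):.2f}%')
```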
96 |
97 | ## Citing our work
98 | Please cite this work as:
99 | ```
100 | @inproceedings{vilas2023analyzing_vit,
101 | title = {Analyzing Vision Transformers for Image Classification in Class Embedding Space},
102 | author = {Vilas, Martina G. and Schauml\"{o}ffel, Timothy and Roig, Gemma},
103 | booktitle = {Advances in Neural Information Processing Systems},
104 | pages = {40030--40041},
105 | volume = {36},
106 | year = {2023},
107 | url = {https://proceedings.neurips.cc/paper_files/paper/2023/file/7dd309df03d37643b96f5048b44da798-Paper-Conference.pdf},
108 | }
109 |
110 | ```
111 |
112 |
113 | ## Acknowledgements
114 | - The pre-trained models are extracted from the [timm](https://github.com/huggingface/pytorch-image-models/tree/main) library.
115 | - Our readme is inspired by [IPViT](https://github.com/Muzammal-Naseer/IPViT).
116 |
--------------------------------------------------------------------------------
/src/perturbation/attn_perturbation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import defaultdict
3 | from pathlib import Path
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from tqdm import tqdm
8 | from transformers import ViTImageProcessor
9 |
10 | from src.datasets.imagenet import ImagenetDatasetS
11 | from src.models.vit_edited import vit_base_patch32_224, vit_base_patch16_224
12 |
13 |
14 | class ViTPerturbAttn():
15 | def __init__(
16 | self, version, project_path, imgs_path, device='cpu',
17 | perturb_type='no_cls'
18 | ):
19 | self.version = version
20 | self.model_name = f'vit_b_{version}'
21 | self.device = device
22 | self.perturb_type = perturb_type
23 |
24 | self.project_path = project_path
25 | self.imgs_path = imgs_path
26 | self.perturb_path = self.project_path / 'results' / 'perturbation' / self.model_name
27 | self.perturb_path.mkdir(parents=True, exist_ok=True)
28 |
29 | self._get_model_layers()
30 | self._load_model()
31 |
32 | def _get_model_layers(self):
33 | self.layers = ['blocks.11']
34 | return
35 |
36 | def _load_model(self):
37 | if self.version == '32':
38 | self.model = vit_base_patch32_224(
39 | pretrained=True, pretrained_cfg=True,
40 | perturb_type=self.perturb_type, block_perturb='all'
41 | ).to(self.device)
42 | source = 'google/vit-base-patch32-224-in21k'
43 | self.n_tokens = 49
44 | elif self.version == '16':
45 | self.model = vit_base_patch16_224(
46 | pretrained=True, pretrained_cfg=True,
47 | perturb_type=self.perturb_type, block_perturb='all'
48 | ).to(self.device)
49 | source = 'google/vit-base-patch16-224-in21k'
50 | self.n_tokens = 196
51 | self.model.eval()
52 |
53 | # Add feature extractor
54 | self._add_feature_extractor()
55 |
56 | # Get image transform
57 | self.img_transform = ViTImageProcessor.from_pretrained(source)
58 | return
59 |
60 | def _add_feature_extractor(self):
61 | def get_activation(layer_name):
62 | def hook(_, input, output):
63 | try:
64 | self.model._features[layer_name] = output.detach()
65 | except AttributeError:
66 | # attention layer can output a tuple
67 | self.model._features[layer_name] = output[0].detach()
68 | return hook
69 |
70 | for layer_name, layer in self.model.named_modules():
71 | if layer_name in self.layers:
72 | layer.register_forward_hook(get_activation(layer_name))
73 | return
74 |
75 | def _get_dataset(self):
76 | dataset = ImagenetDatasetS(self.imgs_path)
77 | self.dataloader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])
78 | return
79 |
80 | def compute(self):
81 | self._get_dataset()
82 | acc = []
83 | dec = defaultdict(list)
84 | for id, data in tqdm(
85 | enumerate(self.dataloader), total=len(self.dataloader)
86 | ):
87 | # Get image features
88 | img_ft = self.img_transform(data['img'], return_tensors="pt")
89 | img_ft = img_ft['pixel_values'].to(self.device)
90 |
91 | # Compute hidden states
92 | self.model._features = {}
93 | with torch.no_grad():
94 | out = self.model(img_ft)
95 |
96 | # Turn hidden states into decodability scores
97 | for l_name, l_repr in self.model._features.items():
98 | l_split = l_name.split('.')
99 | try:
100 | l_name = f'hs-{l_split[2]}_{l_split[1]}'
101 | except:
102 | l_name = f'hs_{l_split[1]}'
103 | with torch.no_grad():
104 | preds = self.model.head(self.model.norm(l_repr))
105 | ordered_idx = torch.argsort(preds, dim=2, descending=True)
106 | label_idx = (ordered_idx == data['index']).nonzero()
107 | dec[l_name].append(label_idx[:, 2])
108 |
109 | # Compute accuracy
110 | pred = out.topk(1)[1]
111 | cat_acc = torch.squeeze((pred == data['index']).long())
112 | acc.append(cat_acc)
113 |
114 | # Save accuracy and hidden states
115 | acc = torch.hstack(acc).to(self.device)
116 | torch.save(acc, self.perturb_path / f'attn-{self.perturb_type}_acc.pt')
117 |
118 | dec = torch.stack(dec['hs_11']).to(self.device)
119 | torch.save(dec, self.perturb_path / f'attn-{self.perturb_type}_dec.pt')
120 |
121 | return
122 |
123 |
124 | if __name__ == '__main__':
125 | parser = argparse.ArgumentParser()
126 | parser.add_argument(
127 | '-pp', action='store', required=True,
128 | help='Path to the folder containing the source code.'
129 | )
130 | parser.add_argument(
131 | '-dp', action='store', required=True,
132 | help='Path to the folder containing the dataset.'
133 | )
134 | parser.add_argument(
135 | '-m', action='store', required=True,
136 | help='Select which model to run. Can be one of the following options: \
137 | vit_b_16, vit_b_32.'
138 | )
139 | parser.add_argument(
140 | '-pt', action='store', required=True,
141 | help='Perturbation type. Can be one of [no_cls], [self_only].'
142 | )
143 | args = parser.parse_args()
144 |
145 | project_path = Path(args.pp)
146 | data_path = Path(args.dp)
147 |
148 | model = args.m
149 | version = model.split('_')[-1]
150 | device = "cuda" if torch.cuda.is_available() else "cpu"
151 | perturb_type = args.pt
152 |
153 | perturb = ViTPerturbAttn(
154 | version, project_path, data_path, device, perturb_type=perturb_type
155 | )
156 | perturb.compute()
--------------------------------------------------------------------------------
/src/linear_probing/prober.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from PIL import Image
4 |
5 | import numpy as np
6 | from sklearn.linear_model import Ridge
7 | from sklearn.metrics import accuracy_score
8 | from timm import create_model
9 | from tqdm import tqdm
10 | import torch
11 | from torchvision.models.feature_extraction import create_feature_extractor
12 | from transformers import AutoImageProcessor
13 |
14 | from src.datasets.imagenet import ImagenetDatasetS
15 |
16 | LAYERS = [f'{layer}-{i}' for layer in ['hs', 'attn', 'mlp'] for i in range(12)]
17 |
18 | CATS_EXCLUDE = [
19 | 'n04356056', 'n04355933', 'n04493381', 'n02808440', 'n03642806',
20 | 'n03832673', 'n04008634', 'n03773504', 'n03887697', 'n15075141'
21 | ]
22 |
23 |
24 | class LinearProber:
25 | def __init__(self, model_name, layer, project_path, imgs_path, device):
26 | self.model_name = model_name
27 | self.layer = layer
28 | self.device = device
29 | self._load_model()
30 |
31 | self.project_path = Path(project_path)
32 | self.imgs_path = Path(imgs_path)
33 |
34 | return
35 |
36 | def _load_model(self):
37 | # Load model
38 | self.model = create_model(
39 | 'vit_base_patch32_224', pretrained=True
40 | ).to(self.device)
41 |
42 | # Add feature extractor
43 | block = self.layer.split('-')[-1]
44 | if 'attn' in self.layer:
45 | self.node = f'blocks.{block}.attn.proj'
46 | elif 'mlp' in self.layer:
47 | self.node = f'blocks.{block}.mlp.fc2'
48 | elif 'hs' in self.layer:
49 | self.node = f'blocks.{block}.add_1'
50 | self.extractor = create_feature_extractor(self.model, [self.node]).to(self.device)
51 |
52 | # Get image processor
53 | self.img_transform = AutoImageProcessor.from_pretrained(
54 | 'google/vit-base-patch32-224-in21k'
55 | )
56 |
57 | return
58 |
59 | def _get_training_data(self, token):
60 | dataset = ImagenetDatasetS(self.imgs_path, partition='test', n=10)
61 |
62 | fts = []
63 | targets = []
64 | for _, row in tqdm(
65 | dataset.stim_info.iterrows(), total=len(dataset.stim_info),
66 | desc=f'{self.node}/{token}'
67 | ):
68 | # Get image features
69 | img_path = self.imgs_path / 'ImageNetS919/test' / row['imagenet_id'] / row['img_name']
70 | img = Image.open(img_path).convert('RGB')
71 |
72 | img_ft = self.img_transform(img, return_tensors="pt")
73 | img_ft = img_ft['pixel_values'].to(self.device)
74 |
75 | # Run model
76 | with torch.no_grad():
77 | out = self.extractor(img_ft)
78 |
79 | # Save features
80 | out = out[self.node][0, token]
81 | fts.append(out)
82 |
83 | # Get one hot encoder
84 | target = torch.zeros(1000).to(self.device) - 1
85 | target[row['index']] = 1
86 | targets.append(target)
87 |
88 | fts = torch.stack(fts)
89 | targets = torch.stack(targets)
90 |
91 | return fts, targets
92 |
93 | def _get_testing_data(self, token):
94 | dataset = ImagenetDatasetS(self.imgs_path)
95 |
96 | fts = []
97 | targets = []
98 | for _, row in tqdm(dataset.stim_info.iterrows(), total=len(dataset.stim_info)):
99 | # Get image features
100 | img_path = self.imgs_path / 'ImageNetS919/validation' / row['imagenet_id'] / row['img_name']
101 | img = Image.open(img_path).convert('RGB')
102 |
103 | img_ft = self.img_transform(img, return_tensors="pt")
104 | img_ft = img_ft['pixel_values'].to(self.device)
105 |
106 | # Run model
107 | with torch.no_grad():
108 | out = self.extractor(img_ft)
109 |
110 | # Save features
111 | out = out[self.node][0, token]
112 | fts.append(out)
113 |
114 | # Get index
115 | targets.append(row['index'])
116 |
117 | fts = torch.stack(fts)
118 |
119 | return fts, targets
120 |
121 | def compute(self):
122 | path = self.project_path / 'results' / 'linear_probing' / self.layer
123 | path.mkdir(parents=True, exist_ok=True)
124 |
125 | accs = []
126 | for token in range(50):
127 | # Get training data
128 | fts, targets = self._get_training_data(token)
129 |
130 | # Train model
131 | lm = Ridge(alpha=1.0, solver='lsqr')
132 | lm.fit(fts.cpu().numpy(), targets.cpu().numpy())
133 |
134 | # Get testing data
135 | fts, targets = self._get_testing_data(token)
136 |
137 | # Test accuracy
138 | preds = lm.predict(fts.cpu().numpy())
139 | topk = np.argmax(preds, axis=1)
140 | acc = accuracy_score(targets, topk)
141 | accs.append(acc)
142 |
143 | # Save position of correct prediction
144 | pos = 999 - np.where(np.argsort(preds) == np.expand_dims(targets, axis=1))[1]
145 | np.save(path / f'pos_t{token}.npy', pos)
146 |
147 | # Save accuracy for token and layer
148 | np.save(path / f'acc.npy', accs)
149 |
150 | return
151 |
152 |
153 | if __name__ == '__main__':
154 | parser = argparse.ArgumentParser()
155 | parser.add_argument(
156 | '-pp', action='store', required=True,
157 | help='Path to the folder containing the source code.'
158 | )
159 | parser.add_argument(
160 | '-dp', action='store', required=True,
161 | help='Path to the folder containing the dataset.'
162 | )
163 | parser.add_argument(
164 | '-l', action='store', required=True,
165 | help='Layer index.'
166 | )
167 |
168 | args = parser.parse_args()
169 | project_path = Path(args.pp)
170 | data_path = Path(args.dp)
171 |
172 | device = "cuda" if torch.cuda.is_available() else "cpu"
173 |
174 | layer = LAYERS[int(args.l)]
175 | LinearProber('vit_b_32', layer, project_path, data_path, device).compute()
176 |
--------------------------------------------------------------------------------
/src/models/deit_ensemble.py:
--------------------------------------------------------------------------------
1 | ###
2 | # From https://github.com/Muzammal-Naseer/ATViT/blob/c7f947831decc43cc2b73884b16e6e79b02ad547/vit_models/deit_ensemble.py#L111
3 | ###
4 |
5 | from functools import partial
6 |
7 | import torch
8 | import torch.nn as nn
9 | import math
10 | from einops import reduce, rearrange
11 | from timm.models.registry import register_model
12 | from timm.models.vision_transformer import VisionTransformer, _cfg
13 |
14 | import torch.nn.functional as F
15 |
16 | __all__ = [
17 | "tiny_patch16_224_hierarchical", "small_patch16_224_hierarchical", "base_patch16_224_hierarchical"
18 | ]
19 |
20 |
21 | class TransformerHead(nn.Module):
22 | expansion = 1
23 |
24 | def __init__(self, token_dim, num_patches=196, num_classes=1000, stride=1):
25 | super(TransformerHead, self).__init__()
26 |
27 | self.token_dim = token_dim
28 | self.num_patches = num_patches
29 | self.num_classes = num_classes
30 |
31 | # To process patches
32 | self.conv = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=stride, padding=1, bias=False)
33 | self.bn = nn.BatchNorm2d(self.token_dim)
34 | self.conv = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=1, padding=1, bias=False)
35 | self.bn = nn.BatchNorm2d(self.token_dim)
36 |
37 | self.shortcut = nn.Sequential()
38 | if stride != 1 or self.token_dim != self.expansion * self.token_dim:
39 | self.shortcut = nn.Sequential(
40 | nn.Conv2d(self.token_dim, self.expansion * self.token_dim, kernel_size=1, stride=stride, bias=False),
41 | nn.BatchNorm2d(self.expansion * self.token_dim)
42 | )
43 |
44 | self.token_fc = nn.Linear(self.token_dim, self.token_dim)
45 |
46 | def forward(self, x):
47 | """
48 | x : (B, num_patches + 1, D) -> (B, C=num_classes)
49 | """
50 | cls_token, patch_tokens = x[:, 0], x[:, 1:]
51 | size = int(math.sqrt(x.shape[1]))
52 |
53 | patch_tokens = rearrange(patch_tokens, 'b (h w) d -> b d h w', h=size, w=size) # B, D, H, W
54 | features = F.relu(self.bn(self.conv(patch_tokens)))
55 | features = self.bn(self.conv(features))
56 | features += self.shortcut(patch_tokens)
57 | features = F.relu(features)
58 | patch_tokens = F.avg_pool2d(features, 14).view(-1, self.token_dim)
59 | cls_token = self.token_fc(cls_token)
60 |
61 | out = patch_tokens + cls_token
62 |
63 | return out
64 |
65 |
66 | class VisionTransformer_hierarchical(VisionTransformer):
67 | def __init__(self, *args, **kwargs):
68 | super().__init__(*args, **kwargs)
69 |
70 | # # Transformer heads
71 | # self.transformerheads = nn.Sequential(*[
72 | # TransformerHead(self.embed_dim)
73 | # for i in range(11)])
74 |
75 | def forward_features(self, x):
76 | B, nc, w, h = x.shape
77 | x = self.patch_embed(x)
78 |
79 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
80 | x = torch.cat((cls_tokens, x), dim=1)
81 | x = x + self.pos_embed
82 | x = self.pos_drop(x)
83 |
84 | # Store transformer outputs
85 | #transformerheads_outputs = []
86 |
87 | for idx, blk in enumerate(self.blocks):
88 | x = blk(x)
89 | # if idx <= 10:
90 | # out = self.norm(x)
91 | # out = self.transformerheads[idx](out)
92 | # transformerheads_outputs.append(out)
93 |
94 | x = self.norm(x)
95 | return x
96 | #return x, transformerheads_outputs
97 |
98 | def forward(self, x):
99 | x = self.forward_features(x)
100 | #x, transformerheads_outputs = self.forward_features(x)
101 | output = self.head(x[:, 0])
102 | # output = []
103 | # for y in transformerheads_outputs:
104 | # output.append(self.head(y))
105 | # output.append(self.head(x[:, 0]))
106 | return output
107 |
108 |
109 | @register_model
110 | def tiny_patch16_224_hierarchical(pretrained=False, **kwargs):
111 | model = VisionTransformer_hierarchical(
112 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
113 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
114 | model.default_cfg = _cfg()
115 | if pretrained:
116 | checkpoint = torch.hub.load_state_dict_from_url(
117 | url="https://github.com/Muzammal-Naseer/Improving-Adversarial-Transferability-of-Vision-Transformers"
118 | "/releases/download/v0/deit_tiny_trm.pth",
119 | map_location="cpu", check_hash=True
120 | )
121 | model.load_state_dict(checkpoint["state_dict"])
122 | return model
123 |
124 |
125 | @register_model
126 | def small_patch16_224_hierarchical(pretrained=False, **kwargs):
127 | model = VisionTransformer_hierarchical(
128 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
129 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
130 | model.default_cfg = _cfg()
131 |
132 | if pretrained:
133 | checkpoint = torch.hub.load_state_dict_from_url(
134 | url="https://github.com/Muzammal-Naseer/Improving-Adversarial-Transferability-of-Vision-Transformers"
135 | "/releases/download/v0/deit_small_trm.pth",
136 | map_location="cpu", check_hash=True
137 | )
138 | model.load_state_dict(checkpoint["state_dict"])
139 | return model
140 |
141 |
142 | @register_model
143 | def base_patch16_224_hierarchical(pretrained=False, **kwargs):
144 | model = VisionTransformer_hierarchical(
145 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
146 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
147 | model.default_cfg = _cfg()
148 |
149 | if pretrained:
150 | checkpoint = torch.hub.load_state_dict_from_url(
151 | url="https://github.com/Muzammal-Naseer/Improving-Adversarial-Transferability-of-Vision-Transformers"
152 | "/releases/download/v0/deit_base_trm.pth",
153 | map_location="cpu", check_hash=True
154 | )
155 | model.load_state_dict(checkpoint["state_dict"], strict=False)
156 | return model
157 |
--------------------------------------------------------------------------------
/src/perturbation/tokens_perturbation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import defaultdict
3 | from pathlib import Path
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from tqdm import tqdm
8 | from transformers import ViTImageProcessor
9 |
10 | from src.datasets.imagenet import ImagenetDatasetS
11 | from src.models.vit_edited import vit_base_patch32_224, vit_base_patch16_224
12 | from src.vis import Vis
13 |
14 |
15 | class ViTPerturbTokens():
16 | def __init__(self, version, project_path, imgs_path, device='cpu'):
17 | self.version = version
18 | self.model_name = f'vit_b_{version}'
19 | self.device = device
20 |
21 | self.project_path = Path(project_path)
22 | self.imgs_path = Path(imgs_path)
23 | self.perturb_path = self.project_path / 'results' / 'perturbation' / self.model_name
24 | self.perturb_path.mkdir(parents=True, exist_ok=True)
25 |
26 | self._get_model_layers()
27 | self._load_model()
28 |
29 | def _get_model_layers(self):
30 | # Compute identifiability of tokens in last block
31 | self.layers = ['blocks.11']
32 | return
33 |
34 | def _load_model(self):
35 | # Load model
36 | if self.version == '32':
37 | self.model = vit_base_patch32_224(
38 | pretrained=True, pretrained_cfg=True
39 | ).to(self.device)
40 | self.n_tokens = 49
41 | source = 'google/vit-base-patch32-224-in21k'
42 | elif self.version == '16':
43 | self.model = vit_base_patch16_224(
44 | pretrained=True, pretrained_cfg=True
45 | ).to(self.device)
46 | source = 'google/vit-base-patch16-224-in21k'
47 | self.n_tokens = 196
48 | self.model.eval()
49 |
50 | # Add feature extractor
51 | self._add_feature_extractor()
52 |
53 | # Get image transform
54 | self.img_transform = ViTImageProcessor.from_pretrained(source)
55 | return
56 |
57 | def _add_feature_extractor(self):
58 | def get_activation(layer_name):
59 | def hook(_, input, output):
60 | try:
61 | self.model._features[layer_name] = output.detach()
62 | except AttributeError:
63 | # attention layer can output a tuple
64 | self.model._features[layer_name] = output[0].detach()
65 | return hook
66 |
67 | for layer_name, layer in self.model.named_modules():
68 | if layer_name in self.layers:
69 | layer.register_forward_hook(get_activation(layer_name))
70 | return
71 |
72 |
73 | def _get_dataset(self):
74 | self.dataset = ImagenetDatasetS(self.imgs_path)
75 | self.dataloader = DataLoader(
76 | self.dataset, batch_size=1, collate_fn=lambda x: x[0]
77 | )
78 | return
79 |
80 | def compute(self, mask_type='context'):
81 | self._get_dataset()
82 |
83 | vis = Vis(self.project_path, self.imgs_path, self.model_name, self.device)
84 |
85 | acc = []
86 | dec = defaultdict(list)
87 | for id, data in tqdm(
88 | enumerate(self.dataloader), total=len(self.dataloader)
89 | ):
90 | # Get image features
91 | img_ft = self.img_transform(data['img'], return_tensors="pt")
92 | img_ft = img_ft['pixel_values'].to(self.device)
93 |
94 | # Get segmentations of discarded tokens
95 | mask = vis.get_segmentation(data['imagenet_id'], idx=data['img_index'])
96 | mask = mask.flatten()
97 | if mask_type == 'context':
98 | mask = torch.tensor((mask == 1).nonzero()[0] + 1).to(self.device)
99 | elif mask_type == 'class_label':
100 | mask = torch.tensor((mask == 0).nonzero()[0] + 1).to(self.device)
101 | tokens = torch.hstack((torch.tensor([0]), mask))
102 |
103 | # Compute hidden states
104 | self.model._features = {}
105 | with torch.no_grad():
106 | out = self.model(img_ft, tokens)
107 |
108 | # Turn hidden states into decodability scores
109 | for l_name, l_repr in self.model._features.items():
110 | l_split = l_name.split('.')
111 | try:
112 | l_name = f'hs-{l_split[2]}_{l_split[1]}'
113 | except:
114 | l_name = f'hs_{l_split[1]}'
115 | with torch.no_grad():
116 | preds = self.model.head(self.model.norm(l_repr))
117 | ordered_idx = torch.argsort(preds, dim=2, descending=True)
118 | label_idx = (ordered_idx == data['index']).nonzero()
119 | dec[l_name].append(label_idx[:, 2])
120 |
121 | # Compute accuracy
122 | pred = out.topk(1)[1]
123 | cat_acc = torch.squeeze((pred == data['index']).long())
124 | acc.append(cat_acc)
125 |
126 | # Save accuracy and hidden states
127 | acc = torch.hstack(acc).to(self.device)
128 | torch.save(acc, self.perturb_path / f'no_{mask_type}_tokens_acc.pt')
129 | torch.save(dec, self.perturb_path / f'no_{mask_type}_tokens_dec.pt')
130 |
131 | return
132 |
133 |
134 | if __name__ == '__main__':
135 | parser = argparse.ArgumentParser()
136 | parser.add_argument(
137 | '-pp', action='store', required=True,
138 | help='Path to the folder containing the source code.'
139 | )
140 | parser.add_argument(
141 | '-dp', action='store', required=True,
142 | help='Path to the folder containing the dataset.'
143 | )
144 | parser.add_argument(
145 | '-m', action='store', required=True,
146 | help='Select which model to run. Can be one of the following options: \
147 | vit_b_16, vit_b_32.'
148 | )
149 | parser.add_argument(
150 | '-mt', action='store', required=True,
151 | help='Mask type. Can be one of [context], [class_label].'
152 | )
153 | args = parser.parse_args()
154 |
155 | project_path = Path(args.pp)
156 | data_path = Path(args.dp)
157 |
158 | model = args.m
159 | version = model.split('_')[-1]
160 | device = "cuda" if torch.cuda.is_available() else "cpu"
161 | mask_type = args.mt
162 |
163 | perturb = ViTPerturbTokens(version, project_path, data_path, device)
164 | if mask_type:
165 | perturb.compute(mask_type)
166 | else:
167 | perturb.compute()
--------------------------------------------------------------------------------
/src/plots/explainability.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | import torch
5 |
6 | from src.gradients import ViTGradients
7 | from src.vis import Vis
8 |
9 |
10 | def plot_specific_ft_importance(
11 | model_name, block, head, imgs_info, stim_info, proj_path, dataset_path
12 | ):
13 | """
14 | Plot block- and head-specific feature importance results.
15 | """
16 | # Get gradients
17 | vis = Vis(proj_path, dataset_path, model_name, 'cpu')
18 | g = ViTGradients(model_name, proj_path, dataset_path)
19 |
20 | # Plot
21 | fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5, 3))
22 | for i, ax in zip(imgs_info, axes.flat):
23 | # Compute gradients
24 | grads, _ = g.compute(i[0], i[1], i[2])
25 | mask = - grads[block][head, 0, 1:]
26 | mask_vis = vis.mask(i[0], i[1], mask)
27 |
28 | ax.imshow(mask_vis)
29 | ax.set_xticks([])
30 | ax.set_yticks([])
31 | cat = stim_info[stim_info['index'] == i[2]]['cat'].unique()[0].split(',')[0].lower()
32 | ax.set_title(cat)
33 |
34 | f = proj_path / 'results/figures' / f'explainability_{model_name}_head_1.png'
35 | plt.tight_layout()
36 | plt.savefig(f, dpi=300)
37 | plt.show()
38 | return
39 |
40 |
41 | def plot_sum_ft_importance(model_name, imgs_info, stim_info, proj_path, dataset_path):
42 | """
43 | Plot sum of gradients feature importance results.
44 | """
45 | # Visualize feature importance
46 | if len(imgs_info) <= 6:
47 | fig, axes = plt.subplots(nrows=1, ncols=len(imgs_info), figsize=(10, 3))
48 | else:
49 | fig, axes = plt.subplots(nrows=2, ncols=int(len(imgs_info)/2), figsize=(12, 5))
50 |
51 | for i, ax in zip(imgs_info, axes.flat):
52 |
53 | # Compute importance by gradients
54 | grads, _ = ViTGradients(model_name, proj_path, dataset_path).compute(i[0], i[1], i[2])
55 | importance = []
56 | for b in range(12):
57 | importance.append(torch.sum(grads[b], dim=0)[0])
58 | mask = torch.sum(torch.stack(importance), dim=0)[1:]
59 | mask = - mask
60 |
61 | # Plot heatmap over image
62 | vis = Vis(proj_path, dataset_path, model_name, 'cpu')
63 | mask_vis = vis.mask(i[0], i[1], mask)
64 | ax.imshow(mask_vis)
65 | ax.set_xticks([])
66 | ax.set_yticks([])
67 | if i[2] == 582:
68 | cat = 'grocery store'
69 | else:
70 | cat = stim_info[stim_info['index'] == i[2]]['cat'].unique()[0].split(',')[0].lower()
71 | ax.set_title(cat)
72 |
73 | f = proj_path / 'results/figures' / f'explainability_{model_name}.png'
74 | plt.tight_layout()
75 | plt.savefig(f, dpi=300)
76 | plt.show()
77 |
78 | return
79 |
80 |
81 | def compare_explainability(proj_path, dataset_path, stim_info):
82 | model_name = 'vit_b_32'
83 |
84 | # Select different images, classes, blocks and head
85 | img_info = [
86 | ['n02422699', 0, 352, 7, 0], ['n02422699', 0, 350, 7, 0],
87 | ['n04404412', 1, 851, 11, 6], ['n04404412', 1, 831, 11, 6],
88 | ] # [imagenet_id, image_id, class_id, block, attention head]
89 |
90 | # Compute gradients
91 | vis = Vis(proj_path, dataset_path, model_name, 'cpu')
92 | g = ViTGradients(model_name, proj_path, dataset_path)
93 |
94 | # Plot
95 | fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(10, 6))
96 | #fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(10, 5))
97 | for i, ax in zip(img_info, axes.flat[:4]):
98 | b = i[3]
99 | h = i[4]
100 | grads, _ = g.compute(i[0], i[1], i[2])
101 | mask = - grads[b][h, 0, 1:]
102 | mask_vis = vis.mask(i[0], i[1], mask)
103 | ax.imshow(mask_vis)
104 | ax.set_xticks([])
105 | ax.set_yticks([])
106 | cat = stim_info[stim_info['index'] == i[2]]['cat'].unique()[0].split(',')[0].lower()
107 | ax.set_title(f'{cat} - bl. {b}, h. {h}', fontsize=10)
108 |
109 | for i, ax in zip(img_info, axes.flat[4:]):
110 | # Compute importance by gradients
111 | grads, _ = g.compute(i[0], i[1], i[2])
112 | importance = []
113 | for b in range(12):
114 | importance.append(torch.sum(grads[b], dim=0)[0])
115 | mask = torch.sum(torch.stack(importance), dim=0)[1:]
116 | mask = - mask
117 |
118 | # Plot heatmap over image
119 | mask_vis = vis.mask(i[0], i[1], mask)
120 | ax.imshow(mask_vis)
121 | ax.set_xticks([])
122 | ax.set_yticks([])
123 | if i[2] == 582:
124 | cat = 'grocery store'
125 | else:
126 | cat = stim_info[stim_info['index'] == i[2]]['cat'].unique()[0].split(',')[0].lower()
127 | ax.set_title(cat, fontsize=10)
128 |
129 | fig.text(x=0.1, y=0.92, s='(a) Block- and head-specific visualization', weight='bold', fontsize=12)
130 | fig.text(x=0.1, y=0.5, s='(b) Sum of gradients visualization', weight='bold', fontsize=12)
131 | fig.text(x=0.11, y=0.86, s='(1)', weight='bold', fontsize=11)
132 | fig.text(x=0.51, y=0.86, s='(2)', weight='bold', fontsize=11)
133 | fig.text(x=0.11, y=0.44, s='(1)', weight='bold', fontsize=11)
134 | fig.text(x=0.51, y=0.44, s='(2)', weight='bold', fontsize=11)
135 |
136 | f = proj_path / 'results/figures' / f'explainability_{model_name}_v2.png'
137 | plt.savefig(f, dpi=200)
138 | plt.show()
139 | return
140 |
141 |
142 | def plot_perturbation(model_name, perturb_type, res_path, random=False):
143 | """
144 | Plot perturbation experiment results.
145 | """
146 | labels = torch.arange(48) / 49 * 100
147 |
148 | fig, ax = plt.subplots(figsize=(6,4))
149 |
150 | f = res_path / 'perturbation' / model_name / f'{perturb_type}_grads.pt'
151 | emb_perturb = torch.load(f, map_location='cpu')
152 | emb_perturb = torch.flip(emb_perturb, dims=(0,)).detach().numpy()
153 | sns.lineplot(x=labels, y=emb_perturb, ax=ax, label='cls-emb removal')
154 | print(f'AUC emb: {np.sum(emb_perturb) / (np.max(emb_perturb) * 49)}')
155 |
156 | f = res_path / 'perturbation' / model_name / f'{perturb_type}_grads_chefer.pt'
157 | chefer_perturb = torch.load(f, map_location='cpu')
158 | chefer_perturb = torch.flip(chefer_perturb, dims=(0,)).detach().numpy()
159 | sns.lineplot(x=labels, y=chefer_perturb, ax=ax, label='chefer removal')
160 | print(f'AUC chefer: {np.sum(chefer_perturb) / (np.max(chefer_perturb) * 49)}')
161 |
162 | if random == True:
163 | f = res_path / 'perturbation' / model_name / 'perturb_random.pt'
164 | rand_perturb = torch.mean(torch.load(f, map_location='cpu'), dim=0)
165 | rand_perturb = torch.flip(rand_perturb, dims=(0,)).detach().numpy()
166 | sns.lineplot(x=labels, y=rand_perturb, ax=ax, label='random removal')
167 | print(f'AUC random: {np.sum(rand_perturb) / (np.max(rand_perturb) * 49)}')
168 |
169 | ax.hlines(
170 | xmin=-0.5, xmax=100.5, y=emb_perturb[0], colors='dimgray', linestyles='--', lw=2,
171 | label='baseline accuracy'
172 | )
173 | ax.set_xticks(np.arange(0, 100, 10))
174 | ax.set_xlim(-0.5, 95)
175 | ax.set_ylim(0,0.9)
176 | ax.set_xlabel('percentage of tokens removed')
177 | ax.set_ylabel('accuracy')
178 | ax.legend()
179 |
180 | plt.tight_layout()
181 | f = res_path / 'figures' / f'neg_perturb_{model_name}.png'
182 | plt.savefig(f, dpi=300)
183 | plt.show()
184 |
185 | return
--------------------------------------------------------------------------------
/src/vis.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from PIL import Image
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 |
9 | from src.datasets.imagenet import ImagenetDatasetS
10 | from src.models.load import load_vit
11 |
12 |
13 | class Vis():
14 | def __init__(self, project_path, dataset_path, model, device):
15 | # Get paths
16 | self.project_path = Path(project_path)
17 | self.res_path = self.project_path / 'results'
18 |
19 | # Get model image processor
20 | self.model = model
21 | self.device = device
22 | self._get_img_processor()
23 |
24 | # Get stimuli information
25 | self.dataset_path = Path(dataset_path)
26 | dataset = ImagenetDatasetS(self.dataset_path)
27 | self.imgs_path = dataset.imgs_path
28 | self.segmentations_path = dataset.path / 'validation-segmentation'
29 | self.stim_info = dataset.stim_info
30 | self.concepts = self.stim_info['imagenet_id'].unique().tolist()
31 |
32 | # Get segmentation mapping
33 | f = self.dataset_path / 'data/categories/ImageNetS_categories_im919.txt'
34 | seg_map = pd.read_csv(f, header=None)
35 | self.seg_map = {row[0]:idx+1 for idx, row in seg_map.iterrows()}
36 |
37 | return
38 |
39 | def _get_img_processor(self):
40 | """Load image processor and get image dimensions.
41 | """
42 |
43 | # Load model
44 | _, _, _, self.processor = load_vit(
45 | self.model, self.device, self.project_path, return_transform=True
46 | )
47 |
48 | # Get image dimension
49 | if '32' in self.model:
50 | self.img_dim = 7
51 | elif '16' in self.model:
52 | self.img_dim = 14
53 |
54 | return
55 |
56 | def get_array(self, concept, idx):
57 | """Get image as array.
58 |
59 | Parameters
60 | ----------
61 | concept : str
62 | Imagenet id.
63 | idx : int
64 | Image id of Imagenet-S.
65 |
66 | Returns
67 | -------
68 | np.array
69 | Image of concept and idx.
70 | """
71 | file = self.imgs_path / concept / (
72 | self.stim_info.loc[self.stim_info['imagenet_id'] == concept]
73 | .iloc[idx]['img_name']
74 | )
75 | img = Image.open(file).convert('RGB')
76 | img = self.processor(img)
77 | try:
78 | img = np.transpose(img['pixel_values'][0], (1, 2, 0))
79 | except:
80 | img = img
81 | img = (img - img.min()) / (img.max() - img.min())
82 | return img
83 |
84 | def get_pil(self, concept, idx):
85 | """Get image as PIL Image.
86 |
87 | Parameters
88 | ----------
89 | concept : str
90 | Imagenet id.
91 | idx : int
92 | Image id of Imagenet-S.
93 |
94 | Returns
95 | -------
96 | PIL Image
97 | Image of concept and idx.
98 | """
99 | img = self.get_array(concept, idx)
100 | img = Image.fromarray((img * 255).astype(np.uint8))
101 | return img
102 |
103 | def get_token(self, concept, idx, token_idx, as_pil=True):
104 | """Return token from image.
105 |
106 | Parameters
107 | ----------
108 | concept : str
109 | Imagenet id.
110 | idx : int
111 | Image id of Imagenet-S.
112 | token_idx : int
113 | Index of token in image.
114 | as_pil : bool, optional
115 | Whether to return the token as image, by default True.
116 |
117 | Returns
118 | -------
119 | np.array or PIL Image
120 | Token of image.
121 | """
122 | # Get token position
123 | row_idx = token_idx // self.img_dim
124 | col_idx = token_idx % self.img_dim
125 |
126 | # Split image
127 | img = self.get_array(concept, idx)
128 | token = np.split(img, self.img_dim, axis=0)[row_idx]
129 | token = np.split(token, self.img_dim, axis=1)[col_idx]
130 |
131 | # Return token
132 | if as_pil:
133 | return Image.fromarray((token * 255).astype(np.uint8)).resize((100,100))
134 | else:
135 | return token
136 |
137 | def get_segmentation(self, concept, idx):
138 | """Get class segmentation.
139 |
140 | Parameters
141 | ----------
142 | concept : str
143 | Imagenet id.
144 | idx : int
145 | Image id of Imagenet-S.
146 |
147 | Returns
148 | -------
149 | np.array
150 | Mask of class in the image.
151 | """
152 |
153 | # Get segmentation information
154 | f = self.stim_info.loc[self.stim_info['imagenet_id'] == concept]
155 | f = f"{Path(f.iloc[idx]['img_name']).stem}.png"
156 | sgt = Image.open(self.segmentations_path / concept / f)
157 |
158 | # Get mask class
159 | mask = np.array(sgt)
160 | mask = mask[:, :, 1] * 256 + mask[:, :, 0]
161 | mask = (mask == self.seg_map[concept]).astype(int)
162 |
163 | # Resize mask
164 | mask = cv2.resize(
165 | mask, dsize=(self.img_dim, self.img_dim),
166 | interpolation=cv2.INTER_NEAREST
167 | )
168 | return mask
169 |
170 | def mask(self, concept, idx, weights, prepro='normalize_minmax', invert=False):
171 | """Generate heatmap over image from weights.
172 | Code adapted from https://github.com/hila-chefer/Transformer-MM-Explainability
173 |
174 | Parameters
175 | ----------
176 | concept : str
177 | Imagenet id.
178 | idx : int
179 | Image id of Imagenet-S.
180 | weights : torch Tensor
181 | Weights to plot the heatmap. Dimension should be same as the number
182 | of image tokens.
183 | prepro : str, optional
184 | Whether to preprocess the weights, by default 'normalize_minmax'.
185 | Can be one of the following: 'normalize_minmax' to scale the
186 | weights in the range of 0-1, or 'normalize_decoding' to scale the
187 | weights with respect to the number of classes in Imagenet.
188 | invert : bool, optional
189 | Whether to invert the weights, by default False.
190 |
191 | Returns
192 | -------
193 | np.array
194 | Image with heatmap.
195 | """
196 |
197 | # Preprocess weights if needed
198 | if prepro == 'normalize_minmax':
199 | weights = (weights - weights.min()) / (weights.max() - weights.min())
200 | elif prepro == 'normalize_decoding':
201 | weights = weights / 1000
202 | elif prepro == None:
203 | pass
204 |
205 | if invert == True:
206 | weights = - weights
207 | print('inverted')
208 |
209 | # Transform weights into same size as image
210 | weights = weights.reshape(1, 1, self.img_dim, self.img_dim).float()
211 | weights = torch.nn.functional.interpolate(weights, size=224, mode='bilinear')
212 | weights = torch.squeeze(weights).detach().numpy()
213 |
214 | # Create heatmap
215 | heatmap = cv2.applyColorMap(np.uint8(255 * weights), cv2.COLORMAP_JET)
216 | heatmap = np.float32(heatmap) / 255
217 |
218 | # Get image
219 | img = self.get_array(concept, idx)
220 |
221 | # Apply heatmap to image
222 | vis = heatmap + np.float32(img)
223 | vis = vis / np.max(vis)
224 | vis = np.uint8(255 * vis)
225 | vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
226 |
227 | return vis
228 |
--------------------------------------------------------------------------------
/src/gradients.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | from torch.nn import CrossEntropyLoss
5 | from torch.nn.functional import softmax
6 | from transformers import ViTImageProcessor
7 |
8 | from src.datasets.imagenet import ImagenetDatasetS
9 | from src.models.vit_edited import vit_base_patch32_224, vit_base_patch16_224
10 |
11 |
12 | class ViTGradients():
13 | """Compute feature importance of image tokens.
14 | """
15 | def __init__(self, model_name, project_path, imgs_path, device='cpu'):
16 | self.model_name = model_name
17 | self.device = device
18 |
19 | self.project_path = Path(project_path)
20 | self.imgs_path = Path(imgs_path)
21 |         self.res_path = self.project_path / 'results'
22 | self.grad_path = self.res_path / 'gradients' / self.model_name
23 | self.grad_path.mkdir(parents=True, exist_ok=True)
24 |
25 | self._load_model()
26 |
27 | def _load_model(self):
28 | if self.model_name == 'vit_b_32':
29 | self.model = vit_base_patch32_224(
30 | pretrained=True, pretrained_cfg=True
31 | ).to(self.device)
32 | self.n_tokens = 50
33 | source = 'google/vit-base-patch32-224-in21k'
34 | elif self.model_name == 'vit_b_16':
35 | self.model = vit_base_patch16_224(
36 | pretrained=True, pretrained_cfg=True
37 | ).to(self.device)
38 | source = 'google/vit-base-patch16-224-in21k'
39 | self.n_tokens = 197
40 | self.model.eval()
41 | self.img_transform = ViTImageProcessor.from_pretrained(source)
42 | return
43 |
44 | def _get_dataset(self):
45 | self.dataset = ImagenetDatasetS(self.imgs_path)
46 | return
47 |
48 | def compute(
49 | self, concept, img_idx, cat_idx=None, grad_type='cross_entropy',
50 | input_type='attn_probs'
51 | ):
52 | # Get data
53 | self._get_dataset()
54 | data = self.dataset.get_item_by_concept(concept, img_idx)
55 | self.cat_idx = cat_idx
56 |
57 | # Get image features
58 | img_ft = self.img_transform(data['img'], return_tensors="pt")
59 | img_ft = img_ft['pixel_values'].to(self.device)
60 |
61 | # Compute hidden states
62 | self.model.zero_grad()
63 | for b in range(12):
64 | self.model.blocks[b].attn.attn_probs = None
65 | self.model.blocks[b].attn.key_proj_vals = None
66 | self.model.blocks[b].attn_cls = None
67 | _ = self.model(img_ft)
68 |
69 | # Get gradients
70 | self.input_type = input_type
71 | if grad_type == 'cross_entropy':
72 | grads, attns = self._compute_cross_entropy(data)
73 | elif grad_type == 'cat_prob':
74 | grads, attns = self._compute_cat_prob(data)
75 |
76 | return grads, attns
77 |
78 | def _compute_cross_entropy(self, data):
79 | loss = CrossEntropyLoss(reduction='none')
80 |
81 | grads = {}
82 | attns = {}
83 | for b in range(12):
84 | if self.input_type == 'attn_probs':
85 | inp = self.model.blocks[b].attn.attn_probs
86 | elif self.input_type == 'key_proj_vals':
87 | inp = self.model.blocks[b].attn.key_proj_vals
88 | out = self.model.blocks[b].attn_cls
89 |
90 | # Collect grads of all tokens
91 | out = torch.squeeze(out)
92 | out = softmax(out, dim=1)
93 | target = torch.zeros(1, 1000).to(self.device)
94 |             if self.cat_idx is not None:
95 | target[0, self.cat_idx] = 1
96 | else:
97 | target[0, data['index']] = 1
98 | target = target.repeat(self.n_tokens, 1)
99 | l = loss(out, target)
100 |
101 | b_grads = []
102 | for t in range(self.n_tokens):
103 | grad = torch.autograd.grad(l[t], inp, retain_graph=True)
104 | if self.input_type == 'attn_probs':
105 | b_grads.append(grad[0][0, :, t, :].detach())
106 | elif self.input_type == 'key_proj_vals':
107 | b_grads.append(grad[0][0, t, :].detach())
108 |
109 | grads[b] = torch.stack(b_grads).transpose(0,1)
110 |
111 | if self.input_type == 'attn_probs':
112 | attns[b] = torch.squeeze(inp)
113 | else:
114 | continue
115 |
116 | return grads, attns
117 |
118 | def _compute_cat_prob(self, data):
119 | grads = {}
120 | attns = {}
121 | for b in range(12):
122 | if self.input_type == 'attn_probs':
123 | inp = self.model.blocks[b].attn.attn_probs
124 | elif self.input_type == 'key_proj_vals':
125 | inp = self.model.blocks[b].attn.key_proj_vals
126 | out = self.model.blocks[b].attn_cls
127 |
128 | # Collect grads of all tokens
129 | b_grads = []
130 | for t in range(self.n_tokens):
131 | cat_prob = torch.zeros(1, self.n_tokens, 1000).to(self.device)
132 | cat_prob[0, t, data['index']] = 1
133 | cat_prob.requires_grad_(True)
134 | cat_prob = torch.sum(cat_prob * out).to(self.device)
135 |
136 | # Compute gradients
137 | grad = torch.autograd.grad(cat_prob, inp, retain_graph=True)
138 |
139 | if self.input_type == 'attn_probs':
140 | b_grads.append(grad[0][0, :, t, :].detach())
141 | elif self.input_type == 'key_proj_vals':
142 | b_grads.append(grad[0][0, t, :].detach())
143 |
144 | grads[b] = torch.stack(b_grads).transpose(0,1)
145 |
146 | if self.input_type == 'attn_probs':
147 | attns[b] = torch.squeeze(inp)
148 | else:
149 | continue
150 |
151 | return grads, attns
152 |
153 |
154 | class CheferGradients(ViTGradients):
155 | """Compute feature importance of image tokens using method from
156 | https://github.com/hila-chefer/Transformer-MM-Explainability
157 | """
158 | def __init__(self, model_name, project_path, imgs_path, device='cpu'):
159 | super().__init__(model_name, project_path, imgs_path, device)
160 | return
161 |
162 | def compute(self, concept, img_idx, cat_idx=None):
163 |
164 | # Get data
165 | self._get_dataset()
166 | data = self.dataset.get_item_by_concept(concept, img_idx)
167 | self.cat_idx = cat_idx
168 |
169 | # Get image features
170 | img_ft = self.img_transform(data['img'], return_tensors="pt")
171 | img_ft = img_ft['pixel_values'].to(self.device)
172 |
173 | # Compute hidden states
174 | self.model.zero_grad()
175 | for b in range(12):
176 | self.model.blocks[b].attn.attn_probs = None
177 | out = self.model(img_ft)
178 | del img_ft
179 |
180 | # Get one hot vector solution
181 | cat_prob = torch.zeros(1, 1000).to(self.device)
182 |         if self.cat_idx is not None:
183 | cat_prob[0, self.cat_idx] = 1
184 | else:
185 | cat_prob[0, data['index']] = 1
186 | cat_prob.requires_grad_(True)
187 | cat_prob = torch.sum(cat_prob * out).to(self.device)
188 |
189 | # Get gradients
190 | R = torch.eye(self.n_tokens, self.n_tokens).to(self.device)
191 | for b in range(12):
192 | cam = self.model.blocks[b].attn.attn_probs
193 | grad = torch.autograd.grad(cat_prob, cam, retain_graph=True)[0]
194 |
195 | # Average over heads
196 | cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])
197 | grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
198 | cam = grad * cam
199 | cam = cam.clamp(min=0).mean(dim=0).detach()
200 |
201 | # Apply self-attention rules
202 | r_add = torch.matmul(cam, R)
203 | R += r_add
204 |
205 | R = R[0, 1:]
206 |
207 | return R
208 |
--------------------------------------------------------------------------------
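A minimal sketch of how the gradient classes above might be called; the paths and the concept/image indices are illustrative assumptions:

    from pathlib import Path
    from src.gradients import ViTGradients, CheferGradients

    proj_path = Path('/path/to/project')        # assumed
    imgs_path = Path('/path/to/imagenet-s')     # assumed

    g = ViTGradients('vit_b_32', proj_path, imgs_path, device='cpu')
    grads, attns = g.compute('n01440764', img_idx=0, grad_type='cross_entropy', input_type='attn_probs')
    # grads[b]: per-token gradients for block b; attns[b]: that block's attention probabilities

    cg = CheferGradients('vit_b_32', proj_path, imgs_path, device='cpu')
    R = cg.compute('n01440764', img_idx=0)      # relevance score per image token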
/src/plots/linear_probing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import pandas as pd
4 | import seaborn as sns
5 | import torch
6 |
7 | from src.identifiability import get_class_embed
8 | from src.linear_probing.prober import LAYERS
9 |
10 |
11 | def plot_top1_acc(res_path, dataset_path):
12 | """
13 | Plot top-1 accuracies of linear probing and cls projection.
14 | """
15 | plt.rcParams.update({'font.size': 12})
16 | fig, axes = plt.subplots(ncols=2, figsize=(13, 5))
17 |
18 | # Get linear probing accuracies
19 | accs = []
20 | for layer in LAYERS:
21 | file = res_path / 'linear_probing' / layer / 'acc.npy'
22 | acc = np.load(file)[1:]
23 | df = pd.DataFrame(acc, columns=['top-1 acc'])
24 | df['layer'] = layer.split('-')[0]
25 | df['block'] = int(layer.split('-')[1]) + 1
26 | accs.append(df)
27 | accs = pd.concat(accs)
28 | ## Plot
29 | sns.lineplot(accs, y='top-1 acc', x='block', hue='layer', ax=axes[0])
30 | axes[0].set_title('linear probing')
31 | axes[0].set_xticks(np.arange(1, 13))
32 |
33 | # Get cls projection accuracies
34 | accs = []
35 | for layer in LAYERS:
36 | block = layer.split('-')[1]
37 | layer_type = layer.split('-')[0]
38 | if 'attn' in layer:
39 | layer = f'hs-attn_{block}'
40 | elif 'mlp' in layer:
41 | layer = f'hs-mlp_{block}'
42 | else:
43 | layer = f'hs_{block}'
44 | dec = get_class_embed(
45 | res_path, dataset_path, 'vit_b_32', layer, 'pos', normalize=False
46 | )
47 | dec = dec[:, 1:]
48 | acc = torch.sum((dec == 0), axis=0) / dec.shape[0]
49 | df = pd.DataFrame(acc.detach().numpy(), columns=['top-1 acc'])
50 | df['layer'] = layer_type
51 | df['block'] = int(block) + 1
52 | accs.append(df)
53 | accs = pd.concat(accs)
54 | ## Plot
55 | sns.lineplot(accs, y='top-1 acc', x='block', hue='layer', ax=axes[1])
56 | axes[1].set_title('cls projection')
57 | axes[1].set_xticks(np.arange(1, 13))
58 |
59 | plt.tight_layout()
60 | plt.show()
61 | return
62 |
63 |
64 | def plot_perturbation(res_path, model_name='vit_b_32'):
65 | """
66 | Plot and compare perturbation experiments.
67 | """
68 | plt.rcParams.update({'font.size': 12})
69 |
70 | labels = torch.arange(48) / 49 * 100
71 |
72 | fig, ax = plt.subplots(figsize=(6,4))
73 |
74 | # Plot our method
75 | f = res_path / 'perturbation' / model_name / 'negative_grads.pt'
76 | neg_perturb = torch.load(f, map_location='cpu')
77 | neg_perturb = torch.flip(neg_perturb, dims=(0,)).detach().numpy()
78 | sns.lineplot(x=labels, y=neg_perturb, ax=ax, label='NEG cls-based removal')
79 | print(f'negative AUC emb: {np.sum(neg_perturb) / (np.max(neg_perturb) * 49)}')
80 |
81 | # Plot linear probing
82 |     f = res_path / 'perturbation' / model_name / 'negative_linear_probe.pt'
83 | linear_perturb = torch.load(f, map_location='cpu')
84 | linear_perturb = torch.flip(linear_perturb, dims=(0,)).detach().numpy()
85 | sns.lineplot(x=labels, y=linear_perturb, ax=ax, label='NEG probe removal')
86 | print(f'negative AUC linear: {np.sum(linear_perturb) / (np.max(linear_perturb) * 49)}')
87 |
88 | # Plot our method
89 | f = res_path / 'perturbation' / model_name / 'positive_grads.pt'
90 | neg_perturb = torch.load(f, map_location='cpu')
91 | neg_perturb = torch.flip(neg_perturb, dims=(0,)).detach().numpy()
92 | sns.lineplot(x=labels, y=neg_perturb, ax=ax, label='POS cls-based removal')
93 | print(f'positive AUC emb: {np.sum(neg_perturb) / (np.max(neg_perturb) * 49)}')
94 |
95 | # Plot linear probing
96 | f = res_path / 'perturbation' / model_name / 'positive_linear_probe.pt'
97 | linear_perturb = torch.load(f, map_location='cpu')
98 | linear_perturb = torch.flip(linear_perturb, dims=(0,)).detach().numpy()
99 | sns.lineplot(x=labels, y=linear_perturb, ax=ax, label='POS probe removal')
100 | print(f'positive AUC linear: {np.sum(linear_perturb) / (np.max(linear_perturb) * 49)}')
101 |
102 | # Plot random perturbation
103 | f = res_path / 'perturbation' / model_name / 'perturb_random.pt'
104 | rand_perturb = torch.mean(torch.load(f, map_location='cpu'), dim=0)
105 | rand_perturb = torch.flip(rand_perturb, dims=(0,)).detach().numpy()
106 | sns.lineplot(x=labels, y=rand_perturb, ax=ax, label='random removal')
107 | print(f'random AUC: {np.sum(rand_perturb) / (np.max(rand_perturb) * 49)}')
108 |
109 |
110 | ax.hlines(
111 | xmin=-0.5, xmax=100.5, y=neg_perturb[0], colors='dimgray', linestyles='--', lw=2,
112 | label='baseline accuracy'
113 | )
114 | ax.set_xticks(np.arange(0, 100, 10))
115 | ax.set_xlim(-0.5, 95)
116 | ax.set_ylim(0,0.9)
117 | ax.set_xlabel('percentage of tokens removed')
118 | ax.set_ylabel('accuracy')
119 | ax.legend(fontsize=12)
120 |
121 | plt.tight_layout()
122 | plt.show()
123 | return
124 |
125 |
126 | def plot_all_linear_probing(res_path, dataset_path):
127 | plt.rcParams.update({'font.size': 15})
128 | fig, axes = plt.subplots(ncols=3, figsize=(13, 4))
129 |
130 | ## ACCURACY
131 | accs = []
132 | for layer in LAYERS:
133 | file = res_path / 'linear_probing' / layer / 'acc.npy'
134 | acc = np.load(file)[1:]
135 | df = pd.DataFrame(acc, columns=['top-1 acc'])
136 | df['layer'] = layer.split('-')[0]
137 | df['block'] = int(layer.split('-')[1]) + 1
138 | accs.append(df)
139 | accs = pd.concat(accs)
140 | sns.lineplot(accs, y='top-1 acc', x='block', hue='layer', ax=axes[0])
141 | axes[0].set_title('linear probing')
142 | axes[0].set_xticks(np.arange(1, 13))
143 |
144 | accs = []
145 | for layer in LAYERS:
146 | block = layer.split('-')[1]
147 | layer_type = layer.split('-')[0]
148 | if 'attn' in layer:
149 | layer = f'hs-attn_{block}'
150 | elif 'mlp' in layer:
151 | layer = f'hs-mlp_{block}'
152 | else:
153 | layer = f'hs_{block}'
154 | dec = get_class_embed(
155 | res_path, dataset_path, 'vit_b_32', layer, 'pos', normalize=False
156 | )
157 | dec = dec[:, 1:]
158 | acc = torch.sum((dec == 0), axis=0) / dec.shape[0]
159 | df = pd.DataFrame(acc.detach().numpy(), columns=['top-1 acc'])
160 | df['layer'] = layer_type
161 | df['block'] = int(block) + 1
162 | accs.append(df)
163 | accs = pd.concat(accs)
164 | sns.lineplot(accs, y='top-1 acc', x='block', hue='layer', ax=axes[1])
165 | axes[1].set_title('cls projection')
166 | axes[1].set_xticks(np.arange(1, 13))
167 |
168 | ## PERTURBATION
169 | model_name = 'vit_b_32'
170 |
171 | labels = torch.arange(48) / 49 * 100
172 |
173 | # Plot our method
174 | f = res_path / 'perturbation' / model_name / 'negative_grads.pt'
175 | neg_perturb = torch.load(f, map_location='cpu')
176 | neg_perturb = torch.flip(neg_perturb, dims=(0,)).detach().numpy()
177 | sns.lineplot(x=labels, y=neg_perturb, ax=axes[2], label='NEG cls-based removal')
178 |
179 | # Plot linear probing
180 |     f = res_path / 'perturbation' / model_name / 'negative_linear_probe.pt'
181 | linear_perturb = torch.load(f, map_location='cpu')
182 | linear_perturb = torch.flip(linear_perturb, dims=(0,)).detach().numpy()
183 | sns.lineplot(x=labels, y=linear_perturb, ax=axes[2], label='NEG probe removal')
184 |
185 | # Plot our method
186 | f = res_path / 'perturbation' / model_name / 'positive_grads.pt'
187 | neg_perturb = torch.load(f, map_location='cpu')
188 | neg_perturb = torch.flip(neg_perturb, dims=(0,)).detach().numpy()
189 | sns.lineplot(x=labels, y=neg_perturb, ax=axes[2], label='POS cls-based removal')
190 |
191 | # Plot linear probing
192 | f = res_path / 'perturbation' / model_name / 'positive_linear_probe.pt'
193 | linear_perturb = torch.load(f, map_location='cpu')
194 | linear_perturb = torch.flip(linear_perturb, dims=(0,)).detach().numpy()
195 | sns.lineplot(x=labels, y=linear_perturb, ax=axes[2], label='POS probe removal')
196 |
197 | # Plot random perturbation
198 | f = res_path / 'perturbation' / model_name / 'perturb_random.pt'
199 | rand_perturb = torch.mean(torch.load(f, map_location='cpu'), dim=0)
200 | rand_perturb = torch.flip(rand_perturb, dims=(0,)).detach().numpy()
201 | sns.lineplot(x=labels, y=rand_perturb, ax=axes[2], label='random removal')
202 |
203 | ax = axes[2]
204 | ax.hlines(
205 | xmin=-0.5, xmax=100.5, y=neg_perturb[0], colors='dimgray', linestyles='--', lw=2,
206 | label='baseline accuracy'
207 | )
208 | ax.set_xticks(np.arange(0, 100, 10))
209 | ax.set_xlim(-0.5, 95)
210 | ax.set_ylim(0,0.9)
211 | ax.set_xlabel('percentage of tokens removed')
212 | ax.set_ylabel('accuracy')
213 | ax.legend(prop={'size': 10})
214 |
215 | plt.tight_layout()
216 | f = res_path / 'figures' / f'compare_linear.png'
217 | plt.savefig(f, dpi=300)
218 | plt.show()
219 | return
--------------------------------------------------------------------------------
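The plotting helpers above only read precomputed result files; a minimal sketch, assuming the results folder already holds the linear-probing accuracies and the perturbation curves saved by explain_perturbation.py:

    from pathlib import Path
    from src.plots.linear_probing import plot_top1_acc, plot_perturbation

    res_path = Path('/path/to/project') / 'results'   # assumed layout
    dataset_path = Path('/path/to/imagenet-s')        # assumed

    plot_top1_acc(res_path, dataset_path)             # linear probing vs. cls projection
    plot_perturbation(res_path, model_name='vit_b_32')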
/src/perturbation/explain_perturbation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from tqdm import tqdm
8 | from transformers import ViTImageProcessor
9 |
10 | from src.datasets.imagenet import ImagenetDatasetS
11 | from src.gradients import CheferGradients, ViTGradients
12 | from src.models.vit_edited import vit_base_patch32_224, vit_base_patch16_224
13 |
14 |
15 | class ViTPerturb():
16 | def __init__(self, model_name, project_path, imgs_path, device='cpu'):
17 | self.model_name = model_name
18 | self.device = device
19 |
20 | self.project_path = project_path
21 | self.imgs_path = imgs_path
22 |         self.perturb_path = project_path / 'results' / 'perturbation' / model_name
23 | self.perturb_path.mkdir(parents=True, exist_ok=True)
24 | (self.perturb_path / 'masks').mkdir(parents=True, exist_ok=True)
25 |
26 | self._load_model()
27 |
28 | def _load_model(self):
29 | # Get model
30 | if self.model_name == 'vit_b_32':
31 | self.model = vit_base_patch32_224(
32 | pretrained=True, pretrained_cfg=True
33 | ).to(self.device)
34 | source = 'google/vit-base-patch32-224-in21k'
35 | self.n_tokens = 49
36 | elif self.model_name == 'vit_b_16':
37 | self.model = vit_base_patch16_224(
38 | pretrained=True, pretrained_cfg=True
39 | ).to(self.device)
40 | source = 'google/vit-base-patch16-224-in21k'
41 | self.n_tokens = 196
42 | self.model.eval()
43 |
44 | # Get image transform
45 | self.img_transform = ViTImageProcessor.from_pretrained(source)
46 | return
47 |
48 | def _get_dataset(self):
49 | self.dataset = ImagenetDatasetS(self.imgs_path)
50 | self.dataloader = DataLoader(self.dataset, batch_size=1, collate_fn=lambda x: x[0])
51 | return
52 |
53 | def compute(self, perturb_type='negative', source_type='grads'):
54 | self.perturb_type = perturb_type
55 | self.source_type = source_type
56 |
57 | # Get mask
58 | self._get_dataset()
59 | if self.source_type == 'grads':
60 | tokens_mask = self.get_accumulated_grads()
61 | elif self.source_type == 'grads_chefer':
62 | tokens_mask = self.get_grads_chefer()
63 | elif self.source_type == 'linear_probe':
64 | tokens_mask = self.get_accumulated_linear_probe()
65 |
66 | # Compute accuracy
67 | accs = []
68 | for n in range(1, self.n_tokens):
69 | accs.append(self._compute(n, tokens_mask))
70 | accs = torch.stack(accs).to(self.device)
71 |
72 | # Save accuracy
73 | f = self.perturb_path / f'{self.perturb_type}_{self.source_type}.pt'
74 | torch.save(accs, f)
75 |
76 | return
77 |
78 | def _compute(self, n_tokens, mask):
79 | self._get_dataset()
80 | acc = []
81 | for id, data in tqdm(
82 | enumerate(self.dataloader), total=len(self.dataloader)
83 | ):
84 | # Get image features
85 | img_ft = self.img_transform(data['img'], return_tensors="pt")
86 | img_ft = img_ft['pixel_values'].to(self.device)
87 |
88 | # Get decoded tokens and add CLS
89 | tokens_mask = mask[id][:n_tokens].to(self.device)
90 | tokens_mask = torch.cat((torch.tensor([0,]).to(self.device), tokens_mask))
91 |
92 | # Compute hidden states
93 | with torch.no_grad():
94 | out = self.model(img_ft, tokens_mask)
95 |
96 | # Compute accuracy
97 | pred = out.topk(1)[1]
98 | cat_acc = torch.squeeze((pred == data['index']).long())
99 | acc.append(cat_acc)
100 |
101 | # Save accuracy
102 | acc = torch.hstack(acc).to(self.device)
103 | acc = torch.sum(acc) / 4550
104 | print(f'acc {n_tokens}: {acc}', flush=True)
105 |
106 | return acc
107 |
108 | def get_accumulated_grads(self):
109 | f = self.perturb_path / 'masks' / f'{self.source_type}.pt'
110 | if f.is_file():
111 | tokens_mask = torch.load(f).to(self.device)
112 | else:
113 | tokens_mask = []
114 | for data in tqdm(self.dataloader, total=len(self.dataloader)):
115 | concept=data['imagenet_id']
116 | img_idx=data['img_index']
117 | cat_idx=data['index']
118 |
119 | # Compute grads
120 | g = ViTGradients(self.model_name, self.project_path, self.imgs_path)
121 | grads, _ = g.compute(concept, img_idx, cat_idx)
122 |
123 | # Accumulate over blocks
124 | importance = []
125 | for b in range(12):
126 | importance.append(torch.sum(grads[b], dim=0)[0])
127 | mask = torch.sum(torch.stack(importance), dim=0)[1:]
128 | mask = - mask
129 |
130 |                 # Order in terms of importance
131 | mask = mask.topk(self.n_tokens, dim=0)[1] + 1
132 | tokens_mask.append(mask)
133 |
134 | tokens_mask = torch.stack(tokens_mask).to(self.device)
135 | torch.save(tokens_mask, f)
136 |
137 | # Flip order if positive perturbation
138 | if self.perturb_type == 'positive':
139 | tokens_mask = torch.flip(tokens_mask, dims=[1])
140 |
141 | return tokens_mask
142 |
143 | def get_grads_chefer(self):
144 | f = self.perturb_path / 'masks' / f'{self.source_type}.pt'
145 | if f.is_file():
146 | tokens_mask = torch.load(f).to(self.device)
147 | else:
148 | tokens_mask = []
149 | for data in tqdm(self.dataloader, total=len(self.dataloader)):
150 | concept=data['imagenet_id']
151 | img_idx=data['img_index']
152 | cat_idx=data['index']
153 |
154 | # Compute grads
155 | g = CheferGradients(self.model_name, self.project_path, self.imgs_path)
156 | R = g.compute(concept, img_idx, cat_idx)
157 | mask = R.topk(self.n_tokens, dim=0)[1] + 1
158 | tokens_mask.append(mask)
159 |
160 | tokens_mask = torch.stack(tokens_mask).to(self.device)
161 | torch.save(tokens_mask, f)
162 |
163 | # Flip order if positive perturbation
164 | if self.perturb_type == 'positive':
165 | tokens_mask = torch.flip(tokens_mask, dims=[1])
166 |
167 | return tokens_mask
168 |
169 | def get_accumulated_linear_probe(self):
170 | # Accumulate linear probing results over blocks
171 | decs = []
172 | for b in range(12):
173 |             path = self.project_path / 'results' / 'linear_probing' / f'hs-{b}'
174 | b_decs = []
175 | for t in range(1, 50):
176 | f = path / f'pos_t{t}.npy'
177 | dec = np.load(f)
178 | b_decs.append(dec)
179 | decs.append(torch.Tensor(np.vstack(b_decs)).to(self.device).T)
180 | decs = torch.sum(torch.stack(decs), dim=0)
181 |
182 | tokens_mask = decs.topk(k=decs.shape[1], dim=1, largest=False)[1]
183 | tokens_mask = tokens_mask + 1
184 |
185 | # Flip order if positive perturbation
186 | if self.perturb_type == 'positive':
187 | tokens_mask = torch.flip(tokens_mask, dims=[1])
188 |
189 | return tokens_mask
190 |
191 | def compute_random(self, perms=10):
192 |         self._get_dataset()
192 |         random_accs = []
193 | for p in range(perms):
194 | # Get random mask
195 | random_mask = torch.randperm(49, generator=torch.manual_seed(p)) + 1
196 |
197 | # Compute accuracy
198 | p_accs = []
199 | for n in range(1, self.n_tokens):
200 | mask = random_mask[:n]
201 |                 mask = torch.cat((torch.tensor([0,]).to(self.device), mask.to(self.device)))
202 | p_accs.append(self._compute_random(mask))
203 | random_accs.append(torch.stack(p_accs).to(self.device))
204 |
205 | random_accs = torch.stack(random_accs).to(self.device)
206 | f = self.perturb_path / f'perturb_random.pt'
207 | torch.save(random_accs, f)
208 |
209 | return
210 |
211 | def _compute_random(self, tokens_mask):
212 | acc = []
213 | for _, data in tqdm(
214 | enumerate(self.dataloader), total=len(self.dataloader)
215 | ):
216 | # Get image features
217 | img_ft = self.img_transform(data['img'], return_tensors="pt")
218 | img_ft = img_ft['pixel_values'].to(self.device)
219 |
220 | # Compute hidden states
221 | with torch.no_grad():
222 | out = self.model(img_ft, tokens_mask)
223 |
224 | # Get sample accuracy
225 | pred = out.topk(1)[1]
226 | cat_acc = torch.squeeze((pred == data['index']).long())
227 | acc.append(cat_acc)
228 |
229 | # Compute accuracy
230 | acc = torch.hstack(acc).to(self.device)
231 | acc = torch.sum(acc) / 4550
232 | print(f'acc {len(tokens_mask)}: {acc}', flush=True)
233 |
234 | return acc
235 |
236 |
237 | if __name__ == '__main__':
238 | parser = argparse.ArgumentParser()
239 | parser.add_argument(
240 | '-pp', action='store', required=True,
241 | help='Path to the folder containing the source code.'
242 | )
243 | parser.add_argument(
244 | '-dp', action='store', required=True,
245 | help='Path to the folder containing the dataset.'
246 | )
247 | parser.add_argument(
248 | '-m', action='store', required=True,
249 | help='Select which model to run. Can be one of the following options: \
250 |             vit_b_16, vit_b_32.'
251 | )
252 | parser.add_argument(
253 | '-random', action='store_true',help='Compute random mask'
254 | )
255 | parser.add_argument(
256 | '-pt', action='store', required=True,
257 | help='Perturbation type: positive or negative'
258 | )
259 | parser.add_argument(
260 | '-st', action='store', required=True,
261 | help='Source for the importance weights: grads, grads_chefer or linear_probe'
262 | )
263 |
264 | args = parser.parse_args()
265 | model = args.m
266 | device = "cuda" if torch.cuda.is_available() else "cpu"
267 | random = args.random
268 | perturb_type = args.pt
269 | source_type = args.st
270 |
271 | project_path = Path(args.pp)
272 | data_path = Path(args.dp)
273 |
274 |     if random:
275 |         p = ViTPerturb(model, project_path, data_path, device)
276 |         p.compute_random()
277 |     else:
278 |         p = ViTPerturb(model, project_path, data_path, device)
279 |         p.compute(perturb_type=perturb_type, source_type=source_type)
--------------------------------------------------------------------------------
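Besides the command-line entry point, ViTPerturb can be driven directly; a minimal sketch with illustrative paths:

    from pathlib import Path
    from src.perturbation.explain_perturbation import ViTPerturb

    proj_path = Path('/path/to/project')       # assumed
    imgs_path = Path('/path/to/imagenet-s')    # assumed

    p = ViTPerturb('vit_b_32', proj_path, imgs_path, device='cpu')
    p.compute(perturb_type='negative', source_type='grads')
    # writes results/perturbation/vit_b_32/negative_grads.pt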
/src/plots/class_ident.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import pandas as pd
4 | from scipy.stats import wilcoxon
5 | import seaborn as sns
6 | import torch
7 |
8 | from src.identifiability import MODEL_MAP
9 | from src.identifiability import (
10 | get_class_embed, get_ident_change_rate, get_ident_mean_evolution,
11 | get_ident_segmented, compute_context_diff
12 | )
13 | from src.vis import Vis
14 |
15 |
16 | def print_identifiability(model_name, res_path, dataset_path, layer=11):
17 | """
18 | Compute and print class identifiability measures.
19 | """
20 | # Get class identifiability
21 | dec = get_class_embed(res_path, dataset_path, model_name, f'hs_{layer}', 'pos')[:, 1:]
22 |
23 | # Compute class identifiability rate
24 | avg_percentage = (torch.sum(dec == 0, dim=1) / dec.shape[1]).float().mean() * 100
25 | print(f'- Average percentage of class identifiable image tokens per image: {avg_percentage}')
26 |
27 | # Compute percentage of images with at least one identifiable token
28 | image_percentage = torch.sum((torch.sum((dec == 0), dim=1) > 0)) / dec.shape[0] * 100
29 | print(
30 | f'- Percentage of images with at least one class identifiable token: {image_percentage}'
31 | )
32 |
33 | return
34 |
35 |
36 | def label_token(row):
37 | """
38 | Assign to token its correct label.
39 | """
40 | if row['token'] == 0:
41 | return '[CLS]'
42 | else:
43 | return 'image'
44 |
45 |
46 | def plot_identifiability_evolution(res_path, dataset_path):
47 | """
48 | Plot identifiability evolution over blocks.
49 | """
50 | plt.rcParams.update({'font.size': 16})
51 | fig, axes = plt.subplots(nrows=7, figsize=(13, 30))
52 | for model_name, ax in zip(MODEL_MAP.keys(), axes.flat):
53 | if 'large' in model_name:
54 | n_layers = 24
55 | else:
56 | n_layers = 12
57 |
58 | dfs = []
59 | pvals = []
60 | for b in range(n_layers):
61 | # Get class identifiability of block
62 | dec = get_class_embed(
63 | res_path, dataset_path, model_name, f'hs_{b}', 'pos', normalize=True
64 | )
65 |
66 | # Compare to random model
67 | if 'miil' in model_name:
68 | random_model = f'vit_b_16_random'
69 | else:
70 | random_model = f'{model_name}_random'
71 | rand_dec = get_class_embed(
72 | res_path, dataset_path, random_model, f'hs_{b}', 'pos', normalize=True
73 | )
74 | wx = wilcoxon(dec.flatten(), rand_dec.flatten(), alternative='greater')
75 | pvals.append(np.round(wx.pvalue, 3))
76 |
77 | # Add to df
78 | df = (
79 | pd.DataFrame(pd.DataFrame(dec).stack())
80 | .reset_index(names=['image', 'token'])
81 | .rename(columns={0:'class identifiability'})
82 | )
83 | if 'gap' not in model_name:
84 | df['token label'] = df.apply(lambda row: label_token(row), axis=1)
85 | df['block'] = b + 1
86 | dfs.append(df)
87 | dfs = pd.concat(dfs)
88 |
89 | # Plot
90 | if 'gap' not in model_name:
91 | sns.boxplot(
92 | dfs, x='block', y='class identifiability', hue='token label',
93 | fliersize=0, palette=sns.color_palette("Set2"), ax=ax
94 | )
95 | sns.move_legend(ax, 'lower right')
96 | else:
97 | sns.boxplot(
98 | dfs, x='block', y='class identifiability', fliersize=0,
99 | palette=sns.color_palette(['#fc8d62']), ax=ax
100 | )
101 | ax.hlines(xmin=-1, xmax=n_layers, y=0.5, colors='dimgray', linestyles='--', lw=2)
102 |
103 | # Add significance stars
104 | sig = np.where(np.array(pvals) < 0.05)[0]
105 | ax.scatter(sig, [1.1] * len(sig), marker='*', c='grey', s=45)
106 |
107 | ax.set_xlim(-0.5, int(n_layers))
108 | ax.set_title(f'{MODEL_MAP[model_name]}')
109 |
110 | plt.tight_layout()
111 | f = res_path / 'figures' / f'class_identifiability_evolution.png'
112 | plt.savefig(f, dpi=300)
113 | plt.show()
114 |
115 | return
116 |
117 |
118 | def compare_identifiability_evolution(res_path, dataset_path, cifar_path):
119 | # Get evolution
120 | df = []
121 | for model in MODEL_MAP.keys():
122 | if 'cifar' in model:
123 | ident = get_ident_mean_evolution(model, cifar_path, res_path)
124 | else:
125 | ident = get_ident_mean_evolution(model, dataset_path, res_path)
126 | df.append(ident)
127 | df = pd.concat(df)
128 | df['class identifiability'] = df['class identifiability'].astype('float')
129 |
130 | # Plot
131 | plt.rcParams.update({'font.size': 17})
132 | fig, axes = plt.subplots(ncols=2, figsize=(13, 3.5))
133 |
134 | cls_df = df.loc[df['token type'] == '[CLS]']
135 | sns.lineplot(
136 | cls_df, x='block', y='class identifiability', hue='model',
137 | ax=axes[0], marker='*', markersize=10
138 | )
139 | axes[0].set_title('[CLS] token')
140 |
141 | img_df = df.loc[df['token type'] == 'image']
142 | sns.lineplot(
143 | img_df, x='block', y='class identifiability', hue='model',
144 | ax=axes[1], marker='*', markersize=10
145 | )
146 | axes[1].set_title('image tokens', fontsize=20)
147 |
148 | for ax in axes.flat:
149 | ax.set_ylim(0.4, 1.05)
150 | ax.set_ylabel('class identifiability', fontsize=20)
151 | ax.set_xticks(np.arange(1, 13))
152 | ax.set_xticklabels(np.arange(1, 13))
153 | ax.set_xticklabels([])
154 | ax.set_xlim((0.8, 12.2))
155 | ax.set_xlabel('normalized block', fontsize=20)
156 | ax.hlines(
157 | xmin=0.8, xmax=12.2, y=0.5, colors='dimgray', linestyles='--', lw=2,
158 | label='Chance Level'
159 | )
160 |
161 | # Unify legend
162 | lines, labels = axes[1].get_legend_handles_labels()
163 | lgd = fig.legend(
164 | lines, labels, loc='center right', #nrow=7,
165 | bbox_to_anchor=(1.15, 0.5),
166 | )
167 | for ax in axes.flat:
168 | ax.get_legend().remove()
169 |
170 | f = res_path / 'figures' / f'mean_evolution.png'
171 | plt.savefig(f, dpi=300, bbox_inches='tight')
172 | plt.show()
173 | return
174 |
175 |
176 | def plot_logits_increment(res_path, dataset_path):
177 | """
178 | Plot percentage of image tokens that increment the logits of the correct class.
179 | """
180 |
181 | fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(13, 13))
182 | for model_name, ax in zip(MODEL_MAP.keys(), axes.flat):
183 |
184 | if 'large' in model_name:
185 | n_layers = 24
186 | else:
187 | n_layers = 12
188 |
189 | # Get change rate
190 | pers = get_ident_change_rate(model_name, dataset_path, res_path, n_layers)
191 |
192 | # Plot
193 | sns.lineplot(pers, x='block', y='change rate', marker='o', markersize=9, ax=ax)
194 | ax.set_xlim(1.8, int(n_layers) + 0.2)
195 | ax.set_xticks(np.arange(2, int(n_layers) + 1))
196 | ax.xaxis.set_tick_params(labelsize=10)
197 | ax.set_ylim(0.5, 1)
198 | ax.set_yticks(np.arange(0.6, 1, 0.1))
199 | ax.yaxis.set_tick_params(labelsize=12)
200 | ax.set_title(MODEL_MAP[model_name])
201 |
202 | fig.delaxes(axes[3,1])
203 |
204 | plt.tight_layout()
205 | f = res_path / 'figures' / f'prob_increment.png'
206 | plt.savefig(f, dpi=300)
207 | plt.show()
208 |
209 | return
210 |
211 |
212 | def print_attn_perturb(model_name, perturb_type, res_path, dataset_path):
213 | """
214 | Print class identifiability of attention perturbation studies.
215 | """
216 | # Get perturbed identifiability
217 | f = res_path / 'perturbation' / model_name / f'attn-{perturb_type}_dec.pt'
218 | attn_dec = torch.load(f, map_location='cpu')
219 | attn_dec = 1 - (attn_dec / 1000)
220 | attn_dec_acc = torch.sum(attn_dec == 1) / attn_dec.flatten().shape[0] * 100
221 | print(f'Class identifiability in {perturb_type}: {attn_dec_acc}')
222 |
223 | # Get unperturbed identifiability
224 | dec = get_class_embed(
225 | res_path, dataset_path, model_name, 'hs_11', decoding_type='pos', normalize=True
226 | )[:, 1:]
227 | dec_acc = torch.sum(dec == 1) / dec.flatten().shape[0] * 100
228 | print(f'Class identifiability in unperturbed model: {dec_acc}')
229 |
230 | return
231 |
232 |
233 | def plot_context_diff(model_name, proj_path, dataset_path):
234 | """
235 | Plot identifiability evolution separately for class- and context-labeled tokens.
236 | """
237 | dfs = get_ident_segmented(model_name, proj_path, dataset_path)
238 |
239 | if 'large' in model_name:
240 | nrows = 4
241 | n_layers = 24
242 | height = 10
243 | else:
244 | nrows = 2
245 | n_layers = 12
246 | height = 5
247 |
248 | plt.rcParams.update({'font.size': 12})
249 | fig, axes = plt.subplots(nrows=nrows, ncols=6, figsize=(13, height))
250 | for ax, b in zip(axes.flat, np.arange(n_layers)):
251 | df = dfs.loc[dfs['block'] == b]
252 | _, pval = compute_context_diff(df)
253 | sns.boxplot(
254 | df, y='class identifiability', x='token location',
255 | fliersize=0, ax=ax
256 | )
257 | if pval < 0.05:
258 | ax.set_title(f'block {b+1} *')
259 | else:
260 | ax.set_title(f'block {b+1}')
261 | ax.set_xlabel('')
262 | for ax in axes.flat[1:]:
263 | ax.set_ylabel('')
264 | fig.suptitle(MODEL_MAP[model_name])
265 |
266 | plt.tight_layout()
267 | f = proj_path / 'results/figures' / f'context_{model_name}.png'
268 | plt.savefig(f, dpi=300)
269 | plt.show()
270 |
271 | return
272 |
273 |
274 | def print_context_perturb(model_name, mask_type, proj_path, dataset_path):
275 | """
276 | Print class identifiability of context perturbation studies.
277 | """
278 | # Get identifiability of perturbed model
279 | res_path = proj_path / 'results'
280 | f = res_path / 'perturbation' / model_name / f'no_{mask_type}_tokens_dec.pt'
281 | context_dec = torch.load(f, map_location='cpu')['hs_11']
282 | context_dec_acc = []
283 | for i in context_dec:
284 | i = 1 - (i / 1000)
285 | context_dec_acc.append(torch.sum(i == 1) / i.shape[0])
286 | context_dec_acc = torch.mean(torch.stack(context_dec_acc)) * 100
287 | print(f'Class identifiability with no {mask_type} tokens: {context_dec_acc}')
288 |
289 | # Get identifiability of unperturbed model
290 | im = Vis(proj_path, dataset_path, model_name, device='cpu')
291 | stim_info = im.stim_info
292 | concepts = stim_info['imagenet_id'].unique().tolist()
293 | sgts = []
294 | for c in concepts:
295 | for i in range(5):
296 | gt = im.get_segmentation(c, i).flatten()
297 | sgts.append(gt)
298 | sgts = np.hstack(sgts)
299 |
300 | dec = get_class_embed(
301 | res_path, dataset_path, model_name, 'hs_11', decoding_type='pos', normalize=True
302 | )[:, 1:]
303 | dec = dec.flatten()
304 | if mask_type == 'context':
305 | dec = dec[(sgts == 1).nonzero()]
306 | elif mask_type == 'class_label':
307 | dec = dec[(sgts == 0).nonzero()]
308 | dec_acc = torch.sum(dec == 1) / dec.shape[0] * 100
309 | print(f'Class identifiability in unperturbed model: {dec_acc}')
310 |
311 | return
--------------------------------------------------------------------------------
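A minimal sketch of the identifiability summaries above, assuming the class-projection results have already been extracted for the listed models:

    from pathlib import Path
    from src.plots.class_ident import print_identifiability, plot_identifiability_evolution

    proj_path = Path('/path/to/project')        # assumed
    res_path = proj_path / 'results'
    dataset_path = Path('/path/to/imagenet-s')  # assumed

    print_identifiability('vit_b_32', res_path, dataset_path, layer=11)
    plot_identifiability_evolution(res_path, dataset_path)   # needs results for every model in MODEL_MAP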
/src/identifiability.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from src.datasets.cifar import MyCIFAR100
7 | from src.datasets.imagenet import ImagenetDatasetS
8 | from src.vis import Vis
9 |
10 |
11 | MODEL_MAP = {
12 | 'vit_b_32': 'ViT-B/32',
13 | 'vit_b_16': 'ViT-B/16',
14 | 'vit_large_16': 'ViT-L/16',
15 | 'vit_miil_16': 'ViT-B/16-MIIL',
16 | 'vit_cifar_16': 'ViT-B/16-CIFAR',
17 | 'deit_ensemble_16': 'ViT-B/16-Refinement',
18 | 'vit_gap_16': 'ViT-B/16-GAP',
19 | }
20 |
21 |
22 | def get_class_embed(
23 | res_path, img_path, model, layer, decoding_type='pos', normalize=False
24 | ):
25 | """Get class projection data.
26 |
27 | Parameters
28 | ----------
29 | res_path : pathlib.Path
30 | Path to results.
31 | img_path : pathlib.Path
32 | Path to dataset.
33 | model : str
34 | Model name. Can be one of the following options: vit_b_16, vit_b_32,
35 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16.
36 | layer : str
37 | Layer index.
38 | decoding_type : str, optional
39 | Type of projection, by default 'pos'. Can be one of the following:
40 |         'pos', 'probs', or 'topk'.
41 | normalize : bool, optional
42 | Whether to normalize over number of classes, by default False.
43 |
44 | Returns
45 | -------
46 | torch.Tensor
47 | Class projection data.
48 | """
49 | net_path = res_path / 'class_embed'
50 |
51 | # Get dataset info
52 | if 'cifar' in model:
53 | dataset = MyCIFAR100(img_path)
54 | stim_info = dataset.stim_info
55 | concepts = list(stim_info.keys())
56 | else:
57 | dataset = ImagenetDatasetS(img_path)
58 | stim_info = dataset.stim_info
59 | concepts = stim_info['imagenet_id'].unique()
60 |
61 | # Stack decoding info across categories
62 | dec = []
63 | for c in concepts:
64 | f = net_path / model / c / f'{decoding_type}_{layer}.pt'
65 | dec.append(torch.load(f, map_location='cpu'))
66 | dec = torch.vstack(dec)
67 | if normalize == True:
68 | if 'cifar' in model:
69 | dec = 1 - (dec / 100)
70 | else:
71 | dec = 1 - (dec / 1000)
72 | return dec
73 | else:
74 | return dec
75 |
76 |
77 | def get_ident_mean_evolution(model_name, dataset_path, res_path):
78 | """Get mean identifiability evolution.
79 |
80 | Parameters
81 | ----------
82 | model_name : str
83 | Model name. Can be one of the following options: vit_b_16, vit_b_32,
84 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16.
85 | dataset_path : pathlib.Path
86 | Path to dataset.
87 | res_path : pathlib.Path
88 | Path to results.
89 |
90 | Returns
91 | -------
92 | pandas.DataFrame
93 | Evolution of mean identifiability across layers.
94 | """
95 | if 'large' in model_name:
96 | n_layers = 24
97 | else:
98 | n_layers = 12
99 |
100 | df = []
101 | for b in range(n_layers):
102 | dec = get_class_embed(
103 | res_path, dataset_path, model_name, f'hs_{b}', 'pos', normalize=True
104 | )
105 | if 'gap' in model_name:
106 | dec = dec.mean()
107 | df.append([b+1, 'image', dec])
108 | else:
109 | dec = dec.mean(dim=0)
110 | cls_dec = dec[0].numpy()
111 | df.append([b+1, '[CLS]', cls_dec])
112 | token_dec = torch.mean(dec[1:]).numpy()
113 | df.append([b+1, 'image', token_dec])
114 |
115 | df = pd.DataFrame(df, columns=['block', 'token type', 'class identifiability'])
116 | df['model'] = MODEL_MAP[model_name]
117 |
118 | if model_name == 'vit_large_16':
119 | new_df = []
120 | for idx, b in enumerate(np.arange(1, 25, 2)):
121 | for tt in ['[CLS]', 'image']:
122 | t_df = df.loc[df['token type'] == tt]
123 | mean = t_df.loc[t_df['block'].isin([b, b+1])]['class identifiability'].mean()
124 | new_df.append([idx+1, tt, mean, MODEL_MAP['vit_large_16']])
125 | new_df = pd.DataFrame(
126 | new_df, columns=['block', 'token type', 'class identifiability', 'model']
127 | )
128 | df = new_df
129 |
130 | return df
131 |
132 |
133 | def get_ident_change_rate(model_name, dataset_path, res_path, n_layers=12):
134 | """Get change rate of class identifiability across layers.
135 |
136 | Parameters
137 | ----------
138 | model_name : str
139 | Model name. Can be one of the following options: vit_b_16, vit_b_32,
140 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16.
141 | dataset_path : pathlib.Path
142 | Path to dataset.
143 | res_path : pathlib.Path
144 | Path to results.
145 | n_layers : int, optional
146 | Number of layers in model, by default 12.
147 |
148 | Returns
149 | -------
150 | pandas.DataFrame
151 | Containing the change rate across layers.
152 | """
153 | pers = []
154 | for b in range(1, n_layers):
155 | i_dec = get_class_embed(res_path, dataset_path, model_name, f'hs_{b-1}', 'probs')[:, 1:]
156 | j_dec = get_class_embed(res_path, dataset_path, model_name, f'hs_{b}', 'probs')[:, 1:]
157 | per = (torch.sum((j_dec - i_dec) > 0) / j_dec.flatten().shape[0]).detach().numpy()
158 | pers.append([b+1, per])
159 | pers = pd.DataFrame(pers, columns=['block', 'change rate'])
160 | pers['change rate'] = pers['change rate'].astype('float')
161 | return pers
162 |
163 |
164 | def get_ident_segmented(model_name, proj_path, dataset_path):
165 | """Get class identifiability separately for class- and context-labeled tokens.
166 |
167 | Parameters
168 | ----------
169 | model_name : str
170 | Model name. Can be one of the following options: vit_b_16, vit_b_32,
171 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16.
172 | proj_path : pathlib.Path
173 | Path to source code.
174 | dataset_path : pathlib.Path
175 | Path to dataset.
176 |
177 | Returns
178 | -------
179 | pandas DataFrame
180 | Containing class identifiability of class- and context-labeled tokens.
181 | """
182 | # Get segmentation annotations
183 | im = Vis(proj_path, dataset_path, model_name, device='cpu')
184 | stim_info = im.stim_info
185 | concepts = stim_info['imagenet_id'].unique().tolist()
186 | sgts = []
187 | for c in concepts:
188 | for i in range(5):
189 | gt = im.get_segmentation(c, i).flatten()
190 | sgts.append(gt)
191 | sgts = np.hstack(sgts)
192 |
193 | # Get identifiability
194 | if 'large' in model_name:
195 | n_layers = 24
196 | else:
197 | n_layers = 12
198 |
199 | # Save in dataframe
200 | dfs = []
201 | for b in range(n_layers):
202 | dec = get_class_embed(
203 | proj_path / 'results', dataset_path, model_name, f'hs_{b}',
204 | decoding_type='pos', normalize=True
205 | )
206 | if 'gap' not in model_name:
207 | dec = dec[:, 1:] # remove cls token
208 | df = pd.DataFrame(
209 | {'class identifiability': dec.flatten().detach().numpy(), 'token location': sgts}
210 | )
211 | df['block'] = b
212 | dfs.append(df)
213 | dfs = pd.concat(dfs)
214 | dfs['token location'] = dfs['token location'].replace({0: 'context', 1: 'class'})
215 |
216 | return dfs
217 |
218 |
219 | def compute_context_diff(df):
220 | """Compute significant difference in identifiability between class- and context-labeled tokens
221 |
222 | Parameters
223 | ----------
224 | df : pandas DataFrame
225 | Containing class identifiability of class- and context-labeled tokens.
226 |
227 | Returns
228 | -------
229 | tuple
230 | Containing difference value and pvalue,
231 | """
232 | context = df.loc[df['token location'] == 'context']['class identifiability'].to_numpy()
233 | category = df.loc[df['token location'] == 'class']['class identifiability'].to_numpy()
234 |
235 | # Compute true difference
236 | true_diff = np.mean(category) - np.mean(context)
237 |
238 |     # Compute difference with shuffled labels
239 | context_len = context.shape[0]
240 | all_data = np.concatenate((context, category))
241 | random_diffs = []
242 | n_perm = 300
243 |     for _ in range(n_perm):
244 | np.random.shuffle(all_data)
245 | context, category = all_data[:context_len], all_data[context_len:]
246 | random_diffs.append(np.mean(category) - np.mean(context))
247 | pval = np.sum(true_diff < random_diffs) / n_perm
248 |
249 | return true_diff, pval
250 |
251 |
252 | def compute_class_similarity_change(model_name, block, layer_type, dataset_path, res_path):
253 | """Compute class similarity change rate of layer.
254 |
255 | Parameters
256 | ----------
257 | model_name : str
258 | Model name. Can be one of the following options: vit_b_16, vit_b_32,
259 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16.
260 | block : int
261 | Index of block.
262 | layer_type : str
263 | Layer type. Can be one of 'attn' or 'mlp'.
264 | dataset_path : pathlib.Path
265 | Path to dataset.
266 | res_path : pathlib.Path
267 | Path to results.
268 |
269 | Returns
270 | -------
271 | torch.Tensor
272 | Class similarity change rate.
273 | """
274 | dec_layer = get_class_embed(
275 | res_path, dataset_path, model_name, f'hs-{layer_type}_{block}', 'probs'
276 | )
277 | dec = get_class_embed(
278 | res_path, dataset_path, model_name, f'hs_{block-1}', 'probs'
279 | )
280 | return (dec_layer - dec)
281 |
282 |
283 | def compute_residual_match(model_name, dataset_path, res_path, token_type='all'):
284 | """Compute match with the predictions of the residual.
285 |
286 | Parameters
287 | ----------
288 | model_name : str
289 | Model name. Can be one of the following options: vit_b_16, vit_b_32,
290 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16.
291 | dataset_path : pathlib.Path
292 | Path to dataset.
293 | res_path : pathlib.Path
294 | Path to results.
295 | token_type : str, optional
296 | Compute match with all tokens, or with cls only, by default 'all'.
297 |
298 | Returns
299 | -------
300 | pd.DataFrame
301 | Match with residual
302 | """
303 |
304 | if 'large' in model_name:
305 | n_layers = 24
306 | else:
307 | n_layers = 12
308 |
309 | if 'gap' in model_name:
310 | idxs = torch.arange(196)
311 | elif token_type == 'cls':
312 | idxs = 0
313 | elif (token_type == 'all') & ('32' in model_name):
314 | idxs = torch.arange(50)
315 | elif (token_type == 'all') & ('16' in model_name):
316 | idxs = torch.arange(197)
317 |
318 | data = []
319 | for block in range(1, n_layers):
320 | # Get residual stream prediction
321 | topk_b = get_class_embed(
322 | res_path, dataset_path, model_name, f'hs_{block}', decoding_type='topk'
323 | )[:, idxs, 0]
324 |
325 | # Get attention layer prediction and compute match
326 | topk_attn = get_class_embed(
327 | res_path, dataset_path, model_name, f'hs-attn_{block}', decoding_type='topk'
328 | )[:, idxs, 0]
329 | attn_match = torch.sum(topk_b == topk_attn) / topk_b.flatten().shape[0] * 100
330 | data.append(['attn', block+1, attn_match.detach().numpy()])
331 |
332 | # Get MLP layer prediction and compute match
333 | topk_mlp = get_class_embed(
334 | res_path, dataset_path, model_name, f'hs-mlp_{block}', decoding_type='topk'
335 | )[:, idxs, 0]
336 | mlp_match = torch.sum(topk_b == topk_mlp) / topk_b.flatten().shape[0] * 100
337 | data.append(['mlp', block+1, mlp_match.detach().numpy()])
338 |
339 | # Get previous block residual stream prediction and compute match
340 | topk_prev = get_class_embed(
341 | res_path, dataset_path, model_name, f'hs_{block-1}', decoding_type='topk'
342 | )[:, idxs, 0]
343 | prev_match = torch.sum(topk_b == topk_prev) / topk_b.flatten().shape[0] * 100
344 | data.append(['prev', block+1, prev_match.detach().numpy()])
345 |
346 | # Compute rate of tokens that do not match any prediction of the above
347 |         comp_tokens = (topk_b != topk_prev) & (topk_b != topk_attn) & (topk_b != topk_mlp)
348 | comp = torch.sum(comp_tokens) / topk_b.flatten().shape[0] * 100
349 | data.append(['comp', block+1, comp.detach().numpy()])
350 |
351 | data = pd.DataFrame(data, columns=['Source', 'block', 'match'])
352 | data['match'] = data['match'].astype('float')
353 | return data
--------------------------------------------------------------------------------
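A minimal sketch of get_class_embed, assuming the class-projection tensors were already saved by the extractor; with normalize=True a value of 1 means the correct label is ranked first:

    from pathlib import Path
    from src.identifiability import get_class_embed

    res_path = Path('/path/to/project') / 'results'   # assumed
    dataset_path = Path('/path/to/imagenet-s')        # assumed

    dec = get_class_embed(res_path, dataset_path, 'vit_b_32', 'hs_11', decoding_type='pos', normalize=True)
    image_tokens = dec[:, 1:]                              # drop the [CLS] token
    top1_rate = (image_tokens == 1).float().mean().item()  # share of image tokens ranking the label first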
/src/extractor.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from itertools import product
3 | from pathlib import Path
4 | import re
5 |
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import DataLoader
9 | from torchvision.datasets import CIFAR100
10 | from torchvision.models.feature_extraction import create_feature_extractor
11 | from tqdm import tqdm
12 |
13 | from src.datasets.imagenet import ImagenetDatasetS
14 | from src.datasets.cifar import MyCIFAR100
15 | from src.models.load import load_vit
16 |
17 |
18 | class ViTExtractor():
19 | """Project hidden states to class embedding space and save key coefficients.
20 | """
21 | def __init__(
22 | self, model_name, project_path, imgs_path, device='cpu',
23 | pretrained=True
24 | ):
25 | self.pretrained = pretrained
26 | if self.pretrained == True:
27 | self.model_name = model_name
28 | elif self.pretrained == False:
29 | self.model_name = f'{model_name}_random'
30 | self.device = device
31 |
32 | self.project_path = project_path
33 | self.imgs_path = imgs_path
34 |         self.hs_path = project_path / 'results' / 'class_embed' / self.model_name
35 | self.hs_path.mkdir(parents=True, exist_ok=True)
36 |
37 | # super().__init__(self.model_name, project_path, imgs_path, device)
38 |
39 | self._get_model_layers()
40 | self._load_model()
41 |
42 | def _get_model_layers(self):
43 | prefix = 'blocks.'
44 | join = '.'
45 | blocks_layer_types = [
46 | 'add_1', 'attn.getitem_4', 'attn.getitem_5', 'attn.softmax',
47 | 'attn.proj', 'mlp.fc1', 'mlp.fc2'
48 | ]
49 |
50 | if '_large_' in self.model_name:
51 | n_blocks = np.arange(24)
52 | else:
53 | n_blocks = np.arange(12)
54 |
55 | layers = ['head']
56 | for b, l in product(n_blocks, blocks_layer_types):
57 | if l == None:
58 | layers.append(f'{prefix}{b}')
59 | else:
60 | layers.append(f'{prefix}{b}{join}{l}')
61 |
62 | self.layers = layers
63 | return
64 |
65 | def _get_layer_name(self, layer_model):
66 | """Create layer name from layer id."""
67 | if layer_model == 'head':
68 | layer_name = 'cls-out'
69 | else:
70 | block = re.findall(r'\d+', layer_model)[0]
71 | if any(w in layer_model for w in ['.k_', 'getitem_4']):
72 | suffix = f'key'
73 | elif any(w in layer_model for w in ['.q_', 'getitem_3']):
74 | suffix = f'query'
75 | elif any(w in layer_model for w in ['.v_', 'getitem_5']):
76 | suffix = f'value'
77 | elif any(w in layer_model for w in ['attn.softmax']):
78 | suffix = f'attn-w'
79 | elif any(w in layer_model for w in ['attn.proj']):
80 | suffix = f'hs-attn'
81 | elif any(w in layer_model for w in ['mlp.fc1']):
82 | suffix = 'key-mlp'
83 | elif any(w in layer_model for w in ['.mlp', 'mlp.fc2']):
84 | suffix = f'hs-mlp'
85 | else:
86 | suffix = 'hs'
87 | layer_name = f'{suffix}_{block}'
88 | return layer_name
89 |
90 | def _load_model(self):
91 | # Get model
92 | self.model, self.n_tokens, _, self.img_transform = load_vit(
93 | self.model_name, self.device, self.project_path, return_transform=True
94 | )
95 | self.model.eval()
96 |
97 | # Add feature extractor
98 | layers_map = {l: self._get_layer_name(l) for l in self.layers}
99 | self.extractor = create_feature_extractor(self.model, layers_map).to(self.device)
100 |
101 | # Get normalization
102 | if self.model_name == 'vit_gap_16':
103 | self.ln = self.model.fc_norm
104 | else:
105 | self.ln = self.model.norm
106 |
107 | # Get projection matrix
108 | self.cls_proj = self.model.head
109 |
110 | return
111 |
112 | def _get_dataset(self):
113 | self.dataset = ImagenetDatasetS(self.imgs_path)
114 | self.dataloader = DataLoader(
115 | self.dataset, batch_size=5, collate_fn=self._imagenet_collate_batch
116 | )
117 | return
118 |
119 | def _imagenet_collate_batch(self, batch):
120 | assert all(i['imagenet_id']==batch[0]['imagenet_id'] for i in batch)
121 | data = {}
122 | data['imagenet_id'] = batch[0]['imagenet_id']
123 | data['index'] = batch[0]['index']
124 | data['cat'] = batch[0]['cat']
125 | data['imgs'] = [i['img'] for i in batch]
126 | return data
127 |
128 | def extract_hidden_states(self):
129 | self._get_dataset()
130 | acc = []
131 | for _, data in tqdm(
132 | enumerate(self.dataloader), total=len(self.dataloader)
133 | ):
134 | # Prepare concept path
135 | cls_emb_path = self.hs_path / data['imagenet_id']
136 | cls_emb_path.mkdir(parents=True, exist_ok=True)
137 | if self.pretrained == True:
138 | net_ft_path = self.project_path / 'results' / 'net_ft' / self.model_name / data['imagenet_id']
139 | net_ft_path.mkdir(parents=True, exist_ok=True)
140 |
141 | # Get image features
142 | try:
143 | img_ft = self.img_transform(data['imgs'], return_tensors="pt")
144 | img_ft = img_ft['pixel_values'].to(self.device)
145 |             except Exception:
146 | img_ft = [self.img_transform(i) for i in data['imgs']]
147 | img_ft = torch.stack(img_ft).to(self.device)
148 |
149 | # Compute hidden states
150 | with torch.no_grad():
151 | out = self.extractor(img_ft)
152 |
153 | # Compute and save projections
154 | for l_name, l_repr in out.items():
155 | if l_name == 'cls-out':
156 | pred = l_repr.topk(1)[1]
157 | cat_acc = torch.squeeze((pred == data['index']).long())
158 | acc.append(cat_acc)
159 |
160 | elif 'hs' in l_name:
161 | # Project to class embedding space
162 | if 'gap' in self.model_name: # add normalization
163 | block = int(l_name.split('_')[-1])
164 | with torch.no_grad():
165 | if 'attn' in l_name:
166 | l_repr = self.model.blocks[block].norm1(l_repr)
167 | elif 'mlp' in l_name:
168 | l_repr = self.model.blocks[block].norm2(l_repr)
169 | preds = self.cls_proj(self.ln(l_repr))
170 | else:
171 | with torch.no_grad():
172 | preds = self.cls_proj(self.ln(l_repr))
173 |
174 | # Get top-5 predictions
175 | top_k = preds.topk(5, dim=-1)[1]
176 | torch.save(top_k, (cls_emb_path / f'topk_{l_name}.pt'))
177 |
178 | # Get correct label position
179 | ordered_idx = torch.argsort(preds, dim=2, descending=True)
180 | label_idx = (ordered_idx == data['index']).nonzero()
181 | pos = label_idx[:, 2].reshape(self.dataset.n_imgs, (self.n_tokens))
182 | torch.save(pos, (cls_emb_path / f'pos_{l_name}.pt'))
183 |
184 | # Get correct label probability
185 | probs = preds[:, :, data['index']].clone()
186 | torch.save(probs, (cls_emb_path / f'probs_{l_name}.pt'))
187 |
188 | # Save key coefficients
189 | elif self.pretrained == True:
190 | l_repr = l_repr.clone()
191 | torch.save(l_repr, (net_ft_path / f'{l_name}.pt'))
192 |
193 | else:
194 | continue
195 |
196 | # Save accuracy
197 | acc = torch.hstack(acc)
198 | file = self.hs_path / 'acc.pt'
199 | torch.save(acc, file)
200 |
201 | return
202 |
203 |
204 | class ExtractorCIFAR100(ViTExtractor):
205 | """Project hidden states to class embedding space and save key coefficients
206 | of the CIFAR model.
207 | """
208 | def __init__(self, model_name, project_path, imgs_path, device, pretrained=True):
209 | self.model_name = model_name
210 | self.pretrained = pretrained
211 | super().__init__(self.model_name, project_path, imgs_path, device, pretrained=pretrained)
212 | self._load_model()
213 | return
214 |
215 | def _get_dataset(self):
216 | self.dataset = CIFAR100(
217 | self.imgs_path, train=False, download=True, transform=self.img_transform
218 | )
219 | self.dataloader = DataLoader(self.dataset, batch_size=1, shuffle=False)
220 | return
221 |
222 | def extract_hidden_states(self):
223 | dataset = MyCIFAR100(self.imgs_path)
224 |
225 | acc = []
226 | for label, data in tqdm(dataset.stim_info.items(), total=len(dataset)):
227 | # Prepare concept path
228 | cls_emb_path = self.hs_path / label
229 | cls_emb_path.mkdir(parents=True, exist_ok=True)
230 |
231 | if self.pretrained == True:
232 | net_ft_path = self.project_path / 'results' / 'net_ft' / self.model_name / label
233 | net_ft_path.mkdir(parents=True, exist_ok=True)
234 |
235 | # Get image features
236 | img_ft = torch.stack([self.img_transform(img) for img in data]).to(self.device)
237 |
238 | # Compute hidden states
239 | with torch.no_grad():
240 | out = self.extractor(img_ft)
241 |
242 | # Save hidden states
243 | for l_name, l_repr in out.items():
244 | if l_name == 'cls-out':
245 | pred = l_repr.topk(1)[1]
246 | cat_acc = torch.squeeze((pred == int(label)).long())
247 | acc.append(cat_acc)
248 |
249 | elif 'hs' in l_name:
250 | with torch.no_grad():
251 | preds = self.cls_proj(self.ln(l_repr))
252 |
253 | top_k = preds.topk(5, dim=-1)[1]
254 | torch.save(top_k, (cls_emb_path / f'topk_{l_name}.pt'))
255 |
256 | ordered_idx = torch.argsort(preds, dim=2, descending=True)
257 | label_idx = (ordered_idx == int(label)).nonzero()
258 | pos = label_idx[:, 2].reshape(dataset.n_imgs, (self.n_tokens))
259 | torch.save(pos, (cls_emb_path / f'pos_{l_name}.pt'))
260 |
261 | probs = preds[:, :, int(label)].clone()
262 | torch.save(probs, (cls_emb_path / f'probs_{l_name}.pt'))
263 |
264 | elif self.pretrained == True:
265 | l_repr = l_repr.clone()
266 | torch.save(l_repr, (net_ft_path / f'{l_name}.pt'))
267 |
268 | else:
269 | continue
270 |
271 | # Save accuracy
272 | acc = torch.hstack(acc)
273 | file = self.hs_path / 'acc.pt'
274 | torch.save(acc, file)
275 | return
276 |
277 |
278 | if __name__ == '__main__':
279 | parser = argparse.ArgumentParser()
280 | parser.add_argument(
281 | '-pp', action='store', required=True,
282 | help='Path to the folder containing the source code.'
283 | )
284 | parser.add_argument(
285 | '-dp', action='store', required=True,
286 | help='Path to the folder containing the dataset.'
287 | )
288 | parser.add_argument(
289 | '-m', action='store', required=True,
290 | help='Select which model to run. Can be one of the following options: \
291 | vit_b_16, vit_b_32, vit_large_16, vit_miil_16, vit_cifar_16, \
292 | deit_ensemble_16, vit_gap_16.'
293 | )
294 | parser.add_argument(
295 | '-pretrained', action='store_true', help='Use pretrained model.'
296 | )
297 |
298 | args = parser.parse_args()
299 | project_path = Path(args.pp)
300 | data_path = Path(args.dp)
301 |
302 | model = args.m
303 | pretrained = args.pretrained
304 |
305 | device = "cuda" if torch.cuda.is_available() else "cpu"
306 |
307 | if model == 'vit_cifar_16':
308 | extractor = ExtractorCIFAR100(
309 | model, project_path, data_path, device=device, pretrained=pretrained
310 | )
311 | extractor.extract_hidden_states()
312 | else:
313 | extractor = ViTExtractor(
314 | model, project_path, data_path, device=device, pretrained=pretrained
315 | )
316 | extractor.extract_hidden_states()
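317 | 
318 | 
319 | # Example invocation (hypothetical paths; module name per the repository layout):
320 | #   python -m src.extractor -pp /path/to/vit-cls_emb -dp /path/to/dataset -m vit_b_16 -pretrained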
--------------------------------------------------------------------------------
/src/memories.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | import pandas as pd
6 | from scipy.stats import wilcoxon
7 | import torch
8 |
9 | from src.datasets.cifar import MyCIFAR100
10 | from src.datasets.imagenet import ImagenetDatasetS
11 | from src.models.load import load_vit
12 |
13 |
14 | class Memory:
15 | def __init__(self, proj_path, dataset_path, model, device='cpu'):
16 | self.model_name = model
17 | self.device = device
18 |
19 | self.proj_path = Path(proj_path)
20 | self.dataset_path = Path(dataset_path)
21 |         self.net_ft_path = self.proj_path / 'results' / 'net_ft' / model
22 |         self.dec_path = self.proj_path / 'results' / 'class_embed' / model
23 |
24 | if 'cifar' in model:
25 | self.stim_info = MyCIFAR100(self.dataset_path).stim_info
26 | else:
27 | self.stim_info = ImagenetDatasetS(self.dataset_path).stim_info
28 |
29 | self._load_model()
30 |
31 | def _load_model(self):
32 | # Load model
33 | self.model, self.n_tokens, self.hs_dim, _ = load_vit(
34 | self.model_name, self.device, self.proj_path
35 | )
36 | self.model.eval()
37 |
38 | # Load random model
39 | self.random_model, _, _, _ = load_vit(
40 | self.model_name, self.device, self.proj_path, pretrained=False
41 | )
42 | self.random_model.eval()
43 |
44 | # Define number of layers
45 | if 'large' in self.model_name:
46 | self.n_layers = 24
47 | else:
48 | self.n_layers = 12
49 |
50 | # Define number of classes
51 | if 'cifar' in self.model_name:
52 | self.n_classes = 100
53 | else:
54 | self.n_classes = 1000
55 |
56 | return
57 |
58 | def compute_class_value_agr(self, layer_type):
59 | """
60 | Compute class-value agreement scores.
61 | """
62 | # Normalize class embedding
63 | with torch.no_grad():
64 | self.model.head.weight /= self.model.head.weight.norm(dim=-1, keepdim=True)
65 | self.random_model.head.weight /= self.random_model.head.weight.norm(dim=-1, keepdim=True)
66 |
67 | # Get top-1 logit per class
68 | max_logits = []
69 | pvals = []
70 | random_k = []
71 | for b in range(self.n_layers):
72 | if layer_type == 'mlp':
73 | with torch.no_grad():
74 | self.model.blocks[b].mlp.fc2.weight /= self.model.blocks[b].mlp.fc2.weight.norm(dim=0, keepdim=True)
75 | val_proj = self.model.head.weight @ self.model.blocks[b].mlp.fc2.weight
76 | else:
77 | with torch.no_grad():
78 | self.model.blocks[b].attn.proj.weight /= self.model.blocks[b].attn.proj.weight.norm(dim=0, keepdim=True)
79 | val_proj = self.model.head.weight @ self.model.blocks[b].attn.proj.weight
80 |
81 | logits_top_k = torch.squeeze(val_proj.topk(1, dim=1)[0])
82 | max_logits.append(logits_top_k)
83 |
84 | # Get top-1 logit per class in random model
85 | if layer_type == 'mlp':
86 | with torch.no_grad():
87 | self.random_model.blocks[b].mlp.fc2.weight /= self.random_model.blocks[b].mlp.fc2.weight.norm(dim=0, keepdim=True)
88 | val_proj = self.random_model.head.weight @ self.random_model.blocks[b].mlp.fc2.weight
89 | else:
90 | with torch.no_grad():
91 | self.random_model.blocks[b].attn.proj.weight /= self.random_model.blocks[b].attn.proj.weight.norm(dim=0, keepdim=True)
92 | val_proj = self.random_model.head.weight @ self.random_model.blocks[b].attn.proj.weight
93 |
94 | random_top_k = torch.squeeze(val_proj.topk(1, dim=1)[0]).detach().numpy()
95 | random_k.append(np.mean(random_top_k))
96 |
97 | # Compute statistical difference
98 | wx = wilcoxon(logits_top_k.detach().numpy(), random_top_k, alternative='greater')
99 | pvals.append(wx.pvalue)
100 |
101 | # Save results in dataframe
102 | max_logits = torch.stack(max_logits).flatten().detach().numpy()
103 | blocks = torch.arange(1, self.n_layers+1).repeat_interleave(self.n_classes)
104 | df = pd.DataFrame({'top-1 logits': max_logits, 'block': blocks})
105 |
106 | return df, pvals, random_k
107 |
108 | def compute_key_value_agreement(self, block, layer_type):
109 |         """Compute key-value agreement rates.
110 | """
111 | res_path = self.proj_path / 'results' / 'memories' / self.model_name
112 | res_path.mkdir(parents=True, exist_ok=True)
113 |
114 | if layer_type == 'mlp':
115 | with torch.no_grad():
116 | val_proj = self.model.head.weight @ self.model.blocks[block].mlp.fc2.weight
117 | elif layer_type == 'attn':
118 | with torch.no_grad():
119 | val_proj = self.model.head.weight @ self.model.blocks[block].attn.proj.weight
120 | key_topk = val_proj.topk(5, dim=1)[1]
121 |
122 | if 'cifar' in self.model_name:
123 | concepts = list(self.stim_info.keys())
124 | indexes = concepts
125 | n_classes = 100
126 | else:
127 | concepts = self.stim_info['imagenet_id'].unique()
128 | indexes = self.stim_info['index'].unique()
129 | n_classes = 1000
130 |
131 | agreement = []
132 | agreement_random = []
133 | logits = []
134 | logits_random = []
135 | for c, c_idx in zip(concepts, indexes):
136 | c_idx = int(c_idx)
137 |
138 | # Get top-5 keys of hidden state
139 | if layer_type == 'mlp':
140 | file = self.net_ft_path / c / f'key-{layer_type}_{block}.pt'
141 | key_val = torch.load(file, map_location=self.device)
142 |
143 | elif layer_type == 'attn':
144 | attn_file = self.net_ft_path / c / f'attn-w_{block}.pt'
145 | attn_data = torch.load(attn_file, map_location=self.device)
146 | val_file = self.net_ft_path / c / f'value_{block}.pt'
147 | val_data = torch.load(val_file, map_location=self.device)
148 | key_val = attn_data @ val_data
149 | key_val = key_val.transpose(1,2).reshape(5, self.n_tokens, self.hs_dim)
150 |
151 | hs_key_topk = key_val.topk(5, dim=-1)[1]
152 |
153 | # Compute top-k value agreement
154 | agreement.append(self._compute_agreement(hs_key_topk, key_topk[c_idx]))
155 | agreement_random.append(self._compute_agreement(
156 | hs_key_topk, key_topk[torch.randint(0, n_classes, (1,))])
157 | )
158 |
159 | # Compute average logits
160 | logits.append(self._get_logits(hs_key_topk, val_proj, c_idx))
161 | logits_random.append(
162 | self._get_logits(hs_key_topk, val_proj, torch.randint(0, n_classes, (1,)))
163 | )
164 |
165 | f_agreement = res_path / f'{layer_type}_{block}_value-class_agreement.pt'
166 | torch.save(torch.stack(agreement), f_agreement)
167 | f_agreement_random = res_path / f'{layer_type}_{block}_value-class_agreement-random.pt'
168 | torch.save(torch.stack(agreement_random), f_agreement_random)
169 |
170 | f_logits = res_path / f'{layer_type}_{block}_value-class_logits.pt'
171 | torch.save(torch.stack(logits), f_logits)
172 | f_logits_random = res_path / f'{layer_type}_{block}_value-class_logits-random.pt'
173 | torch.save(torch.stack(logits_random), f_logits_random)
174 |
175 | return
176 |
177 | def _compute_agreement(self, hs_key_topk, c_key_topk):
178 | agr = torch.isin(hs_key_topk, c_key_topk)
179 | agr = torch.any(agr, dim=-1)
180 | return agr
181 |
182 | def _get_logits(self, hs_key_topk, val_proj, c_idx):
183 | c_logits = []
184 | for img in range(hs_key_topk.shape[0]):
185 | for t in range(hs_key_topk.shape[1]):
186 | ks = hs_key_topk[img, t]
187 | c_logits.append(val_proj[c_idx, ks])
188 | c_logits = torch.stack(c_logits).view(hs_key_topk.shape[0], hs_key_topk.shape[1], -1)
189 | return c_logits
190 |
191 | def compute_composition(self, block, layer_type):
192 | # Create paths
193 | res_path = self.proj_path / 'results' / 'memories' / self.model_name
194 | res_path.mkdir(parents=True, exist_ok=True)
195 |
196 | # Get concepts
197 | if 'cifar' in self.model_name:
198 | concepts = list(self.stim_info.keys())
199 | else:
200 | concepts = self.stim_info['imagenet_id'].unique()
201 |
202 | # Get most activating class per memory value
203 | if layer_type == 'mlp':
204 | with torch.no_grad():
205 | val_proj = self.model.head.weight @ self.model.blocks[block].mlp.fc2.weight
206 | elif layer_type == 'attn':
207 | with torch.no_grad():
208 | val_proj = self.model.head.weight @ self.model.blocks[block].attn.proj.weight
209 | key_topk = torch.squeeze(val_proj.topk(1, dim=0)[1])
210 |
211 | # Get top-5 most activated memories per image and token
212 | match = []
213 | for c in concepts:
214 | if layer_type == 'mlp':
215 | file = self.net_ft_path / c / f'key-{layer_type}_{block}.pt'
216 | key_val = torch.load(file, map_location=self.device)
217 |
218 | elif layer_type == 'attn':
219 | attn_file = self.net_ft_path / c / f'attn-w_{block}.pt'
220 | attn_data = torch.load(attn_file, map_location=self.device)
221 | val_file = self.net_ft_path / c / f'value_{block}.pt'
222 | val_data = torch.load(val_file, map_location=self.device)
223 | key_val = attn_data @ val_data
224 | key_val = key_val.transpose(1,2).reshape(5, self.n_tokens, self.hs_dim)
225 |
226 | hs_key_topk = key_val.topk(5, dim=-1)[1]
227 |
228 | # Get most activating classes for the top-5 activated memories
229 | for img in range(key_val.shape[0]):
230 | for t in range(key_val.shape[1]):
231 | for k in range(5):
232 | k_top = key_topk[hs_key_topk[img, t, k]]
233 | hs_key_topk[img, t, k] = k_top
234 |
235 | # Compute agreement with predictions at output
236 | file = self.dec_path / c / f'topk_hs-{layer_type}_{block}.pt'
237 | preds = torch.load(file, map_location=self.device)[:, :, 0]
238 | img_match = []
239 | for img in range(key_val.shape[0]):
240 | for t in range(key_val.shape[1]):
241 | layer_pred = preds[img, t]
242 | img_match.append(torch.isin(hs_key_topk[img, t], layer_pred))
243 | img_match = torch.stack(img_match)
244 | img_match = img_match.reshape(key_val.shape[0], key_val.shape[1], -1)
245 | match.append(img_match)
246 |
247 | f = res_path / f'{layer_type}_{block}_top5_pred-match.pt'
248 | torch.save(torch.stack(match), f)
249 |
250 | return
251 |
252 |
253 | def compute_key_value_agr_rate(model_name, layer_type, res_path, tokens='all'):
254 | """Compute key-value agreement rate.
255 |
256 | Parameters
257 | ----------
258 | model_name : str
259 | Model name. Can be one of the following options: vit_b_16, vit_b_32,
260 | vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16.
261 | layer_type : str
262 | Type of layer. Can be one of 'attn' or 'mlp'.
263 | res_path : pathlib.Path
264 | Path to results.
265 | tokens : str, optional
266 | Use all tokens or cls only, by default 'all'.
267 |
268 | Returns
269 | -------
270 | pd.DataFrame
271 | Key-value agreement rate.
272 | """
273 | if 'large' in model_name:
274 | n_layers = 24
275 | else:
276 | n_layers = 12
277 |
278 | # Get key-value agreement scores
279 | agreement_rate = []
280 | for b in range(n_layers):
281 |         f = res_path / 'memories' / model_name / f'{layer_type}_{b}_value-class_agreement.pt'
282 | agr = torch.load(f, map_location='cpu')
283 | if tokens == 'all':
284 | agr = agr.flatten(start_dim=1)
285 | elif tokens == 'cls':
286 | agr = agr[:,:, 0]
287 | # Compute agreement rate
288 | tokens_agr = torch.sum(agr, dim=1) / agr.shape[1] * 100
289 | agreement_rate.append(tokens_agr)
290 |
291 | # Save results in dataframe
292 | agreement_rate = torch.stack(agreement_rate).flatten().detach().numpy()
293 | agreement_rate = pd.DataFrame(
294 | {'block': torch.arange(1,n_layers+1).repeat_interleave(agr.shape[0]), 'agreement rate': agreement_rate}
295 | )
296 | agreement_rate['agreement rate'] = agreement_rate['agreement rate'].astype('float')
297 |
298 | return agreement_rate
299 |
300 |
301 | if __name__ == '__main__':
302 | parser = argparse.ArgumentParser()
303 | parser.add_argument(
304 | '-pp', action='store', required=True,
305 | help='Path to the folder containing the source code.'
306 | )
307 | parser.add_argument(
308 | '-dp', action='store', required=True,
309 | help='Path to the folder containing the dataset.'
310 | )
311 | parser.add_argument(
312 | '-m', action='store', required=True,
313 | help='Select which model to run. Can be one of the following options: \
314 |             vit_b_16, vit_b_32, vit_large_16, vit_miil_16, vit_cifar_16, deit_ensemble_16, vit_gap_16.'
315 | )
316 | parser.add_argument(
317 | '-lt', action='store', required=True,
318 |         help="Type of layer to analyze. Can be one of 'attn' or 'mlp'."
319 | )
320 |
321 | args = parser.parse_args()
322 |
323 | project_path = Path(args.pp)
324 | dataset_path = Path(args.dp)
325 |
326 | model = args.m
327 | layer_type = args.lt
328 | device = "cuda" if torch.cuda.is_available() else "cpu"
329 |
330 | if 'large' in model:
331 | n_layers = 24
332 | else:
333 | n_layers = 12
334 |
335 |     memory = Memory(project_path, dataset_path, model, device=device)
336 |     for b in range(n_layers):
337 |         memory.compute_key_value_agreement(b, layer_type)
338 |         memory.compute_composition(b, layer_type)
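339 | 
340 | 
341 | # Example invocation (hypothetical paths):
342 | #   python -m src.memories -pp /path/to/vit-cls_emb -dp /path/to/dataset -m vit_b_16 -lt mlp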
--------------------------------------------------------------------------------
/src/plots/mech_interp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | import pandas as pd
5 | import torch
6 |
7 | from src.identifiability import MODEL_MAP
8 | from src.identifiability import (
9 | compute_class_similarity_change, compute_residual_match
10 | )
11 | from src.memories import Memory, compute_key_value_agr_rate
12 | from src.vis import Vis
13 |
14 |
15 | def plot_class_building(model_name, res_path, dataset_path):
16 | """
17 | Plot categorical building over blocks and layers.
18 | """
19 | plt.rcParams.update({'font.size': 16})
20 |
21 | if 'large' in model_name:
22 | n_layers = 24
23 | else:
24 | n_layers = 12
25 |
26 | fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(13,5))
27 |
28 | # Plot class similarity change rate
29 | pers = []
30 | for layer_type in ['attn', 'mlp']:
31 | for b in range(1, n_layers):
32 |
33 | diff = compute_class_similarity_change(
34 | model_name, b, layer_type, dataset_path, res_path
35 | )
36 |
37 | if 'gap' not in model_name:
38 | diff_cls = diff[:, 0]
39 | cls_change_rate = torch.sum(diff_cls > 0) / diff_cls.flatten().shape[0]
40 |                 pers.append([f'{layer_type} - [CLS]', b+1, cls_change_rate.detach().numpy()])
41 |
42 | diff_img = diff[:, 1:]
43 | diff_change_rate = torch.sum(diff_img > 0) / diff_img.flatten().shape[0]
44 | pers.append([f'{layer_type} - Image', b+1, diff_change_rate.detach().numpy()])
45 |
46 | else:
47 | diff_img = diff
48 | diff_change_rate = torch.sum(diff_img > 0) / diff_img.flatten().shape[0]
49 | pers.append([f'{layer_type} - Image', b+1, diff_change_rate.detach().numpy()])
50 |
51 | pers = pd.DataFrame(pers, columns=['Layer Type', 'block', 'change rate'])
52 | pers['change rate'] = pers['change rate'].astype('float')
53 |
54 | if model_name == 'vit_gap_16':
55 | colors = ["#1f78b4", "#33a02c"]
56 | sns.lineplot(
57 | pers, x='block', y='change rate', hue='Layer Type', marker='*', ax=axes[0],
58 | palette=sns.color_palette(colors)
59 | )
60 | else:
61 | sns.lineplot(
62 | pers, x='block', y='change rate', hue='Layer Type', marker='*', ax=axes[0],
63 | palette=sns.color_palette("Paired", 4)
64 | )
65 | axes[0].hlines(xmin=-1, xmax=n_layers, y=0.5, colors='dimgray', linestyles='--', lw=2)
66 |
67 | axes[0].set_xlim(1.9, int(n_layers) + 0.1)
68 | axes[0].set_xticks(np.arange(2, n_layers + 1))
69 | axes[0].xaxis.set_tick_params(labelsize=10)
70 | axes[0].set_ylim(0, 1)
71 |
72 | axes[0].legend(loc='lower left').set_title('')
73 | axes[0].set_title('class similarity change rate')
74 |
75 | # Plot residual composition
76 | match = compute_residual_match(model_name, dataset_path, res_path)
77 | match = match.pivot(index='block', columns=['Source'])
78 | match.columns = match.columns.droplevel()
79 | match = match[['prev', 'attn', 'mlp', 'comp']]
80 | match.plot(kind='bar', stacked=True, rot=0, ax=axes[1], cmap="Set2")
81 |
82 | axes[1].set_ylabel('% of tokens')
83 | axes[1].xaxis.set_tick_params(labelsize=10)
84 | axes[1].legend(
85 |             ['previous', 'attn', 'mlp', 'composition'], loc='upper left', ncol=2
86 | )
87 | axes[1].set_title('residual stream composition')
88 |
89 | fig.suptitle(MODEL_MAP[model_name])
90 | plt.tight_layout()
91 | f = res_path / 'figures' / f'compositionality_{model_name}.png'
92 | plt.savefig(f, dpi=300)
93 | plt.show()
94 |
95 | return
96 |
97 |
98 | def get_segmented_increases(model_name, layer_type, proj_path, dataset_path):
99 | """
100 | Compute class similarity change separately for class- and context-labeled tokens.
101 | """
102 | # Get segmentations
103 | im = Vis(proj_path, dataset_path, model_name, device='cpu')
104 | stim_info = im.stim_info
105 | concepts = stim_info['imagenet_id'].unique().tolist()
106 | sgts = []
107 | for c in concepts:
108 | for i in range(5):
109 | gt = im.get_segmentation(c, i).flatten()
110 | sgts.append(gt)
111 | sgts = np.hstack(sgts)
112 |
113 | # Compute increases
114 | if 'large' in model_name:
115 | n_layers = 24
116 | else:
117 | n_layers = 12
118 |
119 | dfs = []
120 | for b in range(1, n_layers):
121 | res_path = proj_path / 'results'
122 | dec = compute_class_similarity_change(model_name, b, layer_type, dataset_path, res_path)
123 | if 'gap' not in model_name:
124 | dec = dec[:, 1:]
125 | df = pd.DataFrame({'decoding': dec.flatten().detach().numpy(), 'token location': sgts})
126 | df['block'] = b
127 | dfs.append(df)
128 | dfs = pd.concat(dfs)
129 | dfs['token location'] = dfs['token location'].replace({0: 'context', 1: 'category'})
130 | return dfs
131 |
132 |
133 | def compute_context_diff(model_name, layer_type, proj_path, dataset_path):
134 | """
135 | Compare class similarity change between class- and context-labeled tokens.
136 | """
137 | if 'large' in model_name:
138 | n_layers = 24
139 | else:
140 | n_layers = 12
141 |
142 | dfs = get_segmented_increases(model_name, layer_type, proj_path, dataset_path)
143 | diffs = []
144 | pvals = []
145 | for b in range(1, n_layers):
146 | df = dfs.loc[dfs['block'] == b]
147 | context = df.loc[df['token location'] == 'context']['decoding'].to_numpy()
148 | category = df.loc[df['token location'] == 'category']['decoding'].to_numpy()
149 |
150 | # Compute true difference
151 | true_diff = np.mean(category) - np.mean(context)
152 |
153 | # Compare to random model
154 | context_len = context.shape[0]
155 | all_data = np.concatenate((context, category))
156 | random_diffs = []
157 | n_perm = 300
158 |         for _ in range(n_perm):
159 | np.random.shuffle(all_data)
160 | context, category = all_data[:context_len], all_data[context_len:]
161 | random_diffs.append(np.mean(category) - np.mean(context))
162 | pval = np.sum(true_diff < random_diffs) / n_perm
163 |
164 | diffs.append(true_diff)
165 | pvals.append(pval)
166 |
167 | df = pd.DataFrame({'diffs': diffs, 'pvals': pvals})
168 | df['model'] = model_name
169 | df['layer'] = layer_type
170 | df['block'] = np.arange(len(diffs))
171 |
172 | return df
173 |
174 |
175 | def plot_categorical_updates(proj_path, dataset_path):
176 | """
177 | Plot categorical updates for all models.
178 | """
179 | plt.rcParams.update({'font.size': 14})
180 | fig, axes = plt.subplots(nrows=7, ncols=2, figsize=(10, 20))
181 | for m_idx, model in enumerate(MODEL_MAP.keys()):
182 | mem = Memory(proj_path, dataset_path, model)
183 |
184 | for l_idx, layer_type in enumerate(['attn', 'mlp']):
185 | df, pvals_l, _ = mem.compute_class_value_agr(layer_type)
186 |
187 | # Plot pvalues
188 | sig = np.where(np.array(pvals_l) < 0.05)[0]
189 | axes[m_idx, l_idx].scatter(sig, [0.8] * len(sig), marker='*', c='grey', s=50)
190 |
191 | # Plot agreement values
192 | sns.boxplot(df, x='block', y='top-1 logits', ax=axes[m_idx, l_idx])
193 | axes[m_idx, l_idx].set_ylim((0,0.85))
194 | axes[m_idx, l_idx].xaxis.set_tick_params(labelsize=9)
195 | axes[m_idx, l_idx].yaxis.set_tick_params(labelsize=12)
196 |
197 | axes[m_idx, l_idx].set_title(f'{MODEL_MAP[model]} - {layer_type}')
198 |
199 | plt.tight_layout()
200 |     f = proj_path / 'results/figures' / 'match_score.png'
201 | plt.savefig(f, dpi=300)
202 | plt.show()
203 | return
204 |
205 |
206 | def plot_key_value_agreement_rate(res_path):
207 | """
208 | Plot key-value agreement rate.
209 | """
210 | plt.rcParams.update({'font.size': 14})
211 | fig, axes = plt.subplots(nrows=7, ncols=2, figsize=(10, 20))
212 | for m_idx, model_name in enumerate(MODEL_MAP.keys()):
213 | for l_idx, layer_type in enumerate(['attn', 'mlp']):
214 | df = compute_key_value_agr_rate(model_name, layer_type, res_path)
215 | sns.stripplot(
216 | df, x='block', y='agreement rate', alpha=0.3, zorder=1,
217 | ax=axes[m_idx, l_idx]
218 | )
219 | sns.pointplot(
220 | df, x='block', y='agreement rate', color='red',
221 | markers="d", scale=.75, errorbar=None, ax=axes[m_idx, l_idx]
222 | )
223 | axes[m_idx, l_idx].set_title(f'{MODEL_MAP[model_name]} - {layer_type}')
224 | axes[m_idx, l_idx].set_ylim((0,100))
225 |
226 | axes[m_idx, l_idx].xaxis.set_tick_params(labelsize=9)
227 | axes[m_idx, l_idx].yaxis.set_tick_params(labelsize=12)
228 |
229 | plt.tight_layout()
230 |     f = res_path / 'figures' / 'key_val_agr.png'
231 | plt.savefig(f, dpi=300)
232 | plt.show()
233 | return
234 |
235 |
236 | def compare_memory_pairs(proj_path, dataset_path):
237 | """
238 | Compare memory results across ViT variants.
239 | """
240 | plt.rcParams.update({'font.size': 15})
241 | fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 3))
242 |
243 | # Plot class-value agreement
244 | for layer_type, ax in zip(['attn', 'mlp'], axes.flat[:2]):
245 | dfs = []
246 | random_k = []
247 | for model_name in MODEL_MAP.keys():
248 | mem = Memory(proj_path, dataset_path, model_name)
249 | df, _, rk = mem.compute_class_value_agr(layer_type)
250 | random_k.append(rk)
251 |
252 | if model_name == 'vit_large_16':
253 | new_df = []
254 | for idx, b in enumerate(np.arange(1, 25, 2)):
255 | mean = df.loc[df['block'].isin([b, b+1])]['top-1 logits'].mean()
256 | new_df.append([mean, idx+1, MODEL_MAP['vit_large_16']])
257 | new_df = pd.DataFrame(
258 | new_df, columns=['top-1 logits', 'block', 'model']
259 | )
260 | df = new_df
261 | else:
262 | df['model'] = MODEL_MAP[model_name]
263 |
264 | dfs.append(df)
265 |
266 | dfs = pd.concat(dfs)
267 | sns.lineplot(
268 | dfs, x='block', y='top-1 logits', hue='model', ax=ax,
269 | marker='*', markersize=10
270 | )
271 | ax.set_title(layer_type)
272 | ax.set_ylim((0.08, 0.45))
273 | ax.set_xticks(np.arange(1, 13))
274 | ax.set_xticklabels([])
275 | ax.set_xlim((0.8, 12.2))
276 | ax.set_xlabel('normalized layer')
277 |
278 | axes[0].set_ylabel('class-value agreement')
279 | axes[1].set_ylabel('')
280 |
281 | # Plot key-value agreement rate
282 | for layer_type, ax in zip(['attn', 'mlp'], axes[2:].flat):
283 | dfs = []
284 | for model_name in MODEL_MAP.keys():
285 | res_path = proj_path / 'results'
286 | df = compute_key_value_agr_rate(model_name, layer_type, res_path)
287 |
288 | if model_name == 'vit_large_16':
289 | new_df = []
290 | for idx, b in enumerate(np.arange(1, 25, 2)):
291 | mean = df.loc[df['block'].isin([b, b+1])]['agreement rate'].mean()
292 | new_df.append([mean, idx+1, MODEL_MAP['vit_large_16']])
293 | new_df = pd.DataFrame(
294 | new_df, columns=['agreement rate', 'block', 'model']
295 | )
296 | df = new_df
297 | else:
298 | df['model'] = MODEL_MAP[model_name]
299 |
300 | dfs.append(df)
301 |
302 | dfs = pd.concat(dfs)
303 | sns.lineplot(
304 | dfs, x='block', y='agreement rate', hue='model', ax=ax,
305 | marker='*', markersize=10
306 | )
307 | ax.set_title(layer_type)
308 | ax.set_ylim((0,85))
309 | ax.set_xticks(np.arange(1, 13))
310 | ax.set_xticklabels([])
311 | ax.set_xlim((0.8, 12.2))
312 | ax.set_xlabel('normalized layer')
313 |
314 | axes[2].set_ylabel('key-value agreement rate')
315 | axes[3].set_ylabel('')
316 |
317 | lines, labels = axes[1].get_legend_handles_labels()
318 | lgd = fig.legend(
319 | lines, labels, loc='upper center', ncol=4,
320 | bbox_to_anchor=(0.5, 1.25),
321 | )
322 |
323 | for ax in axes.flat:
324 | ax.get_legend().remove()
325 |
326 | plt.tight_layout()
327 |     f = proj_path / 'results/figures' / 'compare_all.png'
328 | plt.savefig(f, dpi=300, bbox_inches='tight')
329 | plt.show()
330 | return
331 |
332 |
333 | def plot_agr_rate_diff(res_path, device='cpu'):
334 | """
335 | Plot difference in key-value agreement rate between accurate and non-accurate samples.
336 | """
337 | plt.rcParams.update({'font.size': 14})
338 | fig, axes = plt.subplots(nrows=7, ncols=2, figsize=(10,20))
339 | for m_idx, model_name in enumerate(MODEL_MAP.keys()):
340 | for l_idx, layer_type in enumerate(['attn', 'mlp']):
341 | if 'large' in model_name:
342 | n_layers = 24
343 | else:
344 | n_layers = 12
345 |
346 | # Get accuracy
347 | acc_file = res_path / 'class_embed'/ model_name / 'acc.pt'
348 | acc = torch.load(acc_file, map_location=device)
349 |
350 | diffs = []
351 | pvals = []
352 | for b in range(n_layers):
353 | # Compute difference in accuracy
354 |                 f = res_path / 'memories' / model_name / f'{layer_type}_{b}_value-class_agreement.pt'
355 | agr = torch.load(f, map_location='cpu')[:,:, 0].flatten()
356 |
357 | acc_1 = agr[acc==1].float()
358 | agr_acc = torch.sum(acc_1) / acc_1.shape[0] * 100
359 | acc_2 = agr[acc==0].float()
360 | agr_inacc = torch.sum(acc_2) / acc_2.shape[0] * 100
361 |
362 | true_diff = agr_acc - agr_inacc
363 | diffs.append(true_diff.detach().numpy())
364 |
365 | # Compare to random model
366 | random_diffs = []
367 | for p in range(300):
368 | rand_idxs = torch.randperm(len(acc))
369 | r_acc_1 = agr[rand_idxs[:len(acc_1)]].float()
370 | r_agr_acc = torch.sum(r_acc_1) / r_acc_1.shape[0] * 100
371 | r_acc_2 = agr[rand_idxs[len(acc_1):]].float()
372 | r_agr_inacc = torch.sum(r_acc_2) / r_acc_2.shape[0] * 100
373 |
374 | random_diffs.append(r_agr_acc - r_agr_inacc)
375 |
376 | pval = torch.sum(true_diff < torch.stack(random_diffs)) / 300
377 | pvals.append(pval.detach().numpy())
378 |
379 | df = pd.DataFrame({'diff': diffs, 'block': np.arange(1, n_layers+1)})
380 | df['model'] = model_name
381 | df['layer'] = layer_type
382 | df['diff'] = df['diff'].astype('float')
383 |
384 | # Plot pvalues
385 | sig = np.where(np.array(pvals) < 0.05)[0] + 1
386 | axes[m_idx, l_idx].scatter(sig, [69] * len(sig), marker='*', c='grey', s=50)
387 |
388 | sns.lineplot(df, x='block', y='diff', marker='o', ax=axes[m_idx, l_idx])
389 | axes[m_idx, l_idx].set_ylim(-5, 72)
390 | axes[m_idx, l_idx].set_xticks(np.arange(1, n_layers+1))
391 | axes[m_idx, l_idx].set_xlim((0.8, n_layers + 0.2))
392 |
393 | axes[m_idx, l_idx].hlines(
394 | xmin=1, xmax=n_layers+1, y=0, colors='dimgray', linestyles='--', lw=2
395 | )
396 |
397 | axes[m_idx, l_idx].set_ylabel('difference (%)')
398 | axes[m_idx, l_idx].xaxis.set_tick_params(labelsize=9)
399 | axes[m_idx, l_idx].yaxis.set_tick_params(labelsize=12)
400 |
401 | axes[m_idx, l_idx].set_title(f'{MODEL_MAP[model_name]} - {layer_type}')
402 |
403 | plt.tight_layout()
404 |     f = res_path / 'figures' / 'acc_agr_rate.png'
405 | plt.savefig(f, dpi=300)
406 | plt.show()
407 |
408 | return
409 |
410 |
411 | def plot_memory_compositionality(res_path):
412 | """
413 | Plot compositionality of memory pairs.
414 | """
415 | plt.rcParams.update({'font.size': 14})
416 | fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(13, 15))
417 | for model_name, ax in zip(MODEL_MAP.keys(), axes.flat):
418 | if 'large' in model_name:
419 | n_layers = 24
420 | else:
421 | n_layers = 12
422 |
423 | # Get compositionality
424 | comps = []
425 | for layer_type in ['attn', 'mlp']:
426 | for b in range(n_layers):
427 | f = res_path / 'memories' / model_name / f'{layer_type}_{b}_top5_pred-match.pt'
428 |                 comp = torch.load(f, map_location='cpu')  # index with [:, :, 0] to restrict to the [CLS] token
429 |                 preds = torch.any(comp, dim=-1)
430 |                 per = torch.sum(preds) / preds.flatten().shape[0] * 100
431 | comps.append([(b+1), layer_type, per.detach().numpy()])
432 |
433 | comps = pd.DataFrame(comps, columns=['block', 'layer type', 'percentage'])
434 | comps['percentage'] = comps['percentage'].astype('float')
435 |
436 | # Plot
437 | sns.lineplot(comps, x='block', y='percentage', hue='layer type', marker='o', ax=ax)
438 | ax.set_xticks(np.arange(1, n_layers+1))
439 | ax.xaxis.set_tick_params(labelsize=10)
440 | ax.set_xlim(1, n_layers)
441 | ax.set_yticks(np.arange(0, 75, 10))
442 | ax.yaxis.set_tick_params(labelsize=12)
443 | ax.set_ylim((0, 75))
444 | ax.set_ylabel('%')
445 |
446 | ax.legend(loc='upper left')
447 | ax.set_title(MODEL_MAP[model_name])
448 |
449 | fig.delaxes(axes[3,1])
450 |
451 | plt.tight_layout()
452 |     f = res_path / 'figures' / 'composition.png'
453 | plt.savefig(f, dpi=300)
454 | plt.show()
455 | return
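456 | 
457 | 
458 | # Example usage from a notebook (hypothetical paths):
459 | #   from pathlib import Path
460 | #   res_path = Path('/path/to/vit-cls_emb') / 'results'
461 | #   plot_key_value_agreement_rate(res_path)
462 | #   plot_memory_compositionality(res_path)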
--------------------------------------------------------------------------------
/src/models/vit_edited.py:
--------------------------------------------------------------------------------
1 | """ Vision Transformer (ViT) in PyTorch
2 |
3 | A PyTorch implement of Vision Transformers as described in:
4 |
5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
6 | - https://arxiv.org/abs/2010.11929
7 |
8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
9 | - https://arxiv.org/abs/2106.10270
10 |
11 | The official jax code is released and available at https://github.com/google-research/vision_transformer
12 |
13 | Acknowledgments:
14 | * The paper authors for releasing code and weights, thanks!
15 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
16 | for some einops/einsum fun
17 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
18 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
19 |
20 | Hacked together by / Copyright 2020, Ross Wightman
21 | """
22 | import math
23 | import logging
24 | from functools import partial
25 | from collections import OrderedDict
26 | from typing import Optional
27 |
28 | import torch
29 | import torch.nn as nn
30 | import torch.nn.functional as F
31 | import torch.utils.checkpoint
32 |
33 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\
34 | OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
35 | from timm.models.helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq
36 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
37 | from timm.models.registry import register_model
38 |
39 | _logger = logging.getLogger(__name__)
40 |
41 |
42 | def _cfg(url='', **kwargs):
43 | return {
44 | 'url': url,
45 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
46 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
47 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
48 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
49 | **kwargs
50 | }
51 |
52 |
53 | default_cfgs = {
54 | # patch models (weights from official Google JAX impl)
55 | 'vit_tiny_patch16_224': _cfg(
56 | url='https://storage.googleapis.com/vit_models/augreg/'
57 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
58 | 'vit_tiny_patch16_384': _cfg(
59 | url='https://storage.googleapis.com/vit_models/augreg/'
60 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
61 | input_size=(3, 384, 384), crop_pct=1.0),
62 | 'vit_small_patch32_224': _cfg(
63 | url='https://storage.googleapis.com/vit_models/augreg/'
64 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
65 | 'vit_small_patch32_384': _cfg(
66 | url='https://storage.googleapis.com/vit_models/augreg/'
67 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
68 | input_size=(3, 384, 384), crop_pct=1.0),
69 | 'vit_small_patch16_224': _cfg(
70 | url='https://storage.googleapis.com/vit_models/augreg/'
71 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
72 | 'vit_small_patch16_384': _cfg(
73 | url='https://storage.googleapis.com/vit_models/augreg/'
74 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
75 | input_size=(3, 384, 384), crop_pct=1.0),
76 | 'vit_base_patch32_224': _cfg(
77 | url='https://storage.googleapis.com/vit_models/augreg/'
78 | 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
79 | 'vit_base_patch32_384': _cfg(
80 | url='https://storage.googleapis.com/vit_models/augreg/'
81 | 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
82 | input_size=(3, 384, 384), crop_pct=1.0),
83 | 'vit_base_patch16_224': _cfg(
84 | url='https://storage.googleapis.com/vit_models/augreg/'
85 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
86 | 'vit_base_patch16_384': _cfg(
87 | url='https://storage.googleapis.com/vit_models/augreg/'
88 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
89 | input_size=(3, 384, 384), crop_pct=1.0),
90 | 'vit_base_patch8_224': _cfg(
91 | url='https://storage.googleapis.com/vit_models/augreg/'
92 | 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
93 | 'vit_large_patch32_224': _cfg(
94 | url='', # no official model weights for this combo, only for in21k
95 | ),
96 | 'vit_large_patch32_384': _cfg(
97 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
98 | input_size=(3, 384, 384), crop_pct=1.0),
99 | 'vit_large_patch16_224': _cfg(
100 | url='https://storage.googleapis.com/vit_models/augreg/'
101 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
102 | 'vit_large_patch16_384': _cfg(
103 | url='https://storage.googleapis.com/vit_models/augreg/'
104 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
105 | input_size=(3, 384, 384), crop_pct=1.0),
106 |
107 | 'vit_large_patch14_224': _cfg(url=''),
108 | 'vit_huge_patch14_224': _cfg(url=''),
109 | 'vit_giant_patch14_224': _cfg(url=''),
110 | 'vit_gigantic_patch14_224': _cfg(url=''),
111 |
112 |
113 | # patch models, imagenet21k (weights from official Google JAX impl)
114 | 'vit_tiny_patch16_224_in21k': _cfg(
115 | url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
116 | num_classes=21843),
117 | 'vit_small_patch32_224_in21k': _cfg(
118 | url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
119 | num_classes=21843),
120 | 'vit_small_patch16_224_in21k': _cfg(
121 | url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
122 | num_classes=21843),
123 | 'vit_base_patch32_224_in21k': _cfg(
124 | url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
125 | num_classes=21843),
126 | 'vit_base_patch16_224_in21k': _cfg(
127 | url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
128 | num_classes=21843),
129 | 'vit_base_patch8_224_in21k': _cfg(
130 | url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
131 | num_classes=21843),
132 | 'vit_large_patch32_224_in21k': _cfg(
133 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
134 | num_classes=21843),
135 | 'vit_large_patch16_224_in21k': _cfg(
136 | url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
137 | num_classes=21843),
138 | 'vit_huge_patch14_224_in21k': _cfg(
139 | url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
140 | hf_hub_id='timm/vit_huge_patch14_224_in21k',
141 | num_classes=21843),
142 |
143 | # SAM trained models (https://arxiv.org/abs/2106.01548)
144 | 'vit_base_patch32_224_sam': _cfg(
145 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
146 | 'vit_base_patch16_224_sam': _cfg(
147 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
148 |
149 | # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only)
150 | 'vit_small_patch16_224_dino': _cfg(
151 | url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
152 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
153 | 'vit_small_patch8_224_dino': _cfg(
154 | url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth',
155 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
156 | 'vit_base_patch16_224_dino': _cfg(
157 | url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth',
158 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
159 | 'vit_base_patch8_224_dino': _cfg(
160 | url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
161 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
162 |
163 |
164 | # ViT ImageNet-21K-P pretraining by MILL
165 | 'vit_base_patch16_224_miil_in21k': _cfg(
166 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
167 | mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
168 | 'vit_base_patch16_224_miil': _cfg(
169 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
170 | mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
171 |
172 | 'vit_base_patch16_rpn_224': _cfg(
173 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'),
174 |
175 | # experimental (may be removed)
176 | 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
177 | 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
178 | 'vit_small_patch16_36x1_224': _cfg(url=''),
179 | 'vit_small_patch16_18x2_224': _cfg(url=''),
180 | 'vit_base_patch16_18x2_224': _cfg(url=''),
181 |
182 | 'vit_base_patch32_224_clip_laion2b': _cfg(
183 | hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
184 | hf_hub_filename='open_clip_pytorch_model.bin',
185 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
186 | 'vit_large_patch14_224_clip_laion2b': _cfg(
187 | hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
188 | hf_hub_filename='open_clip_pytorch_model.bin',
189 | mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=768),
190 | 'vit_huge_patch14_224_clip_laion2b': _cfg(
191 | hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
192 | hf_hub_filename='open_clip_pytorch_model.bin',
193 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
194 | 'vit_giant_patch14_224_clip_laion2b': _cfg(
195 | hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
196 | hf_hub_filename='open_clip_pytorch_model.bin',
197 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
198 |
199 | }
200 |
201 |
202 | class Attention(nn.Module):
203 | def __init__(
204 | self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
205 | perturb=False
206 |
207 | ):
208 | super().__init__()
209 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
210 | self.num_heads = num_heads
211 | head_dim = dim // num_heads
212 | self.scale = head_dim ** -0.5
213 |
214 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
215 | self.attn_drop = nn.Dropout(attn_drop)
216 | self.proj = nn.Linear(dim, dim)
217 | self.proj_drop = nn.Dropout(proj_drop)
218 |
219 | self.perturb = perturb
220 |
221 | def save_attn_gradients(self, attn_gradients):
222 | self.attn_gradients = attn_gradients
223 |
224 | def forward(self, x):
225 | B, N, C = x.shape
226 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
227 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
228 |
229 | attn = (q @ k.transpose(-2, -1)) * self.scale
230 | attn = attn.softmax(dim=-1)
231 | attn = self.attn_drop(attn)
232 |
233 | self.attn_probs = attn
234 |
235 | #attn.register_hook(self.save_attn_gradients)
236 |
237 |         if self.perturb == 'self_only':  # each non-[CLS] token attends only to itself
238 |             for t in range(1, attn.shape[-1]):
239 |                 attn[:, :, t, :] = 0
240 |                 attn[:, :, t, t] = 1
241 |         elif self.perturb == 'no_cls':  # block non-[CLS] tokens from attending to the [CLS] token
242 |             for t in range(1, attn.shape[-1]):
243 |                 attn[:, :, t, 0] = 0
244 | 
245 |         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
246 |         self.key_proj_vals = x  # cache the pre-projection attention output for later analysis
247 |
248 | x = self.proj(x)
249 | x = self.proj_drop(x)
250 | return x
251 |
252 |
253 | class LayerScale(nn.Module):
254 | def __init__(self, dim, init_values=1e-5, inplace=False):
255 | super().__init__()
256 | self.inplace = inplace
257 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
258 |
259 | def forward(self, x):
260 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
261 |
262 |
263 | class Block(nn.Module):
264 | def __init__(
265 | self,
266 | dim,
267 | num_heads,
268 | mlp_ratio=4.,
269 | qkv_bias=False,
270 | drop=0.,
271 | attn_drop=0.,
272 | init_values=None,
273 | drop_path=0.,
274 | act_layer=nn.GELU,
275 | norm_layer=nn.LayerNorm,
276 | perturb_type=None,
277 | norm_cls=None,
278 | func_cls=None
279 | ):
280 | super().__init__()
281 | self.norm1 = norm_layer(dim)
282 | self.attn = Attention(
283 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop,
284 | proj_drop=drop, perturb=perturb_type
285 | )
286 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
287 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
288 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
289 |
290 | self.norm2 = norm_layer(dim)
291 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
292 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
293 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
294 |
295 | self.norm_cls = norm_cls
296 | self.func_cls = func_cls
297 |
298 | def forward(self, x):
299 | attn_x = self.ls1(self.attn(self.norm1(x)))
300 |         self.attn_cls = self.func_cls(self.norm_cls(attn_x))  # project the attention update onto the class-embedding space
301 | #self.attn_cls = self.func_cls(self.norm_cls(self.attn(self.norm1(x))))
302 | #x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
303 | x = x + self.drop_path1(attn_x)
304 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
305 | return x
306 |
307 |
308 | class ResPostBlock(nn.Module):
309 |
310 | def __init__(
311 | self,
312 | dim,
313 | num_heads,
314 | mlp_ratio=4.,
315 | qkv_bias=False,
316 | drop=0.,
317 | attn_drop=0.,
318 | init_values=None,
319 | drop_path=0.,
320 | act_layer=nn.GELU,
321 | norm_layer=nn.LayerNorm
322 | ):
323 | super().__init__()
324 | self.init_values = init_values
325 |
326 | self.attn = Attention(
327 | dim, num_heads=num_heads, qkv_bias=qkv_bias,
328 | attn_drop=attn_drop, proj_drop=drop
329 | )
330 | self.norm1 = norm_layer(dim)
331 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
332 |
333 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
334 | self.norm2 = norm_layer(dim)
335 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
336 |
337 | self.init_weights()
338 |
339 | def init_weights(self):
340 | # NOTE this init overrides that base model init with specific changes for the block type
341 | if self.init_values is not None:
342 | nn.init.constant_(self.norm1.weight, self.init_values)
343 | nn.init.constant_(self.norm2.weight, self.init_values)
344 |
345 | def forward(self, x):
346 | x = x + self.drop_path1(self.norm1(self.attn(x)))
347 | x = x + self.drop_path2(self.norm2(self.mlp(x)))
348 | return x
349 |
350 |
351 | class ParallelBlock(nn.Module):
352 |
353 | def __init__(
354 | self,
355 | dim,
356 | num_heads,
357 | num_parallel=2,
358 | mlp_ratio=4.,
359 | qkv_bias=False,
360 | init_values=None,
361 | drop=0.,
362 | attn_drop=0.,
363 | drop_path=0.,
364 | act_layer=nn.GELU,
365 | norm_layer=nn.LayerNorm
366 | ):
367 | super().__init__()
368 | self.num_parallel = num_parallel
369 | self.attns = nn.ModuleList()
370 | self.ffns = nn.ModuleList()
371 | for _ in range(num_parallel):
372 | self.attns.append(nn.Sequential(OrderedDict([
373 | ('norm', norm_layer(dim)),
374 | ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
375 | ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
376 | ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
377 | ])))
378 | self.ffns.append(nn.Sequential(OrderedDict([
379 | ('norm', norm_layer(dim)),
380 | ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
381 | ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
382 | ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
383 | ])))
384 |
385 | def _forward_jit(self, x):
386 | x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
387 | x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
388 | return x
389 |
390 | @torch.jit.ignore
391 | def _forward(self, x):
392 | x = x + sum(attn(x) for attn in self.attns)
393 | x = x + sum(ffn(x) for ffn in self.ffns)
394 | return x
395 |
396 | def forward(self, x):
397 | if torch.jit.is_scripting() or torch.jit.is_tracing():
398 | return self._forward_jit(x)
399 | else:
400 | return self._forward(x)
401 |
402 |
403 | class VisionTransformer(nn.Module):
404 | """ Vision Transformer
405 |
406 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
407 | - https://arxiv.org/abs/2010.11929
408 | """
409 |
410 | def __init__(
411 | self,
412 | img_size=224,
413 | patch_size=16,
414 | in_chans=3,
415 | num_classes=1000,
416 | global_pool='token',
417 | embed_dim=768,
418 | depth=12,
419 | num_heads=12,
420 | mlp_ratio=4.,
421 | qkv_bias=True,
422 | init_values=None,
423 | class_token=True,
424 | no_embed_class=False,
425 | pre_norm=False,
426 | fc_norm=None,
427 | drop_rate=0.,
428 | attn_drop_rate=0.,
429 | drop_path_rate=0.,
430 | weight_init='',
431 | embed_layer=PatchEmbed,
432 | norm_layer=None,
433 | act_layer=None,
434 | block_fn=Block,
435 | perturb_type=None,
436 | block_perturb=None
437 | ):
438 | """
439 | Args:
440 | img_size (int, tuple): input image size
441 | patch_size (int, tuple): patch size
442 | in_chans (int): number of input channels
443 | num_classes (int): number of classes for classification head
444 | global_pool (str): type of global pooling for final sequence (default: 'token')
445 | embed_dim (int): embedding dimension
446 | depth (int): depth of transformer
447 | num_heads (int): number of attention heads
448 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
449 | qkv_bias (bool): enable bias for qkv if True
450 | init_values: (float): layer-scale init values
451 | class_token (bool): use class token
452 | fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
453 | drop_rate (float): dropout rate
454 | attn_drop_rate (float): attention dropout rate
455 | drop_path_rate (float): stochastic depth rate
456 | weight_init (str): weight init scheme
457 | embed_layer (nn.Module): patch embedding layer
458 | norm_layer: (nn.Module): normalization layer
459 | act_layer: (nn.Module): MLP activation layer
460 | """
461 | super().__init__()
462 | assert global_pool in ('', 'avg', 'token')
463 | assert class_token or global_pool != 'token'
464 | use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
465 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
466 | act_layer = act_layer or nn.GELU
467 |
468 | self.num_classes = num_classes
469 | self.global_pool = global_pool
470 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
471 | self.num_prefix_tokens = 1 if class_token else 0
472 | self.no_embed_class = no_embed_class
473 | self.grad_checkpointing = False
474 |
475 | self.patch_embed = embed_layer(
476 | img_size=img_size,
477 | patch_size=patch_size,
478 | in_chans=in_chans,
479 | embed_dim=embed_dim,
480 | bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
481 | )
482 | num_patches = self.patch_embed.num_patches
483 |
484 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
485 | embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
486 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
487 | self.pos_drop = nn.Dropout(p=drop_rate)
488 | self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
489 |
490 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
491 |
492 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
493 |
494 | # Classifier Head
495 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
496 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
497 |
498 | if block_perturb == 'all':
499 | self.blocks = nn.Sequential(*[
500 | block_fn(
501 | dim=embed_dim,
502 | num_heads=num_heads,
503 | mlp_ratio=mlp_ratio,
504 | qkv_bias=qkv_bias,
505 | init_values=init_values,
506 | drop=drop_rate,
507 | attn_drop=attn_drop_rate,
508 | drop_path=dpr[i],
509 | norm_layer=norm_layer,
510 | act_layer=act_layer,
511 | perturb_type=perturb_type,
512 | norm_cls=self.norm,
513 |                     func_cls=self.head,
514 | )
515 | for i in range(depth)])
516 | else:
517 | self.blocks = nn.Sequential(*[
518 | block_fn(
519 | dim=embed_dim,
520 | num_heads=num_heads,
521 | mlp_ratio=mlp_ratio,
522 | qkv_bias=qkv_bias,
523 | init_values=init_values,
524 | drop=drop_rate,
525 | attn_drop=attn_drop_rate,
526 | drop_path=dpr[i],
527 | norm_layer=norm_layer,
528 | act_layer=act_layer,
529 | perturb_type=perturb_type,
530 | norm_cls=self.norm,
531 |                     func_cls=self.head
532 | ) if i == block_perturb
533 | else block_fn(
534 | dim=embed_dim,
535 | num_heads=num_heads,
536 | mlp_ratio=mlp_ratio,
537 | qkv_bias=qkv_bias,
538 | init_values=init_values,
539 | drop=drop_rate,
540 | attn_drop=attn_drop_rate,
541 | drop_path=dpr[i],
542 | norm_layer=norm_layer,
543 | act_layer=act_layer,
544 | norm_cls=self.norm,
545 |                     func_cls=self.head
546 | )
547 | for i in range(depth)])
548 |
549 | if weight_init != 'skip':
550 | self.init_weights(weight_init)
551 |
552 | def init_weights(self, mode=''):
553 | assert mode in ('jax', 'jax_nlhb', 'moco', '')
554 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
555 | trunc_normal_(self.pos_embed, std=.02)
556 | if self.cls_token is not None:
557 | nn.init.normal_(self.cls_token, std=1e-6)
558 | named_apply(get_init_weights_vit(mode, head_bias), self)
559 |
560 | def _init_weights(self, m):
561 | # this fn left here for compat with downstream users
562 | init_weights_vit_timm(m)
563 |
564 | @torch.jit.ignore()
565 | def load_pretrained(self, checkpoint_path, prefix=''):
566 | _load_weights(self, checkpoint_path, prefix)
567 |
568 | @torch.jit.ignore
569 | def no_weight_decay(self):
570 | return {'pos_embed', 'cls_token', 'dist_token'}
571 |
572 | @torch.jit.ignore
573 | def group_matcher(self, coarse=False):
574 | return dict(
575 | stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
576 | blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
577 | )
578 |
579 | @torch.jit.ignore
580 | def set_grad_checkpointing(self, enable=True):
581 | self.grad_checkpointing = enable
582 |
583 | @torch.jit.ignore
584 | def get_classifier(self):
585 | return self.head
586 |
587 | def reset_classifier(self, num_classes: int, global_pool=None):
588 | self.num_classes = num_classes
589 | if global_pool is not None:
590 | assert global_pool in ('', 'avg', 'token')
591 | self.global_pool = global_pool
592 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
593 |
594 | def _pos_embed(self, x):
595 | if self.no_embed_class:
596 | # deit-3, updated JAX (big vision)
597 | # position embedding does not overlap with class token, add then concat
598 | x = x + self.pos_embed
599 | if self.cls_token is not None:
600 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
601 | else:
602 | # original timm, JAX, and deit vit impl
603 | # pos_embed has entry for class token, concat then add
604 | if self.cls_token is not None:
605 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
606 | x = x + self.pos_embed
607 | return self.pos_drop(x)
608 |
609 | def forward_features(self, x, tokens_mask=None):
610 | x = self.patch_embed(x)
611 | x = self._pos_embed(x)
612 |
613 |         if tokens_mask is not None:
614 | x = x[:, tokens_mask, :]
615 |
616 | x = self.norm_pre(x)
617 | if self.grad_checkpointing and not torch.jit.is_scripting():
618 | x = checkpoint_seq(self.blocks, x)
619 | else:
620 | x = self.blocks(x)
621 | x = self.norm(x)
622 | return x
623 |
624 | def forward_head(self, x, pre_logits: bool = False):
625 | if self.global_pool:
626 | x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
627 | x = self.fc_norm(x)
628 | return x if pre_logits else self.head(x)
629 |
630 | def forward(self, x, tokens_mask=None):
631 | x = self.forward_features(x, tokens_mask)
632 | x = self.forward_head(x)
633 | return x
634 |
635 |
636 | def init_weights_vit_timm(module: nn.Module, name: str = ''):
637 | """ ViT weight initialization, original timm impl (for reproducibility) """
638 | if isinstance(module, nn.Linear):
639 | trunc_normal_(module.weight, std=.02)
640 | if module.bias is not None:
641 | nn.init.zeros_(module.bias)
642 | elif hasattr(module, 'init_weights'):
643 | module.init_weights()
644 |
645 |
646 | def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
647 | """ ViT weight initialization, matching JAX (Flax) impl """
648 | if isinstance(module, nn.Linear):
649 | if name.startswith('head'):
650 | nn.init.zeros_(module.weight)
651 | nn.init.constant_(module.bias, head_bias)
652 | else:
653 | nn.init.xavier_uniform_(module.weight)
654 | if module.bias is not None:
655 | nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
656 | elif isinstance(module, nn.Conv2d):
657 | lecun_normal_(module.weight)
658 | if module.bias is not None:
659 | nn.init.zeros_(module.bias)
660 | elif hasattr(module, 'init_weights'):
661 | module.init_weights()
662 |
663 |
664 | def init_weights_vit_moco(module: nn.Module, name: str = ''):
665 | """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """
666 | if isinstance(module, nn.Linear):
667 | if 'qkv' in name:
668 | # treat the weights of Q, K, V separately
669 | val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
670 | nn.init.uniform_(module.weight, -val, val)
671 | else:
672 | nn.init.xavier_uniform_(module.weight)
673 | if module.bias is not None:
674 | nn.init.zeros_(module.bias)
675 | elif hasattr(module, 'init_weights'):
676 | module.init_weights()
677 |
678 |
679 | def get_init_weights_vit(mode='jax', head_bias: float = 0.):
680 | if 'jax' in mode:
681 | return partial(init_weights_vit_jax, head_bias=head_bias)
682 | elif 'moco' in mode:
683 | return init_weights_vit_moco
684 | else:
685 | return init_weights_vit_timm
686 |
687 |
688 | @torch.no_grad()
689 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
690 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation
691 | """
692 | import numpy as np
693 |
694 | def _n2p(w, t=True):
695 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
696 | w = w.flatten()
697 | if t:
698 | if w.ndim == 4:
699 | w = w.transpose([3, 2, 0, 1])
700 | elif w.ndim == 3:
701 | w = w.transpose([2, 0, 1])
702 | elif w.ndim == 2:
703 | w = w.transpose([1, 0])
704 | return torch.from_numpy(w)
705 |
706 | w = np.load(checkpoint_path)
707 | if not prefix and 'opt/target/embedding/kernel' in w:
708 | prefix = 'opt/target/'
709 |
710 | if hasattr(model.patch_embed, 'backbone'):
711 | # hybrid
712 | backbone = model.patch_embed.backbone
713 | stem_only = not hasattr(backbone, 'stem')
714 | stem = backbone if stem_only else backbone.stem
715 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
716 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
717 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
718 | if not stem_only:
719 | for i, stage in enumerate(backbone.stages):
720 | for j, block in enumerate(stage.blocks):
721 | bp = f'{prefix}block{i + 1}/unit{j + 1}/'
722 | for r in range(3):
723 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
724 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
725 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
726 | if block.downsample is not None:
727 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
728 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
729 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
730 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
731 | else:
732 | embed_conv_w = adapt_input_conv(
733 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
734 | model.patch_embed.proj.weight.copy_(embed_conv_w)
735 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
736 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
737 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
738 | if pos_embed_w.shape != model.pos_embed.shape:
739 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
740 | pos_embed_w,
741 | model.pos_embed,
742 | getattr(model, 'num_prefix_tokens', 1),
743 | model.patch_embed.grid_size
744 | )
745 | model.pos_embed.copy_(pos_embed_w)
746 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
747 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
748 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
749 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
750 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
751 | # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
752 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
753 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
754 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
755 | for i, block in enumerate(model.blocks.children()):
756 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
757 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
758 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
759 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
760 | block.attn.qkv.weight.copy_(torch.cat([
761 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
762 | block.attn.qkv.bias.copy_(torch.cat([
763 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
764 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
765 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
766 | for r in range(2):
767 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
768 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
769 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
770 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
771 |
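    | # Illustrative use of the .npz loader above (the checkpoint path is hypothetical;
    | # such files come from the official google-research/vision_transformer releases):
    | #   model = vit_base_patch16_224()
    | #   _load_weights(model, 'ViT-B_16.npz')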
772 |
773 | def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
774 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
775 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
776 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
777 | ntok_new = posemb_new.shape[1]
778 | if num_prefix_tokens:
779 | posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
780 | ntok_new -= num_prefix_tokens
781 | else:
782 | posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
783 | gs_old = int(math.sqrt(len(posemb_grid)))
784 | if not len(gs_new): # backwards compatibility
785 | gs_new = [int(math.sqrt(ntok_new))] * 2
786 | assert len(gs_new) >= 2
787 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
788 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
789 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
790 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
791 | posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
792 | return posemb
793 |
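    | # Worked example for the resize above (illustrative ViT-B/16 shapes: 1 + 14*14 = 197
    | # tokens at 224px, 1 + 24*24 = 577 tokens at 384px; `posemb_224` is a hypothetical
    | # (1, 197, 768) tensor):
    | #   posemb_384 = torch.zeros(1, 577, 768)
    | #   resized = resize_pos_embed(posemb_224, posemb_384, num_prefix_tokens=1, gs_new=(24, 24))
    | #   assert resized.shape == (1, 577, 768)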
794 |
795 | def _convert_openai_clip(state_dict, model):
796 | out_dict = {}
797 | swaps = [
798 | ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
799 | ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
800 | ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
801 | ]
802 | for k, v in state_dict.items():
803 | if not k.startswith('visual.'):
804 | continue
805 | for sp in swaps:
806 | k = k.replace(sp[0], sp[1])
807 |
808 | if k == 'proj':
809 | k = 'head.weight'
810 | v = v.transpose(0, 1)
811 | out_dict['head.bias'] = torch.zeros(v.shape[0])
812 | elif k == 'class_embedding':
813 | k = 'cls_token'
814 | v = v.unsqueeze(0).unsqueeze(1)
815 | elif k == 'pos_embed':
816 | v = v.unsqueeze(0)
817 | if v.shape[1] != model.pos_embed.shape[1]:
818 | # To resize pos embedding when using model at different size from pretrained weights
819 | v = resize_pos_embed(
820 | v,
821 | model.pos_embed,
822 | 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
823 | model.patch_embed.grid_size
824 | )
825 | out_dict[k] = v
826 | return out_dict
827 |
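    | # Key-remapping examples for the CLIP converter above, derived from its `swaps` table
    | # (shown for illustration only):
    | #   'visual.transformer.resblocks.0.ln_1.weight' -> 'blocks.0.norm1.weight'
    | #   'visual.conv1.weight'                        -> 'patch_embed.proj.weight'
    | #   'visual.class_embedding'                     -> 'cls_token'  (reshaped to (1, 1, D))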
828 |
829 | def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
830 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
831 | import re
832 | out_dict = {}
833 | if 'model' in state_dict:
834 | # For deit models
835 | state_dict = state_dict['model']
836 |
837 | if 'visual.class_embedding' in state_dict:
838 | return _convert_openai_clip(state_dict, model)
839 |
840 | for k, v in state_dict.items():
841 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
842 | # For old models that I trained prior to conv based patchification
843 | O, I, H, W = model.patch_embed.proj.weight.shape
844 | v = v.reshape(O, -1, H, W)
845 | elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
846 | # To resize pos embedding when using model at different size from pretrained weights
847 | v = resize_pos_embed(
848 | v,
849 | model.pos_embed,
850 | 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
851 | model.patch_embed.grid_size
852 | )
853 | elif adapt_layer_scale and 'gamma_' in k:
854 | # remap layer-scale gamma into sub-module (deit3 models)
855 | k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
856 | elif 'pre_logits' in k:
857 | # NOTE representation layer removed as not used in latest 21k/1k pretrained weights
858 | continue
859 | out_dict[k] = v
860 | return out_dict
861 |
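    | # Hedged sketch of manual use (the file name is hypothetical; the normal timm loading
    | # path already applies this filter via `pretrained_filter_fn` below):
    | #   state_dict = torch.load('deit_base_patch16_224.pth', map_location='cpu')
    | #   model.load_state_dict(checkpoint_filter_fn(state_dict, model))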
862 |
863 | def _create_vision_transformer(variant, pretrained=False, **kwargs):
864 | if kwargs.get('features_only', None):
865 | raise RuntimeError('features_only not implemented for Vision Transformer models.')
866 |
867 | pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
868 | model = build_model_with_cfg(
869 | VisionTransformer, variant, pretrained,
870 | pretrained_cfg=pretrained_cfg,
871 | pretrained_filter_fn=checkpoint_filter_fn,
872 | pretrained_custom_load='npz' in pretrained_cfg['url'],
873 | **kwargs)
874 | return model
875 |
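    | # All register_model entry points below route through this helper; illustratively,
    | # vit_base_patch16_224(pretrained=True) is equivalent to:
    | #   model = _create_vision_transformer(
    | #       'vit_base_patch16_224', pretrained=True,
    | #       patch_size=16, embed_dim=768, depth=12, num_heads=12)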
876 |
877 | @register_model
878 | def vit_tiny_patch16_224(pretrained=False, **kwargs):
879 | """ ViT-Tiny (Vit-Ti/16)
880 | """
881 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
882 | model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
883 | return model
884 |
885 |
886 | @register_model
887 | def vit_tiny_patch16_384(pretrained=False, **kwargs):
888 | """ ViT-Tiny (Vit-Ti/16) @ 384x384.
889 | """
890 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
891 | model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
892 | return model
893 |
894 |
895 | @register_model
896 | def vit_small_patch32_224(pretrained=False, **kwargs):
897 | """ ViT-Small (ViT-S/32)
898 | """
899 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
900 | model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
901 | return model
902 |
903 |
904 | @register_model
905 | def vit_small_patch32_384(pretrained=False, **kwargs):
906 | """ ViT-Small (ViT-S/32) at 384x384.
907 | """
908 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
909 | model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
910 | return model
911 |
912 |
913 | @register_model
914 | def vit_small_patch16_224(pretrained=False, **kwargs):
915 | """ ViT-Small (ViT-S/16)
916 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
917 | """
918 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
919 | model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
920 | return model
921 |
922 |
923 | @register_model
924 | def vit_small_patch16_384(pretrained=False, **kwargs):
925 | """ ViT-Small (ViT-S/16)
926 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
927 | """
928 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
929 | model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
930 | return model
931 |
932 |
933 | @register_model
934 | def vit_base_patch32_224(pretrained=False, **kwargs):
935 | """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
936 | ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
937 | """
938 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
939 | model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
940 | return model
941 |
942 |
943 | @register_model
944 | def vit_base_patch32_384(pretrained=False, **kwargs):
945 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
946 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
947 | """
948 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
949 | model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
950 | return model
951 |
952 |
953 | @register_model
954 | def vit_base_patch16_224(pretrained=False, **kwargs):
955 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
956 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
957 | """
958 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
959 | model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
960 | return model
961 |
962 |
963 | @register_model
964 | def vit_base_patch16_384(pretrained=False, **kwargs):
965 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
966 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
967 | """
968 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
969 | model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
970 | return model
971 |
972 |
973 | @register_model
974 | def vit_base_patch8_224(pretrained=False, **kwargs):
975 | """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
976 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
977 | """
978 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
979 | model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs)
980 | return model
981 |
982 |
983 | @register_model
984 | def vit_large_patch32_224(pretrained=False, **kwargs):
985 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
986 | """
987 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
988 | model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
989 | return model
990 |
991 |
992 | @register_model
993 | def vit_large_patch32_384(pretrained=False, **kwargs):
994 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
995 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
996 | """
997 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
998 | model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
999 | return model
1000 |
1001 |
1002 | @register_model
1003 | def vit_large_patch16_224(pretrained=False, **kwargs):
1004 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
1005 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
1006 | """
1007 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
1008 | model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
1009 | return model
1010 |
1011 |
1012 | @register_model
1013 | def vit_large_patch16_384(pretrained=False, **kwargs):
1014 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
1015 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
1016 | """
1017 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
1018 | model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
1019 | return model
1020 |
1021 |
1022 | @register_model
1023 | def vit_large_patch14_224(pretrained=False, **kwargs):
1024 | """ ViT-Large model (ViT-L/14)
1025 | """
1026 | model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs)
1027 | model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs)
1028 | return model
1029 |
1030 |
1031 | @register_model
1032 | def vit_huge_patch14_224(pretrained=False, **kwargs):
1033 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
1034 | """
1035 | model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
1036 | model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs)
1037 | return model
1038 |
1039 |
1040 | @register_model
1041 | def vit_giant_patch14_224(pretrained=False, **kwargs):
1042 | """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
1043 | """
1044 | model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
1045 | model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
1046 | return model
1047 |
1048 |
1049 | @register_model
1050 | def vit_gigantic_patch14_224(pretrained=False, **kwargs):
1051 | """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
1052 | """
1053 | model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
1054 | model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
1055 | return model
1056 |
1057 |
1058 | @register_model
1059 | def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
1060 | """ ViT-Tiny (Vit-Ti/16).
1061 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
1062 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
1063 | """
1064 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
1065 | model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
1066 | return model
1067 |
1068 |
1069 | @register_model
1070 | def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
1071 | """ ViT-Small (ViT-S/16)
1072 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
1073 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
1074 | """
1075 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
1076 | model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
1077 | return model
1078 |
1079 |
1080 | @register_model
1081 | def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
1082 | """ ViT-Small (ViT-S/16)
1083 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
1084 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
1085 | """
1086 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
1087 | model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
1088 | return model
1089 |
1090 |
1091 | @register_model
1092 | def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
1093 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
1094 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
1095 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
1096 | """
1097 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
1098 | model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
1099 | return model
1100 |
1101 |
1102 | @register_model
1103 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
1104 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
1105 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
1106 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
1107 | """
1108 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
1109 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
1110 | return model
1111 |
1112 |
1113 | @register_model
1114 | def vit_base_patch8_224_in21k(pretrained=False, **kwargs):
1115 | """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
1116 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
1117 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
1118 | """
1119 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
1120 | model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs)
1121 | return model
1122 |
1123 |
1124 | @register_model
1125 | def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
1126 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
1127 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
1128 | NOTE: this model has a representation layer but the 21k classifier head is zeroed out in original weights
1129 | """
1130 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
1131 | model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
1132 | return model
1133 |
1134 |
1135 | @register_model
1136 | def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
1137 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
1138 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
1139 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
1140 | """
1141 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
1142 | model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
1143 | return model
1144 |
1145 |
1146 | @register_model
1147 | def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
1148 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
1149 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
1150 | NOTE: this model has a representation layer but the 21k classifier head is zeroed out in original weights
1151 | """
1152 | model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
1153 | model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
1154 | return model
1155 |
1156 |
1157 | @register_model
1158 | def vit_base_patch16_224_sam(pretrained=False, **kwargs):
1159 | """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
1160 | """
1161 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
1162 | model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs)
1163 | return model
1164 |
1165 |
1166 | @register_model
1167 | def vit_base_patch32_224_sam(pretrained=False, **kwargs):
1168 | """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
1169 | """
1170 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
1171 | model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs)
1172 | return model
1173 |
1174 |
1175 | @register_model
1176 | def vit_small_patch16_224_dino(pretrained=False, **kwargs):
1177 | """ ViT-Small (ViT-S/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
1178 | """
1179 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
1180 | model = _create_vision_transformer('vit_small_patch16_224_dino', pretrained=pretrained, **model_kwargs)
1181 | return model
1182 |
1183 |
1184 | @register_model
1185 | def vit_small_patch8_224_dino(pretrained=False, **kwargs):
1186 | """ ViT-Small (ViT-S/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
1187 | """
1188 | model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
1189 | model = _create_vision_transformer('vit_small_patch8_224_dino', pretrained=pretrained, **model_kwargs)
1190 | return model
1191 |
1192 |
1193 | @register_model
1194 | def vit_base_patch16_224_dino(pretrained=False, **kwargs):
1195 | """ ViT-Base (ViT-B/16) /w DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
1196 | """
1197 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
1198 | model = _create_vision_transformer('vit_base_patch16_224_dino', pretrained=pretrained, **model_kwargs)
1199 | return model
1200 |
1201 |
1202 | @register_model
1203 | def vit_base_patch8_224_dino(pretrained=False, **kwargs):
1204 | """ ViT-Base (ViT-B/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
1205 | """
1206 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
1207 | model = _create_vision_transformer('vit_base_patch8_224_dino', pretrained=pretrained, **model_kwargs)
1208 | return model
1209 |
1210 |
1211 | @register_model
1212 | def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
1213 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
1214 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
1215 | """
1216 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
1217 | model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
1218 | return model
1219 |
1220 |
1221 | @register_model
1222 | def vit_base_patch16_224_miil(pretrained=False, **kwargs):
1223 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
1224 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
1225 | """
1226 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
1227 | model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
1228 | return model
1229 |
1230 |
1231 | # Experimental models below
1232 |
1233 | @register_model
1234 | def vit_base_patch32_plus_256(pretrained=False, **kwargs):
1235 | """ ViT-Base (ViT-B/32+)
1236 | """
1237 | model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
1238 | model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs)
1239 | return model
1240 |
1241 |
1242 | @register_model
1243 | def vit_base_patch16_plus_240(pretrained=False, **kwargs):
1244 | """ ViT-Base (ViT-B/16+)
1245 | """
1246 | model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
1247 | model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
1248 | return model
1249 |
1250 |
1251 | @register_model
1252 | def vit_base_patch16_rpn_224(pretrained=False, **kwargs):
1253 | """ ViT-Base (ViT-B/16) w/ residual post-norm
1254 | """
1255 | model_kwargs = dict(
1256 | patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False,
1257 | block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs)
1258 | model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
1259 | return model
1260 |
1261 |
1262 | @register_model
1263 | def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
1264 | """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.
1265 | Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
1266 | Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
1267 | """
1268 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs)
1269 | model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs)
1270 | return model
1271 |
1272 |
1273 | @register_model
1274 | def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
1275 | """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
1276 | Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
1277 | Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
1278 | """
1279 | model_kwargs = dict(
1280 | patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
1281 | model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
1282 | return model
1283 |
1284 |
1285 | @register_model
1286 | def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
1287 | """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
1288 | Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
1289 | """
1290 | model_kwargs = dict(
1291 | patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
1292 | model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
1293 | return model
1294 |
1295 |
1296 | @register_model
1297 | def vit_base_patch32_224_clip_laion2b(pretrained=False, **kwargs):
1298 | """ ViT-B/32
1299 | Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
1300 | """
1301 | model_kwargs = dict(
1302 | patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
1303 | model = _create_vision_transformer('vit_base_patch32_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
1304 | return model
1305 |
1306 |
1307 | @register_model
1308 | def vit_large_patch14_224_clip_laion2b(pretrained=False, **kwargs):
1309 | """ ViT-Large model (ViT-L/14)
1310 | Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
1311 | """
1312 | model_kwargs = dict(
1313 | patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
1314 | model = _create_vision_transformer('vit_large_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
1315 | return model
1316 |
1317 |
1318 | @register_model
1319 | def vit_huge_patch14_224_clip_laion2b(pretrained=False, **kwargs):
1320 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
1321 | Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
1322 | """
1323 | model_kwargs = dict(
1324 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
1325 | model = _create_vision_transformer('vit_huge_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
1326 | return model
1327 |
1328 |
1329 | @register_model
1330 | def vit_giant_patch14_224_clip_laion2b(pretrained=False, **kwargs):
1331 | """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
1332 | Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
1333 | """
1334 | model_kwargs = dict(
1335 | patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16,
1336 | pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
1337 | model = _create_vision_transformer('vit_giant_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
1338 | return model
1339 |
--------------------------------------------------------------------------------