├── .gitignore ├── README.md ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── data ├── 10k.txt ├── 20k.txt ├── 3k.txt ├── CLIP_Dissect_results │ ├── resnet18_places_imagenet_broden.csv │ └── resnet50_imagenet_broden.csv ├── MILAN_results │ ├── m_base_resnet18_places365.csv │ ├── m_base_resnet50_imagenet.csv │ ├── m_imagenet_resnet18_places365.csv │ └── m_places365_resnet50_imagenet.csv ├── NetDissect_results │ ├── resnet18_places365_fc.csv │ ├── resnet18_places365_layer1.csv │ ├── resnet18_places365_layer4.csv │ ├── resnet50_imagenet_fc.csv │ ├── resnet50_imagenet_layer1.csv │ ├── resnet50_imagenet_layer2.csv │ ├── resnet50_imagenet_layer3.csv │ └── resnet50_imagenet_layer4.csv ├── broden_labels_clean.txt ├── categories_places365.txt ├── github_overview_figure.png └── imagenet_labels.txt ├── data_utils.py ├── describe_neurons.py ├── dlbroden.sh ├── dlzoo_example.sh ├── experiments ├── appendix_a6_predict_class_from_desc.ipynb ├── fig10_compositional.ipynb ├── fig11_vit_qualitative.ipynb ├── fig12_13_larger_range_images.ipynb ├── fig14_similarity_comp_qual.ipynb ├── fig3_quantitative_example.ipynb ├── fig4_detect_missing_concept.ipynb ├── fig5_use_case.ipynb ├── fig8_low_level_comparison.ipynb ├── fig_1_6_7_9_qualitative_comparison.ipynb ├── table1_quantitative_rn50.ipynb ├── table2_quantitative_rn18.ipynb ├── table3_similarity_comparison.ipynb └── text_colorings.py ├── qualitative.ipynb ├── requirements.txt ├── similarity.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.egg-info 5 | .pytest_cache 6 | .ipynb_checkpoints 7 | 8 | thumbs.db 9 | .DS_Store 10 | .idea 11 | saved_activations/ 12 | data/broden1_224 13 | data/resnet18_places365.pth.tar 14 | results/ 15 | *.zip 16 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | ## CLIP-Dissect 2 | 3 | An automatic and efficient tool to describe functionalities of individual neurons in DNNs. 4 | 5 | This is the official repository for our paper: [CLIP-Dissect: Automatic Description of Neuron Representations in Deep Vision Networks](https://arxiv.org/abs/2204.10965) published at ICLR 2023. 6 | 7 | **Update 6/5/23**: We have conducted a crowdsourced evaluation of our description quality; results are available on [arXiv](https://arxiv.org/abs/2204.10965) (Appendix B). 8 | 9 | ![Overview](data/github_overview_figure.png) 10 | 11 | ## Installation 12 | 13 | 1. Install Python (3.10) 14 | 2. Install PyTorch (tested with 1.12.0, also works with 2.0) and Torchvision >= 0.13 following instructions from https://pytorch.org/get-started/previous-versions/ 15 | 3. Install remaining requirements using `pip install -r requirements.txt` 16 | 4. Download the Broden dataset (images only) using `bash dlbroden.sh` 17 | 5. (Optional) Download ResNet-18 pretrained on Places-365: `bash dlzoo_example.sh` 18 | 19 | We do not provide download instructions for ImageNet data. To evaluate using your own copy of the ImageNet validation set, you must set 20 | the correct path in the `DATASET_ROOTS["imagenet_val"]` variable in `data_utils.py`. 21 | 22 | ## Quickstart: 23 | 24 | The command below dissects 5 layers of ResNet-50 (ImageNet) using Broden as the probing dataset. Results will be saved in `results/resnet50_{datetime}/descriptions.csv`. 25 | 26 | ``` 27 | python describe_neurons.py 28 | ``` 29 | 30 | ## Recreating experiments 31 | 32 | The results used for figures and tables of our paper can be recreated by running the corresponding notebook in the `experiments` folder. For example, to reproduce Table 1, run `experiments/table1_quantitative_rn50.ipynb`. 33 | 34 | ## How to modify: 35 | 36 | ### Dissecting your own model 37 | 38 | 1. 
Implement the code to load your model (in eval mode) and a preprocess function that correctly loads images for your model in the `get_target_model` function of `data_utils.py`, under an `if` branch for a `target_name` of your choice. 39 | 2. Dissect the model by running `python describe_neurons.py --target_model {model_name}` 40 | 41 | ### Using your own probing dataset 42 | 43 | 1. Implement code to load your dataset as a torchvision Dataset in the `get_data` function of `data_utils.py` 44 | 2. Add your dataset name to the choices of the `--d_probe` argument in `describe_neurons.py` 45 | 3. Dissect the model by running `python describe_neurons.py --d_probe {dataset_name}` 46 | 47 | ### Using your own concept set 48 | 49 | 1. Create/download a .txt file containing your concept set, with each concept on a separate line 50 | 2. Dissect the model by running `python describe_neurons.py --concept_set {path_to_conceptset}` 51 | 52 | ### Specifying device 53 | 54 | You can specify which device is used with the `--device` argument, which defaults to `cuda`, e.g. `python describe_neurons.py --device cpu` 55 | 56 | ## Sources: 57 | 58 | - CLIP: https://github.com/openai/CLIP 59 | - Text datasets (10k and 20k): https://github.com/first20hours/google-10000-english 60 | - Text dataset (3k): https://www.ef.edu/english-resources/english-vocabulary/top-3000-words/ 61 | - Broden download script based on: https://github.com/CSAILVision/NetDissect-Lite 62 | 63 | ## Common errors 64 | 65 | **Incorrect activations cached:** 66 | 67 | The code automatically caches the activations of the target model and CLIP in `saved_activations`, and if a file with the same save name already exists, the code will load these activations instead of recalculating them. However, sometimes you may wish to modify the pipeline in a way that doesn't change the name of the saved activations and still want to recalculate the activations. 
In this case you need to manually delete the relevant files from saved_activations before rerunning CLIP-Dissect, as using incorrect activations will give incorrect results. 68 | 69 | ## Cite this work 70 | 71 | T. Oikarinen and T.-W. Weng, CLIP-Dissect: Automatic Description of Neuron Representations in Deep Vision Networks, ICLR 2023. 72 | 73 | ``` 74 | @article{oikarinen2023clip, 75 | title={CLIP-Dissect: Automatic Description of Neuron Representations in Deep Vision Networks}, 76 | author={Oikarinen, Tuomas and Weng, Tsui-Wei}, 77 | journal={International Conference on Learning Representations}, 78 | year={2023} 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trustworthy-ML-Lab/CLIP-dissect/21e7697feaea3bf7d7bc2d2cc8e4047d5f5fd502/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if 
packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 
| if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | try: 127 | # loading JIT archive 128 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 129 | state_dict = None 130 | except RuntimeError: 131 | # loading saved state dict 132 | if jit: 133 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 134 | jit = False 135 | state_dict = torch.load(model_path, map_location="cpu") 136 | 137 | if not jit: 138 | model = build_model(state_dict or model.state_dict()).to(device) 139 | if str(device) == "cpu": 140 | model.float() 141 | return model, _transform(model.visual.input_resolution) 142 | 143 | # patch the device names 144 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 145 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 146 | 147 | def patch_device(module): 148 | try: 149 | graphs = [module.graph] if hasattr(module, "graph") else [] 150 | except RuntimeError: 151 | graphs = [] 152 | 153 | if hasattr(module, "forward1"): 154 | graphs.append(module.forward1.graph) 155 | 156 | for graph in graphs: 157 | for node in graph.findAllNodes("prim::Constant"): 158 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 159 | node.copyAttributes(device_node) 160 | 
161 | model.apply(patch_device) 162 | patch_device(model.encode_image) 163 | patch_device(model.encode_text) 164 | 165 | # patch dtype to float32 on CPU 166 | if str(device) == "cpu": 167 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 168 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 169 | float_node = float_input.node() 170 | 171 | def patch_float(module): 172 | try: 173 | graphs = [module.graph] if hasattr(module, "graph") else [] 174 | except RuntimeError: 175 | graphs = [] 176 | 177 | if hasattr(module, "forward1"): 178 | graphs.append(module.forward1.graph) 179 | 180 | for graph in graphs: 181 | for node in graph.findAllNodes("aten::to"): 182 | inputs = list(node.inputs()) 183 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 184 | if inputs[i].node()["value"] == 5: 185 | inputs[i].node().copyAttributes(float_node) 186 | 187 | model.apply(patch_float) 188 | patch_float(model.encode_image) 189 | patch_float(model.encode_text) 190 | 191 | model.float() 192 | 193 | return model, _transform(model.input_resolution.item()) 194 | 195 | 196 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 197 | """ 198 | Returns the tokenized representation of given input string(s) 199 | 200 | Parameters 201 | ---------- 202 | texts : Union[str, List[str]] 203 | An input string or a list of input strings to tokenize 204 | 205 | context_length : int 206 | The context length to use; all CLIP models use 77 as the context length 207 | 208 | truncate: bool 209 | Whether to truncate the text in case its encoding is longer than the context length 210 | 211 | Returns 212 | ------- 213 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 214 | """ 215 | if isinstance(texts, str): 216 | texts = [texts] 217 | 218 | sot_token = _tokenizer.encoder["<|startoftext|>"] 219 | eot_token = 
_tokenizer.encoder["<|endoftext|>"] 220 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 221 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 222 | 223 | for i, tokens in enumerate(all_tokens): 224 | if len(tokens) > context_length: 225 | if truncate: 226 | tokens = tokens[:context_length] 227 | tokens[-1] = eot_token 228 | else: 229 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 230 | result[i, :len(tokens)] = torch.tensor(tokens) 231 | 232 | return result 233 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * 
x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = 
x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = x + self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__(self, 
input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 204 | super().__init__() 205 | self.input_resolution = input_resolution 206 | self.output_dim = output_dim 207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 208 | 209 | scale = width ** -0.5 210 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 212 | self.ln_pre = LayerNorm(width) 213 | 214 | self.transformer = Transformer(width, layers, heads) 215 | 216 | self.ln_post = LayerNorm(width) 217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 218 | 219 | def forward(self, x: torch.Tensor): 220 | x = self.conv1(x) # shape = [*, width, grid, grid] 221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 224 | x = x + self.positional_embedding.to(x.dtype) 225 | x = self.ln_pre(x) 226 | 227 | x = x.permute(1, 0, 2) # NLD -> LND 228 | x = self.transformer(x) 229 | x = x.permute(1, 0, 2) # LND -> NLD 230 | 231 | x = self.ln_post(x[:, 0, :]) 232 | 233 | if self.proj is not None: 234 | x = x @ self.proj 235 | 236 | return x 237 | 238 | 239 | class CLIP(nn.Module): 240 | def __init__(self, 241 | embed_dim: int, 242 | # vision 243 | image_resolution: int, 244 | vision_layers: Union[Tuple[int, int, int, int], int], 245 | vision_width: int, 246 | vision_patch_size: int, 247 | # text 248 | context_length: int, 249 | vocab_size: int, 250 | transformer_width: int, 251 | transformer_heads: int, 252 | transformer_layers: int 253 | ): 254 | super().__init__() 255 | 256 | self.context_length = context_length 257 | 258 | 
if isinstance(vision_layers, (tuple, list)): 259 | vision_heads = vision_width * 32 // 64 260 | self.visual = ModifiedResNet( 261 | layers=vision_layers, 262 | output_dim=embed_dim, 263 | heads=vision_heads, 264 | input_resolution=image_resolution, 265 | width=vision_width 266 | ) 267 | else: 268 | vision_heads = vision_width // 64 269 | self.visual = VisionTransformer( 270 | input_resolution=image_resolution, 271 | patch_size=vision_patch_size, 272 | width=vision_width, 273 | layers=vision_layers, 274 | heads=vision_heads, 275 | output_dim=embed_dim 276 | ) 277 | 278 | self.transformer = Transformer( 279 | width=transformer_width, 280 | layers=transformer_layers, 281 | heads=transformer_heads, 282 | attn_mask=self.build_attention_mask() 283 | ) 284 | 285 | self.vocab_size = vocab_size 286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 288 | self.ln_final = LayerNorm(transformer_width) 289 | 290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 292 | 293 | self.initialize_parameters() 294 | 295 | def initialize_parameters(self): 296 | nn.init.normal_(self.token_embedding.weight, std=0.02) 297 | nn.init.normal_(self.positional_embedding, std=0.01) 298 | 299 | if isinstance(self.visual, ModifiedResNet): 300 | if self.visual.attnpool is not None: 301 | std = self.visual.attnpool.c_proj.in_features ** -0.5 302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 306 | 307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 308 | for name, param in resnet_block.named_parameters(): 309 | if 
name.endswith("bn3.weight"): 310 | nn.init.zeros_(param) 311 | 312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 313 | attn_std = self.transformer.width ** -0.5 314 | fc_std = (2 * self.transformer.width) ** -0.5 315 | for block in self.transformer.resblocks: 316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 320 | 321 | if self.text_projection is not None: 322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 323 | 324 | def build_attention_mask(self): 325 | # lazily create causal attention mask, with full attention between the vision tokens 326 | # pytorch uses additive attention mask; fill with -inf 327 | mask = torch.empty(self.context_length, self.context_length) 328 | mask.fill_(float("-inf")) 329 | mask.triu_(1) # zero out the lower diagonal 330 | return mask 331 | 332 | @property 333 | def dtype(self): 334 | return self.visual.conv1.weight.dtype 335 | 336 | def encode_image(self, image): 337 | return self.visual(image.type(self.dtype)) 338 | 339 | def encode_text(self, text): 340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 341 | 342 | x = x + self.positional_embedding.type(self.dtype) 343 | x = x.permute(1, 0, 2) # NLD -> LND 344 | x = self.transformer(x) 345 | x = x.permute(1, 0, 2) # LND -> NLD 346 | x = self.ln_final(x).type(self.dtype) 347 | 348 | # x.shape = [batch_size, n_ctx, transformer.width] 349 | # take features from the eot embedding (eot_token is the highest number in each sequence) 350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 351 | 352 | return x 353 | 354 | def encode_text_embed(self, emb, last_token): 355 | #x = text.type(self.dtype) # [batch_size, n_ctx, d_model] 356 | x = emb 357 | x = x + 
self.positional_embedding.type(self.dtype) 358 | x = x.permute(1, 0, 2) # NLD -> LND 359 | x = self.transformer(x) 360 | x = x.permute(1, 0, 2) # LND -> NLD 361 | x = self.ln_final(x).type(self.dtype) 362 | 363 | # x.shape = [batch_size, n_ctx, transformer.width] 364 | # take features from the eot embedding (eot_token is the highest number in each sequence) 365 | x = x[torch.arange(x.shape[0]), last_token] @ self.text_projection 366 | 367 | return x 368 | 369 | def forward(self, image, text): 370 | image_features = self.encode_image(image) 371 | text_features = self.encode_text(text) 372 | 373 | # normalized features 374 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 375 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 376 | 377 | # cosine similarity as logits 378 | logit_scale = self.logit_scale.exp() 379 | logits_per_image = logit_scale * image_features @ text_features.t() 380 | logits_per_text = logits_per_image.t() 381 | 382 | # shape = [global_batch_size, global_batch_size] 383 | return logits_per_image, logits_per_text 384 | 385 | 386 | def convert_weights(model: nn.Module): 387 | """Convert applicable model parameters to fp16""" 388 | 389 | def _convert_weights_to_fp16(l): 390 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 391 | l.weight.data = l.weight.data.half() 392 | if l.bias is not None: 393 | l.bias.data = l.bias.data.half() 394 | 395 | if isinstance(l, nn.MultiheadAttention): 396 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 397 | tensor = getattr(l, attr) 398 | if tensor is not None: 399 | tensor.data = tensor.data.half() 400 | 401 | for name in ["text_projection", "proj"]: 402 | if hasattr(l, name): 403 | attr = getattr(l, name) 404 | if attr is not None: 405 | attr.data = attr.data.half() 406 | 407 | model.apply(_convert_weights_to_fp16) 408 | 409 | 410 | def build_model(state_dict: dict): 411 | vit = "visual.proj" in 
state_dict 412 | 413 | if vit: 414 | vision_width = state_dict["visual.conv1.weight"].shape[0] 415 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 416 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 417 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 418 | image_resolution = vision_patch_size * grid_size 419 | else: 420 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 421 | vision_layers = tuple(counts) 422 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 423 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 424 | vision_patch_size = None 425 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 426 | image_resolution = output_width * 32 427 | 428 | embed_dim = state_dict["text_projection"].shape[1] 429 | context_length = state_dict["positional_embedding"].shape[0] 430 | vocab_size = state_dict["token_embedding.weight"].shape[0] 431 | transformer_width = state_dict["ln_final.weight"].shape[0] 432 | transformer_heads = transformer_width // 64 433 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 434 | 435 | model = CLIP( 436 | embed_dim, 437 | image_resolution, vision_layers, vision_width, vision_patch_size, 438 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 439 | ) 440 | 441 | for key in ["input_resolution", "context_length", "vocab_size"]: 442 | if key in state_dict: 443 | del state_dict[key] 444 | 445 | convert_weights(model) 446 | model.load_state_dict(state_dict) 447 | return model.eval() 448 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: 
-------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns a dict mapping utf-8 bytes to corresponding unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a significant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except ValueError: 102 | new_word.extend(word[i:]) 
103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /data/NetDissect_results/resnet18_places365_layer1.csv: -------------------------------------------------------------------------------- 1 | unit,category,label,score,color-label,color-truth,color-activation,color-intersect,color-iou,object-label,object-truth,object-activation,object-intersect,object-iou,part-label,part-truth,part-activation,part-intersect,part-iou,material-label,material-truth,material-activation,material-intersect,material-iou,scene-label,scene-truth,scene-activation,scene-intersect,scene-iou,texture-label,texture-truth,texture-activation,texture-intersect,texture-iou 2 | 1,color,green-c,0.010368466,green-c,63326720,3349727,684238,0.010368466,plant,3760075,1619175,51047,0.009580528,neck,812426,603813,6345,0.004500338,food,1104928,938351,12475,0.006142887,bar-s,765184,837481,8603,0.005396904,dotted,2935296,345067,18848,0.005778909 3 | 
2,texture,perforated,0.005495854,black-c,139916835,1976000,485463,0.003433081,person,14523849,982899,56723,0.003671386,leg,2786636,375934,10428,0.003308227,food,1104928,445181,3449,0.002229967,street-s,28111104,524056,72912,0.00255274,perforated,1793792,279092,11330,0.005495854 4 | 3,texture,dotted,0.025160554,green-c,63326720,3756979,832819,0.012570686,plant,3760075,2158711,87221,0.014956705,arm,2299112,887633,20966,0.006622699,food,1104928,683782,11801,0.006641308,ball_pit-s,275968,1031383,18947,0.014705791,dotted,2935296,656354,88150,0.025160554 5 | 4,texture,studded,0.025190781,black-c,139916835,2890212,686535,0.004830654,car,5354795,1469155,42162,0.006216945,wheel,1513439,623160,13735,0.006470033,glass,2246777,661086,10218,0.003526312,casino-indoor-s,652288,792984,7762,0.005399615,studded,1593088,393824,48822,0.025190781 6 | 5,texture,chequered,0.014940479,red-c,21004512,2621088,128077,0.00545066,wall,48917895,1234181,264799,0.005307947,arm,2299112,418247,16030,0.005934116,painted,13470129,675596,72553,0.005155412,conference_room-s,2308096,682503,11265,0.003781046,chequered,1530368,295787,26882,0.014940479 7 | 6,texture,striped,0.025505639,white-c,72511065,3105929,1028883,0.013794196,windowpane,8841352,1546091,115456,0.01123989,column,193649,503979,6292,0.009101219,glass,2246777,503247,15269,0.005583316,building_facade-s,3148544,921971,30570,0.007566935,striped,2433536,480950,72487,0.025505639 8 | 7,texture,zigzagged,0.036835752,white-c,72511065,3022611,469344,0.006252557,motorbike,1440771,1335569,28050,0.010206346,wheel,1513439,612148,30047,0.014338548,paper,471596,573279,4241,0.0040754,greenhouse-indoor-s,363776,645482,3891,0.003870228,zigzagged,1618176,809324,86242,0.036835752 9 | 8,color,red-c,0.033144792,red-c,21004512,3130862,774298,0.033144792,flower,591549,1515261,28243,0.013587727,hand,398768,558264,10706,0.011313226,food,1104928,647924,8766,0.005026128,ball_pit-s,275968,808396,8882,0.008258623,dotted,2935296,632191,60945,0.017380371 10 | 
9,color,red-c,0.056066255,red-c,21004512,3319748,1291368,0.056066255,person,14523849,1436303,160294,0.010145281,hand,398768,635020,9909,0.009677901,food,1104928,855065,14277,0.007337659,ball_pit-s,275968,723569,8523,0.008600282,dotted,2935296,642054,77020,0.02200364 11 | 10,scene,ball_pit-s,0.017725015,blue-c,61494563,3381326,908270,0.014198903,bus,1628191,1747892,49524,0.014887456,body,1291612,727296,16596,0.008288419,fabric,8548746,823590,44738,0.004796304,ball_pit-s,275968,888630,20283,0.017725015,zigzagged,1618176,468628,36172,0.01763944 12 | 11,texture,perforated,0.011919323,white-c,72511065,1981672,316594,0.004268138,car,5354795,993411,31070,0.004918368,wheel,1513439,438887,10733,0.005527935,paper,471596,420351,2325,0.00261347,shoe_shop-s,577024,522239,3198,0.00291771,perforated,1793792,345622,25200,0.011919323 13 | 12,color,orange-c,0.019005461,orange-c,30348501,3384667,629157,0.019005461,person,14523849,1583712,81978,0.005115446,hand,398768,593915,5228,0.005294418,food,1104928,907671,18525,0.009290026,ball_pit-s,275968,896801,11748,0.01011868,sprinkled,1718528,477233,24454,0.011262341 14 | 13,texture,dotted,0.013120715,white-c,72511065,1988136,264926,0.003568783,painting,2207098,970667,15535,0.004912672,wheel,1513439,440818,7306,0.003752534,glass,2246777,517916,10231,0.003714337,street-s,28111104,557056,85453,0.002989675,dotted,2935296,240649,41131,0.013120715 15 | 14,texture,chequered,0.031014157,white-c,72511065,3521719,835316,0.0111083,car,5354795,1621025,56029,0.008096921,screen,677188,695569,9457,0.006936844,glass,2246777,874887,22452,0.007244422,cockpit-s,539392,771582,6113,0.00468479,chequered,1530368,487403,60697,0.031014157 16 | 15,color,white-c,0.037134467,white-c,72511065,3384445,2717429,0.037134467,sky,36650987,1530364,510257,0.013545054,shade,197478,386818,11392,0.019884658,glass,2246777,521487,46300,0.017009777,ball_pit-s,275968,796929,8907,0.008371319,grid,1682240,627182,60339,0.026828267 17 | 
16,color,blue-c,0.016141138,blue-c,61494563,2919144,1023195,0.016141138,windowpane,8841352,1605613,71916,0.00693163,screen,677188,469309,5568,0.004880234,glass,2246777,493001,14136,0.005186301,ball_pit-s,275968,869920,7976,0.007009329,dotted,2935296,455668,36468,0.01087138 18 | 17,texture,chequered,0.034109456,black-c,139916835,3062366,1552992,0.010980935,person,14523849,1119843,111167,0.007157046,body,1291612,509124,13449,0.007524813,leather,470954,849464,6591,0.005016642,conference_room-s,2308096,521064,13494,0.004792472,chequered,1530368,524999,67795,0.034109456 19 | 18,color,black-c,0.019669133,black-c,139916835,3293428,2762486,0.019669133,car,5354795,1425912,59136,0.008797943,wheel,1513439,685827,27192,0.01251891,glass,2246777,725063,8917,0.003009528,building_facade-s,3148544,696237,16267,0.004248907,chequered,1530368,458986,31538,0.016108766 20 | 19,texture,chequered,0.028129665,white-c,72511065,3003629,509839,0.006797413,windowpane,8841352,1375331,72948,0.007191434,leg,2786636,637905,19777,0.005808626,glass,2246777,821516,15506,0.005079293,living_room-s,9621248,754700,42807,0.00414269,chequered,1530368,386223,52438,0.028129665 21 | 20,texture,perforated,0.00949374,black-c,139916835,1689558,427993,0.003031576,building,35246027,784515,109980,0.003061756,wheel,1513439,315748,6121,0.003357531,metal,3462561,405650,9084,0.0023539,building_facade-s,3148544,426555,9223,0.002586461,perforated,1793792,250558,19226,0.00949374 22 | 21,color,pink-c,0.023997571,pink-c,18491188,3217283,508742,0.023997571,person,14523849,1757194,203230,0.012640401,hand,398768,760223,11905,0.010378472,fabric,8548746,600772,42983,0.004720017,ball_pit-s,275968,885716,9590,0.008323974,dotted,2935296,598484,64829,0.018688359 23 | 
22,part,head,0.00713357,white-c,72511065,2654895,371518,0.004967187,floor,20378155,1233337,118448,0.005510992,head,4946761,516993,38700,0.00713357,metal,3462561,731179,18308,0.004384696,airport_terminal-s,1480192,643643,7876,0.003722189,frilly,1618176,304721,11191,0.005853934 24 | 23,texture,veined,0.055712737,green-c,63326720,3582980,2473359,0.038384535,grass,12961339,2123498,625295,0.043244454,body,1291612,575322,10811,0.005824506,food,1104928,352940,17491,0.012143349,park-s,1179136,958966,33987,0.016152634,veined,1806336,857507,140578,0.055712737 25 | 24,color,blue-c,0.048437225,blue-c,61494563,3716449,3012713,0.048437225,sky,36650987,2122872,628956,0.016488599,screen,677188,575932,32907,0.026968242,plastic-opaque,577397,317768,7810,0.008801438,ball_pit-s,275968,1087820,17990,0.013367534,bubbly,1517824,1042911,79577,0.032072524 26 | 25,texture,banded,0.084041506,white-c,72511065,3052837,869439,0.011639939,curtain,3110939,1288605,64423,0.014860716,door frame,221497,265384,6632,0.013809503,painted,13470129,576817,51578,0.003685362,kitchen-s,9006592,891614,66153,0.0067283,banded,2571520,449050,234173,0.084041506 27 | 26,texture,lined,0.0289984,white-c,72511065,2969623,736857,0.009858432,car,5354795,1665745,80606,0.011614808,crosswalk,241792,564145,5717,0.007144285,plastic-opaque,577397,446455,6785,0.006671144,building_facade-s,3148544,924847,37947,0.009403426,lined,1568000,472439,57502,0.0289984 28 | 27,texture,perforated,0.005204549,orange-c,30348501,2012254,106329,0.003296571,person,14523849,911895,54045,0.003513591,arm,2299112,366852,8885,0.003343898,metal,3462561,567758,11503,0.002862286,conference_room-s,2308096,463891,7240,0.002618684,perforated,1793792,286901,10773,0.005204549 29 | 28,texture,dotted,0.016122383,red-c,21004512,3467681,254848,0.010523367,person,14523849,1684934,181139,0.011301661,arm,2299112,786178,27279,0.008920504,fabric,8548746,942432,51421,0.00544728,ball_pit-s,275968,829217,13755,0.012602732,dotted,2935296,480690,54200,0.016122383 30 | 
29,texture,chequered,0.00998476,white-c,72511065,1734890,331610,0.004486409,car,5354795,781653,30072,0.004924689,body,1291612,297417,5687,0.00359177,paper,471596,483782,2886,0.003029947,conference_room-s,2308096,425199,8230,0.003020111,chequered,1530368,258314,17683,0.00998476 31 | 30,texture,polka-dotted,0.008900034,yellow-c,34166024,2350621,134590,0.003699351,painting,2207098,1108397,20935,0.006354415,drawer,540514,515329,3517,0.00334212,glass,2246777,770701,11903,0.003960307,conference_room-s,2308096,664099,13003,0.004394105,polka-dotted,1705984,203437,16844,0.008900034 32 | 31,texture,grid,0.019467454,white-c,72511065,3316770,453345,0.006014568,skyscraper,1270392,1897183,48119,0.015425446,balcony,317196,595073,7030,0.007765905,glass,2246777,404005,12098,0.004584861,skyscraper-s,4402944,1140409,69838,0.01275926,grid,1682240,550094,42628,0.019467454 33 | 32,color,white-c,0.007407126,white-c,72511065,1890405,547049,0.007407126,windowpane,8841352,869149,34235,0.003538038,shade,197478,342955,1919,0.00356351,glass,2246777,496259,11002,0.004027036,airport_terminal-s,1480192,460473,6133,0.003170276,striped,2433536,265062,17407,0.006492264 34 | 33,texture,grid,0.011435681,white-c,72511065,2595144,334703,0.004476344,windowpane,8841352,1263827,96209,0.009612278,frame,136341,532067,3841,0.005779703,glass,2246777,761066,16980,0.005677291,building_facade-s,3148544,786740,21799,0.005570227,grid,1682240,244459,21784,0.011435681 35 | 34,object,sky,0.015049989,white-c,72511065,2900270,714314,0.009562818,sky,36650987,1529623,566098,0.015049989,body,1291612,335643,7242,0.004470334,painted,13470129,597123,108705,0.007787702,conference_room-s,2308096,884184,16437,0.005175634,dotted,2935296,247429,32952,0.010461706 36 | 
35,color,red-c,0.049208602,red-c,21004512,3405238,1144834,0.049208602,ball,122639,1332399,19305,0.013446093,body,1291612,530885,20329,0.011280302,food,1104928,968087,37740,0.018542949,ball_pit-s,275968,669078,25678,0.027930056,zigzagged,1618176,651632,55897,0.025248079 37 | 36,texture,waffled,0.009607726,yellow-c,34166024,3202648,228809,0.006160739,ceiling,12122565,1363717,84975,0.006340799,head,4946761,600821,38246,0.006942034,metal,3462561,974835,40987,0.009322836,conference_room-s,2308096,687920,11912,0.003991818,waffled,1555456,347070,18105,0.009607726 38 | 37,texture,perforated,0.014671237,black-c,139916835,2687872,1933006,0.013741257,car,5354795,1084174,34764,0.005428308,wheel,1513439,467264,12608,0.006406195,glass,2246777,742617,6861,0.002300394,airport_terminal-s,1480192,552129,8124,0.004013443,perforated,1793792,346797,30951,0.014671237 39 | 38,texture,zigzagged,0.013076401,white-c,72511065,2258153,300274,0.004032204,motorbike,1440771,1003791,9758,0.004007715,wheel,1513439,390242,8227,0.004340385,paper,471596,570836,2559,0.002460877,conference_room-s,2308096,542225,8844,0.003112466,zigzagged,1618176,349422,25397,0.013076401 40 | 39,texture,zigzagged,0.016418372,blue-c,61494563,3504577,759645,0.011825202,bus,1628191,1897946,30988,0.008866003,arm,2299112,734562,22911,0.007609699,glass,2246777,837466,16161,0.00526746,ball_pit-s,275968,1041448,14583,0.0111933,zigzagged,1618176,390347,32444,0.016418372 41 | 40,texture,grid,0.016601402,black-c,139916835,2480945,935207,0.006610985,bus,1628191,1335554,22251,0.007564523,frame,136341,467726,3403,0.005665397,glass,2246777,473026,11590,0.004279575,building_facade-s,3148544,782461,31158,0.007989544,grid,1682240,299108,32356,0.016601402 42 | 41,color,white-c,0.0368031,white-c,72511065,3336160,2692327,0.0368031,light,350013,1633189,45325,0.023388997,shade,197478,580311,13843,0.018120391,glass,2246777,650451,36726,0.012839005,airport_terminal-s,1480192,850696,10924,0.004708694,zigzagged,1618176,464682,42170,0.020664599 43 
| 42,texture,chequered,0.014681321,white-c,72511065,2626213,403790,0.005403066,car,5354795,1306148,36052,0.005441901,screen,677188,489632,4616,0.003971764,glass,2246777,775491,15617,0.005194151,dining_room-s,5694976,736147,27228,0.004251787,chequered,1530368,270322,26054,0.014681321 44 | 43,color,blue-c,0.037889995,blue-c,61494563,3306541,2365678,0.037889995,person,14523849,1954429,153345,0.0093933,screen,677188,687235,15585,0.01155439,glass,2246777,528210,20161,0.00731843,skyscraper-s,4402944,1071344,37734,0.006940794,sprinkled,1718528,448393,27563,0.012883772 45 | 44,texture,grid,0.016173589,white-c,72511065,2650905,573807,0.007693004,windowpane,8841352,1433532,101866,0.010013351,wheel,1513439,594316,16086,0.007690509,glass,2246777,503808,17742,0.00649214,building_facade-s,3148544,842963,28283,0.007136362,grid,1682240,362599,32546,0.016173589 46 | 45,texture,chequered,0.028812301,white-c,72511065,2875412,592357,0.007919834,windowpane,8841352,1269817,64826,0.006452696,screen,677188,580425,5744,0.00458834,glass,2246777,820522,18428,0.006044205,bow_window-outdoor-s,288512,672790,3768,0.003935108,chequered,1530368,353590,52761,0.028812301 47 | 46,color,red-c,0.05488432,red-c,21004512,3738516,1287349,0.05488432,flower,591549,1444514,28264,0.014077106,arm,2299112,628744,36309,0.012556946,fabric,8548746,469829,105433,0.011828938,ball_pit-s,275968,657200,12305,0.013362465,dotted,2935296,1562293,197761,0.04599277 48 | 47,texture,grid,0.026031349,white-c,72511065,3226145,1016234,0.013600384,car,5354795,1882694,94723,0.013261389,crosswalk,241792,619852,9416,0.011048687,plastic-opaque,577397,405985,7147,0.007320983,building_facade-s,3148544,1024907,46221,0.011199037,grid,1682240,547479,56570,0.026031349 49 | 
48,texture,zigzagged,0.031775073,black-c,139916835,2985165,995902,0.007018035,motorbike,1440771,1248291,29318,0.011022865,wheel,1513439,618102,32816,0.01563616,metal,3462561,566286,14532,0.003620045,shoe_shop-s,577024,589628,4593,0.003952467,zigzagged,1618176,832134,75461,0.031775073 50 | 49,texture,grid,0.017045592,black-c,139916835,2957714,921045,0.006488357,bus,1628191,1506955,38701,0.012498527,frame,136341,552930,6282,0.009197806,glass,2246777,714607,14251,0.004835547,building_facade-s,3148544,829628,28252,0.00715255,grid,1682240,295754,33151,0.017045592 51 | 50,texture,zigzagged,0.025497807,white-c,72511065,3137126,507582,0.006755096,motorbike,1440771,1468255,42540,0.01484047,wheel,1513439,634840,40668,0.019295781,metal,3462561,428090,22312,0.00576785,forest-broadleaf-s,1216768,679023,12348,0.006556078,zigzagged,1618176,925961,63257,0.025497807 52 | 51,texture,perforated,0.009108289,white-c,72511065,1664117,294862,0.003991076,car,5354795,796339,27008,0.004410099,shade,197478,348794,1556,0.002856534,paper,471596,380309,1964,0.002310749,shoe_shop-s,577024,406555,2188,0.002229489,perforated,1793792,296485,18867,0.009108289 53 | 52,texture,zigzagged,0.018276552,blue-c,61494563,3173669,901177,0.014132329,sky,36650987,1954219,602459,0.015853038,body,1291612,555371,10275,0.005594248,glass,2246777,530997,13703,0.004957543,field-cultivated-s,639744,1069784,16225,0.009581865,zigzagged,1618176,424110,36656,0.018276552 54 | 53,texture,lacelike,0.032456866,white-c,72511065,3077592,610736,0.008145545,motorbike,1440771,1447935,39166,0.013744675,wheel,1513439,622114,38121,0.018175083,metal,3462561,383252,13381,0.003491517,forest-broadleaf-s,1216768,649484,14494,0.007827157,lacelike,1530368,979896,78914,0.032456866 55 | 
54,texture,chequered,0.016836849,white-c,72511065,2622642,1093034,0.014762616,car,5354795,1282897,40784,0.00618229,body,1291612,453364,8737,0.005032141,glass,2246777,614808,15842,0.005566912,living_room-s,9621248,689032,39608,0.003856418,chequered,1530368,350951,31151,0.016836849 56 | 55,texture,perforated,0.004673532,yellow-c,34166024,1802127,119512,0.003333795,person,14523849,875791,40345,0.002626748,arm,2299112,333490,6585,0.0025076,glass,2246777,437871,5279,0.00197024,ball_pit-s,275968,485906,2597,0.003420359,perforated,1793792,255311,9532,0.004673532 57 | 56,texture,banded,0.020416211,white-c,72511065,3242336,1264881,0.016980885,sky,36650987,1691063,432621,0.011411963,body,1291612,335624,9044,0.005588954,painted,13470129,674572,89030,0.006334098,living_room-s,9621248,968221,58144,0.005521053,banded,2571520,331958,58092,0.020416211 58 | 57,texture,perforated,0.005180185,black-c,139916835,1830569,530471,0.003756426,car,5354795,833691,20365,0.003301654,top,381885,297272,1658,0.002447236,metal,3462561,472437,10039,0.002557734,conference_room-s,2308096,442061,7006,0.002553997,perforated,1793792,268306,10627,0.005180185 59 | 58,texture,perforated,0.015953563,white-c,72511065,1718627,234147,0.003164339,light,350013,816356,8710,0.007523805,wheel,1513439,364744,9722,0.005203213,paper,471596,356537,1945,0.002354186,shoe_shop-s,577024,422013,2560,0.002569051,perforated,1793792,346430,33608,0.015953563 60 | 59,texture,polka-dotted,0.008008401,red-c,21004512,1882411,72114,0.003160842,painting,2207098,920394,19827,0.006380031,body,1291612,414605,5837,0.003432762,glass,2246777,552834,8524,0.003054007,building_facade-s,3148544,559162,12781,0.003459069,polka-dotted,1705984,199293,15137,0.008008401 61 | 60,color,yellow-c,0.029376227,yellow-c,34166024,3150188,1064926,0.029376227,flower,591549,1423012,13317,0.006654361,shade,197478,468819,5743,0.008694217,food,1104928,897070,22426,0.011328711,ball_pit-s,275968,799619,8418,0.00788816,sprinkled,1718528,432731,24366,0.011456148 62 | 
61,texture,perforated,0.026451757,white-c,72511065,2839145,527673,0.007052327,car,5354795,1224496,53638,0.008219561,wheel,1513439,574592,16171,0.007805064,metal,3462561,666894,15880,0.003860389,conference_room-s,2308096,627291,10701,0.003658854,perforated,1793792,634488,62577,0.026451757 63 | 62,texture,grid,0.017094888,white-c,72511065,2521440,351313,0.00470417,windowpane,8841352,1404103,90271,0.008889155,body,1291612,526794,9088,0.005022887,glass,2246777,452453,10040,0.003733466,building_facade-s,3148544,834058,23027,0.005815523,grid,1682240,343574,34049,0.017094888 64 | 63,texture,zigzagged,0.020125863,purple-c,10857154,2664287,63112,0.004689438,motorbike,1440771,1294987,25573,0.009435887,wheel,1513439,564066,18791,0.009127543,food,1104928,507693,6015,0.003743917,ball_pit-s,275968,651759,4150,0.004493399,zigzagged,1618176,608112,43922,0.020125863 65 | 64,texture,chequered,0.031945789,white-c,72511065,3338147,632045,0.008402935,car,5354795,1626244,63927,0.009241863,wing,323329,685734,7171,0.007157458,ceramic,679072,777124,9521,0.006581298,conference_room-s,2308096,847767,16746,0.005334621,chequered,1530368,444930,61149,0.031945789 66 | -------------------------------------------------------------------------------- /data/broden_labels_clean.txt: -------------------------------------------------------------------------------- 1 | wine cellar-bottle storage 2 | bird 3 | brewery-outdoor 4 | grill 5 | nuclear power plant-outdoor 6 | convenience store-indoor 7 | ashcan 8 | railroad train 9 | canyon 10 | flecked 11 | crystalline 12 | mountain snowy 13 | shed 14 | paper towel 15 | elevator-freight elevator 16 | waterfall-cascade 17 | excavation 18 | earth 19 | hot spring 20 | scaly 21 | table cloth 22 | tap 23 | ski resort 24 | plane 25 | badlands 26 | kiosk-outdoor 27 | dam 28 | embankment 29 | streetlight 30 | television camera 31 | calendar 32 | candy store 33 | arcades 34 | railing 35 | basketball hoop 36 | pottedplant 37 | pitcher 38 | covered bridge-exterior 39 | 
exhibitor 40 | tank 41 | toilet tissue 42 | flag 43 | game room 44 | hand cart 45 | trunk 46 | shipyard 47 | console table 48 | crosswalk 49 | shed 50 | corral 51 | ice cream parlor 52 | theater-indoor procenium 53 | joss house 54 | auto factory 55 | clock 56 | fountain 57 | traveling bag 58 | pictures 59 | elevator lobby 60 | toilet 61 | hot tub 62 | menu 63 | skeleton 64 | poolroom-establishment 65 | place mat 66 | workshop 67 | labyrinth-indoor 68 | corn field 69 | rubble 70 | horn 71 | dirt track 72 | snow 73 | trench 74 | saddle 75 | sky 76 | fitting room-interior 77 | subway station-corridor 78 | swimming pool 79 | red-c 80 | shopping mall-indoor 81 | desert-vegetation 82 | checkout counter 83 | deck 84 | library-outdoor 85 | track 86 | guitar 87 | rudder 88 | formal garden 89 | plastic-opaque 90 | ranch house 91 | capital 92 | downtown 93 | bleachers-indoor 94 | grille door 95 | bus interior 96 | garage door 97 | loudspeaker 98 | island 99 | bus 100 | display window 101 | moat-water 102 | flood 103 | cabinet 104 | footboard 105 | wood 106 | shop 107 | skittle alley 108 | aquarium 109 | video player 110 | pool 111 | screen 112 | screen door 113 | fairway 114 | forest road 115 | dashboard 116 | mountain path 117 | orchard 118 | temple-east asia 119 | forecourt 120 | pillar 121 | mountain road 122 | hotel-outdoor 123 | irrigation ditch 124 | hand 125 | ottoman 126 | grooved 127 | armchair 128 | operating room 129 | shoe shop 130 | newspaper 131 | paisley 132 | kettle 133 | newsstand-outdoor 134 | skyscraper 135 | sprinkled 136 | fence 137 | palm 138 | tire 139 | amusement arcade 140 | cardroom 141 | bar 142 | chest 143 | dome 144 | foot 145 | air conditioner 146 | manufactured home 147 | granite 148 | wing 149 | leaf 150 | jar 151 | tile 152 | hotel room 153 | dummy 154 | boxing ring 155 | field-cultivated 156 | window seat 157 | baggage claim 158 | sacristy 159 | beach house 160 | bedchamber 161 | drinking glass 162 | cabin 163 | youth hostel 164 | videostore 
165 | hoodoo 166 | cabana 167 | alley 168 | trade name 169 | canal-urban 170 | jail-outdoor 171 | shelter 172 | conveyer belt 173 | carport 174 | handle bar 175 | machine 176 | watchtower 177 | videos 178 | curtain 179 | campsite 180 | tvmonitor 181 | heliport 182 | stove 183 | gymnasium-indoor 184 | field-wild 185 | spiralled 186 | body 187 | lighthouse 188 | handle 189 | dolmen 190 | gauzy 191 | ball 192 | shower 193 | zigzagged 194 | terraces 195 | side rail 196 | leaves 197 | barn 198 | stage-indoor 199 | lean-to 200 | bus depot-outdoor 201 | covered bridge-interior 202 | cockpit 203 | floor 204 | canister 205 | windows 206 | gym shoe 207 | ice 208 | pulpit 209 | cubicle-library 210 | pitted 211 | wet bar 212 | topiary garden 213 | shirt 214 | wheel 215 | temple 216 | keyboard 217 | creek 218 | sauna 219 | beach 220 | arch 221 | plaza 222 | balcony-exterior 223 | beer garden 224 | rocking chair 225 | shaft 226 | stands 227 | honeycombed 228 | bulletin board 229 | arm 230 | controls 231 | cemetery 232 | parterre 233 | player 234 | monastery-outdoor 235 | handbag 236 | beak 237 | windmill 238 | oast house 239 | painting 240 | candle 241 | roundabout 242 | painted 243 | cradle 244 | swirly 245 | inflatable bounce game 246 | flowerpot 247 | water wheel 248 | rubble 249 | bedroom 250 | bench 251 | shade 252 | booklet 253 | iceberg 254 | terrace 255 | gravestone 256 | frilly 257 | waiting room 258 | lake-natural 259 | towel 260 | golf course 261 | vegetable garden 262 | switch 263 | blinds 264 | semidesert ground 265 | natural history museum 266 | telephone booth 267 | snowfield 268 | desert-sand 269 | structure 270 | mill 271 | recycling bin 272 | pane of glass 273 | cracked 274 | inn-outdoor 275 | shoe 276 | soap dispenser 277 | bedpost 278 | jacuzzi-indoor 279 | hen 280 | water tower 281 | purple-c 282 | box 283 | drawer 284 | barrel 285 | coach roof 286 | estuary 287 | tomb 288 | dishwasher 289 | mission 290 | pallet 291 | train station-outdoor 292 | amusement 
park 293 | highway 294 | aqueduct 295 | chest of drawers 296 | canteen 297 | blotchy 298 | faucet 299 | license plate 300 | building complex 301 | gift shop 302 | bumpy 303 | interlaced 304 | easel 305 | chandelier 306 | seat base 307 | alcove 308 | waterfall-block 309 | bathtub 310 | ocean 311 | berth 312 | straw 313 | knife 314 | casino-outdoor 315 | fitting room-exterior 316 | synagogue-outdoor 317 | bottle 318 | fish 319 | yard 320 | slats 321 | savanna 322 | display board 323 | foliage 324 | steam shovel 325 | stained 326 | mouth 327 | television 328 | greenhouse-indoor 329 | jacket 330 | vending machine 331 | office building 332 | green-c 333 | eye 334 | basketball court-outdoor 335 | mosque-outdoor 336 | jersey 337 | moor 338 | water mill 339 | industrial park 340 | driving range-outdoor 341 | mouse pad 342 | smeared 343 | lockers 344 | granary 345 | courtyard 346 | tennis court 347 | pink-c 348 | canvas 349 | telescope 350 | excavator 351 | lecture room 352 | pleated 353 | sandbox 354 | hat 355 | apparel 356 | mat 357 | set of instruments 358 | hangar-indoor 359 | hill 360 | concrete 361 | utility room 362 | barbecue 363 | roof 364 | step 365 | elephant 366 | sculpture 367 | water tank 368 | mirror 369 | ear 370 | restaurant 371 | motorbike 372 | skin 373 | hunting lodge-indoor 374 | television stand 375 | streetcar 376 | hedge maze 377 | camera 378 | doorway-outdoor 379 | casing 380 | stool 381 | covered bridge 382 | cafeteria 383 | monument 384 | lamp 385 | museum-indoor 386 | pier 387 | fluorescent 388 | ballroom 389 | polka-dotted 390 | taillight 391 | brown-c 392 | bulldozer 393 | drawing 394 | double door 395 | postbox 396 | workbench 397 | windshield 398 | television studio 399 | briefcase 400 | water tower 401 | desert 402 | trestle 403 | doorway-indoor 404 | crane 405 | martial arts gym 406 | train 407 | lighthouse 408 | slum 409 | grey-c 410 | horse-drawn carriage 411 | spindle 412 | landing deck 413 | counter 414 | galley 415 | arcade 416 | 
shelves 417 | shelf 418 | saucepan 419 | vault 420 | ad 421 | niche 422 | tower 423 | mattress 424 | skirt 425 | top 426 | bandstand 427 | eyebrow 428 | classroom 429 | bathrobe 430 | pantry 431 | wall 432 | coffee shop 433 | hut 434 | baptistry-outdoor 435 | toyshop 436 | parasol 437 | convenience store-outdoor 438 | loading dock 439 | volleyball court-outdoor 440 | escalator-indoor 441 | sheep 442 | dishrag 443 | pipe 444 | altarpiece 445 | ring 446 | wrestling ring-indoor 447 | knitted 448 | fork 449 | refrigerator 450 | cat 451 | cathedral-outdoor 452 | patty 453 | pitch 454 | wheat field 455 | court 456 | basketball court-indoor 457 | mug 458 | minibike 459 | lower sash 460 | pane 461 | gate 462 | witness stand 463 | restaurant patio 464 | choir loft-exterior 465 | dormer 466 | elevator-interior 467 | duck 468 | shanties 469 | lake-artificial 470 | roller coaster 471 | weighbridge 472 | brush 473 | outside arm 474 | stile 475 | driveway 476 | toll booth 477 | swimming pool-outdoor 478 | basement 479 | curb 480 | airport ticket counter 481 | water 482 | building facade 483 | veined 484 | chimney 485 | subway interior 486 | sidewalk 487 | home office 488 | viaduct 489 | nursery 490 | bubbly 491 | corridor 492 | blind 493 | kasbah 494 | fibrous 495 | runway 496 | gulch 497 | home theater 498 | movie theater-indoor 499 | fog bank 500 | balcony-interior 501 | manhole 502 | wine cellar-barrel storage 503 | eiderdown 504 | folding screen 505 | bread rolls 506 | carrousel 507 | witness stand 508 | synthesizer 509 | equipment 510 | lido deck-outdoor 511 | plinth 512 | gazebo-exterior 513 | pillow 514 | bird cage 515 | tree 516 | crevasse 517 | japanese garden 518 | catwalk 519 | corner pocket 520 | wardrobe 521 | bidet 522 | cup 523 | fur 524 | rope bridge 525 | botanical garden 526 | drainage ditch 527 | plastic 528 | headboard 529 | magazine 530 | check-in-desk 531 | stratified 532 | dorm room 533 | candelabrum 534 | bathroom 535 | quay 536 | inn-indoor 537 | 
stairway 538 | general store-outdoor 539 | pond 540 | binder 541 | tail 542 | junkyard 543 | assembly line 544 | can 545 | cactus 546 | head 547 | control tower-indoor 548 | backstairs 549 | oar 550 | church-outdoor 551 | lagoon 552 | bazaar-outdoor 553 | moat-dry 554 | fire station 555 | market-outdoor 556 | braided 557 | circus tent-outdoor 558 | garage-indoor 559 | tables 560 | apse-indoor 561 | system 562 | moon bounce 563 | bush 564 | slot machine 565 | metal shutters 566 | figurine 567 | ranch 568 | forest 569 | ball pit 570 | glass 571 | cart 572 | subway station-platform 573 | shopfront 574 | fort 575 | military hut 576 | sash 577 | hotel breakfast area 578 | village 579 | berth 580 | arm panel 581 | observatory-outdoor 582 | office 583 | coach 584 | rubber 585 | airport-entrance 586 | perforated 587 | mountain 588 | dining hall 589 | laundromat 590 | bird feeder 591 | village 592 | catwalk 593 | watering hole 594 | studded 595 | sand trap 596 | fuselage 597 | booth 598 | sink 599 | coffee table 600 | courthouse 601 | panel 602 | helmet 603 | ship 604 | balcony 605 | waterfall 606 | elevator-door 607 | junk pile 608 | control tower 609 | mountain pass 610 | table football 611 | dam 612 | hill 613 | food court 614 | kindergarden classroom 615 | day care center 616 | bus stop 617 | parlor 618 | backpack 619 | helicopter 620 | sales booth 621 | watchtower 622 | playground 623 | baptismal font 624 | plaything 625 | knob 626 | ceiling 627 | booth-indoor 628 | limousine interior 629 | bed 630 | hayfield 631 | vegetables 632 | tearoom 633 | billboard 634 | balloon 635 | river 636 | access road 637 | flower 638 | candies 639 | tray 640 | lake 641 | chain wheel 642 | bowl 643 | buffet 644 | hot tub-indoor 645 | canopy 646 | porch 647 | cow 648 | torso 649 | pantry 650 | mausoleum 651 | town house 652 | baseboard 653 | anechoic chamber 654 | hovel 655 | folding door 656 | shower curtain 657 | service station 658 | living room 659 | swimming pool-indoor 660 | geodesic 
dome-outdoor 661 | firing range-outdoor 662 | hunting lodge-outdoor 663 | teapot 664 | backplate 665 | church-indoor 666 | viaduct 667 | boathouse 668 | rock 669 | dentists office 670 | washer 671 | kitchenette 672 | guardhouse 673 | outhouse-outdoor 674 | bank-indoor 675 | bleachers-outdoor 676 | conference center 677 | ski slope 678 | shore 679 | street 680 | reception 681 | cannon 682 | computer 683 | tapestry 684 | railway 685 | lid 686 | towel rack 687 | hoof 688 | pilothouse-indoor 689 | weighbridge 690 | dacha 691 | remote control 692 | windmill 693 | bowling alley 694 | boot 695 | bridge 696 | staircase 697 | wineglass 698 | sandbox 699 | computer room 700 | table tennis 701 | dirt track 702 | park 703 | ruin 704 | chequered 705 | ice skating rink-indoor 706 | hangar-outdoor 707 | fireplace 708 | lobby 709 | traffic light 710 | planetarium-outdoor 711 | acropolis 712 | coffee maker 713 | river 714 | guardrail 715 | muzzle 716 | bowling alley 717 | mouse 718 | fire place 719 | sea 720 | dental chair 721 | animal 722 | mosque 723 | fire escape 724 | podium-indoor 725 | cottage 726 | banquet hall 727 | stalls 728 | jail cell 729 | grass 730 | scaffolding 731 | bog 732 | sofa 733 | autobus 734 | palace 735 | industrial area 736 | patio 737 | bottle rack 738 | board 739 | tower 740 | rifle 741 | yellow-c 742 | seat cushion 743 | baseball field 744 | kitchen island 745 | light 746 | rubbish 747 | carousel 748 | bullpen 749 | parking 750 | pavilion 751 | back 752 | doors 753 | price tag 754 | head roof 755 | hot tub-outdoor 756 | wicker 757 | plate 758 | valley 759 | stage 760 | jail-indoor 761 | rope 762 | pot 763 | doorframe 764 | washing machines 765 | shops 766 | revolving door 767 | apron 768 | cottage garden 769 | mine 770 | cardboard 771 | carport-freestanding 772 | scale 773 | cash register 774 | bus shelter 775 | bakery-kitchen 776 | bouquet 777 | podium 778 | shower 779 | fridge 780 | mezzanine 781 | frame 782 | landing 783 | riser 784 | bakery-shop 785 
| black-c 786 | earmuffs 787 | diner-outdoor 788 | oven 789 | ticket window 790 | hat shop 791 | file cabinet 792 | heater 793 | bumper 794 | pedestal 795 | blade 796 | castle 797 | imaret 798 | quadrangle 799 | column 800 | engine 801 | potholed 802 | freckled 803 | music studio 804 | muntin 805 | mosque-indoor 806 | parking lot 807 | vineyard 808 | courtroom 809 | oasis 810 | hay 811 | ladder 812 | clouds 813 | fastfood restaurant 814 | parking lot 815 | dinette-vehicle 816 | shutter 817 | door frame 818 | hedge 819 | net 820 | seat 821 | bullring 822 | cushion 823 | kennel-outdoor 824 | airplane 825 | atrium-public 826 | matted 827 | planks 828 | laptop 829 | mansion 830 | rim 831 | trailer 832 | carport-outdoor 833 | henhouse 834 | ramp 835 | garage-outdoor 836 | organ 837 | wrinkled 838 | rack 839 | playroom 840 | art school 841 | waffled 842 | spoon 843 | elevator 844 | forest-broadleaf 845 | ceramic 846 | shower stall 847 | leg 848 | dog 849 | tread 850 | meat 851 | fruit 852 | book stand 853 | lift bridge 854 | pole 855 | post 856 | bookcase 857 | fire escape 858 | stadium-baseball 859 | fountain 860 | slope 861 | ride 862 | artists loft 863 | airport terminal 864 | vineyard 865 | beam 866 | herb garden 867 | campus 868 | monitor 869 | fjord 870 | cathedral-indoor 871 | fabric 872 | bedclothes 873 | paper 874 | map 875 | windscreen 876 | signal box 877 | auto showroom 878 | goal 879 | conference room 880 | spotlight 881 | stabilizer 882 | leather 883 | cheese factory 884 | van 885 | forklift 886 | front 887 | napkin 888 | playground 889 | hospital 890 | motel 891 | brick 892 | boat 893 | safety side 894 | auditorium 895 | caravan 896 | base 897 | embassy 898 | pool table 899 | candlestick 900 | chapel 901 | barrels 902 | work surface 903 | field 904 | door 905 | art studio 906 | dining room 907 | marsh 908 | archaelogical excavation 909 | catacomb 910 | valley 911 | blue-c 912 | car interior-backseat 913 | basket 914 | gravel 915 | book 916 | desk 917 | 
nose 918 | barnyard 919 | locker room 920 | ice skating rink-outdoor 921 | aircraft carrier 922 | florist shop-indoor 923 | nunnery 924 | cobwebbed 925 | horse 926 | signboard 927 | pack 928 | exhaust hood 929 | pasture 930 | baptistry-indoor 931 | cloister-indoor 932 | cloud 933 | road cut 934 | carpet 935 | radio 936 | wheelchair 937 | crosswalk 938 | medina 939 | tent 940 | bank vault 941 | bank-outdoor 942 | forest path 943 | airplane cabin 944 | dinette-home 945 | badminton court-indoor 946 | arrival gate-outdoor 947 | bread 948 | telephone 949 | amphitheater 950 | manhole 951 | delicatessen 952 | table 953 | sand 954 | wall socket 955 | blanket 956 | diffusor 957 | throne room 958 | forest-needleleaf 959 | elevator door 960 | document 961 | harbor 962 | bistro-indoor 963 | person 964 | aquatic theater 965 | dotted 966 | orange-c 967 | tumble dryer 968 | baby buggy 969 | porous 970 | coast 971 | text 972 | tracks 973 | windowpane 974 | islet 975 | freeway 976 | dugout 977 | parking garage-outdoor 978 | back pillow 979 | wallpaper 980 | countertop 981 | footbridge 982 | labyrinth 983 | kiosk-indoor 984 | grid 985 | market-indoor 986 | earth fissure 987 | statue 988 | movie theater-outdoor 989 | hacienda 990 | crosshatched 991 | bow window-indoor 992 | notebook 993 | cross 994 | sweater 995 | escalator-outdoor 996 | disc case 997 | grand piano 998 | pagoda 999 | steering wheel 1000 | bell 1001 | butte 1002 | batters box 1003 | face 1004 | wave 1005 | construction site 1006 | slide 1007 | microphone 1008 | cliff 1009 | silver screen 1010 | vent 1011 | central reservation 1012 | museum-outdoor 1013 | hospital room 1014 | pub-outdoor 1015 | stern 1016 | badlands 1017 | donjon 1018 | student residence 1019 | ruins 1020 | runway 1021 | poolroom-home 1022 | neck 1023 | headlight 1024 | banner 1025 | curtains 1026 | button panel 1027 | cockpit 1028 | hair 1029 | liquor store-indoor 1030 | lined 1031 | ticket counter 1032 | umbrella 1033 | waterfall-fan 1034 | 
skyscraper 1035 | cubicle-office 1036 | chicken coop-outdoor 1037 | pigeonhole 1038 | house 1039 | microwave 1040 | jewelry shop 1041 | plastic-clear 1042 | side 1043 | barbershop 1044 | crate 1045 | zen garden 1046 | sconce 1047 | bridge 1048 | land 1049 | finger 1050 | radiator 1051 | arch 1052 | fire 1053 | drum 1054 | partition 1055 | coat 1056 | deck chair 1057 | instrument panel 1058 | cabin-outdoor 1059 | cavern-indoor 1060 | fence 1061 | butchers shop 1062 | warehouse-indoor 1063 | football field 1064 | entrance 1065 | grandstand 1066 | wire 1067 | computer case 1068 | sandbar 1069 | smoke 1070 | library-indoor 1071 | kennel-indoor 1072 | vase 1073 | paw 1074 | bannister 1075 | apartment building-outdoor 1076 | building 1077 | planter 1078 | canal-natural 1079 | car dealership 1080 | laminate 1081 | brushes 1082 | gas station 1083 | cap 1084 | dining car 1085 | attic 1086 | monitoring device 1087 | road 1088 | gas pump 1089 | metal 1090 | clean room 1091 | aqueduct 1092 | clothing store 1093 | stretcher 1094 | ground 1095 | sewing machine 1096 | skylight 1097 | ice rink 1098 | trouser 1099 | chalet 1100 | parking garage-indoor 1101 | control tower-outdoor 1102 | bazaar-indoor 1103 | blackboard 1104 | arcade machine 1105 | striped 1106 | casino-indoor 1107 | pulpit 1108 | bar 1109 | shop window 1110 | bicycle 1111 | platform 1112 | badminton court-outdoor 1113 | white-c 1114 | metal shutter 1115 | bag 1116 | marbled 1117 | case 1118 | spa-massage room 1119 | upper sash 1120 | gas station 1121 | printer 1122 | kitchen 1123 | path 1124 | house 1125 | abbey 1126 | table game 1127 | ghost town 1128 | elevator shaft 1129 | woven 1130 | cloister-outdoor 1131 | auto mechanics-indoor 1132 | altar 1133 | trellis 1134 | machinery 1135 | sun deck 1136 | cd 1137 | big top 1138 | escalator 1139 | car 1140 | call center 1141 | florist shop-outdoor 1142 | supermarket 1143 | bullring 1144 | childs room 1145 | food 1146 | piano 1147 | airport 1148 | flight of stairs-natural 
1149 | merchandise 1150 | plant 1151 | breads 1152 | beauty salon 1153 | roundabout 1154 | closet 1155 | container 1156 | decoration 1157 | meter 1158 | lawn 1159 | awning 1160 | bookstore 1161 | chair 1162 | cargo container interior 1163 | scoreboard 1164 | sill 1165 | flight of stairs-urban 1166 | crt screen 1167 | inside arm 1168 | jacuzzi-outdoor 1169 | truck 1170 | lacelike 1171 | fan 1172 | banded 1173 | funeral chapel 1174 | linoleum 1175 | hallway 1176 | art gallery 1177 | reading room 1178 | farm 1179 | stairs 1180 | archive 1181 | greenhouse 1182 | field road 1183 | greenhouse-outdoor 1184 | air base 1185 | bow window-outdoor 1186 | swivel chair 1187 | labyrinth-outdoor 1188 | meshed 1189 | bucket 1190 | cage 1191 | bayou 1192 | mountain 1193 | fishpond 1194 | poster 1195 | tunnel 1196 | liquor store-outdoor 1197 | box office -------------------------------------------------------------------------------- /data/categories_places365.txt: -------------------------------------------------------------------------------- 1 | /a/airfield 0 2 | /a/airplane_cabin 1 3 | /a/airport_terminal 2 4 | /a/alcove 3 5 | /a/alley 4 6 | /a/amphitheater 5 7 | /a/amusement_arcade 6 8 | /a/amusement_park 7 9 | /a/apartment_building/outdoor 8 10 | /a/aquarium 9 11 | /a/aqueduct 10 12 | /a/arcade 11 13 | /a/arch 12 14 | /a/archaelogical_excavation 13 15 | /a/archive 14 16 | /a/arena/hockey 15 17 | /a/arena/performance 16 18 | /a/arena/rodeo 17 19 | /a/army_base 18 20 | /a/art_gallery 19 21 | /a/art_school 20 22 | /a/art_studio 21 23 | /a/artists_loft 22 24 | /a/assembly_line 23 25 | /a/athletic_field/outdoor 24 26 | /a/atrium/public 25 27 | /a/attic 26 28 | /a/auditorium 27 29 | /a/auto_factory 28 30 | /a/auto_showroom 29 31 | /b/badlands 30 32 | /b/bakery/shop 31 33 | /b/balcony/exterior 32 34 | /b/balcony/interior 33 35 | /b/ball_pit 34 36 | /b/ballroom 35 37 | /b/bamboo_forest 36 38 | /b/bank_vault 37 39 | /b/banquet_hall 38 40 | /b/bar 39 41 | /b/barn 40 42 | /b/barndoor 41 
43 | /b/baseball_field 42 44 | /b/basement 43 45 | /b/basketball_court/indoor 44 46 | /b/bathroom 45 47 | /b/bazaar/indoor 46 48 | /b/bazaar/outdoor 47 49 | /b/beach 48 50 | /b/beach_house 49 51 | /b/beauty_salon 50 52 | /b/bedchamber 51 53 | /b/bedroom 52 54 | /b/beer_garden 53 55 | /b/beer_hall 54 56 | /b/berth 55 57 | /b/biology_laboratory 56 58 | /b/boardwalk 57 59 | /b/boat_deck 58 60 | /b/boathouse 59 61 | /b/bookstore 60 62 | /b/booth/indoor 61 63 | /b/botanical_garden 62 64 | /b/bow_window/indoor 63 65 | /b/bowling_alley 64 66 | /b/boxing_ring 65 67 | /b/bridge 66 68 | /b/building_facade 67 69 | /b/bullring 68 70 | /b/burial_chamber 69 71 | /b/bus_interior 70 72 | /b/bus_station/indoor 71 73 | /b/butchers_shop 72 74 | /b/butte 73 75 | /c/cabin/outdoor 74 76 | /c/cafeteria 75 77 | /c/campsite 76 78 | /c/campus 77 79 | /c/canal/natural 78 80 | /c/canal/urban 79 81 | /c/candy_store 80 82 | /c/canyon 81 83 | /c/car_interior 82 84 | /c/carrousel 83 85 | /c/castle 84 86 | /c/catacomb 85 87 | /c/cemetery 86 88 | /c/chalet 87 89 | /c/chemistry_lab 88 90 | /c/childs_room 89 91 | /c/church/indoor 90 92 | /c/church/outdoor 91 93 | /c/classroom 92 94 | /c/clean_room 93 95 | /c/cliff 94 96 | /c/closet 95 97 | /c/clothing_store 96 98 | /c/coast 97 99 | /c/cockpit 98 100 | /c/coffee_shop 99 101 | /c/computer_room 100 102 | /c/conference_center 101 103 | /c/conference_room 102 104 | /c/construction_site 103 105 | /c/corn_field 104 106 | /c/corral 105 107 | /c/corridor 106 108 | /c/cottage 107 109 | /c/courthouse 108 110 | /c/courtyard 109 111 | /c/creek 110 112 | /c/crevasse 111 113 | /c/crosswalk 112 114 | /d/dam 113 115 | /d/delicatessen 114 116 | /d/department_store 115 117 | /d/desert/sand 116 118 | /d/desert/vegetation 117 119 | /d/desert_road 118 120 | /d/diner/outdoor 119 121 | /d/dining_hall 120 122 | /d/dining_room 121 123 | /d/discotheque 122 124 | /d/doorway/outdoor 123 125 | /d/dorm_room 124 126 | /d/downtown 125 127 | /d/dressing_room 126 128 | /d/driveway 127 
129 | /d/drugstore 128 130 | /e/elevator/door 129 131 | /e/elevator_lobby 130 132 | /e/elevator_shaft 131 133 | /e/embassy 132 134 | /e/engine_room 133 135 | /e/entrance_hall 134 136 | /e/escalator/indoor 135 137 | /e/excavation 136 138 | /f/fabric_store 137 139 | /f/farm 138 140 | /f/fastfood_restaurant 139 141 | /f/field/cultivated 140 142 | /f/field/wild 141 143 | /f/field_road 142 144 | /f/fire_escape 143 145 | /f/fire_station 144 146 | /f/fishpond 145 147 | /f/flea_market/indoor 146 148 | /f/florist_shop/indoor 147 149 | /f/food_court 148 150 | /f/football_field 149 151 | /f/forest/broadleaf 150 152 | /f/forest_path 151 153 | /f/forest_road 152 154 | /f/formal_garden 153 155 | /f/fountain 154 156 | /g/galley 155 157 | /g/garage/indoor 156 158 | /g/garage/outdoor 157 159 | /g/gas_station 158 160 | /g/gazebo/exterior 159 161 | /g/general_store/indoor 160 162 | /g/general_store/outdoor 161 163 | /g/gift_shop 162 164 | /g/glacier 163 165 | /g/golf_course 164 166 | /g/greenhouse/indoor 165 167 | /g/greenhouse/outdoor 166 168 | /g/grotto 167 169 | /g/gymnasium/indoor 168 170 | /h/hangar/indoor 169 171 | /h/hangar/outdoor 170 172 | /h/harbor 171 173 | /h/hardware_store 172 174 | /h/hayfield 173 175 | /h/heliport 174 176 | /h/highway 175 177 | /h/home_office 176 178 | /h/home_theater 177 179 | /h/hospital 178 180 | /h/hospital_room 179 181 | /h/hot_spring 180 182 | /h/hotel/outdoor 181 183 | /h/hotel_room 182 184 | /h/house 183 185 | /h/hunting_lodge/outdoor 184 186 | /i/ice_cream_parlor 185 187 | /i/ice_floe 186 188 | /i/ice_shelf 187 189 | /i/ice_skating_rink/indoor 188 190 | /i/ice_skating_rink/outdoor 189 191 | /i/iceberg 190 192 | /i/igloo 191 193 | /i/industrial_area 192 194 | /i/inn/outdoor 193 195 | /i/islet 194 196 | /j/jacuzzi/indoor 195 197 | /j/jail_cell 196 198 | /j/japanese_garden 197 199 | /j/jewelry_shop 198 200 | /j/junkyard 199 201 | /k/kasbah 200 202 | /k/kennel/outdoor 201 203 | /k/kindergarden_classroom 202 204 | /k/kitchen 203 205 | /l/lagoon 204 
206 | /l/lake/natural 205 207 | /l/landfill 206 208 | /l/landing_deck 207 209 | /l/laundromat 208 210 | /l/lawn 209 211 | /l/lecture_room 210 212 | /l/legislative_chamber 211 213 | /l/library/indoor 212 214 | /l/library/outdoor 213 215 | /l/lighthouse 214 216 | /l/living_room 215 217 | /l/loading_dock 216 218 | /l/lobby 217 219 | /l/lock_chamber 218 220 | /l/locker_room 219 221 | /m/mansion 220 222 | /m/manufactured_home 221 223 | /m/market/indoor 222 224 | /m/market/outdoor 223 225 | /m/marsh 224 226 | /m/martial_arts_gym 225 227 | /m/mausoleum 226 228 | /m/medina 227 229 | /m/mezzanine 228 230 | /m/moat/water 229 231 | /m/mosque/outdoor 230 232 | /m/motel 231 233 | /m/mountain 232 234 | /m/mountain_path 233 235 | /m/mountain_snowy 234 236 | /m/movie_theater/indoor 235 237 | /m/museum/indoor 236 238 | /m/museum/outdoor 237 239 | /m/music_studio 238 240 | /n/natural_history_museum 239 241 | /n/nursery 240 242 | /n/nursing_home 241 243 | /o/oast_house 242 244 | /o/ocean 243 245 | /o/office 244 246 | /o/office_building 245 247 | /o/office_cubicles 246 248 | /o/oilrig 247 249 | /o/operating_room 248 250 | /o/orchard 249 251 | /o/orchestra_pit 250 252 | /p/pagoda 251 253 | /p/palace 252 254 | /p/pantry 253 255 | /p/park 254 256 | /p/parking_garage/indoor 255 257 | /p/parking_garage/outdoor 256 258 | /p/parking_lot 257 259 | /p/pasture 258 260 | /p/patio 259 261 | /p/pavilion 260 262 | /p/pet_shop 261 263 | /p/pharmacy 262 264 | /p/phone_booth 263 265 | /p/physics_laboratory 264 266 | /p/picnic_area 265 267 | /p/pier 266 268 | /p/pizzeria 267 269 | /p/playground 268 270 | /p/playroom 269 271 | /p/plaza 270 272 | /p/pond 271 273 | /p/porch 272 274 | /p/promenade 273 275 | /p/pub/indoor 274 276 | /r/racecourse 275 277 | /r/raceway 276 278 | /r/raft 277 279 | /r/railroad_track 278 280 | /r/rainforest 279 281 | /r/reception 280 282 | /r/recreation_room 281 283 | /r/repair_shop 282 284 | /r/residential_neighborhood 283 285 | /r/restaurant 284 286 | /r/restaurant_kitchen 285 
287 | /r/restaurant_patio 286 288 | /r/rice_paddy 287 289 | /r/river 288 290 | /r/rock_arch 289 291 | /r/roof_garden 290 292 | /r/rope_bridge 291 293 | /r/ruin 292 294 | /r/runway 293 295 | /s/sandbox 294 296 | /s/sauna 295 297 | /s/schoolhouse 296 298 | /s/science_museum 297 299 | /s/server_room 298 300 | /s/shed 299 301 | /s/shoe_shop 300 302 | /s/shopfront 301 303 | /s/shopping_mall/indoor 302 304 | /s/shower 303 305 | /s/ski_resort 304 306 | /s/ski_slope 305 307 | /s/sky 306 308 | /s/skyscraper 307 309 | /s/slum 308 310 | /s/snowfield 309 311 | /s/soccer_field 310 312 | /s/stable 311 313 | /s/stadium/baseball 312 314 | /s/stadium/football 313 315 | /s/stadium/soccer 314 316 | /s/stage/indoor 315 317 | /s/stage/outdoor 316 318 | /s/staircase 317 319 | /s/storage_room 318 320 | /s/street 319 321 | /s/subway_station/platform 320 322 | /s/supermarket 321 323 | /s/sushi_bar 322 324 | /s/swamp 323 325 | /s/swimming_hole 324 326 | /s/swimming_pool/indoor 325 327 | /s/swimming_pool/outdoor 326 328 | /s/synagogue/outdoor 327 329 | /t/television_room 328 330 | /t/television_studio 329 331 | /t/temple/asia 330 332 | /t/throne_room 331 333 | /t/ticket_booth 332 334 | /t/topiary_garden 333 335 | /t/tower 334 336 | /t/toyshop 335 337 | /t/train_interior 336 338 | /t/train_station/platform 337 339 | /t/tree_farm 338 340 | /t/tree_house 339 341 | /t/trench 340 342 | /t/tundra 341 343 | /u/underwater/ocean_deep 342 344 | /u/utility_room 343 345 | /v/valley 344 346 | /v/vegetable_garden 345 347 | /v/veterinarians_office 346 348 | /v/viaduct 347 349 | /v/village 348 350 | /v/vineyard 349 351 | /v/volcano 350 352 | /v/volleyball_court/outdoor 351 353 | /w/waiting_room 352 354 | /w/water_park 353 355 | /w/water_tower 354 356 | /w/waterfall 355 357 | /w/watering_hole 356 358 | /w/wave 357 359 | /w/wet_bar 358 360 | /w/wheat_field 359 361 | /w/wind_farm 360 362 | /w/windmill 361 363 | /y/yard 362 364 | /y/youth_hostel 363 365 | /z/zen_garden 364 
-------------------------------------------------------------------------------- /data/github_overview_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trustworthy-ML-Lab/CLIP-dissect/21e7697feaea3bf7d7bc2d2cc8e4047d5f5fd502/data/github_overview_figure.png -------------------------------------------------------------------------------- /data/imagenet_labels.txt: -------------------------------------------------------------------------------- 1 | tench, Tinca tinca 2 | goldfish, Carassius auratus 3 | great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias 4 | tiger shark, Galeocerdo cuvieri 5 | hammerhead, hammerhead shark 6 | electric ray, crampfish, numbfish, torpedo 7 | stingray 8 | cock 9 | hen 10 | ostrich, Struthio camelus 11 | brambling, Fringilla montifringilla 12 | goldfinch, Carduelis carduelis 13 | house finch, linnet, Carpodacus mexicanus 14 | junco, snowbird 15 | indigo bunting, indigo finch, indigo bird, Passerina cyanea 16 | robin, American robin, Turdus migratorius 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | water ouzel, dipper 22 | kite 23 | bald eagle, American eagle, Haliaeetus leucocephalus 24 | vulture 25 | great grey owl, great gray owl, Strix nebulosa 26 | European fire salamander, Salamandra salamandra 27 | common newt, Triturus vulgaris 28 | eft 29 | spotted salamander, Ambystoma maculatum 30 | axolotl, mud puppy, Ambystoma mexicanum 31 | bullfrog, Rana catesbeiana 32 | tree frog, tree-frog 33 | tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui 34 | loggerhead, loggerhead turtle, Caretta caretta 35 | leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea 36 | mud turtle 37 | terrapin 38 | box turtle, box tortoise 39 | banded gecko 40 | common iguana, iguana, Iguana iguana 41 | American chameleon, anole, Anolis carolinensis 42 | whiptail, whiptail lizard 43 | agama 44 | frilled lizard, Chlamydosaurus 
kingi 45 | alligator lizard 46 | Gila monster, Heloderma suspectum 47 | green lizard, Lacerta viridis 48 | African chameleon, Chamaeleo chamaeleon 49 | Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis 50 | African crocodile, Nile crocodile, Crocodylus niloticus 51 | American alligator, Alligator mississipiensis 52 | triceratops 53 | thunder snake, worm snake, Carphophis amoenus 54 | ringneck snake, ring-necked snake, ring snake 55 | hognose snake, puff adder, sand viper 56 | green snake, grass snake 57 | king snake, kingsnake 58 | garter snake, grass snake 59 | water snake 60 | vine snake 61 | night snake, Hypsiglena torquata 62 | boa constrictor, Constrictor constrictor 63 | rock python, rock snake, Python sebae 64 | Indian cobra, Naja naja 65 | green mamba 66 | sea snake 67 | horned viper, cerastes, sand viper, horned asp, Cerastes cornutus 68 | diamondback, diamondback rattlesnake, Crotalus adamanteus 69 | sidewinder, horned rattlesnake, Crotalus cerastes 70 | trilobite 71 | harvestman, daddy longlegs, Phalangium opilio 72 | scorpion 73 | black and gold garden spider, Argiope aurantia 74 | barn spider, Araneus cavaticus 75 | garden spider, Aranea diademata 76 | black widow, Latrodectus mactans 77 | tarantula 78 | wolf spider, hunting spider 79 | tick 80 | centipede 81 | black grouse 82 | ptarmigan 83 | ruffed grouse, partridge, Bonasa umbellus 84 | prairie chicken, prairie grouse, prairie fowl 85 | peacock 86 | quail 87 | partridge 88 | African grey, African gray, Psittacus erithacus 89 | macaw 90 | sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita 91 | lorikeet 92 | coucal 93 | bee eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | drake 99 | red-breasted merganser, Mergus serrator 100 | goose 101 | black swan, Cygnus atratus 102 | tusker 103 | echidna, spiny anteater, anteater 104 | platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus 105 | wallaby, brush kangaroo 106 | 
koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus 107 | wombat 108 | jellyfish 109 | sea anemone, anemone 110 | brain coral 111 | flatworm, platyhelminth 112 | nematode, nematode worm, roundworm 113 | conch 114 | snail 115 | slug 116 | sea slug, nudibranch 117 | chiton, coat-of-mail shell, sea cradle, polyplacophore 118 | chambered nautilus, pearly nautilus, nautilus 119 | Dungeness crab, Cancer magister 120 | rock crab, Cancer irroratus 121 | fiddler crab 122 | king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica 123 | American lobster, Northern lobster, Maine lobster, Homarus americanus 124 | spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish 125 | crayfish, crawfish, crawdad, crawdaddy 126 | hermit crab 127 | isopod 128 | white stork, Ciconia ciconia 129 | black stork, Ciconia nigra 130 | spoonbill 131 | flamingo 132 | little blue heron, Egretta caerulea 133 | American egret, great white heron, Egretta albus 134 | bittern 135 | crane, bird 136 | limpkin, Aramus pictus 137 | European gallinule, Porphyrio porphyrio 138 | American coot, marsh hen, mud hen, water hen, Fulica americana 139 | bustard 140 | ruddy turnstone, Arenaria interpres 141 | red-backed sandpiper, dunlin, Erolia alpina 142 | redshank, Tringa totanus 143 | dowitcher 144 | oystercatcher, oyster catcher 145 | pelican 146 | king penguin, Aptenodytes patagonica 147 | albatross, mollymawk 148 | grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus 149 | killer whale, killer, orca, grampus, sea wolf, Orcinus orca 150 | dugong, Dugong dugon 151 | sea lion 152 | Chihuahua 153 | Japanese spaniel 154 | Maltese dog, Maltese terrier, Maltese 155 | Pekinese, Pekingese, Peke 156 | Shih-Tzu 157 | Blenheim spaniel 158 | papillon 159 | toy terrier 160 | Rhodesian ridgeback 161 | Afghan hound, Afghan 162 | basset, basset hound 163 | beagle 164 | bloodhound, sleuthhound 165 | bluetick 166 | black-and-tan coonhound 
167 | Walker hound, Walker foxhound 168 | English foxhound 169 | redbone 170 | borzoi, Russian wolfhound 171 | Irish wolfhound 172 | Italian greyhound 173 | whippet 174 | Ibizan hound, Ibizan Podenco 175 | Norwegian elkhound, elkhound 176 | otterhound, otter hound 177 | Saluki, gazelle hound 178 | Scottish deerhound, deerhound 179 | Weimaraner 180 | Staffordshire bullterrier, Staffordshire bull terrier 181 | American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier 182 | Bedlington terrier 183 | Border terrier 184 | Kerry blue terrier 185 | Irish terrier 186 | Norfolk terrier 187 | Norwich terrier 188 | Yorkshire terrier 189 | wire-haired fox terrier 190 | Lakeland terrier 191 | Sealyham terrier, Sealyham 192 | Airedale, Airedale terrier 193 | cairn, cairn terrier 194 | Australian terrier 195 | Dandie Dinmont, Dandie Dinmont terrier 196 | Boston bull, Boston terrier 197 | miniature schnauzer 198 | giant schnauzer 199 | standard schnauzer 200 | Scotch terrier, Scottish terrier, Scottie 201 | Tibetan terrier, chrysanthemum dog 202 | silky terrier, Sydney silky 203 | soft-coated wheaten terrier 204 | West Highland white terrier 205 | Lhasa, Lhasa apso 206 | flat-coated retriever 207 | curly-coated retriever 208 | golden retriever 209 | Labrador retriever 210 | Chesapeake Bay retriever 211 | German short-haired pointer 212 | vizsla, Hungarian pointer 213 | English setter 214 | Irish setter, red setter 215 | Gordon setter 216 | Brittany spaniel 217 | clumber, clumber spaniel 218 | English springer, English springer spaniel 219 | Welsh springer spaniel 220 | cocker spaniel, English cocker spaniel, cocker 221 | Sussex spaniel 222 | Irish water spaniel 223 | kuvasz 224 | schipperke 225 | groenendael 226 | malinois 227 | briard 228 | kelpie 229 | komondor 230 | Old English sheepdog, bobtail 231 | Shetland sheepdog, Shetland sheep dog, Shetland 232 | collie 233 | Border collie 234 | Bouvier des Flandres, Bouviers des Flandres 235 | 
Rottweiler 236 | German shepherd, German shepherd dog, German police dog, alsatian 237 | Doberman, Doberman pinscher 238 | miniature pinscher 239 | Greater Swiss Mountain dog 240 | Bernese mountain dog 241 | Appenzeller 242 | EntleBucher 243 | boxer 244 | bull mastiff 245 | Tibetan mastiff 246 | French bulldog 247 | Great Dane 248 | Saint Bernard, St Bernard 249 | Eskimo dog, husky 250 | malamute, malemute, Alaskan malamute 251 | Siberian husky 252 | dalmatian, coach dog, carriage dog 253 | affenpinscher, monkey pinscher, monkey dog 254 | basenji 255 | pug, pug-dog 256 | Leonberg 257 | Newfoundland, Newfoundland dog 258 | Great Pyrenees 259 | Samoyed, Samoyede 260 | Pomeranian 261 | chow, chow chow 262 | keeshond 263 | Brabancon griffon 264 | Pembroke, Pembroke Welsh corgi 265 | Cardigan, Cardigan Welsh corgi 266 | toy poodle 267 | miniature poodle 268 | standard poodle 269 | Mexican hairless 270 | timber wolf, grey wolf, gray wolf, Canis lupus 271 | white wolf, Arctic wolf, Canis lupus tundrarum 272 | red wolf, maned wolf, Canis rufus, Canis niger 273 | coyote, prairie wolf, brush wolf, Canis latrans 274 | dingo, warrigal, warragal, Canis dingo 275 | dhole, Cuon alpinus 276 | African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus 277 | hyena, hyaena 278 | red fox, Vulpes vulpes 279 | kit fox, Vulpes macrotis 280 | Arctic fox, white fox, Alopex lagopus 281 | grey fox, gray fox, Urocyon cinereoargenteus 282 | tabby, tabby cat 283 | tiger cat 284 | Persian cat 285 | Siamese cat, Siamese 286 | Egyptian cat 287 | cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 288 | lynx, catamount 289 | leopard, Panthera pardus 290 | snow leopard, ounce, Panthera uncia 291 | jaguar, panther, Panthera onca, Felis onca 292 | lion, king of beasts, Panthera leo 293 | tiger, Panthera tigris 294 | cheetah, chetah, Acinonyx jubatus 295 | brown bear, bruin, Ursus arctos 296 | American black bear, black bear, Ursus americanus, Euarctos americanus 297 | ice 
bear, polar bear, Ursus Maritimus, Thalarctos maritimus 298 | sloth bear, Melursus ursinus, Ursus ursinus 299 | mongoose 300 | meerkat, mierkat 301 | tiger beetle 302 | ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle 303 | ground beetle, carabid beetle 304 | long-horned beetle, longicorn, longicorn beetle 305 | leaf beetle, chrysomelid 306 | dung beetle 307 | rhinoceros beetle 308 | weevil 309 | fly 310 | bee 311 | ant, emmet, pismire 312 | grasshopper, hopper 313 | cricket 314 | walking stick, walkingstick, stick insect 315 | cockroach, roach 316 | mantis, mantid 317 | cicada, cicala 318 | leafhopper 319 | lacewing, lacewing fly 320 | dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk 321 | damselfly 322 | admiral 323 | ringlet, ringlet butterfly 324 | monarch, monarch butterfly, milkweed butterfly, Danaus plexippus 325 | cabbage butterfly 326 | sulphur butterfly, sulfur butterfly 327 | lycaenid, lycaenid butterfly 328 | starfish, sea star 329 | sea urchin 330 | sea cucumber, holothurian 331 | wood rabbit, cottontail, cottontail rabbit 332 | hare 333 | Angora, Angora rabbit 334 | hamster 335 | porcupine, hedgehog 336 | fox squirrel, eastern fox squirrel, Sciurus niger 337 | marmot 338 | beaver 339 | guinea pig, Cavia cobaya 340 | sorrel 341 | zebra 342 | hog, pig, grunter, squealer, Sus scrofa 343 | wild boar, boar, Sus scrofa 344 | warthog 345 | hippopotamus, hippo, river horse, Hippopotamus amphibius 346 | ox 347 | water buffalo, water ox, Asiatic buffalo, Bubalus bubalis 348 | bison 349 | ram, tup 350 | bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis 351 | ibex, Capra ibex 352 | hartebeest 353 | impala, Aepyceros melampus 354 | gazelle 355 | Arabian camel, dromedary, Camelus dromedarius 356 | llama 357 | weasel 358 | mink 359 | polecat, fitch, foulmart, foumart, Mustela putorius 360 | black-footed ferret, ferret, Mustela nigripes 
361 | otter 362 | skunk, polecat, wood pussy 363 | badger 364 | armadillo 365 | three-toed sloth, ai, Bradypus tridactylus 366 | orangutan, orang, orangutang, Pongo pygmaeus 367 | gorilla, Gorilla gorilla 368 | chimpanzee, chimp, Pan troglodytes 369 | gibbon, Hylobates lar 370 | siamang, Hylobates syndactylus, Symphalangus syndactylus 371 | guenon, guenon monkey 372 | patas, hussar monkey, Erythrocebus patas 373 | baboon 374 | macaque 375 | langur 376 | colobus, colobus monkey 377 | proboscis monkey, Nasalis larvatus 378 | marmoset 379 | capuchin, ringtail, Cebus capucinus 380 | howler monkey, howler 381 | titi, titi monkey 382 | spider monkey, Ateles geoffroyi 383 | squirrel monkey, Saimiri sciureus 384 | Madagascar cat, ring-tailed lemur, Lemur catta 385 | indri, indris, Indri indri, Indri brevicaudatus 386 | Indian elephant, Elephas maximus 387 | African elephant, Loxodonta africana 388 | lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens 389 | giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca 390 | barracouta, snoek 391 | eel 392 | coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch 393 | rock beauty, Holocanthus tricolor 394 | anemone fish 395 | sturgeon 396 | gar, garfish, garpike, billfish, Lepisosteus osseus 397 | lionfish 398 | puffer, pufferfish, blowfish, globefish 399 | abacus 400 | abaya 401 | academic gown, academic robe, judge's robe 402 | accordion, piano accordion, squeeze box 403 | acoustic guitar 404 | aircraft carrier, carrier, flattop, attack aircraft carrier 405 | airliner 406 | airship, dirigible 407 | altar 408 | ambulance 409 | amphibian, amphibious vehicle 410 | analog clock 411 | apiary, bee house 412 | apron 413 | ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 414 | assault rifle, assault gun 415 | backpack, back pack, knapsack, packsack, rucksack, haversack 416 | bakery, bakeshop, bakehouse 417 | balance beam, beam 418 | balloon 
419 | ballpoint, ballpoint pen, ballpen, Biro 420 | Band Aid 421 | banjo 422 | bannister, banister, balustrade, balusters, handrail 423 | barbell 424 | barber chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel, cask 429 | barrow, garden cart, lawn cart, wheelbarrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | bathing cap, swimming cap 435 | bath towel 436 | bathtub, bathing tub, bath, tub 437 | beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon 438 | beacon, lighthouse, beacon light, pharos 439 | beaker 440 | bearskin, busby, shako 441 | beer bottle 442 | beer glass 443 | bell cote, bell cot 444 | bib 445 | bicycle-built-for-two, tandem bicycle, tandem 446 | bikini, two-piece 447 | binder, ring-binder 448 | binoculars, field glasses, opera glasses 449 | birdhouse 450 | boathouse 451 | bobsled, bobsleigh, bob 452 | bolo tie, bolo, bola tie, bola 453 | bonnet, poke bonnet 454 | bookcase 455 | bookshop, bookstore, bookstall 456 | bottlecap 457 | bow 458 | bow tie, bow-tie, bowtie 459 | brass, memorial tablet, plaque 460 | brassiere, bra, bandeau 461 | breakwater, groin, groyne, mole, bulwark, seawall, jetty 462 | breastplate, aegis, egis 463 | broom 464 | bucket, pail 465 | buckle 466 | bulletproof vest 467 | bullet train, bullet 468 | butcher shop, meat market 469 | cab, hack, taxi, taxicab 470 | caldron, cauldron 471 | candle, taper, wax light 472 | cannon 473 | canoe 474 | can opener, tin opener 475 | cardigan 476 | car mirror 477 | carousel, carrousel, merry-go-round, roundabout, whirligig 478 | carpenter's kit, tool kit 479 | carton 480 | car wheel 481 | cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM 482 | cassette 483 | cassette player 484 | castle 485 | catamaran 486 | CD player 487 | cello, violoncello 488 | cellular telephone, cellular phone, cellphone, cell, mobile phone 489 | chain 490 | chainlink fence 491 | chain 
mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour 492 | chain saw, chainsaw 493 | chest 494 | chiffonier, commode 495 | chime, bell, gong 496 | china cabinet, china closet 497 | Christmas stocking 498 | church, church building 499 | cinema, movie theater, movie theatre, movie house, picture palace 500 | cleaver, meat cleaver, chopper 501 | cliff dwelling 502 | cloak 503 | clog, geta, patten, sabot 504 | cocktail shaker 505 | coffee mug 506 | coffeepot 507 | coil, spiral, volute, whorl, helix 508 | combination lock 509 | computer keyboard, keypad 510 | confectionery, confectionary, candy store 511 | container ship, containership, container vessel 512 | convertible 513 | corkscrew, bottle screw 514 | cornet, horn, trumpet, trump 515 | cowboy boot 516 | cowboy hat, ten-gallon hat 517 | cradle 518 | crane, machine 519 | crash helmet 520 | crate 521 | crib, cot 522 | Crock Pot 523 | croquet ball 524 | crutch 525 | cuirass 526 | dam, dike, dyke 527 | desk 528 | desktop computer 529 | dial telephone, dial phone 530 | diaper, nappy, napkin 531 | digital clock 532 | digital watch 533 | dining table, board 534 | dishrag, dishcloth 535 | dishwasher, dish washer, dishwashing machine 536 | disk brake, disc brake 537 | dock, dockage, docking facility 538 | dogsled, dog sled, dog sleigh 539 | dome 540 | doormat, welcome mat 541 | drilling platform, offshore rig 542 | drum, membranophone, tympan 543 | drumstick 544 | dumbbell 545 | Dutch oven 546 | electric fan, blower 547 | electric guitar 548 | electric locomotive 549 | entertainment center 550 | envelope 551 | espresso maker 552 | face powder 553 | feather boa, boa 554 | file, file cabinet, filing cabinet 555 | fireboat 556 | fire engine, fire truck 557 | fire screen, fireguard 558 | flagpole, flagstaff 559 | flute, transverse flute 560 | folding chair 561 | football helmet 562 | forklift 563 | fountain 564 | fountain pen 565 | four-poster 566 | freight car 567 | French horn, horn 568 | frying pan, 
frypan, skillet 569 | fur coat 570 | garbage truck, dustcart 571 | gasmask, respirator, gas helmet 572 | gas pump, gasoline pump, petrol pump, island dispenser 573 | goblet 574 | go-kart 575 | golf ball 576 | golfcart, golf cart 577 | gondola 578 | gong, tam-tam 579 | gown 580 | grand piano, grand 581 | greenhouse, nursery, glasshouse 582 | grille, radiator grille 583 | grocery store, grocery, food market, market 584 | guillotine 585 | hair slide 586 | hair spray 587 | half track 588 | hammer 589 | hamper 590 | hand blower, blow dryer, blow drier, hair dryer, hair drier 591 | hand-held computer, hand-held microcomputer 592 | handkerchief, hankie, hanky, hankey 593 | hard disc, hard disk, fixed disk 594 | harmonica, mouth organ, harp, mouth harp 595 | harp 596 | harvester, reaper 597 | hatchet 598 | holster 599 | home theater, home theatre 600 | honeycomb 601 | hook, claw 602 | hoopskirt, crinoline 603 | horizontal bar, high bar 604 | horse cart, horse-cart 605 | hourglass 606 | iPod 607 | iron, smoothing iron 608 | jack-o'-lantern 609 | jean, blue jean, denim 610 | jeep, landrover 611 | jersey, T-shirt, tee shirt 612 | jigsaw puzzle 613 | jinrikisha, ricksha, rickshaw 614 | joystick 615 | kimono 616 | knee pad 617 | knot 618 | lab coat, laboratory coat 619 | ladle 620 | lampshade, lamp shade 621 | laptop, laptop computer 622 | lawn mower, mower 623 | lens cap, lens cover 624 | letter opener, paper knife, paperknife 625 | library 626 | lifeboat 627 | lighter, light, igniter, ignitor 628 | limousine, limo 629 | liner, ocean liner 630 | lipstick, lip rouge 631 | Loafer 632 | lotion 633 | loudspeaker, speaker, speaker unit, loudspeaker system, speaker system 634 | loupe, jeweler's loupe 635 | lumbermill, sawmill 636 | magnetic compass 637 | mailbag, postbag 638 | mailbox, letter box 639 | maillot 640 | maillot, tank suit 641 | manhole cover 642 | maraca 643 | marimba, xylophone 644 | mask 645 | matchstick 646 | maypole 647 | maze, labyrinth 648 | measuring cup 649 | 
medicine chest, medicine cabinet 650 | megalith, megalithic structure 651 | microphone, mike 652 | microwave, microwave oven 653 | military uniform 654 | milk can 655 | minibus 656 | miniskirt, mini 657 | minivan 658 | missile 659 | mitten 660 | mixing bowl 661 | mobile home, manufactured home 662 | Model T 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar 668 | mortarboard 669 | mosque 670 | mosquito net 671 | motor scooter, scooter 672 | mountain bike, all-terrain bike, off-roader 673 | mountain tent 674 | mouse, computer mouse 675 | mousetrap 676 | moving van 677 | muzzle 678 | nail 679 | neck brace 680 | necklace 681 | nipple 682 | notebook, notebook computer 683 | obelisk 684 | oboe, hautboy, hautbois 685 | ocarina, sweet potato 686 | odometer, hodometer, mileometer, milometer 687 | oil filter 688 | organ, pipe organ 689 | oscilloscope, scope, cathode-ray oscilloscope, CRO 690 | overskirt 691 | oxcart 692 | oxygen mask 693 | packet 694 | paddle, boat paddle 695 | paddlewheel, paddle wheel 696 | padlock 697 | paintbrush 698 | pajama, pyjama, pj's, jammies 699 | palace 700 | panpipe, pandean pipe, syrinx 701 | paper towel 702 | parachute, chute 703 | parallel bars, bars 704 | park bench 705 | parking meter 706 | passenger car, coach, carriage 707 | patio, terrace 708 | pay-phone, pay-station 709 | pedestal, plinth, footstall 710 | pencil box, pencil case 711 | pencil sharpener 712 | perfume, essence 713 | Petri dish 714 | photocopier 715 | pick, plectrum, plectron 716 | pickelhaube 717 | picket fence, paling 718 | pickup, pickup truck 719 | pier 720 | piggy bank, penny bank 721 | pill bottle 722 | pillow 723 | ping-pong ball 724 | pinwheel 725 | pirate, pirate ship 726 | pitcher, ewer 727 | plane, carpenter's plane, woodworking plane 728 | planetarium 729 | plastic bag 730 | plate rack 731 | plow, plough 732 | plunger, plumber's helper 733 | Polaroid camera, Polaroid Land camera 734 | pole 735 | police van, police wagon, paddy wagon, patrol 
wagon, wagon, black Maria 736 | poncho 737 | pool table, billiard table, snooker table 738 | pop bottle, soda bottle 739 | pot, flowerpot 740 | potter's wheel 741 | power drill 742 | prayer rug, prayer mat 743 | printer 744 | prison, prison house 745 | projectile, missile 746 | projector 747 | puck, hockey puck 748 | punching bag, punch bag, punching ball, punchball 749 | purse 750 | quill, quill pen 751 | quilt, comforter, comfort, puff 752 | racer, race car, racing car 753 | racket, racquet 754 | radiator 755 | radio, wireless 756 | radio telescope, radio reflector 757 | rain barrel 758 | recreational vehicle, RV, R.V. 759 | reel 760 | reflex camera 761 | refrigerator, icebox 762 | remote control, remote 763 | restaurant, eating house, eating place, eatery 764 | revolver, six-gun, six-shooter 765 | rifle 766 | rocking chair, rocker 767 | rotisserie 768 | rubber eraser, rubber, pencil eraser 769 | rugby ball 770 | rule, ruler 771 | running shoe 772 | safe 773 | safety pin 774 | saltshaker, salt shaker 775 | sandal 776 | sarong 777 | sax, saxophone 778 | scabbard 779 | scale, weighing machine 780 | school bus 781 | schooner 782 | scoreboard 783 | screen, CRT screen 784 | screw 785 | screwdriver 786 | seat belt, seatbelt 787 | sewing machine 788 | shield, buckler 789 | shoe shop, shoe-shop, shoe store 790 | shoji 791 | shopping basket 792 | shopping cart 793 | shovel 794 | shower cap 795 | shower curtain 796 | ski 797 | ski mask 798 | sleeping bag 799 | slide rule, slipstick 800 | sliding door 801 | slot, one-armed bandit 802 | snorkel 803 | snowmobile 804 | snowplow, snowplough 805 | soap dispenser 806 | soccer ball 807 | sock 808 | solar dish, solar collector, solar furnace 809 | sombrero 810 | soup bowl 811 | space bar 812 | space heater 813 | space shuttle 814 | spatula 815 | speedboat 816 | spider web, spider's web 817 | spindle 818 | sports car, sport car 819 | spotlight, spot 820 | stage 821 | steam locomotive 822 | steel arch bridge 823 | steel drum 824 | 
stethoscope 825 | stole 826 | stone wall 827 | stopwatch, stop watch 828 | stove 829 | strainer 830 | streetcar, tram, tramcar, trolley, trolley car 831 | stretcher 832 | studio couch, day bed 833 | stupa, tope 834 | submarine, pigboat, sub, U-boat 835 | suit, suit of clothes 836 | sundial 837 | sunglass 838 | sunglasses, dark glasses, shades 839 | sunscreen, sunblock, sun blocker 840 | suspension bridge 841 | swab, swob, mop 842 | sweatshirt 843 | swimming trunks, bathing trunks 844 | swing 845 | switch, electric switch, electrical switch 846 | syringe 847 | table lamp 848 | tank, army tank, armored combat vehicle, armoured combat vehicle 849 | tape player 850 | teapot 851 | teddy, teddy bear 852 | television, television system 853 | tennis ball 854 | thatch, thatched roof 855 | theater curtain, theatre curtain 856 | thimble 857 | thresher, thrasher, threshing machine 858 | throne 859 | tile roof 860 | toaster 861 | tobacco shop, tobacconist shop, tobacconist 862 | toilet seat 863 | torch 864 | totem pole 865 | tow truck, tow car, wrecker 866 | toyshop 867 | tractor 868 | trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi 869 | tray 870 | trench coat 871 | tricycle, trike, velocipede 872 | trimaran 873 | tripod 874 | triumphal arch 875 | trolleybus, trolley coach, trackless trolley 876 | trombone 877 | tub, vat 878 | turnstile 879 | typewriter keyboard 880 | umbrella 881 | unicycle, monocycle 882 | upright, upright piano 883 | vacuum, vacuum cleaner 884 | vase 885 | vault 886 | velvet 887 | vending machine 888 | vestment 889 | viaduct 890 | violin, fiddle 891 | volleyball 892 | waffle iron 893 | wall clock 894 | wallet, billfold, notecase, pocketbook 895 | wardrobe, closet, press 896 | warplane, military plane 897 | washbasin, handbasin, washbowl, lavabo, wash-hand basin 898 | washer, automatic washer, washing machine 899 | water bottle 900 | water jug 901 | water tower 902 | whiskey jug 903 | whistle 904 | wig 905 | window screen 906 | 
window shade 907 | Windsor tie 908 | wine bottle 909 | wing 910 | wok 911 | wooden spoon 912 | wool, woolen, woollen 913 | worm fence, snake fence, snake-rail fence, Virginia fence 914 | wreck 915 | yawl 916 | yurt 917 | web site, website, internet site, site 918 | comic book 919 | crossword puzzle, crossword 920 | street sign 921 | traffic light, traffic signal, stoplight 922 | book jacket, dust cover, dust jacket, dust wrapper 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot pot, hotpot 928 | trifle 929 | ice cream, icecream 930 | ice lolly, lolly, lollipop, popsicle 931 | French loaf 932 | bagel, beigel 933 | pretzel 934 | cheeseburger 935 | hotdog, hot dog, red hot 936 | mashed potato 937 | head cabbage 938 | broccoli 939 | cauliflower 940 | zucchini, courgette 941 | spaghetti squash 942 | acorn squash 943 | butternut squash 944 | cucumber, cuke 945 | artichoke, globe artichoke 946 | bell pepper 947 | cardoon 948 | mushroom 949 | Granny Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple, ananas 955 | banana 956 | jackfruit, jak, jack 957 | custard apple 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate sauce, chocolate syrup 962 | dough 963 | meat loaf, meatloaf 964 | pizza, pizza pie 965 | potpie 966 | burrito 967 | red wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff, drop, drop-off 974 | coral reef 975 | geyser 976 | lakeside, lakeshore 977 | promontory, headland, head, foreland 978 | sandbar, sand bar 979 | seashore, coast, seacoast, sea-coast 980 | valley, vale 981 | volcano 982 | ballplayer, baseball player 983 | groom, bridegroom 984 | scuba diver 985 | rapeseed 986 | daisy 987 | yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum 988 | corn 989 | acorn 990 | hip, rose hip, rosehip 991 | buckeye, horse chestnut, conker 992 | coral fungus 993 | agaric 994 | gyromitra 995 | stinkhorn, carrion fungus 996 | earthstar 997 | hen-of-the-woods, hen 
of the woods, Polyporus frondosus, Grifola frondosa 998 | bolete 999 | ear, spike, capitulum 1000 | toilet tissue, toilet paper, bathroom tissue -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pandas as pd 4 | from torchvision import datasets, transforms, models 5 | 6 | DATASET_ROOTS = {"imagenet_val": "YOUR_PATH/ImageNet_val/", 7 | "broden": "data/broden1_224/images/"} 8 | 9 | 10 | def get_target_model(target_name, device): 11 | """ 12 | returns target model in eval mode and its preprocess function 13 | target_name: supported options - {resnet18_places, resnet18, resnet34, resnet50, resnet101, resnet152} 14 | except for resnet18_places this will return a model trained on ImageNet from torchvision 15 | 16 | To Dissect a different model implement its loading and preprocessing function here 17 | """ 18 | if target_name == 'resnet18_places': 19 | target_model = models.resnet18(num_classes=365).to(device) 20 | state_dict = torch.load('data/resnet18_places365.pth.tar')['state_dict'] 21 | new_state_dict = {} 22 | for key in state_dict: 23 | if key.startswith('module.'): 24 | new_state_dict[key[7:]] = state_dict[key] 25 | target_model.load_state_dict(new_state_dict) 26 | target_model.eval() 27 | preprocess = get_resnet_imagenet_preprocess() 28 | elif "vit_b" in target_name: 29 | target_name_cap = target_name.replace("vit_b", "ViT_B") 30 | weights = eval("models.{}_Weights.IMAGENET1K_V1".format(target_name_cap)) 31 | preprocess = weights.transforms() 32 | target_model = eval("models.{}(weights=weights).to(device)".format(target_name)) 33 | elif "resnet" in target_name: 34 | target_name_cap = target_name.replace("resnet", "ResNet") 35 | weights = eval("models.{}_Weights.IMAGENET1K_V1".format(target_name_cap)) 36 | preprocess = weights.transforms() 37 | target_model = 
eval("models.{}(weights=weights).to(device)".format(target_name)) 38 | 39 | target_model.eval() 40 | return target_model, preprocess 41 | 42 | def get_resnet_imagenet_preprocess(): 43 | target_mean = [0.485, 0.456, 0.406] 44 | target_std = [0.229, 0.224, 0.225] 45 | preprocess = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), 46 | transforms.ToTensor(), transforms.Normalize(mean=target_mean, std=target_std)]) 47 | return preprocess 48 | 49 | def get_data(dataset_name, preprocess=None): 50 | if dataset_name == "cifar100_train": 51 | data = datasets.CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=True, 52 | transform=preprocess) 53 | 54 | elif dataset_name == "cifar100_val": 55 | data = datasets.CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False, 56 | transform=preprocess) 57 | 58 | elif dataset_name in DATASET_ROOTS.keys(): 59 | data = datasets.ImageFolder(DATASET_ROOTS[dataset_name], preprocess) 60 | 61 | elif dataset_name == "imagenet_broden": 62 | data = torch.utils.data.ConcatDataset([datasets.ImageFolder(DATASET_ROOTS["imagenet_val"], preprocess), 63 | datasets.ImageFolder(DATASET_ROOTS["broden"], preprocess)]) 64 | 65 | return data 66 | 67 | 68 | def get_places_id_to_broden_label(): 69 | with open("data/categories_places365.txt", "r") as f: 70 | places365_classes = f.read().split("\n") 71 | 72 | broden_scenes = pd.read_csv('data/broden1_224/c_scene.csv') 73 | id_to_broden_label = {} 74 | for i, cls in enumerate(places365_classes): 75 | name = cls[3:].split(' ')[0] 76 | name = name.replace('/', '-') 77 | 78 | found = (name+'-s' in broden_scenes['name'].values) 79 | 80 | if found: 81 | id_to_broden_label[i] = name.replace('-', '/')+'-s' 82 | if not found: 83 | id_to_broden_label[i] = None 84 | return id_to_broden_label 85 | 86 | def get_cifar_superclass(): 87 | cifar100_has_superclass = [i for i in range(7)] 88 | cifar100_has_superclass.extend([i for i in range(33, 69)]) 89 | 
cifar100_has_superclass.append(70) 90 | cifar100_has_superclass.extend([i for i in range(72, 78)]) 91 | cifar100_has_superclass.extend([101, 104, 110, 111, 113, 114]) 92 | cifar100_has_superclass.extend([i for i in range(118, 126)]) 93 | cifar100_has_superclass.extend([i for i in range(147, 151)]) 94 | cifar100_has_superclass.extend([i for i in range(269, 281)]) 95 | cifar100_has_superclass.extend([i for i in range(286, 298)]) 96 | cifar100_has_superclass.extend([i for i in range(300, 308)]) 97 | cifar100_has_superclass.extend([309, 314]) 98 | cifar100_has_superclass.extend([i for i in range(321, 327)]) 99 | cifar100_has_superclass.extend([i for i in range(330, 339)]) 100 | cifar100_has_superclass.extend([345, 354, 355, 360, 361]) 101 | cifar100_has_superclass.extend([i for i in range(385, 398)]) 102 | cifar100_has_superclass.extend([409, 438, 440, 441, 455, 463, 466, 483, 487]) 103 | cifar100_doesnt_have_superclass = [i for i in range(500) if (i not in cifar100_has_superclass)] 104 | 105 | return cifar100_has_superclass, cifar100_doesnt_have_superclass -------------------------------------------------------------------------------- /describe_neurons.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datetime 4 | import json 5 | import pandas as pd 6 | import torch 7 | 8 | import utils 9 | import similarity 10 | 11 | 12 | parser = argparse.ArgumentParser(description='CLIP-Dissect') 13 | 14 | parser.add_argument("--clip_model", type=str, default="ViT-B/16", 15 | choices=['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14'], 16 | help="Which CLIP-model to use") 17 | parser.add_argument("--target_model", type=str, default="resnet50", 18 | help="""Which model to dissect, supported options are pretrained imagenet models from 19 | torchvision and resnet18_places""") 20 | parser.add_argument("--target_layers", type=str, 
default="conv1,layer1,layer2,layer3,layer4", 21 | help="""Which layer neurons to describe. String list of layer names to describe, separated by commas (no spaces). 22 | Follows the naming scheme of the PyTorch module used""") 23 | parser.add_argument("--d_probe", type=str, default="broden", 24 | choices = ["imagenet_broden", "cifar100_val", "imagenet_val", "broden"]) 25 | parser.add_argument("--concept_set", type=str, default="data/20k.txt", help="Path to txt file containing concept set") 26 | parser.add_argument("--batch_size", type=int, default=200, help="Batch size when running CLIP/target model") 27 | parser.add_argument("--device", type=str, default="cuda", help="whether to use GPU/which gpu") 28 | parser.add_argument("--activation_dir", type=str, default="saved_activations", help="where to save activations") 29 | parser.add_argument("--result_dir", type=str, default="results", help="where to save results") 30 | parser.add_argument("--pool_mode", type=str, default="avg", help="Aggregation function for channels, max or avg") 31 | parser.add_argument("--similarity_fn", type=str, default="soft_wpmi", choices=["soft_wpmi", "wpmi", "rank_reorder", 32 | "cos_similarity", "cos_similarity_cubed"]) 33 | 34 | 35 | 36 | if __name__ == '__main__': 37 | args = parser.parse_args() 38 | args.target_layers = args.target_layers.split(",") 39 | 40 | similarity_fn = eval("similarity.{}".format(args.similarity_fn)) 41 | 42 | utils.save_activations(clip_name = args.clip_model, target_name = args.target_model, 43 | target_layers = args.target_layers, d_probe = args.d_probe, 44 | concept_set = args.concept_set, batch_size = args.batch_size, 45 | device = args.device, pool_mode=args.pool_mode, 46 | save_dir = args.activation_dir) 47 | 48 | outputs = {"layer":[], "unit":[], "description":[], "similarity":[]} 49 | with open(args.concept_set, 'r') as f: 50 | words = (f.read()).split('\n') 51 | 52 | for target_layer in args.target_layers: 53 | 
save_names = utils.get_save_names(clip_name = args.clip_model, target_name = args.target_model, 54 | target_layer = target_layer, d_probe = args.d_probe, 55 | concept_set = args.concept_set, pool_mode = args.pool_mode, 56 | save_dir = args.activation_dir) 57 | target_save_name, clip_save_name, text_save_name = save_names 58 | 59 | similarities = utils.get_similarity_from_activations( 60 | target_save_name, clip_save_name, text_save_name, similarity_fn, return_target_feats=False, device=args.device 61 | ) 62 | vals, ids = torch.max(similarities, dim=1) 63 | 64 | del similarities 65 | torch.cuda.empty_cache() 66 | 67 | descriptions = [words[int(idx)] for idx in ids] 68 | 69 | outputs["unit"].extend([i for i in range(len(vals))]) 70 | outputs["layer"].extend([target_layer]*len(vals)) 71 | outputs["description"].extend(descriptions) 72 | outputs["similarity"].extend(vals.cpu().numpy()) 73 | 74 | df = pd.DataFrame(outputs) 75 | if not os.path.exists(args.result_dir): 76 | os.mkdir(args.result_dir) 77 | save_path = "{}/{}_{}".format(args.result_dir, args.target_model, datetime.datetime.now().strftime("%y_%m_%d_%H_%M")) 78 | os.mkdir(save_path) 79 | df.to_csv(os.path.join(save_path,"descriptions.csv"), index=False) 80 | with open(os.path.join(save_path, "args.txt"), 'w') as f: 81 | json.dump(args.__dict__, f, indent=2) -------------------------------------------------------------------------------- /dlbroden.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # Start from parent directory of script 5 | #cd "$(dirname "$(dirname "$(readlink -f "$0")")")" 6 | 7 | # Download broden1_224 8 | if [ ! 
-f data/broden1_224/index.csv ] 9 | then 10 | 11 | echo "Downloading broden1_224" 12 | mkdir -p data 13 | pushd data 14 | wget --progress=bar \ 15 | http://netdissect.csail.mit.edu/data/broden1_224.zip \ 16 | -O broden1_224.zip 17 | unzip -q broden1_224.zip 18 | rm broden1_224.zip 19 | #remove unneeded files 20 | pushd broden1_224 21 | #rm *.csv 22 | rm *.txt 23 | rm images/ade20k/*object.png 24 | rm images/ade20k/*color.png 25 | rm images/ade20k/*.png 26 | rm images/dtd/*.png 27 | rm images/opensurfaces/*color.png 28 | rm images/opensurfaces/*.png 29 | rm images/pascal/*.png 30 | 31 | popd 32 | 33 | fi -------------------------------------------------------------------------------- /dlzoo_example.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | echo "Download resnet18 trained on Places365" 5 | echo "Downloading resnet18_places365.pth.tar" 6 | mkdir -p data 7 | pushd data 8 | wget --progress=bar \ 9 | http://places2.csail.mit.edu/models_places365/resnet18_places365.pth.tar 10 | popd 11 | 12 | echo "done" -------------------------------------------------------------------------------- /experiments/appendix_a6_predict_class_from_desc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "8e52ee82-e7e5-4bb9-a1a8-efc6e009c98f", 6 | "metadata": {}, 7 | "source": [ 8 | "## Predicting input class from descriptions of highly activating images" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "e29c7be9-8a6a-44dd-89cb-877dffe8d4db", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "#virtually move to parent directory\n", 20 | "os.chdir(\"..\")\n", 21 | "\n", 22 | "import math\n", 23 | "import torch\n", 24 | "import pandas as pd\n", 25 | "\n", 26 | "import matplotlib\n", 27 | "from matplotlib import pyplot as plt\n", 28 | "from sentence_transformers 
import SentenceTransformer\n", 29 | "\n", 30 | "import clip\n", 31 | "import utils\n", 32 | "import data_utils\n", 33 | "import similarity" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "0e693e37-a3ae-4aad-a03d-492e2a08eb5d", 39 | "metadata": {}, 40 | "source": [ 41 | "## Settings" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "id": "1724590a-2333-4daa-9948-6be1dfc60c25", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "target_name = 'resnet50'\n", 52 | "target_layer = 'layer4'\n", 53 | "\n", 54 | "clip_name = 'ViT-B/16'\n", 55 | "d_probe = 'imagenet_broden'\n", 56 | "concept_set = 'data/20k.txt'\n", 57 | "batch_size = 200\n", 58 | "device = 'cuda'\n", 59 | "pool_mode = 'avg'\n", 60 | "\n", 61 | "save_dir = 'saved_activations'\n", 62 | "similarity_fn = similarity.soft_wpmi" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "89f1eb6d-87cc-4430-b8e4-cd48d2643c7d", 68 | "metadata": {}, 69 | "source": [ 70 | "## Run CLIP-Dissect" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 3, 76 | "id": "4b6a1e91-5363-43a3-8f0b-4a034515923e", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "utils.save_activations(clip_name = clip_name, target_name = target_name, target_layers = [target_layer], \n", 81 | " d_probe = d_probe, concept_set = concept_set, batch_size = batch_size, \n", 82 | " device = device, pool_mode=pool_mode, save_dir = save_dir)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 4, 88 | "id": "edd0e205-0b81-4d59-80ac-16b321e56949", 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "name": "stderr", 93 | "output_type": "stream", 94 | "text": [ 95 | "100%|██████████| 2048/2048 [00:14<00:00, 145.32it/s]\n" 96 | ] 97 | }, 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "torch.Size([2048, 20000])\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "save_names = utils.get_save_names(clip_name = 
clip_name, target_name = target_name,\n", 108 | " target_layer = target_layer, d_probe = d_probe,\n", 109 | " concept_set = concept_set, pool_mode=pool_mode,\n", 110 | " save_dir = save_dir)\n", 111 | "\n", 112 | "target_save_name, clip_save_name, text_save_name = save_names\n", 113 | "\n", 114 | "similarities, target_feats = utils.get_similarity_from_activations(target_save_name, clip_save_name, \n", 115 | " text_save_name, similarity_fn, device=device)\n", 116 | "\n", 117 | "with open(concept_set, 'r') as f: \n", 118 | " words = (f.read()).split('\\n')\n", 119 | " \n", 120 | "vals, ids = torch.max(similarities, dim=1)\n", 121 | "descriptions = {\"CLIP-Dissect\":[words[int(idx)] for idx in ids]}" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "id": "ac805683-52a8-4c01-8ca6-0fabbaf19434", 127 | "metadata": {}, 128 | "source": [ 129 | "## Calculate standard accuracy" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 5, 135 | "id": "ede1fcd1-8f41-4f99-926b-1fdd92e4f7e4", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "#only use imagenet val for this part\n", 140 | "pil_data = data_utils.get_data('imagenet_val')\n", 141 | "target_model, target_preprocess = data_utils.get_target_model(target_name, device)\n", 142 | "\n", 143 | "save_names = utils.get_save_names(clip_name = clip_name, target_name = target_name,\n", 144 | " target_layer = target_layer, d_probe = 'imagenet_val',\n", 145 | " concept_set = concept_set, pool_mode=pool_mode,\n", 146 | " save_dir = save_dir)\n", 147 | "target_save_name, clip_save_name, text_save_name = save_names\n", 148 | "\n", 149 | "dataset = data_utils.get_data('imagenet_val', target_preprocess)\n", 150 | "utils.save_target_activations(target_model, dataset, target_save_name, target_layers = [target_layer], batch_size = batch_size,\n", 151 | " device = device, pool_mode = pool_mode)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 6, 157 | "id": 
"d6792ebc-46cf-4f13-9a00-b62da0b1292e", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "activations = torch.load(target_save_name, map_location='cpu')\n", 162 | "W_f = target_model.fc.weight\n", 163 | "b_f = target_model.fc.bias\n", 164 | "\n", 165 | "targets = torch.LongTensor(pil_data.targets).to(device)\n", 166 | "with open('data/imagenet_labels.txt', 'r') as f:\n", 167 | " classes = f.read().split('\\n')" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 7, 173 | "id": "eac48766-e06e-4b4b-bfed-d91452739e21", 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "name": "stdout", 178 | "output_type": "stream", 179 | "text": [ 180 | "Standard Accuracy:76.13%\n" 181 | ] 182 | } 183 | ], 184 | "source": [ 185 | "correct = 0\n", 186 | "with torch.no_grad():\n", 187 | " for i in range(math.ceil(len(targets)/batch_size)):\n", 188 | " targ = targets[i*batch_size:(i+1)*batch_size]\n", 189 | " act = activations[i*batch_size:(i+1)*batch_size].to(device)\n", 190 | " out = act@W_f.T + b_f\n", 191 | " pred = torch.max(out, dim=1)[1]\n", 192 | " correct += torch.sum(pred==targ)\n", 193 | "print(\"Standard Accuracy:{:.2f}%\".format(correct/len(targets)*100))" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "id": "0b02d9dc-14e9-4868-994e-bcb777f692a9", 199 | "metadata": {}, 200 | "source": [ 201 | "## Measure how often most contributing neuron description matches target class" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 8, 207 | "id": "e8121976-9234-414c-957d-b97ef41adc04", 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "torch.Size([1000, 512]) torch.Size([1000, 768])\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "mpnet_model = SentenceTransformer('all-mpnet-base-v2')\n", 220 | "clip_model, _ = clip.load(clip_name, device=device)\n", 221 | "\n", 222 | "with torch.no_grad():\n", 223 | " tokens = 
clip.tokenize(classes).to(device)\n", 224 | " class_clip = clip_model.encode_text(tokens)\n", 225 | " class_clip /= class_clip.norm(dim=-1, keepdim=True)\n", 226 | "\n", 227 | "class_mpnet = mpnet_model.encode(classes)\n", 228 | "class_mpnet = torch.tensor(class_mpnet).to(device)\n", 229 | "print(class_clip.shape, class_mpnet.shape)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 9, 235 | "id": "4cdedb99-435f-45f1-a7ff-12489988123a", 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "name_conversion = {'resnet50':'resnet50_imagenet', 'resnet18_places':'resnet18_places365'}\n", 240 | "\n", 241 | "netdissect_res = pd.read_csv('data/NetDissect_results/{}_{}.csv'.format(name_conversion[target_name],\n", 242 | " target_layer))\n", 243 | "descriptions[\"Network Dissection\"] = netdissect_res['label'].values\n", 244 | "\n", 245 | "milan_base = pd.read_csv('data/MILAN_results/m_base_{}.csv'.format(name_conversion[target_name]))\n", 246 | "milan_base = milan_base[milan_base['layer']==target_layer]\n", 247 | "milan_base = milan_base.sort_values(by=['unit'])\n", 248 | "descriptions[\"MILAN base\"] = list(milan_base['description'])" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 10, 254 | "id": "19ccf797-3ff5-44c1-8982-c1a6f1e8531d", 255 | "metadata": {}, 256 | "outputs": [ 257 | { 258 | "name": "stdout", 259 | "output_type": "stream", 260 | "text": [ 261 | "CLIP-Dissect\n", 262 | "Same as gt:9.87%\n", 263 | "Same as pred:11.83% \n", 264 | "\n", 265 | "Network Dissection\n", 266 | "Same as gt:3.04%\n", 267 | "Same as pred:3.68% \n", 268 | "\n", 269 | "MILAN base\n", 270 | "Same as gt:2.30%\n", 271 | "Same as pred:2.63% \n", 272 | "\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "for key in descriptions:\n", 278 | " print(key)\n", 279 | " with torch.no_grad():\n", 280 | " tokens = clip.tokenize(descriptions[key]).to(device)\n", 281 | " desc_clip = clip_model.encode_text(tokens)\n", 282 | " 
desc_clip /= desc_clip.norm(dim=-1, keepdim=True)\n", 283 | "\n", 284 | " desc_mpnet = mpnet_model.encode(descriptions[key])\n", 285 | " desc_mpnet = torch.tensor(desc_mpnet).to(device)\n", 286 | "\n", 287 | " correct_gt = 0\n", 288 | " correct_pred = 0\n", 289 | "\n", 290 | " with torch.no_grad():\n", 291 | " for i in range(math.ceil(len(targets)/batch_size)):\n", 292 | " targ = targets[i*batch_size:(i+1)*batch_size]\n", 293 | " act = activations[i*batch_size:(i+1)*batch_size].to(device)\n", 294 | "\n", 295 | " out = act@W_f.T + b_f\n", 296 | " pred = torch.max(out, dim=1)[1]\n", 297 | "\n", 298 | " contrib = W_f[pred]*act\n", 299 | " max_contrib = torch.max(contrib, dim=1)[1]\n", 300 | "\n", 301 | " clip_cos = desc_clip[max_contrib]@class_clip.T\n", 302 | " mpnet_cos = desc_mpnet[max_contrib]@class_mpnet.T\n", 303 | " \n", 304 | " cos = 3*clip_cos.detach() + mpnet_cos\n", 305 | " most_sim = torch.max(cos, dim=1)[1]\n", 306 | " \n", 307 | " correct_gt += torch.sum(most_sim==targ)\n", 308 | " correct_pred += torch.sum(most_sim==pred)\n", 309 | "\n", 310 | " print(\"Same as gt:{:.2f}%\".format(100*correct_gt/len(targets)))\n", 311 | " print(\"Same as pred:{:.2f}% \\n\".format(100*correct_pred/len(targets)))" 312 | ] 313 | } 314 | ], 315 | "metadata": { 316 | "kernelspec": { 317 | "display_name": "Python [conda env:jovyan-clip_dissect]", 318 | "language": "python", 319 | "name": "conda-env-jovyan-clip_dissect-py" 320 | }, 321 | "language_info": { 322 | "codemirror_mode": { 323 | "name": "ipython", 324 | "version": 3 325 | }, 326 | "file_extension": ".py", 327 | "mimetype": "text/x-python", 328 | "name": "python", 329 | "nbconvert_exporter": "python", 330 | "pygments_lexer": "ipython3", 331 | "version": "3.10.9" 332 | } 333 | }, 334 | "nbformat": 4, 335 | "nbformat_minor": 5 336 | } 337 | -------------------------------------------------------------------------------- /experiments/table1_quantitative_rn50.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5c8e4da1-1970-47dc-bab4-35e3b2d9f2e3", 6 | "metadata": {}, 7 | "source": [ 8 | "## Performance on describing final layer neurons of ResNet-50 (ImageNet)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "7926c513-eeff-48e1-b72d-572a1a313c26", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "#virtually move to parent directory\n", 20 | "os.chdir(\"..\")\n", 21 | "\n", 22 | "import torch\n", 23 | "import pandas as pd\n", 24 | "from sentence_transformers import SentenceTransformer\n", 25 | "\n", 26 | "import clip\n", 27 | "import utils\n", 28 | "import similarity" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "5b5c3c8f-e7d4-4409-a061-7995483dfe21", 34 | "metadata": {}, 35 | "source": [ 36 | "## Arguments for CLIP-Dissect" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "1724590a-2333-4daa-9948-6be1dfc60c25", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "clip_name = 'ViT-B/16'\n", 47 | "target_name = 'resnet50'\n", 48 | "target_layer = 'fc'\n", 49 | "batch_size = 200\n", 50 | "device = 'cuda'\n", 51 | "pool_mode = 'avg'\n", 52 | "\n", 53 | "save_dir = 'saved_activations'\n", 54 | "similarity_fn = similarity.soft_wpmi" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "id": "ea9b1413-bfce-4941-a09b-bf45606de62c", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "model = SentenceTransformer('all-mpnet-base-v2')\n", 65 | "clip_model, _ = clip.load(clip_name, device=device)\n", 66 | "\n", 67 | "with open('data/imagenet_labels.txt', 'r') as f: \n", 68 | " imagenet_classnames = (f.read()).split('\\n')" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "20097db4-7a7b-461a-9cf0-beed037c4e6f", 74 | "metadata": {}, 75 | "source": [ 76 | 
"## Run CLIP-Dissect" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 4, 82 | "id": "4b6a1e91-5363-43a3-8f0b-4a034515923e", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "rows = [(\"imagenet_val\", \"data/broden_labels_clean.txt\"),\n", 87 | " (\"imagenet_val\", \"data/3k.txt\"),\n", 88 | " (\"imagenet_val\", \"data/10k.txt\"),\n", 89 | " (\"imagenet_val\", \"data/20k.txt\"),\n", 90 | " (\"imagenet_val\", \"data/imagenet_labels.txt\"),\n", 91 | " (\"cifar100_train\", \"data/20k.txt\"),\n", 92 | " (\"broden\", \"data/20k.txt\"),\n", 93 | " (\"imagenet_val\", \"data/20k.txt\"),\n", 94 | " (\"imagenet_broden\", \"data/20k.txt\"),]" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "edd0e205-0b81-4d59-80ac-16b321e56949", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "for d_probe, concept_set in rows:\n", 105 | " with open(concept_set, 'r') as f: \n", 106 | " words = (f.read()).split('\\n')\n", 107 | " utils.save_activations(clip_name = clip_name, target_name = target_name, target_layers = [target_layer], \n", 108 | " d_probe = d_probe, concept_set = concept_set, batch_size = batch_size, \n", 109 | " device = device, pool_mode=pool_mode, save_dir = save_dir)\n", 110 | "\n", 111 | " save_names = utils.get_save_names(clip_name = clip_name, target_name = target_name,\n", 112 | " target_layer = target_layer, d_probe = d_probe,\n", 113 | " concept_set = concept_set, pool_mode=pool_mode,\n", 114 | " save_dir = save_dir)\n", 115 | "\n", 116 | " target_save_name, clip_save_name, text_save_name = save_names\n", 117 | "\n", 118 | " similarities, target_feats = utils.get_similarity_from_activations(target_save_name, clip_save_name, \n", 119 | " text_save_name, similarity_fn, device=device)\n", 120 | "\n", 121 | " clip_preds = torch.argmax(similarities, dim=1)\n", 122 | " clip_preds = [words[int(pred)] for pred in clip_preds]\n", 123 | "\n", 124 | " clip_cos, mpnet_cos = 
utils.get_cos_similarity(clip_preds, imagenet_classnames, clip_model, model, device, batch_size)\n", 125 | " print(\"D_probe:{}, Concept set:{}\".format(d_probe, concept_set))\n", 126 | " print(\"CLIP-Dissect - Clip similarity: {:.4f}, mpnet similarity: {:.4f}\".format(clip_cos, mpnet_cos))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "id": "0689b84a-cfdc-4ee3-a4b2-5b251396fb64", 132 | "metadata": {}, 133 | "source": [ 134 | "## Baselines" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "422ba0dd-909d-4c78-a439-0e95c98662af", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "netdissect_res = pd.read_csv('data/NetDissect_results/resnet50_imagenet_fc.csv')\n", 145 | "nd_preds = netdissect_res['label'].values\n", 146 | "\n", 147 | "clip_cos, mpnet_cos = utils.get_cos_similarity(nd_preds, imagenet_classnames, clip_model, model, device, batch_size)\n", 148 | "print(\"Network Dissection - Clip similarity: {:.4f}, mpnet similarity: {:.4f}\".format(clip_cos, mpnet_cos))" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "083f8e02-4eb0-422d-97bc-ea3914d36d6b", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "milan_preds = pd.read_csv('data/MILAN_results/m_base_resnet50_imagenet.csv')\n", 159 | "milan_preds = milan_preds[milan_preds['layer']=='fc']\n", 160 | "milan_preds = milan_preds.sort_values(by=['unit'])\n", 161 | "milan_preds = list(milan_preds['description'])\n", 162 | "\n", 163 | "clip_cos, mpnet_cos = utils.get_cos_similarity(milan_preds, imagenet_classnames, clip_model, model, device, batch_size)\n", 164 | "print(\"MILAN - Clip similarity: {:.4f}, mpnet similarity: {:.4f}\".format(clip_cos, mpnet_cos))" 165 | ] 166 | } 167 | ], 168 | "metadata": { 169 | "kernelspec": { 170 | "display_name": "Python [conda env:jovyan-clip]", 171 | "language": "python", 172 | "name": "conda-env-jovyan-clip-py" 173 | }, 174 | 
"language_info": { 175 | "codemirror_mode": { 176 | "name": "ipython", 177 | "version": 3 178 | }, 179 | "file_extension": ".py", 180 | "mimetype": "text/x-python", 181 | "name": "python", 182 | "nbconvert_exporter": "python", 183 | "pygments_lexer": "ipython3", 184 | "version": "3.9.9" 185 | } 186 | }, 187 | "nbformat": 4, 188 | "nbformat_minor": 5 189 | } 190 | -------------------------------------------------------------------------------- /experiments/table2_quantitative_rn18.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ef9877fe-5d0d-4ad1-950a-8cfaa5e733fe", 6 | "metadata": {}, 7 | "source": [ 8 | "## Describing final layer neurons of ResNet-18 (Places)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "29cd5fcc-cc5c-4cca-8708-491e6f904ee5", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "2023-03-22 07:19:55.947759: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "import os\n", 27 | "#virtually move to parent directory\n", 28 | "os.chdir(\"..\")\n", 29 | "\n", 30 | "import torch\n", 31 | "import pandas as pd\n", 32 | "\n", 33 | "from sentence_transformers import SentenceTransformer\n", 34 | "\n", 35 | "import clip\n", 36 | "import utils\n", 37 | "import data_utils\n", 38 | "import similarity" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "id": "1724590a-2333-4daa-9948-6be1dfc60c25", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "#Arguments\n", 49 | "clip_name = 'ViT-B/16'\n", 50 | "target_name = 'resnet18_places'\n", 51 | "target_layer = 'fc'\n", 52 | "d_probe = 'broden'\n", 53 | "concept_set = 'data/broden_labels_clean.txt'\n", 54 | "batch_size = 200\n", 55 | "device = 'cuda'\n", 
56 | "pool_mode = 'avg'\n", 57 | "\n", 58 | "save_dir = 'saved_activations'\n", 59 | "similarity_fn = similarity.soft_wpmi" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "id": "4b6a1e91-5363-43a3-8f0b-4a034515923e", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "utils.save_activations(clip_name = clip_name, target_name = target_name, target_layers = [target_layer], \n", 70 | " d_probe = d_probe, concept_set = concept_set, batch_size = batch_size, \n", 71 | " device = device, pool_mode=pool_mode, save_dir = save_dir)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 4, 77 | "id": "edd0e205-0b81-4d59-80ac-16b321e56949", 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stderr", 82 | "output_type": "stream", 83 | "text": [ 84 | "100%|██████████| 365/365 [00:00<00:00, 1611.75it/s]" 85 | ] 86 | }, 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "torch.Size([365, 1197])\n" 92 | ] 93 | }, 94 | { 95 | "name": "stderr", 96 | "output_type": "stream", 97 | "text": [ 98 | "\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "save_names = utils.get_save_names(clip_name = clip_name, target_name = target_name,\n", 104 | " target_layer = target_layer, d_probe = d_probe,\n", 105 | " concept_set = concept_set, pool_mode=pool_mode,\n", 106 | " save_dir = save_dir)\n", 107 | "\n", 108 | "target_save_name, clip_save_name, text_save_name = save_names\n", 109 | "\n", 110 | "similarities, target_feats = utils.get_similarity_from_activations(target_save_name, clip_save_name, \n", 111 | " text_save_name, similarity_fn, device=device)\n", 112 | "\n", 113 | "with open(concept_set, 'r') as f: \n", 114 | " words = (f.read()).split('\\n')" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 5, 120 | "id": "d39dfe00-ff36-4d18-adcb-b348220dca6a", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "#Clean up names of target classes\n", 125 | 
"with open('data/categories_places365.txt', 'r') as f:\n", 126 | " cls_id_to_name = f.read().split('\\n')\n", 127 | " cls_id_to_name = [(cls[3:]).split(' ')[0] for cls in cls_id_to_name]\n", 128 | "\n", 129 | "def process_word(word):\n", 130 | " if concept_set.startswith('data/broden_labels'):\n", 131 | " if word.endswith(\"-s\"):\n", 132 | " word = word[:-2]\n", 133 | " word = word.replace('_', ' ')\n", 134 | " return \"{}\".format(word)\n", 135 | " elif concept_set == 'data/categories_places365.txt':\n", 136 | " \n", 137 | " word = word[3:].split(' ')[0]\n", 138 | " word = word.replace('/', '-')\n", 139 | " word = word.replace('_', ' ')\n", 140 | " \n", 141 | " return \"{}\".format(word)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "id": "81c91edd-85ec-4b37-b7d6-2df042c1f622", 147 | "metadata": {}, 148 | "source": [ 149 | "# Accuracies" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 6, 155 | "id": "a71c2a95-1536-4fbc-8bc8-95798fd59d45", 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "id_to_label = data_utils.get_places_id_to_broden_label()\n", 160 | "\n", 161 | "def clean_label(label):\n", 162 | " if label.startswith('/'):\n", 163 | " label = label[3:]\n", 164 | " label = label.split(' ')[0]\n", 165 | " if label.endswith('-s'):\n", 166 | " label = label[:-2]\n", 167 | " return label\n", 168 | " \n", 169 | "def get_topk_acc(similarities, k=5):\n", 170 | " correct = 0\n", 171 | " total = 0\n", 172 | " for orig_id in range(len(similarities)):\n", 173 | " #skip classes not in Broden\n", 174 | " if id_to_label[orig_id] is None:\n", 175 | " continue\n", 176 | " else:\n", 177 | " vals, ids = torch.topk(similarities[orig_id], k, largest=True)\n", 178 | " for idx in ids[:k]:\n", 179 | " if (process_word(words[idx])==process_word(id_to_label[orig_id])):\n", 180 | " correct += 1\n", 181 | " continue\n", 182 | " total += 1\n", 183 | " return (correct/total)*100" 184 | ] 185 | }, 186 | { 187 | "cell_type": 
"code", 188 | "execution_count": 7, 189 | "id": "f60b4f49-6829-4da6-9add-94ff0e96a98a", 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "CLIP-Dissect Top 1 acc:58.0524\n", 197 | "CLIP-Dissect Top 5 acc:86.1423\n" 198 | ] 199 | } 200 | ], 201 | "source": [ 202 | "print(\"CLIP-Dissect Top 1 acc:{:.4f}\".format(get_topk_acc(similarities, k=1)))\n", 203 | "print(\"CLIP-Dissect Top 5 acc:{:.4f}\".format(get_topk_acc(similarities, k=5)))" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 8, 209 | "id": "f5ecd688-007f-4efa-89e8-7fbca8dd206b", 210 | "metadata": {}, 211 | "outputs": [ 212 | { 213 | "name": "stdout", 214 | "output_type": "stream", 215 | "text": [ 216 | "Network Dissection Top 1 acc:43.8202\n" 217 | ] 218 | } 219 | ], 220 | "source": [ 221 | "df = pd.read_csv('data/NetDissect_results/resnet18_places365_fc.csv')\n", 222 | "correct = 0\n", 223 | "total = 0\n", 224 | "for i, label in enumerate(df['label']):\n", 225 | " if id_to_label[i] is None:\n", 226 | " continue\n", 227 | " else:\n", 228 | " correct += (clean_label(label)==clean_label(id_to_label[i]))\n", 229 | " total += 1\n", 230 | "\n", 231 | "print(\"Network Dissection Top 1 acc:{:.4f}\".format(correct/total*100))" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "id": "cbfa97eb-1e6a-45c5-96e5-3d5724be0c13", 237 | "metadata": {}, 238 | "source": [ 239 | "# Cos similarities" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 9, 245 | "id": "f2ba548c-59fc-4112-8e93-9bb034c0adb0", 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "model = SentenceTransformer('all-mpnet-base-v2')\n", 250 | "clip_model, _ = clip.load(clip_name, device=device)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 10, 256 | "id": "0a33a7c1-0acc-444b-a530-baedc39ffe45", 257 | "metadata": {}, 258 | "outputs": [ 259 | { 260 | "name": "stdout", 261 | 
"output_type": "stream", 262 | "text": [ 263 | "CLIP-Dissect - Clip similarity: 0.9106, mpnet similarity: 0.7024\n" 264 | ] 265 | } 266 | ], 267 | "source": [ 268 | "clip_preds = torch.argmax(similarities, dim=1)\n", 269 | "clip_preds = [words[int(pred)] for pred in clip_preds]\n", 270 | "\n", 271 | "clip_cos, mpnet_cos = utils.get_cos_similarity(clip_preds, cls_id_to_name, clip_model, model, device, batch_size)\n", 272 | "print(\"CLIP-Dissect - Clip similarity: {:.4f}, mpnet similarity: {:.4f}\".format(clip_cos, mpnet_cos))" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 11, 278 | "id": "314b1e83-93fb-4617-9a2a-55d6d2c95f3a", 279 | "metadata": {}, 280 | "outputs": [ 281 | { 282 | "name": "stdout", 283 | "output_type": "stream", 284 | "text": [ 285 | "Network Dissection - Clip similarity: 0.8887, mpnet similarity: 0.6697\n" 286 | ] 287 | } 288 | ], 289 | "source": [ 290 | "netdissect_res = pd.read_csv('data/NetDissect_results/resnet18_places365_fc.csv')\n", 291 | "nd_preds = netdissect_res['label'].values\n", 292 | "nd_preds = [clean_label(pred) for pred in nd_preds]\n", 293 | "\n", 294 | "clip_cos, mpnet_cos = utils.get_cos_similarity(nd_preds, cls_id_to_name, clip_model, model, device, batch_size)\n", 295 | "print(\"Network Dissection - Clip similarity: {:.4f}, mpnet similarity: {:.4f}\".format(clip_cos, mpnet_cos))" 296 | ] 297 | } 298 | ], 299 | "metadata": { 300 | "kernelspec": { 301 | "display_name": "Python 3 (ipykernel)", 302 | "language": "python", 303 | "name": "python3" 304 | }, 305 | "language_info": { 306 | "codemirror_mode": { 307 | "name": "ipython", 308 | "version": 3 309 | }, 310 | "file_extension": ".py", 311 | "mimetype": "text/x-python", 312 | "name": "python", 313 | "nbconvert_exporter": "python", 314 | "pygments_lexer": "ipython3", 315 | "version": "3.8.12" 316 | } 317 | }, 318 | "nbformat": 4, 319 | "nbformat_minor": 5 320 | } 321 | -------------------------------------------------------------------------------- 
/experiments/table3_similarity_comparison.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "10471e16-5f1a-486a-a053-f32e1faeff3f", 6 | "metadata": {}, 7 | "source": [ 8 | "## Similarity function comparison" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "4b33deec-4d03-46d8-8ac5-b7b3e5b2f6c9", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "#virtually move to parent directory\n", 20 | "os.chdir(\"..\")\n", 21 | "\n", 22 | "import torch\n", 23 | "from sentence_transformers import SentenceTransformer\n", 24 | "from sklearn import metrics\n", 25 | "\n", 26 | "import clip\n", 27 | "import utils\n", 28 | "import similarity" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "c5eb399e-2c95-4609-bda4-cc5a5fc7f268", 34 | "metadata": {}, 35 | "source": [ 36 | "## Settings" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "1724590a-2333-4daa-9948-6be1dfc60c25", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "similarity_fns = [\"cos_similarity\", \"rank_reorder\", \"wpmi\", \"soft_wpmi\"]\n", 47 | "d_probes = ['cifar100_train', 'broden', 'imagenet_val', 'imagenet_broden']\n", 48 | "\n", 49 | "clip_name = 'ViT-B/16'\n", 50 | "target_name = 'resnet50'\n", 51 | "target_layer = 'fc'\n", 52 | "batch_size = 200\n", 53 | "device = 'cuda'\n", 54 | "pool_mode = 'avg'\n", 55 | "save_dir = 'saved_activations'" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "id": "5ed87e72-f472-488d-89bd-c08bee7657d5", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "model = SentenceTransformer('all-mpnet-base-v2')\n", 66 | "clip_model, _ = clip.load(clip_name, device=device)\n", 67 | "\n", 68 | "with open(\"data/imagenet_labels.txt\", \"r\") as f:\n", 69 | " cls_id_to_name = f.read().split(\"\\n\")" 70 | ] 71 | }, 72 
| { 73 | "cell_type": "markdown", 74 | "id": "cbfa97eb-1e6a-45c5-96e5-3d5724be0c13", 75 | "metadata": {}, 76 | "source": [ 77 | "# Cos similarities" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 4, 83 | "id": "4b6a1e91-5363-43a3-8f0b-4a034515923e", 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "Files already downloaded and verified\n", 91 | "Files already downloaded and verified\n" 92 | ] 93 | }, 94 | { 95 | "name": "stderr", 96 | "output_type": "stream", 97 | "text": [ 98 | "100%|██████████| 1/1 [00:01<00:00, 1.65s/it]\n" 99 | ] 100 | }, 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "Similarity fn: cos_similarity, D_probe: cifar100_train\n", 106 | "Clip similarity: 0.6484, mpnet similarity: 0.2756\n" 107 | ] 108 | }, 109 | { 110 | "name": "stderr", 111 | "output_type": "stream", 112 | "text": [ 113 | "100%|██████████| 1/1 [00:01<00:00, 1.81s/it]\n" 114 | ] 115 | }, 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "Similarity fn: cos_similarity, D_probe: broden\n", 121 | "Clip similarity: 0.6235, mpnet similarity: 0.2153\n" 122 | ] 123 | }, 124 | { 125 | "name": "stderr", 126 | "output_type": "stream", 127 | "text": [ 128 | "100%|██████████| 1/1 [00:01<00:00, 1.48s/it]\n" 129 | ] 130 | }, 131 | { 132 | "name": "stdout", 133 | "output_type": "stream", 134 | "text": [ 135 | "Similarity fn: cos_similarity, D_probe: imagenet_val\n", 136 | "Clip similarity: 0.6216, mpnet similarity: 0.2829\n" 137 | ] 138 | }, 139 | { 140 | "name": "stderr", 141 | "output_type": "stream", 142 | "text": [ 143 | "100%|██████████| 1/1 [00:03<00:00, 3.17s/it]\n" 144 | ] 145 | }, 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "Similarity fn: cos_similarity, D_probe: imagenet_broden\n", 151 | "Clip similarity: 0.6421, mpnet similarity: 0.2587\n", 152 | "Files already downloaded and 
verified\n", 153 | "Files already downloaded and verified\n" 154 | ] 155 | }, 156 | { 157 | "name": "stderr", 158 | "output_type": "stream", 159 | "text": [ 160 | "100%|██████████| 1000/1000 [02:49<00:00, 5.91it/s]\n" 161 | ] 162 | }, 163 | { 164 | "name": "stdout", 165 | "output_type": "stream", 166 | "text": [ 167 | "Similarity fn: rank_reorder, D_probe: cifar100_train\n", 168 | "Clip similarity: 0.7227, mpnet similarity: 0.3247\n" 169 | ] 170 | }, 171 | { 172 | "name": "stderr", 173 | "output_type": "stream", 174 | "text": [ 175 | "100%|██████████| 1000/1000 [03:39<00:00, 4.55it/s]\n" 176 | ] 177 | }, 178 | { 179 | "name": "stdout", 180 | "output_type": "stream", 181 | "text": [ 182 | "Similarity fn: rank_reorder, D_probe: broden\n", 183 | "Clip similarity: 0.7471, mpnet similarity: 0.3856\n" 184 | ] 185 | }, 186 | { 187 | "name": "stderr", 188 | "output_type": "stream", 189 | "text": [ 190 | "100%|██████████| 1000/1000 [02:46<00:00, 6.00it/s]\n" 191 | ] 192 | }, 193 | { 194 | "name": "stdout", 195 | "output_type": "stream", 196 | "text": [ 197 | "Similarity fn: rank_reorder, D_probe: imagenet_val\n", 198 | "Clip similarity: 0.7832, mpnet similarity: 0.4911\n" 199 | ] 200 | }, 201 | { 202 | "name": "stderr", 203 | "output_type": "stream", 204 | "text": [ 205 | "100%|██████████| 1000/1000 [06:54<00:00, 2.41it/s]\n" 206 | ] 207 | }, 208 | { 209 | "name": "stdout", 210 | "output_type": "stream", 211 | "text": [ 212 | "Similarity fn: rank_reorder, D_probe: imagenet_broden\n", 213 | "Clip similarity: 0.7866, mpnet similarity: 0.5035\n", 214 | "Files already downloaded and verified\n", 215 | "Files already downloaded and verified\n" 216 | ] 217 | }, 218 | { 219 | "name": "stderr", 220 | "output_type": "stream", 221 | "text": [ 222 | "100%|██████████| 1000/1000 [00:01<00:00, 622.31it/s]\n" 223 | ] 224 | }, 225 | { 226 | "name": "stdout", 227 | "output_type": "stream", 228 | "text": [ 229 | "Similarity fn: wpmi, D_probe: cifar100_train\n", 230 | "Clip similarity: 
0.7192, mpnet similarity: 0.3457\n" 231 | ] 232 | }, 233 | { 234 | "name": "stderr", 235 | "output_type": "stream", 236 | "text": [ 237 | "100%|██████████| 1000/1000 [00:01<00:00, 597.84it/s]\n" 238 | ] 239 | }, 240 | { 241 | "name": "stdout", 242 | "output_type": "stream", 243 | "text": [ 244 | "Similarity fn: wpmi, D_probe: broden\n", 245 | "Clip similarity: 0.7427, mpnet similarity: 0.3886\n" 246 | ] 247 | }, 248 | { 249 | "name": "stderr", 250 | "output_type": "stream", 251 | "text": [ 252 | "100%|██████████| 1000/1000 [00:01<00:00, 553.30it/s]\n" 253 | ] 254 | }, 255 | { 256 | "name": "stdout", 257 | "output_type": "stream", 258 | "text": [ 259 | "Similarity fn: wpmi, D_probe: imagenet_val\n", 260 | "Clip similarity: 0.7944, mpnet similarity: 0.5301\n" 261 | ] 262 | }, 263 | { 264 | "name": "stderr", 265 | "output_type": "stream", 266 | "text": [ 267 | "100%|██████████| 1000/1000 [00:01<00:00, 553.67it/s]\n" 268 | ] 269 | }, 270 | { 271 | "name": "stdout", 272 | "output_type": "stream", 273 | "text": [ 274 | "Similarity fn: wpmi, D_probe: imagenet_broden\n", 275 | "Clip similarity: 0.7930, mpnet similarity: 0.5266\n", 276 | "Files already downloaded and verified\n", 277 | "Files already downloaded and verified\n" 278 | ] 279 | }, 280 | { 281 | "name": "stderr", 282 | "output_type": "stream", 283 | "text": [ 284 | "100%|██████████| 1000/1000 [00:05<00:00, 185.16it/s]\n" 285 | ] 286 | }, 287 | { 288 | "name": "stdout", 289 | "output_type": "stream", 290 | "text": [ 291 | "torch.Size([1000, 20000])\n", 292 | "Similarity fn: soft_wpmi, D_probe: cifar100_train\n", 293 | "Clip similarity: 0.7300, mpnet similarity: 0.3671\n" 294 | ] 295 | }, 296 | { 297 | "name": "stderr", 298 | "output_type": "stream", 299 | "text": [ 300 | "100%|██████████| 1000/1000 [00:04<00:00, 203.97it/s]\n" 301 | ] 302 | }, 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "torch.Size([1000, 20000])\n", 308 | "Similarity fn: soft_wpmi, D_probe: broden\n", 309 
| "Clip similarity: 0.7412, mpnet similarity: 0.3946\n" 310 | ] 311 | }, 312 | { 313 | "name": "stderr", 314 | "output_type": "stream", 315 | "text": [ 316 | "100%|██████████| 1000/1000 [00:04<00:00, 209.43it/s]\n" 317 | ] 318 | }, 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "torch.Size([1000, 20000])\n", 324 | "Similarity fn: soft_wpmi, D_probe: imagenet_val\n", 325 | "Clip similarity: 0.7900, mpnet similarity: 0.5262\n" 326 | ] 327 | }, 328 | { 329 | "name": "stderr", 330 | "output_type": "stream", 331 | "text": [ 332 | "100%|██████████| 1000/1000 [00:04<00:00, 209.87it/s]\n" 333 | ] 334 | }, 335 | { 336 | "name": "stdout", 337 | "output_type": "stream", 338 | "text": [ 339 | "torch.Size([1000, 20000])\n", 340 | "Similarity fn: soft_wpmi, D_probe: imagenet_broden\n", 341 | "Clip similarity: 0.7900, mpnet similarity: 0.5239\n" 342 | ] 343 | } 344 | ], 345 | "source": [ 346 | "concept_set = 'data/20k.txt'\n", 347 | "\n", 348 | "with open(concept_set, 'r') as f:\n", 349 | " words = f.read().split('\\n')\n", 350 | "\n", 351 | "for similarity_fn in similarity_fns:\n", 352 | " for d_probe in d_probes:\n", 353 | " utils.save_activations(clip_name = clip_name, target_name = target_name, target_layers = [target_layer], \n", 354 | " d_probe = d_probe, concept_set = concept_set, batch_size = batch_size, \n", 355 | " device = device, pool_mode=pool_mode, save_dir = save_dir)\n", 356 | "\n", 357 | " save_names = utils.get_save_names(clip_name = clip_name, target_name = target_name,\n", 358 | " target_layer = target_layer, d_probe = d_probe,\n", 359 | " concept_set = concept_set, pool_mode=pool_mode,\n", 360 | " save_dir = save_dir)\n", 361 | "\n", 362 | " target_save_name, clip_save_name, text_save_name = save_names\n", 363 | "\n", 364 | " similarities, target_feats = utils.get_similarity_from_activations(target_save_name, clip_save_name, \n", 365 | " text_save_name, \n", 366 | " getattr(similarity, similarity_fn),\n", 367 | 
" device=device)\n", 368 | "\n", 369 | " clip_preds = torch.argmax(similarities, dim=1)\n", 370 | " clip_preds = [words[int(pred)] for pred in clip_preds]\n", 371 | "\n", 372 | " clip_cos, mpnet_cos = utils.get_cos_similarity(clip_preds, cls_id_to_name, clip_model, model, device, batch_size)\n", 373 | " print(\"Similarity fn: {}, D_probe: {}\".format(similarity_fn, d_probe))\n", 374 | " print(\"Clip similarity: {:.4f}, mpnet similarity: {:.4f}\".format(clip_cos, mpnet_cos))" 375 | ] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "id": "81c91edd-85ec-4b37-b7d6-2df042c1f622", 380 | "metadata": {}, 381 | "source": [ 382 | "# Accuracies" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 5, 388 | "id": "1ebc5fec-0878-4f80-b3f1-ca7c1f6a2486", 389 | "metadata": {}, 390 | "outputs": [], 391 | "source": [ 392 | "def get_topk_acc(sim, k=5):\n", 393 | " correct = 0\n", 394 | " for orig_id in range(1000):\n", 395 | " vals, ids = torch.topk(sim[orig_id], k=k)\n", 396 | " for idx in ids[:k]:\n", 397 | " correct += (int(idx)==orig_id)\n", 398 | " return (correct/1000)*100\n", 399 | "\n", 400 | "def get_correct_rank_mean_median(sim):\n", 401 | " ranks = []\n", 402 | " for orig_id in range(1000):\n", 403 | " vals, ids = torch.sort(sim[orig_id], descending=True)\n", 404 | " \n", 405 | " ranks.append(list(ids).index(orig_id)+1)\n", 406 | " \n", 407 | " mean = sum(ranks)/len(ranks)\n", 408 | " median = sorted(ranks)[500]\n", 409 | " return mean, median\n", 410 | "\n", 411 | "def get_auc(sim):\n", 412 | " max_sim, preds = torch.max(sim.cpu(), dim=1)\n", 413 | " gtruth = torch.arange(0, 1000)\n", 414 | " correct = (preds==gtruth)\n", 415 | " fpr, tpr, thresholds = metrics.roc_curve(correct, max_sim)\n", 416 | " auc = metrics.roc_auc_score(correct, max_sim)\n", 417 | " return auc" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 9, 423 | "id": "3ae886f7-6e33-4226-8d5e-9e6ccd1699ee", 424 | "metadata": {}, 425 | "outputs": [ 426 
| { 427 | "name": "stdout", 428 | "output_type": "stream", 429 | "text": [ 430 | "Files already downloaded and verified\n", 431 | "Files already downloaded and verified\n" 432 | ] 433 | }, 434 | { 435 | "name": "stderr", 436 | "output_type": "stream", 437 | "text": [ 438 | "100%|██████████| 1/1 [00:00<00:00, 9.83it/s]\n" 439 | ] 440 | }, 441 | { 442 | "name": "stdout", 443 | "output_type": "stream", 444 | "text": [ 445 | "Similarity fn: cos_similarity, D_probe: cifar100_train\n", 446 | "Top 1 acc: 8.60%, Top 5 acc: 25.10%\n", 447 | "Mean rank of correct class: 53.94, Median rank of correct class: 21\n", 448 | "AUC: 0.5926\n" 449 | ] 450 | }, 451 | { 452 | "name": "stderr", 453 | "output_type": "stream", 454 | "text": [ 455 | "100%|██████████| 1/1 [00:00<00:00, 8.83it/s]\n" 456 | ] 457 | }, 458 | { 459 | "name": "stdout", 460 | "output_type": "stream", 461 | "text": [ 462 | "Similarity fn: cos_similarity, D_probe: broden\n", 463 | "Top 1 acc: 5.70%, Top 5 acc: 21.30%\n", 464 | "Mean rank of correct class: 63.92, Median rank of correct class: 24\n", 465 | "AUC: 0.5710\n" 466 | ] 467 | }, 468 | { 469 | "name": "stderr", 470 | "output_type": "stream", 471 | "text": [ 472 | "100%|██████████| 1/1 [00:00<00:00, 11.84it/s]\n" 473 | ] 474 | }, 475 | { 476 | "name": "stdout", 477 | "output_type": "stream", 478 | "text": [ 479 | "Similarity fn: cos_similarity, D_probe: imagenet_val\n", 480 | "Top 1 acc: 15.90%, Top 5 acc: 43.80%\n", 481 | "Mean rank of correct class: 22.56, Median rank of correct class: 7\n", 482 | "AUC: 0.4849\n" 483 | ] 484 | }, 485 | { 486 | "name": "stderr", 487 | "output_type": "stream", 488 | "text": [ 489 | "100%|██████████| 1/1 [00:00<00:00, 5.21it/s]\n" 490 | ] 491 | }, 492 | { 493 | "name": "stdout", 494 | "output_type": "stream", 495 | "text": [ 496 | "Similarity fn: cos_similarity, D_probe: imagenet_broden\n", 497 | "Top 1 acc: 11.30%, Top 5 acc: 34.60%\n", 498 | "Mean rank of correct class: 32.64, Median rank of correct class: 11\n", 499 | "AUC: 
0.5003\n", 500 | "Files already downloaded and verified\n", 501 | "Files already downloaded and verified\n" 502 | ] 503 | }, 504 | { 505 | "name": "stderr", 506 | "output_type": "stream", 507 | "text": [ 508 | "100%|██████████| 1000/1000 [00:12<00:00, 81.02it/s]\n" 509 | ] 510 | }, 511 | { 512 | "name": "stdout", 513 | "output_type": "stream", 514 | "text": [ 515 | "Similarity fn: rank_reorder, D_probe: cifar100_train\n", 516 | "Top 1 acc: 36.60%, Top 5 acc: 67.50%\n", 517 | "Mean rank of correct class: 13.63, Median rank of correct class: 3\n", 518 | "AUC: 0.6338\n" 519 | ] 520 | }, 521 | { 522 | "name": "stderr", 523 | "output_type": "stream", 524 | "text": [ 525 | "100%|██████████| 1000/1000 [00:10<00:00, 93.23it/s]\n" 526 | ] 527 | }, 528 | { 529 | "name": "stdout", 530 | "output_type": "stream", 531 | "text": [ 532 | "Similarity fn: rank_reorder, D_probe: broden\n", 533 | "Top 1 acc: 57.70%, Top 5 acc: 83.70%\n", 534 | "Mean rank of correct class: 6.69, Median rank of correct class: 1\n", 535 | "AUC: 0.6853\n" 536 | ] 537 | }, 538 | { 539 | "name": "stderr", 540 | "output_type": "stream", 541 | "text": [ 542 | "100%|██████████| 1000/1000 [00:13<00:00, 75.74it/s]\n" 543 | ] 544 | }, 545 | { 546 | "name": "stdout", 547 | "output_type": "stream", 548 | "text": [ 549 | "Similarity fn: rank_reorder, D_probe: imagenet_val\n", 550 | "Top 1 acc: 89.80%, Top 5 acc: 98.60%\n", 551 | "Mean rank of correct class: 2.28, Median rank of correct class: 1\n", 552 | "AUC: 0.6434\n" 553 | ] 554 | }, 555 | { 556 | "name": "stderr", 557 | "output_type": "stream", 558 | "text": [ 559 | "100%|██████████| 1000/1000 [00:14<00:00, 67.79it/s]\n" 560 | ] 561 | }, 562 | { 563 | "name": "stdout", 564 | "output_type": "stream", 565 | "text": [ 566 | "Similarity fn: rank_reorder, D_probe: imagenet_broden\n", 567 | "Top 1 acc: 89.90%, Top 5 acc: 98.20%\n", 568 | "Mean rank of correct class: 2.12, Median rank of correct class: 1\n", 569 | "AUC: 0.5993\n", 570 | "Files already downloaded and 
verified\n", 571 | "Files already downloaded and verified\n" 572 | ] 573 | }, 574 | { 575 | "name": "stderr", 576 | "output_type": "stream", 577 | "text": [ 578 | "100%|██████████| 1000/1000 [00:00<00:00, 7502.62it/s]\n" 579 | ] 580 | }, 581 | { 582 | "name": "stdout", 583 | "output_type": "stream", 584 | "text": [ 585 | "Similarity fn: wpmi, D_probe: cifar100_train\n", 586 | "Top 1 acc: 24.00%, Top 5 acc: 55.00%\n", 587 | "Mean rank of correct class: 20.46, Median rank of correct class: 4\n", 588 | "AUC: 0.6355\n" 589 | ] 590 | }, 591 | { 592 | "name": "stderr", 593 | "output_type": "stream", 594 | "text": [ 595 | "100%|██████████| 1000/1000 [00:00<00:00, 6698.21it/s]\n" 596 | ] 597 | }, 598 | { 599 | "name": "stdout", 600 | "output_type": "stream", 601 | "text": [ 602 | "Similarity fn: wpmi, D_probe: broden\n", 603 | "Top 1 acc: 47.10%, Top 5 acc: 79.40%\n", 604 | "Mean rank of correct class: 7.58, Median rank of correct class: 2\n", 605 | "AUC: 0.7118\n" 606 | ] 607 | }, 608 | { 609 | "name": "stderr", 610 | "output_type": "stream", 611 | "text": [ 612 | "100%|██████████| 1000/1000 [00:00<00:00, 6421.12it/s]\n" 613 | ] 614 | }, 615 | { 616 | "name": "stdout", 617 | "output_type": "stream", 618 | "text": [ 619 | "Similarity fn: wpmi, D_probe: imagenet_val\n", 620 | "Top 1 acc: 86.90%, Top 5 acc: 98.10%\n", 621 | "Mean rank of correct class: 2.00, Median rank of correct class: 1\n", 622 | "AUC: 0.7176\n" 623 | ] 624 | }, 625 | { 626 | "name": "stderr", 627 | "output_type": "stream", 628 | "text": [ 629 | "100%|██████████| 1000/1000 [00:00<00:00, 6964.22it/s]\n" 630 | ] 631 | }, 632 | { 633 | "name": "stdout", 634 | "output_type": "stream", 635 | "text": [ 636 | "Similarity fn: wpmi, D_probe: imagenet_broden\n", 637 | "Top 1 acc: 86.90%, Top 5 acc: 98.10%\n", 638 | "Mean rank of correct class: 1.99, Median rank of correct class: 1\n", 639 | "AUC: 0.7270\n", 640 | "Files already downloaded and verified\n", 641 | "Files already downloaded and verified\n" 642 | ] 643 
| }, 644 | { 645 | "name": "stderr", 646 | "output_type": "stream", 647 | "text": [ 648 | "100%|██████████| 1000/1000 [00:00<00:00, 1393.59it/s]\n" 649 | ] 650 | }, 651 | { 652 | "name": "stdout", 653 | "output_type": "stream", 654 | "text": [ 655 | "torch.Size([1000, 1000])\n", 656 | "Similarity fn: soft_wpmi, D_probe: cifar100_train\n", 657 | "Top 1 acc: 46.30%, Top 5 acc: 79.40%\n", 658 | "Mean rank of correct class: 8.61, Median rank of correct class: 2\n", 659 | "AUC: 0.6673\n" 660 | ] 661 | }, 662 | { 663 | "name": "stderr", 664 | "output_type": "stream", 665 | "text": [ 666 | "100%|██████████| 1000/1000 [00:00<00:00, 1180.32it/s]\n" 667 | ] 668 | }, 669 | { 670 | "name": "stdout", 671 | "output_type": "stream", 672 | "text": [ 673 | "torch.Size([1000, 1000])\n", 674 | "Similarity fn: soft_wpmi, D_probe: broden\n", 675 | "Top 1 acc: 70.70%, Top 5 acc: 90.00%\n", 676 | "Mean rank of correct class: 4.80, Median rank of correct class: 1\n", 677 | "AUC: 0.7856\n" 678 | ] 679 | }, 680 | { 681 | "name": "stderr", 682 | "output_type": "stream", 683 | "text": [ 684 | "100%|██████████| 1000/1000 [00:00<00:00, 1344.09it/s]\n" 685 | ] 686 | }, 687 | { 688 | "name": "stdout", 689 | "output_type": "stream", 690 | "text": [ 691 | "torch.Size([1000, 1000])\n", 692 | "Similarity fn: soft_wpmi, D_probe: imagenet_val\n", 693 | "Top 1 acc: 95.50%, Top 5 acc: 98.90%\n", 694 | "Mean rank of correct class: 1.18, Median rank of correct class: 1\n", 695 | "AUC: 0.9208\n" 696 | ] 697 | }, 698 | { 699 | "name": "stderr", 700 | "output_type": "stream", 701 | "text": [ 702 | "100%|██████████| 1000/1000 [00:00<00:00, 1253.33it/s]\n" 703 | ] 704 | }, 705 | { 706 | "name": "stdout", 707 | "output_type": "stream", 708 | "text": [ 709 | "torch.Size([1000, 1000])\n", 710 | "Similarity fn: soft_wpmi, D_probe: imagenet_broden\n", 711 | "Top 1 acc: 95.40%, Top 5 acc: 99.00%\n", 712 | "Mean rank of correct class: 1.19, Median rank of correct class: 1\n", 713 | "AUC: 0.9166\n" 714 | ] 715 | } 716 
| ], 717 | "source": [ 718 | "concept_set = 'data/imagenet_labels.txt'\n", 719 | "with open(concept_set, 'r') as f: \n", 720 | " words = (f.read()).split('\\n')\n", 721 | " \n", 722 | "\n", 723 | "for similarity_fn in similarity_fns:\n", 724 | " for d_probe in d_probes:\n", 725 | " utils.save_activations(clip_name = clip_name, target_name = target_name, target_layers = [target_layer], \n", 726 | " d_probe = d_probe, concept_set = concept_set, batch_size = batch_size, \n", 727 | " device = device, pool_mode=pool_mode, save_dir = save_dir)\n", 728 | "\n", 729 | " save_names = utils.get_save_names(clip_name = clip_name, target_name = target_name,\n", 730 | " target_layer = target_layer, d_probe = d_probe,\n", 731 | " concept_set = concept_set, pool_mode=pool_mode,\n", 732 | " \n", 733 | " save_dir = save_dir)\n", 734 | "\n", 735 | " target_save_name, clip_save_name, text_save_name = save_names\n", 736 | "\n", 737 | " similarities, target_feats = utils.get_similarity_from_activations(target_save_name, clip_save_name, \n", 738 | " text_save_name, \n", 739 | " eval(\"similarity.{}\".format(similarity_fn)),\n", 740 | " device=device)\n", 741 | " \n", 742 | " print(\"Similarity fn: {}, D_probe: {}\".format(similarity_fn, d_probe))\n", 743 | " print(\"Top 1 acc: {:.2f}%, Top 5 acc: {:.2f}%\".format(get_topk_acc(similarities, k=1),\n", 744 | " get_topk_acc(similarities, k=5)))\n", 745 | " \n", 746 | " mean, median = get_correct_rank_mean_median(similarities)\n", 747 | " print(\"Mean rank of correct class: {:.2f}, Median rank of correct class: {}\".format(mean, median))\n", 748 | " print(\"AUC: {:.4f}\".format(get_auc(similarities)))\n", 749 | "\n" 750 | ] 751 | } 752 | ], 753 | "metadata": { 754 | "kernelspec": { 755 | "display_name": "Python [conda env:jovyan-clip]", 756 | "language": "python", 757 | "name": "conda-env-jovyan-clip-py" 758 | }, 759 | "language_info": { 760 | "codemirror_mode": { 761 | "name": "ipython", 762 | "version": 3 763 | }, 764 | "file_extension": 
".py", 765 | "mimetype": "text/x-python", 766 | "name": "python", 767 | "nbconvert_exporter": "python", 768 | "pygments_lexer": "ipython3", 769 | "version": "3.9.9" 770 | } 771 | }, 772 | "nbformat": 4, 773 | "nbformat_minor": 5 774 | } 775 | -------------------------------------------------------------------------------- /experiments/text_colorings.py: -------------------------------------------------------------------------------- 1 | def get_coloring(fig_name): 2 | 3 | if fig_name=='fig1a': 4 | def get_color(method, i): 5 | if method=="clip": 6 | if i in []: 7 | return "orange" 8 | elif i in []: 9 | return "red" 10 | else: 11 | return "green" 12 | elif method=="nd": 13 | if i in []: 14 | return "orange" 15 | elif i in [0]: 16 | return "red" 17 | else: 18 | return "green" 19 | elif method=="milan_b": 20 | if i in []: 21 | return "orange" 22 | elif i in [0, 3]: 23 | return "red" 24 | else: 25 | return "green" 26 | elif method=="milan_ood": 27 | if i in []: 28 | return "orange" 29 | elif i in [0, 3]: 30 | return "red" 31 | else: 32 | return "green" 33 | 34 | elif fig_name=="fig1b": 35 | def get_color(method, i): 36 | if method=="clip": 37 | if i in []: 38 | return "orange" 39 | elif i in []: 40 | return "red" 41 | else: 42 | return "green" 43 | elif method=="nd": 44 | if i in [0, 1]: 45 | return "orange" 46 | elif i in [2]: 47 | return "red" 48 | else: 49 | return "green" 50 | elif method=="milan_b": 51 | if i in [2]: 52 | return "orange" 53 | elif i in [0, 1]: 54 | return "red" 55 | else: 56 | return "green" 57 | elif method=="milan_ood": 58 | if i in [1]: 59 | return "orange" 60 | elif i in [0, 2, 3]: 61 | return "red" 62 | else: 63 | return "green" 64 | 65 | elif fig_name == 'fig6a': 66 | def get_color(method, i): 67 | if method=="clip": 68 | if i in []: 69 | return "orange" 70 | elif i in []: 71 | return "red" 72 | else: 73 | return "green" 74 | elif method=="nd": 75 | if i in []: 76 | return "orange" 77 | elif i in []: 78 | return "red" 79 | else: 80 | return 
"green" 81 | elif method=="milan_b": 82 | if i in [0, 5]: 83 | return "orange" 84 | elif i in []: 85 | return "red" 86 | else: 87 | return "green" 88 | elif method=="milan_ood": 89 | if i in [0, 1, 3, 5, 6]: 90 | return "orange" 91 | elif i in []: 92 | return "red" 93 | else: 94 | return "green" 95 | 96 | elif fig_name == 'fig6b': 97 | def get_color(method, i): 98 | if method=="clip": 99 | if i in []: 100 | return "orange" 101 | elif i in []: 102 | return "red" 103 | else: 104 | return "green" 105 | elif method=="nd": 106 | if i in []: 107 | return "orange" 108 | elif i in []: 109 | return "red" 110 | else: 111 | return "green" 112 | elif method=="milan_b": 113 | if i in [0, 3, 5]: 114 | return "orange" 115 | elif i in [2, 4, 6, 7, 8]: 116 | return "red" 117 | else: 118 | return "green" 119 | elif method=="milan_ood": 120 | if i in [5]: 121 | return "orange" 122 | elif i in [0, 2, 3, 4, 6, 7, 8]: 123 | return "red" 124 | else: 125 | return "green" 126 | 127 | elif fig_name == 'fig7a': 128 | def get_color(method, i): 129 | if method=="clip": 130 | if i in [2]: 131 | return "orange" 132 | elif i in []: 133 | return "red" 134 | else: 135 | return "green" 136 | elif method=="nd": 137 | if i in []: 138 | return "orange" 139 | elif i in []: 140 | return "red" 141 | else: 142 | return "green" 143 | elif method=="milan_b": 144 | if i in [0, 2, 4, 7]: 145 | return "orange" 146 | elif i in [1, 5, 6, 8, 9]: 147 | return "red" 148 | else: 149 | return "green" 150 | elif method=="milan_ood": 151 | if i in [0, 6, 8]: 152 | return "orange" 153 | elif i in [1, 2, 3, 4, 5, 7, 9]: 154 | return "red" 155 | else: 156 | return "green" 157 | 158 | elif fig_name == 'fig7b': 159 | def get_color(method, i): 160 | if method=="clip": 161 | if i in []: 162 | return "orange" 163 | elif i in []: 164 | return "red" 165 | else: 166 | return "green" 167 | elif method=="nd": 168 | if i in [2,6]: 169 | return "orange" 170 | elif i in [3,4]: 171 | return "red" 172 | else: 173 | return "green" 174 | 
elif method=="milan_b": 175 | if i in [0, 2, 3, 4, 6, 8, 9]: 176 | return "orange" 177 | elif i in []: 178 | return "red" 179 | else: 180 | return "green" 181 | elif method=="milan_ood": 182 | if i in [3, 6, 8]: 183 | return "orange" 184 | elif i in [0, 2, 4, 7, 9]: 185 | return "red" 186 | else: 187 | return "green" 188 | 189 | elif fig_name=='fig14a': 190 | def get_color(method, i): 191 | if method=="cos": 192 | if i in []: 193 | return "orange" 194 | elif i in [0,1,3]: 195 | return "red" 196 | else: 197 | return "green" 198 | elif method=="soft_wpmi": 199 | return "green" 200 | 201 | 202 | elif fig_name=="fig14b": 203 | def get_color(method, i): 204 | if method=="cos": 205 | return "red" 206 | elif method=="soft_wpmi": 207 | return "green" 208 | 209 | else: 210 | def get_color(method, i): 211 | return "black" 212 | 213 | return get_color -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision >= 0.13 3 | ftfy >= 6.1.1 4 | regex >= 2023.3.23 5 | tqdm >= 4.65.0 6 | black >= 23.1.0 7 | isort >= 5.12.0 8 | pandas >= 1.5.3 9 | matplotlib == 3.5.1 10 | scipy >= 1.10.1 11 | huggingface-hub==0.4.0 12 | sentence-transformers == 2.2.0 13 | jupyter 14 | notebook 15 | -------------------------------------------------------------------------------- /similarity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from tqdm import tqdm 4 | from scipy import stats 5 | from matplotlib import pyplot as plt 6 | 7 | def cos_similarity_cubed(clip_feats, target_feats, device='cuda', batch_size=10000, min_norm=1e-3): 8 | """ 9 | Subtracts the mean from each vector, raises the result to the third power, then compares cosine similarity 10 | Does not modify any tensors in place 11 | """ 12 | with torch.no_grad(): 13 | torch.cuda.empty_cache() 14 | 15 | clip_feats = clip_feats - 
torch.mean(clip_feats, dim=0, keepdim=True) 16 | target_feats = target_feats - torch.mean(target_feats, dim=0, keepdim=True) 17 | 18 | clip_feats = clip_feats**3 19 | target_feats = target_feats**3 20 | 21 | clip_feats = clip_feats/torch.clip(torch.norm(clip_feats, p=2, dim=0, keepdim=True), min_norm) 22 | target_feats = target_feats/torch.clip(torch.norm(target_feats, p=2, dim=0, keepdim=True), min_norm) 23 | 24 | similarities = [] 25 | for t_i in tqdm(range(math.ceil(target_feats.shape[1]/batch_size))): 26 | curr_similarities = [] 27 | curr_target = target_feats[:, t_i*batch_size:(t_i+1)*batch_size].to(device).T 28 | for c_i in range(math.ceil(clip_feats.shape[1]/batch_size)): 29 | curr_similarities.append(curr_target @ clip_feats[:, c_i*batch_size:(c_i+1)*batch_size].to(device)) 30 | similarities.append(torch.cat(curr_similarities, dim=1)) 31 | return torch.cat(similarities, dim=0) 32 | 33 | def cos_similarity(clip_feats, target_feats, device='cuda'): 34 | with torch.no_grad(): 35 | clip_feats = clip_feats / torch.norm(clip_feats, p=2, dim=0, keepdim=True) 36 | target_feats = target_feats / torch.norm(target_feats, p=2, dim=0, keepdim=True) 37 | 38 | batch_size = 10000 39 | 40 | similarities = [] 41 | for t_i in tqdm(range(math.ceil(target_feats.shape[1]/batch_size))): 42 | curr_similarities = [] 43 | curr_target = target_feats[:, t_i*batch_size:(t_i+1)*batch_size].to(device).T 44 | for c_i in range(math.ceil(clip_feats.shape[1]/batch_size)): 45 | curr_similarities.append(curr_target @ clip_feats[:, c_i*batch_size:(c_i+1)*batch_size].to(device)) 46 | similarities.append(torch.cat(curr_similarities, dim=1)) 47 | return torch.cat(similarities, dim=0) 48 | 49 | def soft_wpmi(clip_feats, target_feats, top_k=100, a=10, lam=1, device='cuda', 50 | min_prob=1e-7, p_start=0.998, p_end=0.97): 51 | 52 | with torch.no_grad(): 53 | torch.cuda.empty_cache() 54 | clip_feats = torch.nn.functional.softmax(a*clip_feats, dim=1) 55 | 56 | inds = torch.topk(target_feats, dim=0, 
k=top_k)[1] 57 | prob_d_given_e = [] 58 | 59 | p_in_examples = p_start-(torch.arange(start=0, end=top_k)/top_k*(p_start-p_end)).unsqueeze(1).to(device) 60 | for orig_id in tqdm(range(target_feats.shape[1])): 61 | 62 | curr_clip_feats = clip_feats.gather(0, inds[:,orig_id:orig_id+1].expand(-1,clip_feats.shape[1])).to(device) 63 | 64 | curr_p_d_given_e = 1+p_in_examples*(curr_clip_feats-1) 65 | curr_p_d_given_e = torch.sum(torch.log(curr_p_d_given_e+min_prob), dim=0, keepdim=True) 66 | prob_d_given_e.append(curr_p_d_given_e) 67 | torch.cuda.empty_cache() 68 | 69 | prob_d_given_e = torch.cat(prob_d_given_e, dim=0) 70 | print(prob_d_given_e.shape) 71 | #logsumexp trick to avoid underflow 72 | prob_d = (torch.logsumexp(prob_d_given_e, dim=0, keepdim=True) - 73 | torch.log(prob_d_given_e.shape[0]*torch.ones([1]).to(device))) 74 | mutual_info = prob_d_given_e - lam*prob_d 75 | return mutual_info 76 | 77 | def wpmi(clip_feats, target_feats, top_k=28, a=2, lam=0.6, device='cuda', min_prob=1e-7): 78 | 79 | with torch.no_grad(): 80 | torch.cuda.empty_cache() 81 | 82 | clip_feats = torch.nn.functional.softmax(a*clip_feats, dim=1) 83 | 84 | inds = torch.topk(target_feats, dim=0, k=top_k)[1] 85 | prob_d_given_e = [] 86 | 87 | for orig_id in tqdm(range(target_feats.shape[1])): 88 | torch.cuda.empty_cache() 89 | curr_clip_feats = clip_feats.gather(0, inds[:,orig_id:orig_id+1].expand(-1,clip_feats.shape[1])).to(device) 90 | curr_p_d_given_e = torch.sum(torch.log(curr_clip_feats+min_prob), dim=0, keepdim=True) 91 | prob_d_given_e.append(curr_p_d_given_e) 92 | 93 | prob_d_given_e = torch.cat(prob_d_given_e, dim=0) 94 | #logsumexp trick to avoid underflow 95 | prob_d = (torch.logsumexp(prob_d_given_e, dim=0, keepdim=True) - 96 | torch.log(prob_d_given_e.shape[0]*torch.ones([1]).to(device))) 97 | 98 | mutual_info = prob_d_given_e - lam*prob_d 99 | return mutual_info 100 | 101 | def rank_reorder(clip_feats, target_feats, device="cuda", p=3, top_fraction=0.05, scale_p=0.5): 102 | """ 103 
| top_fraction: fraction of the most highly activating target images to use for eval, between 0 and 1 104 | """ 105 | with torch.no_grad(): 106 | batch = 1500 107 | errors = [] 108 | top_n = int(target_feats.shape[0]*top_fraction) 109 | target_feats, inds = torch.topk(target_feats, dim=0, k=top_n) 110 | 111 | for orig_id in tqdm(range(target_feats.shape[1])): 112 | clip_indices = clip_feats.gather(0, inds[:, orig_id:orig_id+1].expand([-1,clip_feats.shape[1]])).to(device) 113 | #calculate the average CLIP score of this neuron's top-activating images for each concept 114 | avg_clip = torch.mean(clip_indices, dim=0, keepdim=True) 115 | clip_indices = torch.argsort(clip_indices, dim=0) 116 | clip_indices = torch.argsort(clip_indices, dim=0) 117 | curr_errors = [] 118 | target = target_feats[:, orig_id:orig_id+1].to(device) 119 | sorted_target = torch.flip(target, dims=[0]) 120 | 121 | baseline_diff = sorted_target - torch.cat([sorted_target[torch.randperm(len(sorted_target))] for _ in range(5)], dim=1) 122 | baseline_diff = torch.mean(torch.abs(baseline_diff)**p) 123 | torch.cuda.empty_cache() 124 | 125 | for i in range(math.ceil(clip_indices.shape[1]/batch)): 126 | 127 | clip_id = (clip_indices[:, i*batch:(i+1)*batch]) 128 | reorg = sorted_target.expand(-1, batch).gather(dim=0, index=clip_id) 129 | diff = (target-reorg) 130 | curr_errors.append(torch.mean(torch.abs(diff)**p, dim=0, keepdim=True)/baseline_diff) 131 | errors.append(torch.cat(curr_errors, dim=1)/(avg_clip)**scale_p) 132 | 133 | errors = torch.cat(errors, dim=0) 134 | return -errors 135 | 136 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | import torch 5 | import clip 6 | from tqdm import tqdm 7 | from torch.utils.data import DataLoader 8 | import data_utils 9 | 10 | PM_SUFFIX = {"max":"_max", "avg":""} 11 | 12 | def 
get_activation(outputs, mode): 13 | ''' 14 | mode: how to pool activations: one of avg, max 15 | for fc or ViT neurons does no pooling 16 | ''' 17 | if mode=='avg': 18 | def hook(model, input, output): 19 | if len(output.shape)==4: #CNN layers 20 | outputs.append(output.mean(dim=[2,3]).detach()) 21 | elif len(output.shape)==3: #ViT 22 | outputs.append(output[:, 0].clone()) 23 | elif len(output.shape)==2: #FC layers 24 | outputs.append(output.detach()) 25 | elif mode=='max': 26 | def hook(model, input, output): 27 | if len(output.shape)==4: #CNN layers 28 | outputs.append(output.amax(dim=[2,3]).detach()) 29 | elif len(output.shape)==3: #ViT 30 | outputs.append(output[:, 0].clone()) 31 | elif len(output.shape)==2: #FC layers 32 | outputs.append(output.detach()) 33 | return hook 34 | 35 | def get_save_names(clip_name, target_name, target_layer, d_probe, concept_set, pool_mode, save_dir): 36 | 37 | target_save_name = "{}/{}_{}_{}{}.pt".format(save_dir, d_probe, target_name, target_layer, 38 | PM_SUFFIX[pool_mode]) 39 | clip_save_name = "{}/{}_{}.pt".format(save_dir, d_probe, clip_name.replace('/', '')) 40 | concept_set_name = (concept_set.split("/")[-1]).split(".")[0] 41 | text_save_name = "{}/{}_{}.pt".format(save_dir, concept_set_name, clip_name.replace('/', '')) 42 | 43 | return target_save_name, clip_save_name, text_save_name 44 | 45 | def save_target_activations(target_model, dataset, save_name, target_layers = ["layer4"], batch_size = 1000, 46 | device = "cuda", pool_mode='avg'): 47 | """ 48 | save_name: save_file path, should include {} which will be formatted by layer names 49 | """ 50 | _make_save_dir(save_name) 51 | save_names = {} 52 | for target_layer in target_layers: 53 | save_names[target_layer] = save_name.format(target_layer) 54 | 55 | if _all_saved(save_names): 56 | return 57 | 58 | all_features = {target_layer:[] for target_layer in target_layers} 59 | 60 | hooks = {} 61 | for target_layer in target_layers: 62 | command = 
"target_model.{}.register_forward_hook(get_activation(all_features[target_layer], pool_mode))".format(target_layer) 63 | hooks[target_layer] = eval(command) 64 | 65 | with torch.no_grad(): 66 | for images, labels in tqdm(DataLoader(dataset, batch_size, num_workers=8, pin_memory=True)): 67 | features = target_model(images.to(device)) 68 | 69 | for target_layer in target_layers: 70 | torch.save(torch.cat(all_features[target_layer]), save_names[target_layer]) 71 | hooks[target_layer].remove() 72 | #free memory 73 | del all_features 74 | torch.cuda.empty_cache() 75 | return 76 | 77 | 78 | def save_clip_image_features(model, dataset, save_name, batch_size=1000 , device = "cuda"): 79 | _make_save_dir(save_name) 80 | all_features = [] 81 | 82 | if os.path.exists(save_name): 83 | return 84 | 85 | save_dir = save_name[:save_name.rfind("/")] 86 | if not os.path.exists(save_dir): 87 | os.makedirs(save_dir) 88 | with torch.no_grad(): 89 | for images, labels in tqdm(DataLoader(dataset, batch_size, num_workers=8, pin_memory=True)): 90 | features = model.encode_image(images.to(device)) 91 | all_features.append(features) 92 | torch.save(torch.cat(all_features), save_name) 93 | #free memory 94 | del all_features 95 | torch.cuda.empty_cache() 96 | return 97 | 98 | def save_clip_text_features(model, text, save_name, batch_size=1000): 99 | if os.path.exists(save_name): 100 | return 101 | _make_save_dir(save_name) 102 | text_features = [] 103 | with torch.no_grad(): 104 | for i in tqdm(range(math.ceil(len(text)/batch_size))): 105 | text_features.append(model.encode_text(text[batch_size*i:batch_size*(i+1)])) 106 | text_features = torch.cat(text_features, dim=0) 107 | torch.save(text_features, save_name) 108 | del text_features 109 | torch.cuda.empty_cache() 110 | return 111 | 112 | def get_clip_text_features(model, text, batch_size=1000): 113 | """ 114 | gets text features without saving, useful with dynamic concept sets 115 | """ 116 | text_features = [] 117 | with torch.no_grad(): 118 
| for i in tqdm(range(math.ceil(len(text)/batch_size))): 119 | text_features.append(model.encode_text(text[batch_size*i:batch_size*(i+1)])) 120 | text_features = torch.cat(text_features, dim=0) 121 | return text_features 122 | 123 | def save_activations(clip_name, target_name, target_layers, d_probe, 124 | concept_set, batch_size, device, pool_mode, save_dir): 125 | 126 | clip_model, clip_preprocess = clip.load(clip_name, device=device) 127 | target_model, target_preprocess = data_utils.get_target_model(target_name, device) 128 | #setup data 129 | data_c = data_utils.get_data(d_probe, clip_preprocess) 130 | data_t = data_utils.get_data(d_probe, target_preprocess) 131 | 132 | with open(concept_set, 'r') as f: 133 | words = (f.read()).split('\n') 134 | #ignore empty lines 135 | words = [i for i in words if i!=""] 136 | 137 | text = clip.tokenize(["{}".format(word) for word in words]).to(device) 138 | 139 | save_names = get_save_names(clip_name = clip_name, target_name = target_name, 140 | target_layer = '{}', d_probe = d_probe, concept_set = concept_set, 141 | pool_mode=pool_mode, save_dir = save_dir) 142 | target_save_name, clip_save_name, text_save_name = save_names 143 | 144 | save_clip_text_features(clip_model, text, text_save_name, batch_size) 145 | save_clip_image_features(clip_model, data_c, clip_save_name, batch_size, device) 146 | save_target_activations(target_model, data_t, target_save_name, target_layers, 147 | batch_size, device, pool_mode) 148 | return 149 | 150 | def get_similarity_from_activations(target_save_name, clip_save_name, text_save_name, similarity_fn, 151 | return_target_feats=True, device="cuda"): 152 | 153 | image_features = torch.load(clip_save_name, map_location='cpu').float() 154 | text_features = torch.load(text_save_name, map_location='cpu').float() 155 | with torch.no_grad(): 156 | image_features /= image_features.norm(dim=-1, keepdim=True) 157 | text_features /= text_features.norm(dim=-1, keepdim=True) 158 | clip_feats = 
(image_features @ text_features.T) 159 | del image_features, text_features 160 | torch.cuda.empty_cache() 161 | 162 | target_feats = torch.load(target_save_name, map_location='cpu') 163 | similarity = similarity_fn(clip_feats, target_feats, device=device) 164 | 165 | del clip_feats 166 | torch.cuda.empty_cache() 167 | 168 | if return_target_feats: 169 | return similarity, target_feats 170 | else: 171 | del target_feats 172 | torch.cuda.empty_cache() 173 | return similarity 174 | 175 | def get_cos_similarity(preds, gt, clip_model, mpnet_model, device="cuda", batch_size=200): 176 | """ 177 | preds: predicted concepts, list of strings 178 | gt: correct concepts, list of strings 179 | """ 180 | pred_tokens = clip.tokenize(preds).to(device) 181 | gt_tokens = clip.tokenize(gt).to(device) 182 | pred_embeds = [] 183 | gt_embeds = [] 184 | 185 | #print(preds) 186 | with torch.no_grad(): 187 | for i in range(math.ceil(len(pred_tokens)/batch_size)): 188 | pred_embeds.append(clip_model.encode_text(pred_tokens[batch_size*i:batch_size*(i+1)])) 189 | gt_embeds.append(clip_model.encode_text(gt_tokens[batch_size*i:batch_size*(i+1)])) 190 | 191 | pred_embeds = torch.cat(pred_embeds, dim=0) 192 | pred_embeds /= pred_embeds.norm(dim=-1, keepdim=True) 193 | gt_embeds = torch.cat(gt_embeds, dim=0) 194 | gt_embeds /= gt_embeds.norm(dim=-1, keepdim=True) 195 | 196 | #l2_norm_pred = torch.norm(pred_embeds-gt_embeds, dim=1) 197 | cos_sim_clip = torch.sum(pred_embeds*gt_embeds, dim=1) 198 | 199 | gt_embeds = mpnet_model.encode([gt_x for gt_x in gt]) 200 | pred_embeds = mpnet_model.encode(preds) 201 | cos_sim_mpnet = np.sum(pred_embeds*gt_embeds, axis=1) 202 | 203 | return float(torch.mean(cos_sim_clip)), float(np.mean(cos_sim_mpnet)) 204 | 205 | def _all_saved(save_names): 206 | """ 207 | save_names: {layer_name:save_path} dict 208 | Returns True if there is a file corresponding to each one of the values in save_names, 209 | else Returns False 210 | """ 211 | for save_name in 
save_names.values(): 212 | if not os.path.exists(save_name): 213 | return False 214 | return True 215 | 216 | def _make_save_dir(save_name): 217 | """ 218 | creates save directory if one does not exist 219 | save_name: full save path 220 | """ 221 | save_dir = save_name[:save_name.rfind("/")] 222 | if not os.path.exists(save_dir): 223 | os.makedirs(save_dir) 224 | return 225 | 226 | 227 | --------------------------------------------------------------------------------
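
The cache file layout produced by `get_save_names` in `utils.py` can be sketched as below. This is a standalone copy of that function's string formatting, so it runs without PyTorch or CLIP installed. The argument values (`ViT-B/16`, `resnet50`, `broden`, `data/20k.txt`, `saved_activations`) are illustrative assumptions and are not confirmed defaults of the repository.

```python
# Standalone sketch of the path construction in utils.get_save_names.
# All argument values used further down are assumed, for illustration only.

PM_SUFFIX = {"max": "_max", "avg": ""}


def get_save_names(clip_name, target_name, target_layer, d_probe,
                   concept_set, pool_mode, save_dir):
    # Target-model activations: one file per (probe set, model, layer, pooling).
    target_save_name = "{}/{}_{}_{}{}.pt".format(
        save_dir, d_probe, target_name, target_layer, PM_SUFFIX[pool_mode])
    # CLIP image features: "/" is dropped from the CLIP model name.
    clip_save_name = "{}/{}_{}.pt".format(
        save_dir, d_probe, clip_name.replace("/", ""))
    # CLIP text features: keyed by the concept file's base name.
    concept_set_name = (concept_set.split("/")[-1]).split(".")[0]
    text_save_name = "{}/{}_{}.pt".format(
        save_dir, concept_set_name, clip_name.replace("/", ""))
    return target_save_name, clip_save_name, text_save_name


# Passing "{}" as the layer leaves a placeholder in the target path,
# which save_target_activations later fills in once per layer.
target_tmpl, clip_path, text_path = get_save_names(
    clip_name="ViT-B/16", target_name="resnet50", target_layer="{}",
    d_probe="broden", concept_set="data/20k.txt",
    pool_mode="avg", save_dir="saved_activations")

layer4_path = target_tmpl.format("layer4")

# Max pooling appends the "_max" suffix from PM_SUFFIX.
max_tmpl, _, _ = get_save_names(
    clip_name="ViT-B/16", target_name="resnet50", target_layer="{}",
    d_probe="broden", concept_set="data/20k.txt",
    pool_mode="max", save_dir="saved_activations")

print(layer4_path)
print(clip_path)
print(text_path)
print(max_tmpl.format("layer4"))
```

Because the CLIP image and text paths do not depend on the target model or layer, those features are computed once per probing dataset and concept set and then reused across target layers.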