├── README.md
├── assets
│   └── main.png
├── clip
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── clip.cpython-38.pyc
│   │   ├── clip.cpython-39.pyc
│   │   ├── model.cpython-38.pyc
│   │   ├── model.cpython-39.pyc
│   │   ├── simple_tokenizer.cpython-38.pyc
│   │   └── simple_tokenizer.cpython-39.pyc
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   ├── simple_tokenizer.py
│   └── utils.py
├── configs
│   ├── caltech101.yaml
│   ├── dtd.yaml
│   ├── eurosat.yaml
│   ├── fgvc.yaml
│   ├── imagenet.yaml
│   ├── oxford_pets.yaml
│   ├── sun397.yaml
│   └── ucf101.yaml
├── datasets
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── caltech101.cpython-38.pyc
│   │   ├── dtd.cpython-38.pyc
│   │   ├── eurosat.cpython-38.pyc
│   │   ├── fgvc.cpython-38.pyc
│   │   ├── imagenet.cpython-38.pyc
│   │   ├── oxford_pets.cpython-38.pyc
│   │   ├── sun397.cpython-38.pyc
│   │   ├── ucf101.cpython-38.pyc
│   │   └── utils.cpython-38.pyc
│   ├── caltech101.py
│   ├── dtd.py
│   ├── eurosat.py
│   ├── fgvc.py
│   ├── imagenet.py
│   ├── oxford_pets.py
│   ├── sun397.py
│   ├── ucf101.py
│   └── utils.py
└── main.py
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Adapter: An Online Few-shot Learner for Vision-Language Model
2 | 
3 | Cheng Cheng, [Lin Song](http://linsong.info), Ruoyi Xue, Hang Wang, [Hongbin Sun](https://gr.xjtu.edu.cn/en/web/hsun/home), [Yixiao Ge](https://geyixiao.com), [Ying Shan](https://www.linkedin.com/in/YingShanProfile)
4 | 
5 | Meta-Adapter is a new few-shot learning method for CLIP that aims to overcome the limitations of previous methods, namely poor generalization ability and low efficiency.
6 | Employing a meta-testing mechanism and a lightweight residual-style network, Meta-Adapter extracts knowledge from few-shot samples without additional fine-tuning, alleviating the over-fitting issue while maintaining high efficiency.
7 | 
8 | ![Intro](assets/main.png)
9 | [arXiv Paper](https://arxiv.org/pdf/2311.03774.pdf)
10 | 
11 | ## Installation
12 | ---
13 | 1. This code is built on top of the toolbox [Dassl.pytorch](https://github.com/KaiyangZhou/Dassl.pytorch), so you need to install the `dassl` environment first. Simply follow the instructions described [here](https://github.com/KaiyangZhou/Dassl.pytorch#installation) to install `dassl` as well as PyTorch.
14 | 
15 | 2. Follow the [guidelines](https://github.com/KaiyangZhou/CoOp/blob/main/DATASETS.md) to install the datasets.
16 | 
17 | 3. Requirements: `torch>=2.0.0`, CUDA 11.8.
18 | 
19 | 
20 | ## Model Zoo
21 | ----
22 | 
23 | 1. The pre-trained weights of Meta-Adapter on ImageNet with RN50, RN101, ViT-B/16 and ViT-B/32 backbones can be downloaded altogether via this [link](https://drive.google.com/drive/folders/1esyFhs4gj9cEZoFo6B45Mp3eMmsuxwW-?usp=drive_link). The weights can be used to reproduce the results in Table 3 of the Meta-Adapter paper.
24 | 
25 | 2. Quantitative results on other datasets are as follows:
26 | 
27 | | Model | SUN397 | UCF101 | Caltech101 | DTD | FGVCAircraft | EuroSAT | Oxford_Pets |
28 | |:----: |:------:|:------:|:----------:|:---:|:------------:|:-------:|:-----------:|
29 | | Zero-Shot CLIP | 29.0 | 21.1 | 60.6 | 10.0 | 0.4 | 4.2 | 84.0 |
30 | | Meta-Adapter | 52.7 | 52.3 | 71.5 | 49.2 | 19.6 | 66.7 | 87.0 |
31 | 
32 | 
33 | ## Getting Started
34 | ----
35 | 
36 | ### Validate
37 | 
38 | 1. Change `root_path` in `configs/$DATA.yaml`; the default configuration uses `shots=16` and `backbone=RN50`.
39 | 2. Run `python main.py --config ./configs/$DATA.yaml`.
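The snippet below is a minimal, illustrative sketch of how the pieces in this repository fit together for the zero-shot CLIP baseline: it loads a backbone with `clip.load` and builds the text classifier with `clip_classifier` from `clip/utils.py`. The class names, template, and random image batch are placeholders (the real ones come from the dataset classes under `datasets/` and from `main.py`, which also applies the Meta-Adapter weights and is not reproduced here), and a CUDA device is assumed because the helpers in `clip/utils.py` move tensors to the GPU.

```python
import torch

import clip
from clip.utils import clip_classifier, cls_acc

# Placeholder class names and prompt template; real values come from the
# dataset classes in datasets/ (e.g. DescribableTextures.template).
classnames = ['brick', 'grass', 'sky']
template = ['a photo of a {}.']

clip_model, preprocess = clip.load('RN50')  # default backbone in the configs
clip_model.eval()

with torch.no_grad():
    # [embed_dim, num_classes] text classifier built by prompt ensembling.
    clip_weights = clip_classifier(classnames, template, clip_model)

    # Stand-in batch; in the real pipeline images come from a dataloader over
    # the few-shot splits and are preprocessed with `preprocess`.
    images = torch.randn(8, 3, 224, 224).cuda()
    labels = torch.randint(0, len(classnames), (8,)).cuda()

    image_features = clip_model.encode_image(images)
    image_features /= image_features.norm(dim=-1, keepdim=True)

    logits = 100. * image_features @ clip_weights  # zero-shot logits, as in clip/utils.py
    print('accuracy: {:.2f}'.format(cls_acc(logits, labels)))
```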
40 | 
41 | 
42 | ## References
43 | [1]: Zhang, Renrui, Wei Zhang, Rongyao Fang, Peng Gao, Kunchang Li, Jifeng Dai, Yu Qiao, and Hongsheng Li. "Tip-Adapter: Training-free adaption of CLIP for few-shot classification." In European Conference on Computer Vision, pp. 493-510. Cham: Springer Nature Switzerland, 2022. https://arxiv.org/pdf/2207.09519.pdf
44 | 
45 | [2]: Gao, Peng, Shijie Geng, Renrui Zhang, Teli Ma, Rongyao Fang, Yongfeng Zhang, Hongsheng Li, and Yu Qiao. "CLIP-Adapter: Better vision-language models with feature adapters." International Journal of Computer Vision (2023): 1-15. https://arxiv.org/pdf/2110.04544.pdf
46 | 
47 | [3]: Zhou, Kaiyang, Jingkang Yang, Chen Change Loy, and Ziwei Liu. "Learning to prompt for vision-language models." International Journal of Computer Vision 130, no. 9 (2022): 2337-2348. https://arxiv.org/pdf/2109.01134
48 | 
49 | 
50 | ## Citation
51 | 
52 | If you find Meta-Adapter helpful, please cite:
53 | ```
54 | @inproceedings{cheng2023meta,
55 |   title={Meta-Adapter: An Online Few-shot Learner for Vision-Language Model},
56 |   author={Cheng, Cheng and Song, Lin and Xue, Ruoyi and Wang, Hang and Sun, Hongbin and Ge, Yixiao and Shan, Ying},
57 |   booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
58 |   year={2023}
59 | }
60 | ```
--------------------------------------------------------------------------------
/assets/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/assets/main.png
--------------------------------------------------------------------------------
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 | 
--------------------------------------------------------------------------------
/clip/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/clip/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/clip/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/clip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/clip/__pycache__/clip.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/clip.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/clip/__pycache__/clip.cpython-39.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/clip/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/clip/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /clip/__pycache__/simple_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/clip/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/simple_tokenizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/clip/__pycache__/simple_tokenizer.cpython-39.pyc -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 
= url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px, interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | """Returns the names of available CLIP models""" 83 | return list(_MODELS.keys()) 84 | 85 | 86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 87 | """Load a CLIP model 88 | 89 | Parameters 90 | ---------- 91 | name : str 92 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 93 | 94 | device : Union[str, torch.device] 95 | The device to put the loaded model 96 | 97 | jit : bool 98 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 99 | 100 | Returns 101 | ------- 102 | model : torch.nn.Module 103 | The CLIP model 104 | 105 | preprocess : Callable[[PIL.Image], torch.Tensor] 106 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 107 | """ 108 | if name in _MODELS: 109 | model_path = _download(_MODELS[name]) 110 | elif os.path.isfile(name): 111 | model_path = name 112 | else: 113 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 114 | 115 | try: 116 | # loading JIT archive 117 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 118 | state_dict = None 119 | except RuntimeError: 120 | # loading saved state dict 121 | if jit: 122 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 123 | jit = False 124 | state_dict = torch.load(model_path, map_location="cpu") 125 | 126 | if not jit: 127 | model = build_model(state_dict or model.state_dict()).to(device) 128 | if str(device) == "cpu": 129 | model.float() 130 | return model, _transform(model.visual.input_resolution) 131 | 132 | # patch the device names 133 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 134 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 135 | 136 | def patch_device(module): 137 | try: 138 | graphs = [module.graph] if hasattr(module, "graph") else [] 139 | except RuntimeError: 140 | graphs = [] 141 | 142 | if hasattr(module, "forward1"): 143 | graphs.append(module.forward1.graph) 144 | 145 | for graph in graphs: 146 | for node in graph.findAllNodes("prim::Constant"): 147 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 148 | node.copyAttributes(device_node) 149 | 150 | model.apply(patch_device) 151 | patch_device(model.encode_image) 152 | patch_device(model.encode_text) 153 | 154 | # patch dtype to float32 on CPU 155 | if str(device) == "cpu": 156 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 157 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 158 | float_node = float_input.node() 159 | 160 | def patch_float(module): 161 | try: 162 | graphs = [module.graph] if hasattr(module, "graph") else [] 163 | except RuntimeError: 164 | graphs = [] 165 | 166 | if hasattr(module, "forward1"): 167 | graphs.append(module.forward1.graph) 168 | 169 | for graph in graphs: 170 | for node in graph.findAllNodes("aten::to"): 171 | inputs = list(node.inputs()) 172 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 173 | if inputs[i].node()["value"] == 5: 174 | inputs[i].node().copyAttributes(float_node) 175 | 176 | model.apply(patch_float) 177 | patch_float(model.encode_image) 178 | patch_float(model.encode_text) 179 | 180 | model.float() 181 | 182 | return model, _transform(model.input_resolution.item()) 183 | 184 | 185 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 186 | """ 187 | Returns the tokenized representation of given input string(s) 188 | 189 | Parameters 190 | ---------- 191 | texts : Union[str, List[str]] 192 | An input string or a list of input strings to tokenize 193 | 194 | context_length : int 195 | The context length to use; all CLIP models use 77 as the context length 196 | 197 | truncate: bool 198 | Whether to truncate the text in case its encoding is longer than the context length 199 | 200 | Returns 201 | ------- 202 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 203 | """ 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | 207 | sot_token = _tokenizer.encoder["<|startoftext|>"] 208 | eot_token = _tokenizer.encoder["<|endoftext|>"] 209 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 210 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 211 | 212 | for i, tokens in enumerate(all_tokens): 213 | if len(tokens) > context_length: 214 | if truncate: 215 | tokens = tokens[:context_length] 216 | tokens[-1] = eot_token 217 | else: 218 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 219 | result[i, 
:len(tokens)] = torch.tensor(tokens) 220 | 221 | return result 222 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's 
but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = x + 
self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 204 | super().__init__() 205 | self.input_resolution = input_resolution 206 | self.output_dim = output_dim 207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 208 | 209 | scale = width ** -0.5 210 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 212 | self.ln_pre = LayerNorm(width) 213 | 214 | self.transformer = Transformer(width, layers, heads) 215 | 216 | self.ln_post = LayerNorm(width) 217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 218 | 219 | def forward(self, x: torch.Tensor): 220 | x = self.conv1(x) # shape = [*, width, grid, grid] 221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 224 | x = x + self.positional_embedding.to(x.dtype) 225 | x = self.ln_pre(x) 226 | 227 | x = x.permute(1, 0, 2) # NLD -> LND 228 | x = self.transformer(x) 229 | x = x.permute(1, 0, 2) # LND -> NLD 230 | 231 | x = self.ln_post(x[:, 0, :]) 232 | 233 | if self.proj is not None: 234 | x = x @ self.proj 235 | 236 | return x 237 | 238 | 239 | class CLIP(nn.Module): 240 | def __init__(self, 241 | embed_dim: int, 242 | # vision 243 | image_resolution: int, 244 | vision_layers: Union[Tuple[int, int, int, int], int], 245 | vision_width: int, 246 | vision_patch_size: int, 247 | # text 248 | context_length: int, 249 | vocab_size: int, 250 | transformer_width: int, 251 | transformer_heads: int, 252 | transformer_layers: int 253 | ): 254 | super().__init__() 255 | 256 | self.context_length = context_length 257 | 258 | if isinstance(vision_layers, (tuple, list)): 259 | vision_heads = vision_width * 32 // 64 260 | self.visual = ModifiedResNet( 261 | layers=vision_layers, 262 | output_dim=embed_dim, 263 | heads=vision_heads, 264 | input_resolution=image_resolution, 265 | width=vision_width 266 | ) 267 | else: 268 | vision_heads = vision_width // 64 269 | self.visual = VisionTransformer( 270 | input_resolution=image_resolution, 271 | patch_size=vision_patch_size, 272 | width=vision_width, 273 | layers=vision_layers, 274 | heads=vision_heads, 275 | output_dim=embed_dim 276 | ) 277 | 278 | self.transformer = Transformer( 279 | width=transformer_width, 280 | layers=transformer_layers, 281 | heads=transformer_heads, 282 | attn_mask=self.build_attention_mask() 283 | ) 284 | 285 | self.vocab_size = vocab_size 286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 288 | 
self.ln_final = LayerNorm(transformer_width) 289 | 290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 292 | 293 | self.initialize_parameters() 294 | 295 | def initialize_parameters(self): 296 | nn.init.normal_(self.token_embedding.weight, std=0.02) 297 | nn.init.normal_(self.positional_embedding, std=0.01) 298 | 299 | if isinstance(self.visual, ModifiedResNet): 300 | if self.visual.attnpool is not None: 301 | std = self.visual.attnpool.c_proj.in_features ** -0.5 302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 306 | 307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 308 | for name, param in resnet_block.named_parameters(): 309 | if name.endswith("bn3.weight"): 310 | nn.init.zeros_(param) 311 | 312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 313 | attn_std = self.transformer.width ** -0.5 314 | fc_std = (2 * self.transformer.width) ** -0.5 315 | for block in self.transformer.resblocks: 316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 320 | 321 | if self.text_projection is not None: 322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 323 | 324 | def build_attention_mask(self): 325 | # lazily create causal attention mask, with full attention between the vision tokens 326 | # pytorch uses additive attention mask; fill with -inf 327 | mask = torch.empty(self.context_length, self.context_length) 328 | mask.fill_(float("-inf")) 329 | mask.triu_(1) # zero out the lower diagonal 330 | return mask 331 | 332 | @property 333 | def dtype(self): 334 | return self.visual.conv1.weight.dtype 335 | 336 | def encode_image(self, image): 337 | return self.visual(image.type(self.dtype)) 338 | 339 | def encode_text(self, text): 340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 341 | 342 | x = x + self.positional_embedding.type(self.dtype) 343 | x = x.permute(1, 0, 2) # NLD -> LND 344 | x = self.transformer(x) 345 | x = x.permute(1, 0, 2) # LND -> NLD 346 | x = self.ln_final(x).type(self.dtype) 347 | 348 | # x.shape = [batch_size, n_ctx, transformer.width] 349 | # take features from the eot embedding (eot_token is the highest number in each sequence) 350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 351 | 352 | return x 353 | 354 | def forward(self, image, text): 355 | image_features = self.encode_image(image) 356 | text_features = self.encode_text(text) 357 | 358 | # normalized features 359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 361 | 362 | # cosine similarity as logits 363 | logit_scale = self.logit_scale.exp() 364 | logits_per_image = logit_scale * image_features @ text_features.t() 365 | logits_per_text = logit_scale * text_features @ image_features.t() 366 | 367 | # shape = [global_batch_size, global_batch_size] 368 | return logits_per_image, logits_per_text 369 | 370 | 371 | def 
convert_weights(model: nn.Module): 372 | """Convert applicable model parameters to fp16""" 373 | 374 | def _convert_weights_to_fp16(l): 375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 376 | l.weight.data = l.weight.data.half() 377 | if l.bias is not None: 378 | l.bias.data = l.bias.data.half() 379 | 380 | if isinstance(l, nn.MultiheadAttention): 381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 382 | tensor = getattr(l, attr) 383 | if tensor is not None: 384 | tensor.data = tensor.data.half() 385 | 386 | for name in ["text_projection", "proj"]: 387 | if hasattr(l, name): 388 | attr = getattr(l, name) 389 | if attr is not None: 390 | attr.data = attr.data.half() 391 | 392 | model.apply(_convert_weights_to_fp16) 393 | 394 | 395 | def build_model(state_dict: dict): 396 | vit = "visual.proj" in state_dict 397 | 398 | if vit: 399 | vision_width = state_dict["visual.conv1.weight"].shape[0] 400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 403 | image_resolution = vision_patch_size * grid_size 404 | else: 405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 406 | vision_layers = tuple(counts) 407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 409 | vision_patch_size = None 410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 411 | image_resolution = output_width * 32 412 | 413 | embed_dim = state_dict["text_projection"].shape[1] 414 | context_length = state_dict["positional_embedding"].shape[0] 415 | vocab_size = state_dict["token_embedding.weight"].shape[0] 416 | transformer_width = state_dict["ln_final.weight"].shape[0] 417 | transformer_heads = transformer_width // 64 418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 419 | 420 | model = CLIP( 421 | embed_dim, 422 | image_resolution, vision_layers, vision_width, vision_patch_size, 423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 424 | ) 425 | 426 | for key in ["input_resolution", "context_length", "vocab_size"]: 427 | if key in state_dict: 428 | del state_dict[key] 429 | 430 | convert_weights(model) 431 | model.load_state_dict(state_dict) 432 | return model.eval() 433 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 
127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /clip/utils.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | import clip 8 | 9 | 10 | def cls_acc(output, target, topk=1): 11 | pred = output.topk(topk, 1, True, True)[1].t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 14 | acc = 100 * acc / target.shape[0] 15 | return acc 16 | 17 | 18 | def clip_classifier(classnames, template, clip_model): 19 | with torch.no_grad(): 20 | clip_weights = [] 21 | 22 | for classname in classnames: 23 | # Tokenize the prompts 24 | classname = classname.replace('_', ' ') 25 | texts = [t.format(classname) for t in template] 26 | texts = clip.tokenize(texts).cuda() 27 | # prompt ensemble for ImageNet 28 | class_embeddings = clip_model.encode_text(texts) 29 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 30 | class_embedding = class_embeddings.mean(dim=0) 31 | class_embedding /= class_embedding.norm() 32 | clip_weights.append(class_embedding) 33 | 34 | clip_weights = torch.stack(clip_weights, dim=1).cuda() 35 | return clip_weights 36 | 37 | 38 | def build_cache_model(cfg, clip_model, train_loader_cache): 39 | if cfg['load_cache'] == False: 40 | cache_keys = [] 41 | 42 | with torch.no_grad(): 43 | # Data augmentation for the cache model 44 | for augment_idx in range(cfg['augment_epoch']): 45 | train_features = [] 46 | 47 | print('Augment Epoch: {:} / {:}'.format(augment_idx, cfg['augment_epoch'])) 48 | for i, (images, target) in enumerate(tqdm(train_loader_cache)): 49 | images = images.cuda() 50 | image_features = clip_model.encode_image(images) 51 | train_features.append(image_features) 52 | cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0)) 53 | 54 | cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0) 55 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True) 56 | cache_keys = cache_keys.permute(1, 0) 57 | 58 | torch.save(cache_keys, cfg['cache_dir'] + '/keys_' + str(cfg['shots']) + "shots.pt") 59 | 60 | else: 61 | cache_keys = torch.load(cfg['cache_dir'] + '/keys_' + str(cfg['shots']) + "shots.pt") 62 | 63 | return cache_keys 64 | 65 | 66 | def pre_load_features(cfg, split, clip_model, loader): 67 | if cfg['load_pre_feat'] == False: 68 | features, labels = [], [] 69 | 70 | with torch.no_grad(): 71 | for i, (images, target) in enumerate(tqdm(loader)): 72 | images, target = images.cuda(), target.cuda() 73 | image_features = clip_model.encode_image(images) 74 | image_features /= image_features.norm(dim=-1, keepdim=True) 75 | features.append(image_features) 76 | labels.append(target) 77 | 78 | features, labels = torch.cat(features), torch.cat(labels) 79 | 80 | torch.save(features, cfg['cache_dir'] + "/" + split + "_f.pt") 81 | torch.save(labels, cfg['cache_dir'] + "/" + split + "_l.pt") 82 | 83 | else: 84 | features = torch.load(cfg['cache_dir'] + "/" + split + "_f.pt") 85 | labels = torch.load(cfg['cache_dir'] + "/" + split + "_l.pt") 86 | 87 | return features, labels 88 | 89 | 90 | def search_hp(cfg, cache_keys, cache_values, features, labels, 
clip_weights, adapter=None): 91 | if cfg['search_hp'] == True: 92 | 93 | beta_list = [i * (cfg['search_scale'][0] - 0.1) / cfg['search_step'][0] + 0.1 for i in 94 | range(cfg['search_step'][0])] 95 | alpha_list = [i * (cfg['search_scale'][1] - 0.1) / cfg['search_step'][1] + 0.1 for i in 96 | range(cfg['search_step'][1])] 97 | 98 | best_acc = 0 99 | best_beta, best_alpha = 0, 0 100 | 101 | for beta in beta_list: 102 | for alpha in alpha_list: 103 | if adapter: 104 | affinity = adapter(features) 105 | else: 106 | affinity = features @ cache_keys 107 | 108 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 109 | clip_logits = 100. * features @ clip_weights 110 | tip_logits = clip_logits + cache_logits * alpha 111 | acc = cls_acc(tip_logits, labels) 112 | 113 | if acc > best_acc: 114 | print("New best setting, beta: {:.2f}, alpha: {:.2f}; accuracy: {:.2f}".format(beta, alpha, acc)) 115 | best_acc = acc 116 | best_beta = beta 117 | best_alpha = alpha 118 | 119 | print("\nAfter searching, the best accuarcy: {:.2f}.\n".format(best_acc)) 120 | 121 | return best_beta, best_alpha -------------------------------------------------------------------------------- /configs/caltech101.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: ~ 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | 10 | # ------ Basic Config ------ 11 | dataset: 'caltech101' 12 | shots: 16 13 | backbone: 'RN50' 14 | lr: 0.0001 15 | augment_epoch: 10 16 | train_epoch: 5 17 | -------------------------------------------------------------------------------- /configs/dtd.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: ~ 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | 10 | # ------ Basic Config ------ 11 | dataset: 'dtd' 12 | shots: 16 13 | backbone: 'RN50' 14 | lr: 0.0001 15 | augment_epoch: 10 16 | train_epoch: 5 17 | -------------------------------------------------------------------------------- /configs/eurosat.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: ~ 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | 10 | # ------ Basic Config ------ 11 | dataset: 'eurosat' 12 | shots: 16 13 | backbone: 'RN50' 14 | lr: 0.0001 15 | augment_epoch: 10 16 | train_epoch: 5 17 | -------------------------------------------------------------------------------- /configs/fgvc.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: ~ 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | 10 | # ------ Basic Config ------ 11 | dataset: 'fgvc' 12 | shots: 16 13 | backbone: 'RN50' 14 | lr: 0.0001 15 | augment_epoch: 10 16 | train_epoch: 5 17 | -------------------------------------------------------------------------------- /configs/imagenet.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: ~ 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | 10 | # ------ Basic Config ------ 11 | dataset: 
'imagenet' 12 | shots: 16 13 | backbone: 'RN50' 14 | lr: 0.0001 15 | augment_epoch: 10 16 | train_epoch: 5 17 | -------------------------------------------------------------------------------- /configs/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: ~ 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | 10 | # ------ Basic Config ------ 11 | dataset: 'oxford_pets' 12 | shots: 16 13 | backbone: 'RN50' 14 | lr: 0.0001 15 | augment_epoch: 10 16 | train_epoch: 5 17 | -------------------------------------------------------------------------------- /configs/sun397.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: ~ 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | 10 | # ------ Basic Config ------ 11 | dataset: 'sun397' 12 | shots: 16 13 | backbone: 'RN50' 14 | lr: 0.0001 15 | augment_epoch: 10 16 | train_epoch: 5 17 | -------------------------------------------------------------------------------- /configs/ucf101.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: ~ 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | 10 | # ------ Basic Config ------ 11 | dataset: 'ucf101' 12 | shots: 16 13 | backbone: 'RN50' 14 | lr: 0.0001 15 | augment_epoch: 10 16 | train_epoch: 5 17 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .oxford_pets import OxfordPets 2 | from .eurosat import EuroSAT 3 | from .ucf101 import UCF101 4 | from .sun397 import SUN397 5 | from .caltech101 import Caltech101 6 | from .dtd import DescribableTextures 7 | from .fgvc import FGVCAircraft 8 | from .imagenet import ImageNet 9 | 10 | 11 | dataset_list = { 12 | "oxford_pets": OxfordPets, 13 | "eurosat": EuroSAT, 14 | "ucf101": UCF101, 15 | "sun397": SUN397, 16 | "caltech101": Caltech101, 17 | "dtd": DescribableTextures, 18 | "fgvc": FGVCAircraft, 19 | "imagenet": ImageNet, 20 | } 21 | 22 | 23 | def build_dataset(dataset, root_path, shots): 24 | return dataset_list[dataset](root_path, shots) -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/caltech101.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/datasets/__pycache__/caltech101.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dtd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/datasets/__pycache__/dtd.cpython-38.pyc 
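For reference, `build_dataset` in `datasets/__init__.py` above simply dispatches on the `dataset` field of a config to the matching dataset class. Below is a minimal sketch of how one of the YAML configs drives it; this mirrors what `main.py` presumably does (`main.py` itself is not shown in this listing), and PyYAML is assumed for parsing.

```python
import yaml

from datasets import build_dataset

# Load one of the provided configs; root_path must first be pointed at the
# directory holding the downloaded datasets (it defaults to ~ / null).
with open('configs/dtd.yaml') as f:
    cfg = yaml.safe_load(f)

# Builds a DescribableTextures instance whose few-shot splits follow the
# base/novel class lists defined in datasets/dtd.py.
dataset = build_dataset(cfg['dataset'], cfg['root_path'], cfg['shots'])
```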
-------------------------------------------------------------------------------- /datasets/__pycache__/eurosat.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/datasets/__pycache__/eurosat.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/fgvc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/datasets/__pycache__/fgvc.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/datasets/__pycache__/imagenet.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_pets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/datasets/__pycache__/oxford_pets.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sun397.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/datasets/__pycache__/sun397.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/ucf101.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/datasets/__pycache__/ucf101.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/Meta-Adapter/dc1d9c08d968f30587d94bf25ec731fbf73bc049/datasets/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase 3 | from .oxford_pets import OxfordPets 4 | 5 | 6 | template = ['a photo of a {}.'] 7 | base_classes = ['windsor_chair', 'trilobite', 'tick', 'sunflower', 'strawberry', 'stop_sign', 'stegosaurus', 'soccer_ball', 'rooster', 'pyramid', 8 | 'pizza', 'panda', 'pagoda', 'okapi', 'motorbike', 'metronome', 'laptop', 'inline_skate', 'headphone', 'gramophone', 'ewer', 9 | 'dollar_bill', 'dalmatian', 'car_side', 'cannon', 'buddha', 'brain', 'bonsai', 'barrel', 'accordion', 'airplane', 'watch', 10 | 'starfish', 'helicopter', 'revolver', 'ferry', 'joshua_tree', 'yin_yang', 'wheelchair', 'nautilus', 'emu', 'grand_piano', 'stapler', 11 | 'pigeon', 'menorah', 'water_lilly', 'saxophone', 'cougar_face', 'platypus', 'garfield', 'binocular', 'sea_horse', 'cup', 'kangaroo', 12 | 'hedgehog', 'bass', 'hawksbill', 'camera', 'umbrella', 'cougar_body', 
'dolphin', 'scorpion', 'minaret', 'llama', 'wrench', 13 | 'scissors', 'butterfly', 'snoopy', 'euphonium', 'ceiling_fan'] 14 | novel_classes = ['beaver', 'leopard', 'mayfly', 'ibis', 'brontosaurus', 'elephant', 'schooner', 'flamingo_head', 'gerenuk', 'flamingo', 15 | 'mandolin', 'crocodile', 'chandelier', 'face', 'crayfish', 'anchor', 'rhino', 'lamp', 'lotus', 'dragonfly', 'electric_guitar', 16 | 'wild_cat', 'octopus', 'cellphone', 'lobster', 'ketch', 'ant', 'chair', 'crab', 'crocodile_head'] 17 | 18 | 19 | class Caltech101(DatasetBase): 20 | 21 | 22 | dataset_dir = 'caltech-101' 23 | 24 | def __init__(self, root, num_shots): 25 | self.dataset_dir = os.path.join(root, self.dataset_dir) 26 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories') 27 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Caltech101.json') 28 | 29 | self.template = template 30 | 31 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 32 | few_shot_base = [] 33 | for item in train: 34 | if item.classname in base_classes: 35 | few_shot_base.append(item) 36 | few_shot_base = self.generate_fewshot_dataset(few_shot_base, num_shots=num_shots) 37 | few_shot_full = self.generate_fewshot_dataset(val, num_shots=num_shots) 38 | 39 | test_novel = [] 40 | for item in test: 41 | if item.classname in novel_classes: 42 | test_novel.append(item) 43 | test_novel = self.generate_fewshot_dataset(test_novel, num_shots=num_shots) 44 | 45 | super().__init__(train=few_shot_base, full=few_shot_full, val=test_novel) 46 | -------------------------------------------------------------------------------- /datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from .utils import Datum, DatasetBase, listdir_nohidden 4 | from .oxford_pets import OxfordPets 5 | 6 | template = ['{} texture.'] 7 | 8 | base_classes = ['paisley', 'knitted', 'chequered', 'bubbly', 'crystalline', 'cobwebbed', 'striped', 'pleated', 9 | 'cracked', 'studded', 10 | 'waffled', 'polka-dotted', 'freckled', 'perforated', 'honeycombed', 'stratified', 'potholed', 'swirly', 11 | 'porous', 'grid', 12 | 'frilly', 'sprinkled', 'meshed', 'wrinkled', 'spiralled', 'marbled', 'scaly', 'blotchy', 'gauzy', 13 | 'woven', 'veined', 'crosshatched'] 14 | novel_classes = ['braided', 'dotted', 'matted', 'flecked', 'smeared', 'grooved', 'lined', 'banded', 'stained', 15 | 'interlaced', 'fibrous', 16 | 'zigzagged', 'pitted', 'lacelike', 'bumpy'] 17 | 18 | 19 | class DescribableTextures(DatasetBase): 20 | dataset_dir = 'dtd' 21 | 22 | def __init__(self, root, num_shots): 23 | self.dataset_dir = os.path.join(root, self.dataset_dir) 24 | self.image_dir = os.path.join(self.dataset_dir, 'images') 25 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_DescribableTextures.json') 26 | 27 | self.template = template 28 | 29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 30 | few_shot_base = [] 31 | for item in train: 32 | if item.classname in base_classes: 33 | few_shot_base.append(item) 34 | few_shot_base = self.generate_fewshot_dataset(few_shot_base, num_shots=num_shots) 35 | few_shot_full = self.generate_fewshot_dataset(val, num_shots=num_shots) 36 | 37 | test_novel = [] 38 | for item in test: 39 | if item.classname in novel_classes: 40 | test_novel.append(item) 41 | test_novel = self.generate_fewshot_dataset(test_novel, num_shots=num_shots) 42 | 43 | super().__init__(train=few_shot_base, full=few_shot_full, val=test_novel) 44 | 45 | 
@staticmethod
46 |     def read_and_split_data(
47 |         image_dir,
48 |         p_trn=0.5,
49 |         p_val=0.2,
50 |         ignored=[],
51 |         new_cnames=None
52 |     ):
53 |         # The data are supposed to be organized into the following structure
54 |         # =============
55 |         # images/
56 |         #     dog/
57 |         #     cat/
58 |         #     horse/
59 |         # =============
60 |         categories = listdir_nohidden(image_dir)
61 |         categories = [c for c in categories if c not in ignored]
62 |         categories.sort()
63 | 
64 |         p_tst = 1 - p_trn - p_val
65 |         print(f'Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test')
66 | 
67 |         def _collate(ims, y, c):
68 |             items = []
69 |             for im in ims:
70 |                 item = Datum(
71 |                     impath=im,
72 |                     label=y,  # is already 0-based
73 |                     classname=c
74 |                 )
75 |                 items.append(item)
76 |             return items
77 | 
78 |         train, val, test = [], [], []
79 |         for label, category in enumerate(categories):
80 |             category_dir = os.path.join(image_dir, category)
81 |             images = listdir_nohidden(category_dir)
82 |             images = [os.path.join(category_dir, im) for im in images]
83 |             random.shuffle(images)
84 |             n_total = len(images)
85 |             n_train = round(n_total * p_trn)
86 |             n_val = round(n_total * p_val)
87 |             n_test = n_total - n_train - n_val
88 |             assert n_train > 0 and n_val > 0 and n_test > 0
89 | 
90 |             if new_cnames is not None and category in new_cnames:
91 |                 category = new_cnames[category]
92 | 
93 |             train.extend(_collate(images[:n_train], label, category))
94 |             val.extend(_collate(images[n_train:n_train + n_val], label, category))
95 |             test.extend(_collate(images[n_train + n_val:], label, category))
96 | 
97 |         return train, val, test
98 | 
--------------------------------------------------------------------------------
/datasets/eurosat.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from .utils import Datum, DatasetBase
4 | from .oxford_pets import OxfordPets
5 | 
6 | template = ['a centered satellite photo of {}.']
7 | 
8 | NEW_CNAMES = {
9 |     'Annual Crop Land': 'AnnualCrop',
10 |     'Forest': 'Forest',
11 |     'Herbaceous Vegetation Land': 'HerbaceousVegetation',
12 |     'Highway or Road': 'Highway',
13 |     'Industrial Buildings': 'Industrial',
14 |     'Pasture Land': 'Pasture',
15 |     'Permanent Crop Land': 'PermanentCrop',
16 |     'Residential Buildings': 'Residential',
17 |     'River': 'River',
18 |     'Sea or Lake': 'SeaLake'
19 | }
20 | base_classes = ['Forest', 'Industrial Buildings', 'Highway or Road', 'Residential Buildings', 'Pasture Land',
21 |                 'Permanent Crop Land', 'Sea or Lake']
22 | novel_classes = ['River', 'Herbaceous Vegetation Land', 'Annual Crop Land']
23 | 
24 | 
25 | class EuroSAT(DatasetBase):
26 |     dataset_dir = 'eurosat'
27 | 
28 |     def __init__(self, root, num_shots):
29 |         self.dataset_dir = os.path.join(root, self.dataset_dir)
30 |         self.image_dir = os.path.join(self.dataset_dir, '2750')
31 |         self.split_path = os.path.join(self.dataset_dir, 'split_zhou_EuroSAT.json')
32 | 
33 |         self.template = template
34 | 
35 |         train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
36 |         few_shot_base = []
37 |         for item in train:
38 |             if item.classname in base_classes:
39 |                 few_shot_base.append(item)
40 |         few_shot_base = self.generate_fewshot_dataset(few_shot_base, num_shots=num_shots)
41 |         few_shot_full = self.generate_fewshot_dataset(val, num_shots=num_shots)
42 | 
43 |         test_novel = []
44 |         for item in test:
45 |             if item.classname in novel_classes:
46 |                 test_novel.append(item)
47 |         test_novel = self.generate_fewshot_dataset(test_novel, num_shots=num_shots)
48 | 
49 |         self.update_classname(few_shot_base)
50 |
self.update_classname(test_novel) 51 | self.update_classname(few_shot_full) 52 | self.update_classname(val) 53 | 54 | super().__init__(train=few_shot_base, full=few_shot_full, val=test_novel) 55 | 56 | def update_classname(self, dataset_old): 57 | dataset_new = [] 58 | for item_old in dataset_old: 59 | cname_old = item_old.classname 60 | cname_new = NEW_CNAMES[cname_old] 61 | item_new = Datum( 62 | impath=item_old.impath, 63 | label=item_old.label, 64 | classname=cname_new 65 | ) 66 | dataset_new.append(item_new) 67 | return dataset_new 68 | -------------------------------------------------------------------------------- /datasets/fgvc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase 4 | 5 | template = ['a photo of a {}, a type of aircraft.'] 6 | 7 | base_classes = ['Eurofighter Typhoon', 'Hawk T1', 'Spitfire', 'F-16A/B', 'DH-82', 'C-130', 'A380', 'F/A-18', 8 | 'Cessna 208', 'Il-76', 'Embraer Legacy 600', 'BAE 146-200', 'ATR-72', 'Global Express', 'DC-3', 'A318', 9 | '777-300', 'A310', 'DC-8', 'DHC-1', 'Challenger 600', 'A340-600', 'A340-200', 'Fokker 50', 10 | 'Falcon 2000', 'MD-11', 'Gulfstream V', 'A319', 'Fokker 70', 'DC-10', 'A330-300', 'A320', '777-200', 11 | 'SR-20', 'DHC-6', 'Cessna 172', 'DHC-8-100', 'DC-6', 'Beechcraft 1900', '707-320', 'Cessna 560', 12 | 'A340-300', 'DC-9-30', 'Fokker 100', 'Cessna 525', '747-300', '727-200', 'Metroliner', 'Yak-42', 13 | 'Tu-134', 'Saab 340', 'Saab 2000', 'PA-28', 'ERJ 145', 'DHC-8-300', 'C-47', 'ATR-42', 'A330-200', 14 | '767-200', 'BAE 146-300', '757-200', 'Model B200', 'MD-90', 'Falcon 900', 'Dornier 328', 'A340-500', 15 | '747-400', '747-100', '737-400', 'MD-80'] 16 | novel_classes = ['Gulfstream IV', 'CRJ-200', 'Boeing 717', '747-200', '737-800', 'Tu-154', 'Tornado', 'MD-87', 'L-1011', 17 | 'ERJ 135', 'EMB-120', 'E-195', 'E-190', 'E-170', 'DR-400', 'CRJ-900', 'CRJ-700', 'BAE-125', 'An-12', 18 | 'A321', 'A300B4', '767-400', '767-300', '757-300', '737-900', '737-700', '737-600', '737-500', 19 | '737-300', '737-200'] 20 | 21 | 22 | class FGVCAircraft(DatasetBase): 23 | dataset_dir = 'fgvc_aircraft' 24 | 25 | def __init__(self, root, num_shots): 26 | 27 | self.dataset_dir = os.path.join(root, self.dataset_dir) 28 | self.image_dir = os.path.join(self.dataset_dir, 'images') 29 | 30 | self.template = template 31 | 32 | classnames = [] 33 | with open(os.path.join(self.dataset_dir, 'variants.txt'), 'r') as f: 34 | lines = f.readlines() 35 | for line in lines: 36 | classnames.append(line.strip()) 37 | cname2lab = {c: i for i, c in enumerate(classnames)} 38 | 39 | train = self.read_data(cname2lab, 'images_variant_train.txt') 40 | val = self.read_data(cname2lab, 'images_variant_val.txt') 41 | test = self.read_data(cname2lab, 'images_variant_test.txt') 42 | 43 | few_shot_base = [] 44 | for item in train: 45 | if item.classname in base_classes: 46 | few_shot_base.append(item) 47 | few_shot_base = self.generate_fewshot_dataset(few_shot_base, num_shots=num_shots) 48 | few_shot_full = self.generate_fewshot_dataset(val, num_shots=num_shots) 49 | 50 | test_novel = [] 51 | for item in test: 52 | if item.classname in novel_classes: 53 | test_novel.append(item) 54 | test_novel = self.generate_fewshot_dataset(test_novel, num_shots=num_shots) 55 | 56 | super().__init__(train=few_shot_base, full=few_shot_full, val=test_novel) 57 | 58 | def read_data(self, cname2lab, split_file): 59 | filepath = os.path.join(self.dataset_dir, split_file) 60 | items = [] 61 | 62 | with 
open(filepath, 'r') as f: 63 | lines = f.readlines() 64 | for line in lines: 65 | line = line.strip().split(' ') 66 | imname = line[0] + '.jpg' 67 | classname = ' '.join(line[1:]) 68 | impath = os.path.join(self.image_dir, imname) 69 | label = cname2lab[classname] 70 | item = Datum( 71 | impath=impath, 72 | label=label, 73 | classname=classname 74 | ) 75 | items.append(item) 76 | 77 | return items 78 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | import torchvision 4 | import glob 5 | import os 6 | import torch 7 | import torch.utils.data 8 | import torchvision.models 9 | from torchvision import transforms 10 | import numpy as np 11 | from torch.utils.data import Dataset 12 | import PIL.Image as Image 13 | 14 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 15 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 16 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 17 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 18 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 19 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 20 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 21 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 22 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 23 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 24 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 25 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 26 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 27 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 28 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 29 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 30 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 31 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 32 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 33 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 34 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 35 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 36 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 37 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 38 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 39 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 40 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 41 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", 
"Basset Hound", "Beagle", 42 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 43 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 44 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 45 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 46 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 47 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 48 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 49 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 50 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 51 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 52 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 53 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 54 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 55 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 56 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 57 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 58 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 59 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 60 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 61 | "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 62 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 63 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 64 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 65 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 66 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 67 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 68 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 69 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 70 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 71 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 72 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 73 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 74 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 75 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 76 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 77 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 78 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 79 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 80 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 81 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 82 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 83 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 84 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 85 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 86 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 87 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 88 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 89 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 90 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 91 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 92 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 93 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 94 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 95 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 96 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 97 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 98 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 99 | "car 
wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 100 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 101 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 102 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 103 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 104 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 105 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 106 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 107 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 108 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 109 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 110 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 111 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 112 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 113 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 114 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 115 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 116 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 117 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 118 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 119 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 120 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 121 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 122 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 123 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 124 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 125 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 126 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 127 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 128 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 129 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 130 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 131 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 132 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 133 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 134 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 135 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", 
"Pickelhaube", 136 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 137 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 138 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 139 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 140 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 141 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 142 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 143 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 144 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 145 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 146 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 147 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 148 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 149 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 150 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 151 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 152 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 153 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 154 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 155 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 156 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 157 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 158 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 159 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 160 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 161 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 162 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 163 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 164 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 165 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 166 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 167 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 168 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 169 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 170 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 171 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 172 | "lemon", "fig", "pineapple", "banana", 
"jackfruit", "cherimoya (custard apple)", "pomegranate", 173 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 174 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 175 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 176 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 177 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 178 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 179 | imagenet_templates = ["itap of a {}.", 180 | "a bad photo of the {}.", 181 | "a origami {}.", 182 | "a photo of the large {}.", 183 | "a {} in a video game.", 184 | "art of the {}.", 185 | "a photo of the small {}."] 186 | base_idx = [986, 985, 668, 430, 14, 974, 685, 607, 537, 466, 90, 24, 993, 984, 933, 927, 800, 781, 679, 645, 573, 565, 187 | 510, 476, 340, 339, 333, 283, 95, 89, 983, 916, 820, 701, 614, 554, 458, 444, 400, 396, 323, 322, 145, 143, 188 | 69, 13, 0, 996, 959, 895, 890, 874, 802, 779, 425, 404, 399, 388, 284, 275, 139, 137, 98, 87, 989, 982, 189 | 964, 955, 926, 924, 922, 863, 805, 795, 755, 746, 640, 555, 535, 533, 500, 475, 351, 325, 293, 289, 255, 190 | 148, 144, 135, 15, 9, 992, 903, 873, 867, 832, 803, 739, 688, 671, 628, 625, 580, 574, 560, 547, 496, 332, 191 | 330, 321, 320, 292, 195, 149, 138, 76, 19, 11, 10, 980, 957, 937, 936, 917, 900, 878, 763, 687, 576, 564, 192 | 532, 471, 410, 383, 382, 346, 336, 294, 286, 268, 181, 131, 130, 118, 88, 51, 995, 965, 963, 946, 825, 766, 193 | 752, 719, 661, 611, 586, 546, 450, 449, 424, 407, 391, 352, 350, 324, 316, 309, 300, 291, 146, 91, 84, 82, 194 | 80, 48, 25, 18, 12, 8, 991, 953, 886, 822, 780, 736, 732, 682, 627, 557, 528, 524, 498, 486, 477, 474, 437, 195 | 403, 387, 367, 365, 363, 347, 344, 317, 306, 301, 259, 251, 147, 75, 16, 1, 994, 956, 938, 918, 915, 884, 196 | 775, 734, 703, 690, 672, 563, 548, 525, 511, 454, 395, 376, 354, 338, 305, 299, 205, 178, 152, 33, 981, 197 | 962, 958, 948, 945, 944, 934, 850, 847, 791, 783, 727, 605, 603, 568, 562, 520, 518, 467, 401, 393, 386, 198 | 295, 258, 111, 28, 22, 4, 952, 760, 743, 695, 694, 642, 610, 597, 594, 551, 540, 531, 483, 465, 342, 308, 199 | 296, 260, 223, 210, 150, 140, 127, 116, 105, 104, 102, 96, 70, 30, 950, 935, 932, 921, 919, 892, 880, 829, 200 | 768, 761, 713, 712, 711, 654, 639, 621, 595, 592, 561, 448, 440, 439, 420, 406, 398, 335, 307, 279, 254, 201 | 245, 213, 208, 132, 128, 108, 100, 31, 17, 997, 931, 888, 881, 866, 860, 853, 849, 833, 797, 741, 723, 652, 202 | 649, 637, 515, 433, 426, 402, 397, 349, 253, 252, 243, 235, 222, 180, 174, 171, 136, 133, 109, 107, 39, 29, 203 | 20, 951, 897, 865, 796, 759, 509, 443, 384, 355, 288, 276, 274, 247, 217, 194, 183, 141, 123, 92, 77, 68, 204 | 45, 37, 21, 913, 912, 858, 812, 786, 758, 756, 751, 714, 709, 697, 630, 575, 522, 491, 487, 480, 431, 421, 205 | 364, 357, 328, 261, 214, 153, 117, 74, 971, 967, 940, 907, 882, 872, 871, 862, 844, 843, 839, 827, 789, 206 | 726, 720, 646, 613, 570, 517, 495, 453, 392, 337, 329, 297, 287, 270, 249, 230, 203, 182, 161, 156, 142, 207 | 106, 81, 50, 973, 852, 835, 788, 717, 707, 704, 698, 643, 635, 629, 617, 577, 552, 543, 539, 468, 428, 422, 208 | 343, 327, 298, 234, 216, 190, 71, 61, 53, 34, 966, 939, 846, 831, 747, 686, 650, 620, 553, 526, 514, 436, 209 | 429, 366, 318, 273, 241, 209, 169, 115, 113, 72, 954, 941, 929, 901, 889, 883, 879, 
877, 823, 798, 777, 210 | 770, 757, 669, 662, 660, 657, 647, 588, 571, 521, 470, 452, 442, 358, 334, 319, 239, 228, 207, 173, 159, 211 | 129, 125, 122, 86, 38, 869, 864, 793, 776, 769, 684, 666, 655, 634, 632, 615, 612, 608, 569, 559, 508, 484, 212 | 432, 378, 375, 370, 362, 348, 313, 302, 263, 262, 256, 237, 199, 176, 168, 120, 57, 56, 49, 41, 7, 2, 969, 213 | 949, 819, 806, 706, 674, 626, 616, 589, 513, 481, 462, 441, 427, 415, 379, 373, 361, 360, 244, 191, 93, 42, 214 | 5, 3, 990, 979, 977, 925, 894, 716, 675, 624, 606, 585, 538, 435, 423, 408, 405, 385, 304, 179, 175, 165, 215 | 164, 94, 85, 65, 961, 920, 857, 830, 826, 774, 738, 724, 692, 581, 534, 490, 485, 478, 472, 369, 311, 290, 216 | 232, 221, 218, 212, 162, 151, 97, 43, 35, 32, 26, 23, 987, 978, 943, 887, 861, 851, 821, 815, 764, 762, 217 | 740, 729, 699, 665, 636, 622, 601, 593, 591, 530, 507, 419, 394, 368, 233, 206, 198, 154, 112, 972, 960, 218 | 928, 909, 898] 219 | novel_idx = [891, 875, 854, 836, 801, 773, 631, 602, 584, 558, 541, 529, 489, 460, 451, 341, 303, 277, 271, 236, 202, 220 | 185, 184, 160, 126, 83, 64, 63, 46, 942, 910, 904, 893, 817, 808, 794, 785, 656, 651, 599, 598, 583, 582, 221 | 579, 572, 544, 497, 492, 417, 414, 380, 331, 281, 224, 196, 188, 103, 99, 44, 845, 834, 814, 811, 809, 222 | 807, 790, 737, 689, 678, 641, 578, 566, 549, 527, 506, 479, 457, 456, 413, 377, 372, 315, 310, 225, 197, 223 | 167, 124, 114, 79, 930, 856, 855, 841, 721, 590, 542, 459, 447, 390, 371, 272, 265, 220, 192, 170, 62, 27, 224 | 975, 914, 765, 735, 710, 708, 683, 503, 502, 501, 463, 455, 314, 264, 200, 121, 119, 52, 908, 896, 870, 225 | 799, 673, 663, 596, 494, 411, 246, 229, 211, 40, 36, 6, 911, 772, 754, 753, 728, 623, 523, 512, 280, 267, 226 | 227, 219, 177, 172, 166, 906, 859, 840, 804, 792, 748, 696, 691, 545, 504, 464, 434, 242, 238, 186, 110, 227 | 999, 976, 923, 745, 715, 567, 482, 473, 285, 266, 204, 187, 78, 988, 970, 968, 838, 667, 659, 644, 619, 228 | 556, 257, 158, 157, 66, 58, 55, 848, 778, 693, 680, 604, 412, 278, 250, 134, 842, 824, 816, 733, 718, 676, 229 | 648, 519, 438, 374, 356, 353, 312, 163, 67, 749, 742, 731, 725, 722, 653, 633, 609, 345, 226, 54, 998, 230 | 828, 505, 101, 947, 705, 670, 536, 418, 269, 813, 488, 445, 409, 201, 155, 59, 47, 905, 771, 677, 664, 469, 231 | 446, 381, 240, 193, 899, 885, 876, 818, 787, 782, 767, 730, 461, 73, 902, 868, 810, 618, 499, 326, 189, 232 | 784, 700, 600, 493, 416, 248, 215, 702, 658, 550, 282, 231, 681, 587, 359, 389, 837, 750, 744, 638, 516, 233 | 60] 234 | 235 | 236 | class imagenet_test_dataset(Dataset): 237 | def __init__(self, valdir, preprocess): 238 | f = open(os.path.join(valdir, 'val.txt')) 239 | target = f.read().splitlines() 240 | self.target = {} 241 | for idx in range(len(target)): 242 | self.target[target[idx].split(' ')[0]] = target[idx].split(' ')[1] 243 | 244 | img_list = glob.glob(valdir + '/*.JPEG') 245 | 246 | self.img_list = [] 247 | 248 | for idx in range(len(img_list)): 249 | current_name = os.path.basename(img_list[idx]) 250 | if int(self.target[current_name]) in novel_idx: 251 | self.img_list.append(img_list[idx]) 252 | 253 | self.size = len(self.img_list) 254 | 255 | self.transform = preprocess 256 | 257 | def __getitem__(self, index): 258 | img = Image.open(self.img_list[index]).convert('RGB') 259 | img = self.transform(img) 260 | img = torch.from_numpy(np.array(img)) 261 | ind = os.path.basename(self.img_list[index]) 262 | target = int(self.target[ind]) 263 | return img, target 264 | 265 | def __len__(self): 266 | return self.size 267 | 268 
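# imagenet_test_dataset above reads val.txt to map each validation filename to its label and
# keeps only images whose label appears in `novel_idx`, so it evaluates on novel classes only;
# `preprocess` supplies the image transform (e.g. the transform returned by clip.load).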
| 269 | class ImageNet(): 270 | dataset_dir = 'imagenet' 271 | 272 | def __init__(self, root, num_shots): 273 | 274 | self.dataset_dir = os.path.join(root, self.dataset_dir) 275 | self.image_dir = os.path.join(self.dataset_dir, 'images') 276 | 277 | train_preprocess = transforms.Compose([ 278 | transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC), 279 | transforms.RandomHorizontalFlip(p=0.5), 280 | transforms.ToTensor(), 281 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 282 | ]) 283 | test_preprocess = transforms.Compose([ 284 | transforms.Resize(size=224), 285 | transforms.CenterCrop(size=(224, 224)), 286 | transforms.ToTensor(), 287 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 288 | ]) 289 | 290 | self.train = torchvision.datasets.ImageFolder(os.path.join(self.image_dir, 'train'), transform=train_preprocess) 291 | self.full = torchvision.datasets.ImageFolder(os.path.join(self.image_dir, 'train'), transform=train_preprocess) 292 | self.val = torchvision.datasets.ImageFolder(os.path.join(self.image_dir, 'val'), transform=test_preprocess) 293 | 294 | self.template = imagenet_templates 295 | self.classnames = imagenet_classes 296 | 297 | split_by_label_dict = defaultdict(list) 298 | for i in range(len(self.train.imgs)): 299 | split_by_label_dict[self.train.targets[i]].append(self.train.imgs[i]) 300 | imgs = [] 301 | targets = [] 302 | few_shot_base = defaultdict(list) 303 | for idx in range(len(self.classnames)): 304 | if idx in base_idx: 305 | few_shot_base[idx] = split_by_label_dict[idx] 306 | split_by_label_dict = few_shot_base 307 | for label, items in split_by_label_dict.items(): 308 | if num_shots > 0: 309 | imgs = imgs + random.sample(items, num_shots) 310 | targets = targets + [label for _ in range(num_shots)] 311 | else: 312 | imgs = imgs + items 313 | targets = targets + [label for _ in range(len(items))] 314 | self.train.imgs = imgs 315 | self.train.targets = targets 316 | self.train.samples = imgs 317 | 318 | val_imgs = [] 319 | val_targets = [] 320 | for i in range(len(self.val.imgs)): 321 | if self.val.targets[i] in novel_idx: 322 | val_imgs.append(self.val.imgs[i]) 323 | val_targets.append(self.val.targets[i]) 324 | self.val.imgs = val_imgs 325 | self.val.targets = val_targets 326 | self.val.samples = val_imgs 327 | 328 | split_by_label_dict = defaultdict(list) 329 | for i in range(len(self.full.imgs)): 330 | split_by_label_dict[self.full.targets[i]].append(self.full.imgs[i]) 331 | imgs_full = [] 332 | targets_full = [] 333 | for label, items in split_by_label_dict.items(): 334 | if num_shots > 0: 335 | imgs_full = imgs_full + random.sample(items, num_shots) 336 | targets_full = targets_full + [label for _ in range(num_shots)] 337 | else: 338 | imgs_full = imgs_full + items 339 | targets_full = targets_full + [label for _ in range(len(items))] 340 | self.full.imgs = imgs_full 341 | self.full.targets = targets_full 342 | self.full.samples = imgs_full 343 | -------------------------------------------------------------------------------- /datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import defaultdict 4 | from .utils import Datum, DatasetBase, read_json, write_json 5 | 6 | 7 | template = ['a photo of a {}, a type of pet.'] 8 | base_classes = ['shiba_inu', 'samoyed', 'pug', 'keeshond', 
'wheaten_terrier', 'newfoundland', 'german_shorthaired', 'sphynx', 9 | 'pomeranian', 'chihuahua', 'saint_bernard', 'russian_blue', 'basset_hound', 'scottish_terrier', 'yorkshire_terrier', 10 | 'japanese_chin', 'havanese', 'bengal', 'great_pyrenees', 'beagle', 'miniature_pinscher', 'english_cocker_spaniel', 'siamese', 11 | 'leonberger', 'english_setter', 'american_bulldog', 'boxer', 'abyssinian', 'british_shorthair'] 12 | novel_classes = ['maine_coon', 'egyptian_mau', 'american_pit_bull_terrier', 13 | 'staffordshire_bull_terrier', 'ragdoll', 'persian', 'birman', 'bombay'] 14 | 15 | 16 | class OxfordPets(DatasetBase): 17 | 18 | dataset_dir = 'oxford_pets' 19 | 20 | def __init__(self, root, num_shots): 21 | self.dataset_dir = os.path.join(root, self.dataset_dir) 22 | self.image_dir = os.path.join(self.dataset_dir, 'images') 23 | self.anno_dir = os.path.join(self.dataset_dir, 'annotations') 24 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordPets.json') 25 | 26 | self.template = template 27 | 28 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 29 | few_shot_base = [] 30 | for item in train: 31 | if item.classname in base_classes: 32 | few_shot_base.append(item) 33 | few_shot_base = self.generate_fewshot_dataset(train, num_shots=num_shots) 34 | few_shot_full = self.generate_fewshot_dataset(val, num_shots=num_shots) 35 | 36 | test_novel = [] 37 | for item in test: 38 | if item.classname in novel_classes: 39 | test_novel.append(item) 40 | test_novel = self.generate_fewshot_dataset(test, num_shots=num_shots) 41 | 42 | super().__init__(train=few_shot_base, val=test_novel, full=few_shot_full) 43 | 44 | def read_data(self, split_file): 45 | filepath = os.path.join(self.anno_dir, split_file) 46 | items = [] 47 | 48 | with open(filepath, 'r') as f: 49 | lines = f.readlines() 50 | for line in lines: 51 | line = line.strip() 52 | imname, label, species, _ = line.split(' ') 53 | breed = imname.split('_')[:-1] 54 | breed = '_'.join(breed) 55 | breed = breed.lower() 56 | imname += '.jpg' 57 | impath = os.path.join(self.image_dir, imname) 58 | label = int(label) - 1 # convert to 0-based index 59 | item = Datum( 60 | impath=impath, 61 | label=label, 62 | classname=breed 63 | ) 64 | items.append(item) 65 | 66 | return items 67 | 68 | @staticmethod 69 | def split_trainval(trainval, p_val=0.2): 70 | p_trn = 1 - p_val 71 | print(f'Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val') 72 | tracker = defaultdict(list) 73 | for idx, item in enumerate(trainval): 74 | label = item.label 75 | tracker[label].append(idx) 76 | 77 | train, val = [], [] 78 | for label, idxs in tracker.items(): 79 | n_val = round(len(idxs) * p_val) 80 | assert n_val > 0 81 | random.shuffle(idxs) 82 | for n, idx in enumerate(idxs): 83 | item = trainval[idx] 84 | if n < n_val: 85 | val.append(item) 86 | else: 87 | train.append(item) 88 | 89 | return train, val 90 | 91 | @staticmethod 92 | def save_split(train, val, test, filepath, path_prefix): 93 | def _extract(items): 94 | out = [] 95 | for item in items: 96 | impath = item.impath 97 | label = item.label 98 | classname = item.classname 99 | impath = impath.replace(path_prefix, '') 100 | if impath.startswith('/'): 101 | impath = impath[1:] 102 | out.append((impath, label, classname)) 103 | return out 104 | 105 | train = _extract(train) 106 | val = _extract(val) 107 | test = _extract(test) 108 | 109 | split = { 110 | 'train': train, 111 | 'val': val, 112 | 'test': test 113 | } 114 | 115 | write_json(split, filepath) 116 | print(f'Saved 
split to {filepath}') 117 | 118 | @staticmethod 119 | def read_split(filepath, path_prefix): 120 | def _convert(items): 121 | out = [] 122 | for impath, label, classname in items: 123 | impath = os.path.join(path_prefix, impath) 124 | item = Datum( 125 | impath=impath, 126 | label=int(label), 127 | classname=classname 128 | ) 129 | out.append(item) 130 | return out 131 | 132 | print(f'Reading split from {filepath}') 133 | split = read_json(filepath) 134 | train = _convert(split['train']) 135 | val = _convert(split['val']) 136 | test = _convert(split['test']) 137 | 138 | return train, val, test -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | template = ['a photo of a {}.'] 8 | base_classes = ['indoor florist_shop', 'skatepark', 'raft', 'oilrig', 'ball_pit', 'martial_arts_gym', 'courtroom', 9 | 'cockpit', 'airplane_cabin', 'volcano', 'sauna', 'music_studio', 'indoor volleyball_court', 10 | 'batters_box', 'wind_farm', 'wave', 'rock_arch', 'raceway', 'outdoor track', 'oast_house', 11 | 'limousine_interior', 'indoor cloister', 'cemetery', 'carrousel', 'baseball stadium', 12 | 'auto_factory', 'vineyard', 'toll_plaza', 'television_studio', 'outdoor tennis_court', 13 | 'outdoor oil_refinery', 'manufactured_home', 'lift_bridge', 'indoor pilothouse', 'forest_road', 14 | 'exterior covered_bridge', 'coral_reef underwater', 'bowling_alley', 'bamboo_forest', 'aquarium', 15 | 'veterinarians_office', 'vegetation desert', 'outdoor hangar', 'dining_car', 'control_room', 16 | 'barrel_storage wine_cellar', 'squash_court', 'sky', 'promenade_deck', 'playground', 17 | 'platform train_station', 'pantry', 'outdoor lido_deck', 'outdoor ice_skating_rink', 18 | 'outdoor control_tower', 'kindergarden_classroom', 'kasbah', 'islet', 'indoor brewery', 'igloo', 19 | 'heliport', 'courthouse', 'rope_bridge', 'rice_paddy', 'racecourse', 'pulpit', 'landing_deck', 20 | 'indoor gymnasium', 'indoor cavern', 'indoor casino', 'ice_floe', 'crevasse', 'butte', 'bus_interior', 21 | 'boxing_ring', 'topiary_garden', 'ski_resort', 'pharmacy', 'outdoor greenhouse', 22 | 'outdoor athletic_field', 'orchard', 'lighthouse', 'indoor wrestling_ring', 'indoor tennis_court', 23 | 'indoor swimming_pool', 'fire_station', 'closet', 'bottle_storage wine_cellar', 'boardwalk', 24 | 'outdoor labyrinth', 'landfill', 'indoor jail', 'iceberg', 'bullring', 'art_gallery', 25 | 'anechoic_chamber', 'amusement_park', 'videostore', 'throne_room', 'slum', 'sandbox', 26 | 'picnic_area', 'outdoor tent', 'laundromat', 'indoor warehouse', 'indoor ice_skating_rink', 27 | 'hot_spring', 'exterior gazebo', 'dam', 'campus', 'aqueduct', 'windmill', 28 | 'water_tower', 'subway_interior', 'phone_booth', 'pagoda', 'indoor escalator', 'indoor badminton_court', 29 | 'establishment poolroom', 'discotheque', 'childs_room', 'archive', 'amphitheater', 'shop bakery', 30 | 'riding_arena', 'residential_neighborhood', 'outdoor volleyball_court', 'outdoor general_store', 31 | 'outdoor basketball_court', 'interior elevator', 'indoor synagogue', 'indoor firing_range', 32 | 'gas_station', 'electrical_substation', 'driveway', 'classroom', 'basilica', 'schoolhouse', 33 | 'physics_laboratory', 'outdoor podium', 'mausoleum', 'fountain', 'excavation', 'dorm_room', 34 | 'cheese_factory', 'viaduct', 'utility_room', 'outdoor outhouse', 'outdoor 
driving_range', 35 | 'outdoor doorway', 'music_store', 'marsh', 'locker_room', 'kitchenette', 'kitchen', 36 | 'indoor shopping_mall', 'indoor booth', 'canyon', 'badlands', 'south_asia temple', 'shoe_shop', 37 | 'sandbar', 'sand desert', 'restaurant_kitchen', 'outdoor bazaar', 'indoor market', 'conference_room', 38 | 'butchers_shop', 'banquet_hall', 'vegetable_garden', 'railroad_track', 'patio', 'outdoor hot_tub', 39 | 'medina', 'hospital_room', 'harbor', 'frontseat car_interior', 'creek', 'chalet', 'campsite', 40 | 'boathouse', 'biology_laboratory', 'barn', 'tree_farm', 'snowfield', 'outdoor observatory', 41 | 'indoor parking_garage', 'indoor bow_window', 'fishpond', 'elevator_shaft', 'cafeteria', 42 | 'broadleaf forest', 'beach', 'train_railway', 'server_room', 'pasture', 'outdoor market', 43 | 'indoor hangar', 'golf_course', 'food_court', 'corridor', 'bedroom', 'valley', 'urban canal', 44 | 'restaurant_patio', 'public atrium', 'outdoor nuclear_power_plant', 'office cubicle', 'indoor pub', 45 | 'highway', 'engine_room', 'dining_room', 'crosswalk', 'computer_room', 'tree_house', 'rainforest', 46 | 'outdoor bow_window', 'outdoor apartment_building', 'lecture_room', 'indoor stage', 'indoor library', 47 | 'indoor jacuzzi', 'indoor chicken_coop', 'indoor bazaar', 'hospital', 'hayfield', 'football stadium', 48 | 'beauty_salon', 'skyscraper', 'putting_green', 'operating_room', 'indoor bistro', 'garbage_dump', 49 | 'formal_garden', 'dock', 'corn_field', 'construction_site', 'ballroom', 'baggage_claim', 'art_studio', 50 | 'wheat_field', 'sushi_bar', 'supermarket', 'ski_lodge', 'runway', 'park', 'outdoor kennel', 51 | 'outdoor diner', 'lobby', 'indoor general_store', 'exterior balcony', 'watering_hole', 'van_interior', 52 | 'plaza', 'outdoor arrival_gate', 'fire_escape', 'fairway', 'water moat', 'village', 'street', 'shower', 53 | 'outdoor planetarium', 'outdoor church', 'jail_cell', 'indoor church', 'indoor cathedral', 54 | 'candy_store', 'ticket_booth', 'staircase', 'outdoor power_plant', 'office_building', 'indoor garage', 55 | 'catacomb', 'amusement_arcade', 'plunge waterfall', 'jewelry_shop', 'forest_path'] 56 | novel_classes = ['east_asia temple', 'dentists_office', 'castle', 'bookstore', 'arch', 'alley', 'toyshop', 'pond', 57 | 'platform subway_station', 58 | 'palace', 'outdoor chicken_coop', 'motel', 'ice_cream_parlor', 'home_office', 'clothing_store', 59 | 'auditorium', 'wet_bar', 60 | 'tower', 'swamp', 'shopfront', 'parlor', 'outdoor swimming_pool', 'outdoor mosque', 61 | 'outdoor cathedral', 'mountain_snowy', 62 | 'indoor diner', 'fastfood_restaurant', 'cultivated field', 'parking_lot', 'natural lake', 63 | 'herb_garden', 'basement', 64 | 'sea_cliff', 'indoor kennel', 'home poolroom', 'game_room', 'fan waterfall', 'conference_center', 65 | 'coast', 'bathroom', 66 | 'barndoor', 'office', 'indoor factory', 'ice_shelf', 'delicatessen', 'courtyard', 'bridge', 'abbey', 67 | 'veranda', 'ski_slope', 68 | 'shed', 'indoor mosque', 'indoor greenhouse', 'gift_shop', 'cottage_garden', 'playroom', 69 | 'outdoor monastery', 'indoor museum', 70 | 'outdoor cabin', 'indoor apse', 'hill', 'burial_chamber', 'berth', 'bar', 'airport_terminal', 'yard', 71 | 'stable', 'recreation_room', 72 | 'outdoor parking_garage', 'corral', 'thriftshop', 'natural canal', 'indoor movie_theater', 'house', 73 | 'attic', 'trench', 'ruin', 74 | 'outdoor hunting_lodge', 'interior balcony', 'home dinette', 'building_facade', 'boat_deck', 'river', 75 | 'ocean', 'hotel_room', 76 | 'baseball_field', 'cliff', 'botanical_garden', 
'waiting_room', 'mountain', 'lock_chamber', 77 | 'indoor podium', 'door elevator', 'coffee_shop', 'bayou', 'chemistry_lab', 'assembly_line', 78 | 'youth_hostel', 'pavilion', 'industrial_area', 'galley', 79 | 'art_school', 'reception', 'outdoor hotel', 'living_room', 'wild field', 'outdoor inn', 80 | 'outdoor synagogue', 'indoor_procenium theater', 'restaurant', 'nursery', 'needleleaf forest', 81 | 'mansion', 'indoor_seats theater', 'drugstore', 'block waterfall', 'vehicle dinette', 82 | 'outdoor library', 'clean_room', 'backseat car_interior' 83 | ] 84 | 85 | class SUN397(DatasetBase): 86 | dataset_dir = 'sun397' 87 | 88 | def __init__(self, root, num_shots): 89 | self.dataset_dir = os.path.join(root, self.dataset_dir) 90 | self.image_dir = os.path.join(self.dataset_dir, 'SUN397') 91 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_SUN397.json') 92 | 93 | self.template = template 94 | 95 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 96 | few_shot_base = [] 97 | for item in train: 98 | if item.classname in base_classes: 99 | few_shot_base.append(item) 100 | few_shot_base = self.generate_fewshot_dataset(few_shot_base, num_shots=num_shots) 101 | few_shot_full = self.generate_fewshot_dataset(val, num_shots=num_shots) 102 | 103 | test_novel = [] 104 | for item in test: 105 | if item.classname in novel_classes: 106 | test_novel.append(item) 107 | test_novel = self.generate_fewshot_dataset(test_novel, num_shots=num_shots) 108 | 109 | super().__init__(train=few_shot_base, val=test_novel, full=few_shot_full) 110 | 111 | def read_data(self, cname2lab, text_file): 112 | text_file = os.path.join(self.dataset_dir, text_file) 113 | items = [] 114 | 115 | with open(text_file, 'r') as f: 116 | lines = f.readlines() 117 | for line in lines: 118 | imname = line.strip()[1:] # remove / 119 | classname = os.path.dirname(imname) 120 | label = cname2lab[classname] 121 | impath = os.path.join(self.image_dir, imname) 122 | 123 | names = classname.split('/')[1:] # remove 1st letter 124 | names = names[::-1] # put words like indoor/outdoor at first 125 | classname = ' '.join(names) 126 | 127 | item = Datum( 128 | impath=impath, 129 | label=label, 130 | classname=classname 131 | ) 132 | items.append(item) 133 | 134 | return items 135 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['a photo of a person doing {}.'] 9 | base_classes = ['Typing', 'Table_Tennis_Shot', 'Soccer_Penalty', 'Playing_Guitar', 'Military_Parade', 'Ice_Dancing', 'Bowling', 10 | 'Blowing_Candles', 'Billiards', 'Bench_Press', 'Field_Hockey_Penalty', 'Baby_Crawling', 'Writing_On_Board', 11 | 'Basketball_Dunk', 'Horse_Race', 'Sumo_Wrestling', 'Surfing', 'Clean_And_Jerk', 'Pull_Ups', 'Rock_Climbing_Indoor', 12 | 'Playing_Violin', 'Playing_Piano', 'Apply_Eye_Makeup', 'Horse_Riding', 'Sky_Diving', 'Tai_Chi', 'Rafting', 'Playing_Dhol', 13 | 'Breast_Stroke', 'Fencing', 'Cutting_In_Kitchen', 'Punch', 'Golf_Swing', 'Playing_Sitar', 'Band_Marching', 'Biking', 14 | 'Mopping_Floor', 'Shaving_Beard', 'Uneven_Bars', 'Handstand_Pushups', 'Brushing_Teeth', 'Baseball_Pitch', 'Rowing', 15 | 'Blow_Dry_Hair', 'Tennis_Swing', 'Drumming', 'Diving', 'Archery', 'Playing_Flute', 'Walking_With_Dog', 'Skate_Boarding', 16 | 'Cliff_Diving', 'Boxing_Punching_Bag', 'Knitting', 
'Cricket_Shot', 'Playing_Cello', 'Skiing', 'Playing_Tabla', 'Hula_Hoop', 17 | 'Haircut', 'Pommel_Horse', 'Trampoline_Jumping', 'Skijet', 'Basketball', 'Salsa_Spin', 'Long_Jump', 'Apply_Lipstick', 18 | 'Volleyball_Spiking', 'Juggling_Balls', 'Floor_Gymnastics'] 19 | novel_classes = ['High_Jump', 'Front_Crawl', 'Pole_Vault', 'Hammer_Throw', 'Pizza_Tossing', 'Swing', 'Yo_Yo', 'Shotput', 'Head_Massage', 20 | 'Jump_Rope', 'Soccer_Juggling', 'Hammering', 'Mixing', 'Kayaking', 'Cricket_Bowling', 'Jumping_Jack', 'Boxing_Speed_Bag', 21 | 'Javelin_Throw', 'Handstand_Walking', 'Lunges', 'Push_Ups', 'Throw_Discus', 'Wall_Pushups', 'Nunchucks', 'Frisbee_Catch', 22 | 'Body_Weight_Squats', 'Rope_Climbing', 'Parallel_Bars', 'Still_Rings', 'Playing_Daf', 'Balance_Beam'] 23 | 24 | 25 | class UCF101(DatasetBase): 26 | 27 | dataset_dir = 'ucf101' 28 | 29 | def __init__(self, root, num_shots): 30 | self.dataset_dir = os.path.join(root, self.dataset_dir) 31 | self.image_dir = os.path.join(self.dataset_dir, 'UCF-101-midframes') 32 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_UCF101.json') 33 | 34 | self.template = template 35 | 36 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 37 | few_shot_base = [] 38 | for item in train: 39 | if item.classname in base_classes: 40 | few_shot_base.append(item) 41 | few_shot_base = self.generate_fewshot_dataset(few_shot_base, num_shots=num_shots) 42 | few_shot_full = self.generate_fewshot_dataset(val, num_shots=16) 43 | 44 | test_novel = [] 45 | for item in test: 46 | if item.classname in novel_classes: 47 | test_novel.append(item) 48 | test_novel = self.generate_fewshot_dataset(test_novel, num_shots=num_shots) 49 | 50 | super().__init__(train=few_shot_base, val=test_novel, full=few_shot_full) 51 | 52 | def read_data(self, cname2lab, text_file): 53 | text_file = os.path.join(self.dataset_dir, text_file) 54 | items = [] 55 | 56 | with open(text_file, 'r') as f: 57 | lines = f.readlines() 58 | for line in lines: 59 | line = line.strip().split(' ')[0] # trainlist: filename, label 60 | action, filename = line.split('/') 61 | label = cname2lab[action] 62 | 63 | elements = re.findall('[A-Z][^A-Z]*', action) 64 | renamed_action = '_'.join(elements) 65 | 66 | filename = filename.replace('.avi', '.jpg') 67 | impath = os.path.join(self.image_dir, renamed_action, filename) 68 | 69 | item = Datum( 70 | impath=impath, 71 | label=label, 72 | classname=renamed_action 73 | ) 74 | items.append(item) 75 | 76 | return items 77 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import os.path as osp 4 | import tarfile 5 | import zipfile 6 | from collections import defaultdict 7 | import gdown 8 | import json 9 | import torch 10 | from torch.utils.data import Dataset as TorchDataset 11 | import torchvision.transforms as T 12 | from PIL import Image 13 | 14 | 15 | def read_json(fpath): 16 | """Read json file from a path.""" 17 | with open(fpath, 'r') as f: 18 | obj = json.load(f) 19 | return obj 20 | 21 | 22 | def write_json(obj, fpath): 23 | """Writes to a json file.""" 24 | if not osp.exists(osp.dirname(fpath)): 25 | os.makedirs(osp.dirname(fpath)) 26 | with open(fpath, 'w') as f: 27 | json.dump(obj, f, indent=4, separators=(',', ': ')) 28 | 29 | 30 | def read_image(path): 31 | """Read image from path using ``PIL.Image``. 32 | 33 | Args: 34 | path (str): path to an image. 
35 | 36 | Returns: 37 | PIL image 38 | """ 39 | if not osp.exists(path): 40 | raise IOError('No file exists at {}'.format(path)) 41 | 42 | while True: 43 | try: 44 | img = Image.open(path).convert('RGB') 45 | return img 46 | except IOError: 47 | print( 48 | 'Cannot read image from {}, ' 49 | 'probably due to heavy IO. Will re-try'.format(path) 50 | ) 51 | 52 | 53 | def listdir_nohidden(path, sort=False): 54 | """List non-hidden items in a directory. 55 | 56 | Args: 57 | path (str): directory path. 58 | sort (bool): sort the items. 59 | """ 60 | items = [f for f in os.listdir(path) if not f.startswith('.') and 'sh' not in f] 61 | if sort: 62 | items.sort() 63 | return items 64 | 65 | 66 | class Datum: 67 | """Data instance which defines the basic attributes. 68 | 69 | Args: 70 | impath (str): image path. 71 | label (int): class label. 72 | domain (int): domain label. 73 | classname (str): class name. 74 | """ 75 | 76 | def __init__(self, impath='', label=0, domain=-1, classname=''): 77 | assert isinstance(impath, str) 78 | assert isinstance(label, int) 79 | assert isinstance(domain, int) 80 | assert isinstance(classname, str) 81 | 82 | self._impath = impath 83 | self._label = label 84 | self._domain = domain 85 | self._classname = classname 86 | 87 | @property 88 | def impath(self): 89 | return self._impath 90 | 91 | @property 92 | def label(self): 93 | return self._label 94 | 95 | @property 96 | def domain(self): 97 | return self._domain 98 | 99 | @property 100 | def classname(self): 101 | return self._classname 102 | 103 | 104 | class DatasetBase: 105 | """A unified dataset class for 106 | 1) domain adaptation 107 | 2) domain generalization 108 | 3) semi-supervised learning 109 | """ 110 | dataset_dir = '' # the directory where the dataset is stored 111 | domains = [] # string names of all domains 112 | 113 | def __init__(self, train=None, val=None, full=None): 114 | self._train = train 115 | self._val = val 116 | self._full = full 117 | 118 | self._num_classes = self.get_num_classes(full) 119 | self._lab2cname, self._classnames = self.get_lab2cname(full) 120 | 121 | @property 122 | def train(self): 123 | return self._train 124 | 125 | @property 126 | def val(self): 127 | return self._val 128 | 129 | @property 130 | def full(self): 131 | return self._full 132 | 133 | @property 134 | def lab2cname(self): 135 | return self._lab2cname 136 | 137 | @property 138 | def classnames(self): 139 | return self._classnames 140 | 141 | @property 142 | def num_classes(self): 143 | return self._num_classes 144 | 145 | def get_num_classes(self, data_source): 146 | """Count number of classes. 147 | 148 | Args: 149 | data_source (list): a list of Datum objects. 150 | """ 151 | label_set = set() 152 | for item in data_source: 153 | label_set.add(item.label) 154 | return max(label_set) + 1 155 | 156 | def get_lab2cname(self, data_source): 157 | """Get a label-to-classname mapping (dict). 158 | 159 | Args: 160 | data_source (list): a list of Datum objects. 
161 | """ 162 | container = set() 163 | for item in data_source: 164 | container.add((item.label, item.classname)) 165 | mapping = {label: classname for label, classname in container} 166 | labels = list(mapping.keys()) 167 | labels.sort() 168 | classnames = [mapping[label] for label in labels] 169 | return mapping, classnames 170 | 171 | def check_input_domains(self, source_domains, target_domains): 172 | self.is_input_domain_valid(source_domains) 173 | self.is_input_domain_valid(target_domains) 174 | 175 | def is_input_domain_valid(self, input_domains): 176 | for domain in input_domains: 177 | if domain not in self.domains: 178 | raise ValueError( 179 | 'Input domain must belong to {}, ' 180 | 'but got [{}]'.format(self.domains, domain) 181 | ) 182 | 183 | def download_data(self, url, dst, from_gdrive=True): 184 | if not osp.exists(osp.dirname(dst)): 185 | os.makedirs(osp.dirname(dst)) 186 | 187 | if from_gdrive: 188 | gdown.download(url, dst, quiet=False) 189 | else: 190 | raise NotImplementedError 191 | 192 | print('Extracting file ...') 193 | 194 | try: 195 | tar = tarfile.open(dst) 196 | tar.extractall(path=osp.dirname(dst)) 197 | tar.close() 198 | except: 199 | zip_ref = zipfile.ZipFile(dst, 'r') 200 | zip_ref.extractall(osp.dirname(dst)) 201 | zip_ref.close() 202 | 203 | print('File extracted to {}'.format(osp.dirname(dst))) 204 | 205 | def generate_fewshot_dataset( 206 | self, *data_sources, num_shots=-1, repeat=True 207 | ): 208 | """Generate a few-shot dataset (typically for the training set). 209 | 210 | This function is useful when one wants to evaluate a model 211 | in a few-shot learning setting where each class only contains 212 | a few number of images. 213 | 214 | Args: 215 | data_sources: each individual is a list containing Datum objects. 216 | num_shots (int): number of instances per class to sample. 217 | repeat (bool): repeat images if needed. 218 | """ 219 | if num_shots < 1: 220 | if len(data_sources) == 1: 221 | return data_sources[0] 222 | return data_sources 223 | 224 | print(f'Creating a {num_shots}-shot dataset') 225 | 226 | output = [] 227 | 228 | for data_source in data_sources: 229 | tracker = self.split_dataset_by_label(data_source) 230 | dataset = [] 231 | 232 | for label, items in tracker.items(): 233 | if len(items) >= num_shots: 234 | sampled_items = random.sample(items, num_shots) 235 | else: 236 | if repeat: 237 | sampled_items = random.choices(items, k=num_shots) 238 | else: 239 | sampled_items = items 240 | dataset.extend(sampled_items) 241 | 242 | output.append(dataset) 243 | 244 | if len(output) == 1: 245 | return output[0] 246 | 247 | return output 248 | 249 | def split_dataset_by_label(self, data_source): 250 | """Split a dataset, i.e. a list of Datum objects, 251 | into class-specific groups stored in a dictionary. 252 | 253 | Args: 254 | data_source (list): a list of Datum objects. 255 | """ 256 | output = defaultdict(list) 257 | 258 | for item in data_source: 259 | output[item.label].append(item) 260 | 261 | return output 262 | 263 | def split_dataset_by_domain(self, data_source): 264 | """Split a dataset, i.e. a list of Datum objects, 265 | into domain-specific groups stored in a dictionary. 266 | 267 | Args: 268 | data_source (list): a list of Datum objects. 
269 | """ 270 | output = defaultdict(list) 271 | 272 | for item in data_source: 273 | output[item.domain].append(item) 274 | 275 | return output 276 | 277 | 278 | class DatasetWrapper(TorchDataset): 279 | def __init__(self, data_source, input_size, transform=None, is_train=False, 280 | return_img0=False, k_tfm=1): 281 | self.data_source = data_source 282 | self.transform = transform # accept list (tuple) as input 283 | self.is_train = is_train 284 | # Augmenting an image K>1 times is only allowed during training 285 | self.k_tfm = k_tfm if is_train else 1 286 | self.return_img0 = return_img0 287 | 288 | if self.k_tfm > 1 and transform is None: 289 | raise ValueError( 290 | 'Cannot augment the image {} times ' 291 | 'because transform is None'.format(self.k_tfm) 292 | ) 293 | 294 | # Build transform that doesn't apply any data augmentation 295 | interp_mode = T.InterpolationMode.BICUBIC 296 | to_tensor = [] 297 | to_tensor += [T.Resize(input_size, interpolation=interp_mode)] 298 | to_tensor += [T.ToTensor()] 299 | normalize = T.Normalize( 300 | mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) 301 | ) 302 | to_tensor += [normalize] 303 | self.to_tensor = T.Compose(to_tensor) 304 | 305 | def __len__(self): 306 | return len(self.data_source) 307 | 308 | def __getitem__(self, idx): 309 | item = self.data_source[idx] 310 | 311 | output = { 312 | 'label': item.label, 313 | 'domain': item.domain, 314 | 'impath': item.impath 315 | } 316 | 317 | img0 = read_image(item.impath) 318 | 319 | if self.transform is not None: 320 | if isinstance(self.transform, (list, tuple)): 321 | for i, tfm in enumerate(self.transform): 322 | img = self._transform_image(tfm, img0) 323 | keyname = 'img' 324 | if (i + 1) > 1: 325 | keyname += str(i + 1) 326 | output[keyname] = img 327 | else: 328 | img = self._transform_image(self.transform, img0) 329 | output['img'] = img 330 | 331 | if self.return_img0: 332 | output['img0'] = self.to_tensor(img0) 333 | 334 | return output['img'], output['label'] 335 | 336 | def _transform_image(self, tfm, img0): 337 | img_list = [] 338 | 339 | for k in range(self.k_tfm): 340 | img_list.append(tfm(img0)) 341 | 342 | img = img_list 343 | if len(img) == 1: 344 | img = img[0] 345 | 346 | return img 347 | 348 | 349 | def build_data_loader( 350 | data_source=None, 351 | batch_size=64, 352 | input_size=224, 353 | tfm=None, 354 | is_train=True, 355 | shuffle=False, 356 | dataset_wrapper=None 357 | ): 358 | 359 | if dataset_wrapper is None: 360 | dataset_wrapper = DatasetWrapper 361 | 362 | # Build data loader 363 | data_loader = torch.utils.data.DataLoader( 364 | dataset_wrapper(data_source, input_size=input_size, transform=tfm, is_train=is_train), 365 | batch_size=batch_size, 366 | num_workers=8, 367 | shuffle=shuffle, 368 | drop_last=False, 369 | pin_memory=(torch.cuda.is_available()) 370 | ) 371 | assert len(data_loader) > 0 372 | 373 | return data_loader 374 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | import yaml 5 | from tqdm import tqdm 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | from datasets import build_dataset 10 | from datasets.utils import build_data_loader 11 | import torchvision.transforms as transforms 12 | import clip 13 | from clip.utils import (cls_acc, 14 | clip_classifier, 15 | build_cache_model, 16 | pre_load_features) 
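# main() below expects the YAML passed via --config to define at least the keys it reads:
# 'dataset', 'root_path', 'shots', 'backbone', 'lr' and 'train_epoch' ('cache_dir' is derived
# inside main()). A minimal config sketch with illustrative values (assumed, not taken from the
# shipped config files):
#
#   dataset: dtd
#   root_path: /path/to/datasets
#   shots: 16
#   backbone: RN50
#   lr: 0.001
#   train_epoch: 20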
17 | 18 | 19 | def get_arguments(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--config', dest='config', help='settings of Meta-Adapter in yaml format') 22 | args = parser.parse_args() 23 | 24 | return args 25 | 26 | 27 | class MetaAdapter(nn.Module): 28 | def __init__(self, dim=1024, num_heads=1): 29 | super().__init__() 30 | self.dim = dim 31 | self.num_heads = num_heads 32 | self.q_proj = nn.Linear(dim, dim, bias=False) 33 | self.k_proj = nn.Linear(dim, dim, bias=False) 34 | self.alpha_proj = nn.Linear(dim, 1, bias=True) 35 | self._reset_parameters() 36 | 37 | def _reset_parameters(self): 38 | nn.init.xavier_uniform_(self.q_proj.weight) 39 | nn.init.xavier_uniform_(self.k_proj.weight) 40 | nn.init.xavier_uniform_(self.alpha_proj.weight) 41 | nn.init.constant_(self.alpha_proj.bias, 1) 42 | 43 | def forward(self, query, key, value): 44 | B, K, C = key.shape 45 | res = query 46 | 47 | query = query.reshape(B, 1, C) 48 | key = torch.cat([query, key], dim=1) 49 | value = torch.cat([query, value], dim=1) 50 | query = self.q_proj(query).reshape(B, self.num_heads, C) 51 | key = self.k_proj(key) 52 | 53 | query = query.reshape(B, self.num_heads, 1, -1).permute(0, 2, 1, 3) 54 | key = key.reshape(B, K + 1, 1, -1).permute(0, 2, 1, 3) 55 | value = value.reshape(B, K + 1, 1, -1).permute(0, 2, 1, 3) 56 | 57 | attn_weight = (query @ key.transpose(-1, -2) / torch.sqrt(torch.tensor(self.dim, dtype=torch.float))).softmax(-1) 58 | attn = attn_weight @ value 59 | 60 | alpha = torch.nn.functional.sigmoid(self.alpha_proj(res).reshape(B, -1, 1, 1)) 61 | attn = (alpha * attn).squeeze() 62 | 63 | attn = res + attn 64 | attn = F.normalize(attn, p=2, dim=-1) 65 | return attn 66 | 67 | 68 | def run_meta_adapter(cfg, cache_keys, test_features, test_labels, clip_weights, clip_model, 69 | train_loader_image): 70 | # Zero-shot CLIP 71 | clip_logits = 100. * test_features @ clip_weights 72 | acc = cls_acc(clip_logits, test_labels) 73 | print("**** Zero-shot CLIP's test accuracy on novel classes: {:.2f}. ****".format(acc)) 74 | 75 | adapter = MetaAdapter(dim=cache_keys.shape[0]).to(clip_model.dtype).cuda() 76 | 77 | optimizer = torch.optim.AdamW(adapter.parameters(), lr=cfg['lr'], eps=1e-4) 78 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg['train_epoch'] * len(train_loader_image)) 79 | 80 | best_acc, best_epoch = 0.0, 0 81 | 82 | query = clip_weights.T 83 | key = cache_keys.T.reshape(query.shape[0], -1, query.shape[1]) 84 | 85 | for train_idx in range(cfg['train_epoch']): 86 | # Train 87 | adapter.train() 88 | correct_samples, all_samples = 0, 0 89 | loss_list = [] 90 | print('Train Epoch: {:} / {:}'.format(train_idx, cfg['train_epoch'])) 91 | 92 | for i, (images, target) in enumerate(tqdm(train_loader_image)): 93 | images, target = images.cuda(), target.cuda() 94 | with torch.no_grad(): 95 | image_features = clip_model.encode_image(images) 96 | image_features /= image_features.norm(dim=-1, keepdim=True) 97 | 98 | weights = adapter(query, key, key) 99 | tip_logits = 100. 
* image_features @ weights.T 100 | 101 | loss = F.cross_entropy(tip_logits, target) 102 | 103 | acc = cls_acc(tip_logits, target) 104 | correct_samples += acc / 100 * len(tip_logits) 105 | all_samples += len(tip_logits) 106 | loss_list.append(loss.item()) 107 | 108 | optimizer.zero_grad() 109 | loss.backward() 110 | optimizer.step() 111 | scheduler.step() 112 | 113 | # update cache_keys 114 | with torch.no_grad(): 115 | for tar, feat in zip(target, image_features): 116 | key[tar] = torch.cat([feat[None, :], key[tar][:key.shape[1] - 1]], dim=0) 117 | 118 | current_lr = scheduler.get_last_lr()[0] 119 | print('LR: {:.6f}, Acc: {:.4f} ({:}/{:}), Loss: {:.4f}'.format(current_lr, correct_samples / all_samples, 120 | correct_samples, all_samples, 121 | sum(loss_list) / len(loss_list))) 122 | 123 | # Eval 124 | adapter.eval() 125 | 126 | query_test = clip_weights.T 127 | key_test = cache_keys.T.reshape(query_test.shape[0], -1, query_test.shape[1]) 128 | weights = adapter(query_test, key_test, key_test) 129 | tip_logits = 100. * test_features @ weights.T 130 | acc = cls_acc(tip_logits, test_labels) 131 | 132 | if acc > best_acc: 133 | best_acc = acc 134 | torch.save(adapter.state_dict(), cfg['cache_dir'] + "/best_meta_" + str(cfg['shots']) + "shots.pt") 135 | torch.save(key, cfg['cache_dir'] + "/keys" + str(cfg['shots']) + "shots.pt") 136 | 137 | print("**** Meta-Adapter's best accuracy: {:.2f}. ****".format(best_acc)) 138 | 139 | 140 | def main(): 141 | # Load config file 142 | args = get_arguments() 143 | assert (os.path.exists(args.config)) 144 | 145 | cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 146 | 147 | cache_dir = os.path.join('./caches', cfg['dataset']) 148 | os.makedirs(cache_dir, exist_ok=True) 149 | cfg['cache_dir'] = cache_dir 150 | 151 | # CLIP 152 | clip_model, preprocess = clip.load(cfg['backbone']) 153 | clip_model.eval() 154 | 155 | # ImageNet dataset 156 | random.seed(1) 157 | torch.manual_seed(1) 158 | 159 | print("Preparing dataset.") 160 | dataset = build_dataset(cfg['dataset'], cfg['root_path'], cfg['shots']) 161 | 162 | if cfg['dataset'] == 'imagenet': 163 | test_loader = torch.utils.data.DataLoader(dataset.val, batch_size=64, num_workers=8, shuffle=False) 164 | train_loader_cache = torch.utils.data.DataLoader(dataset.full, batch_size=64, num_workers=8, shuffle=False) 165 | train_loader_F = torch.utils.data.DataLoader(dataset.train, batch_size=64, num_workers=8, shuffle=True) 166 | else: 167 | test_loader = build_data_loader(data_source=dataset.val, batch_size=64, is_train=False, tfm=preprocess, 168 | shuffle=False) 169 | train_tranform = transforms.Compose([ 170 | transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC), 171 | transforms.RandomHorizontalFlip(p=0.5), 172 | transforms.ToTensor(), 173 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 174 | ]) 175 | train_loader_cache = build_data_loader(data_source=dataset.full, batch_size=64, tfm=train_tranform, 176 | is_train=True, shuffle=False) 177 | train_loader_F = build_data_loader(data_source=dataset.train, batch_size=64, tfm=train_tranform, 178 | is_train=True, shuffle=True) 179 | 180 | # Textual features 181 | print("Getting textual features as CLIP's classifier.") 182 | clip_weights = clip_classifier(dataset.classnames, dataset.template, clip_model) 183 | 184 | # Construct the cache model by few-shot training set 185 | print("Constructing cache model by few-shot visual features and labels.") 186 
| cache_keys = build_cache_model(cfg, clip_model, train_loader_cache) 187 | 188 | # Pre-load test features 189 | print("Loading visual features and labels from test set.") 190 | test_features, test_labels = pre_load_features(cfg, "test", clip_model, test_loader) 191 | 192 | run_meta_adapter(cfg, cache_keys, test_features, test_labels, clip_weights, clip_model, train_loader_F) 193 | 194 | 195 | if __name__ == '__main__': 196 | main() 197 | --------------------------------------------------------------------------------
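For orientation, the `MetaAdapter` module defined in `main.py` maps the frozen CLIP text embeddings (the query, one row per class) and the cached few-shot visual features (the keys/values, `shots` rows per class) to a refined set of classifier weights with the same shape as the query. Below is a minimal shape-level smoke test, assuming the repository's dependencies are installed so that `main.py` can be imported; the feature dimension 1024 corresponds to the RN50 backbone, and the class/shot counts are arbitrary placeholders:

```
import torch
from main import MetaAdapter  # main.py only runs training under the __main__ guard, so importing it is safe

num_classes, shots, dim = 10, 16, 1024    # dim=1024 matches the RN50 CLIP backbone
adapter = MetaAdapter(dim=dim)

query = torch.randn(num_classes, dim)        # stands in for clip_weights.T (one text embedding per class)
key = torch.randn(num_classes, shots, dim)   # stands in for cache_keys.T reshaped to [classes, shots, dim]

weights = adapter(query, key, key)           # few-shot refined classifier weights
print(weights.shape)                         # torch.Size([10, 1024]); rows are L2-normalized
```

Note that during training, `run_meta_adapter` also refreshes each class's rows of `key` in place with the latest image features (the `# update cache_keys` block), so the cached keys evolve over the training epochs rather than staying fixed.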
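Similarly, the class-balanced sampling performed by `generate_fewshot_dataset` in `datasets/utils.py` reduces to the standalone logic sketched below. This is only an illustrative sketch: it replaces `Datum` objects with plain values and assumes the label grouping produced by `split_dataset_by_label` is already available as a dict.

```
import random

def fewshot_sample(items_by_label, num_shots, repeat=True):
    """Pick num_shots items per class, padding with replacement when a class is too small."""
    dataset = []
    for label, items in items_by_label.items():
        if len(items) >= num_shots:
            dataset.extend(random.sample(items, num_shots))     # without replacement
        elif repeat:
            dataset.extend(random.choices(items, k=num_shots))  # pad by sampling with replacement
        else:
            dataset.extend(items)                               # keep whatever is available
    return dataset

toy = {0: list(range(20)), 1: list(range(3))}
print(len(fewshot_sample(toy, num_shots=16)))  # 32: class 1 is padded because repeat=True
```

With `repeat=False`, a class smaller than `num_shots` simply contributes all of its images, so the resulting few-shot set is no longer exactly class-balanced.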