├── LICENSE ├── README.md ├── clip ├── .DS_Store ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ ├── __init__.cpython-38.pyc │ ├── clip.cpython-311.pyc │ ├── clip.cpython-38.pyc │ ├── model.cpython-311.pyc │ ├── model.cpython-38.pyc │ ├── simple_tokenizer.cpython-311.pyc │ └── simple_tokenizer.cpython-38.pyc ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── configs ├── caltech101.yaml ├── dtd.yaml ├── eurosat.yaml ├── fgvc.yaml ├── food101.yaml ├── imagenet.yaml ├── imagenet_a.yaml ├── imagenet_r.yaml ├── imagenet_s.yaml ├── imagenet_v.yaml ├── oxford_flowers.yaml ├── oxford_pets.yaml ├── stanford_cars.yaml ├── sun397.yaml └── ucf101.yaml ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ ├── __init__.cpython-38.pyc │ ├── augmix_ops.cpython-38.pyc │ ├── caltech101.cpython-38.pyc │ ├── dtd.cpython-38.pyc │ ├── eurosat.cpython-38.pyc │ ├── fgvc.cpython-38.pyc │ ├── food101.cpython-38.pyc │ ├── imagenet.cpython-38.pyc │ ├── imagenet_a.cpython-38.pyc │ ├── imagenet_r.cpython-38.pyc │ ├── imagenet_sketch.cpython-38.pyc │ ├── imagenetv2.cpython-38.pyc │ ├── oxford_flowers.cpython-38.pyc │ ├── oxford_pets.cpython-311.pyc │ ├── oxford_pets.cpython-38.pyc │ ├── stanford_cars.cpython-38.pyc │ ├── sun397.cpython-38.pyc │ ├── ucf101.cpython-38.pyc │ ├── utils.cpython-311.pyc │ └── utils.cpython-38.pyc ├── augmix_ops.py ├── caltech101.py ├── dtd.py ├── eurosat.py ├── fgvc.py ├── food101.py ├── imagenet.py ├── imagenet_a.py ├── imagenet_r.py ├── imagenet_sketch.py ├── imagenetv2.py ├── oxford_flowers.py ├── oxford_pets.py ├── stanford_cars.py ├── sun397.py ├── ucf101.py └── utils.py ├── docs ├── DATASETS.md ├── comparison.png └── overview.png ├── gpt3_prompts ├── CuPL_prompts_caltech101.json ├── CuPL_prompts_dtd.json ├── CuPL_prompts_eurosat.json ├── CuPL_prompts_fgvcaircraft.json ├── CuPL_prompts_flowers102.json ├── CuPL_prompts_food101.json ├── CuPL_prompts_imagenet.json ├── CuPL_prompts_oxfordpets.json ├── CuPL_prompts_stanfordcars.json ├── CuPL_prompts_sun397.json └── CuPL_prompts_ucf101.json ├── main_dpe.py ├── requirements.txt ├── scripts ├── run_cd_benchmark_rn50.sh ├── run_cd_benchmark_vit.sh ├── run_ood_benchmark_rn50.sh └── run_ood_benchmark_vit.sh └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ce Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [NeurIPS 2024] DPE-CLIP 2 | [![Website](https://img.shields.io/badge/Project-Website-green)](https://zhangce01.github.io/DPE-CLIP/) 3 | [![arXiv](https://img.shields.io/badge/arXiv-2410.12790-red)](http://arxiv.org/abs/2410.12790) 4 | [![Conference](https://img.shields.io/badge/NeurIPS-2024-blue)](https://neurips.cc/) 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | 7 | ## 👀Introduction 8 | This repository contains the code for our NeurIPS 2024 paper `Dual Prototype Evolving for Test-Time Generalization of Vision-Language Models`. [[Paper](https://arxiv.org/abs/2410.12790)] 9 | 10 | ![](docs/comparison.png) 11 | 12 | ## ⏳Setup 13 | 14 | #### 1. Environment 15 | We tested our codebase with PyTorch 2.1.1 and CUDA 12.1. Please install the PyTorch and CUDA versions that match your computational resources. Then install the remaining required packages by running `pip install -r requirements.txt`. Finally, install the info-nce-pytorch package by following the instructions at https://github.com/RElbers/info-nce-pytorch. 16 | 17 | #### 2. Dataset 18 | To set up all required datasets, please follow the guidance in [DATASETS.md](docs/DATASETS.md), which covers the preparation steps for both benchmarks. 19 | 20 | ## 📦Usage 21 | To run the code, execute one of the following four bash scripts: 22 | 23 | #### Robustness to Natural Distribution Shifts 24 | * **ResNet50**: Run DPE on the OOD Benchmark using the ResNet-50 model: 25 | ``` 26 | bash ./scripts/run_ood_benchmark_rn50.sh 27 | ``` 28 | * **ViT-B/16**: Run DPE on the OOD Benchmark using the ViT-B/16 model: 29 | ``` 30 | bash ./scripts/run_ood_benchmark_vit.sh 31 | ``` 32 | 33 | #### Cross-Dataset Generalization 34 | * **ResNet50**: Run DPE on the Cross-Domain Benchmark using the ResNet-50 model: 35 | ``` 36 | bash ./scripts/run_cd_benchmark_rn50.sh 37 | ``` 38 | * **ViT-B/16**: Run DPE on the Cross-Domain Benchmark using the ViT-B/16 model: 39 | ``` 40 | bash ./scripts/run_cd_benchmark_vit.sh 41 | ``` 42 | 43 | #### Arguments 44 | In each bash script, you can modify the following arguments: (1) `--datasets` to specify the datasets, (2) `--backbone` to specify the backbone model (RN50 or ViT-B/16), and (3) `--coop` to enable the prompts learned by CoOp. We use `wandb` to track the results; if you wish to deactivate this feature, simply omit the `--wandb-log` argument. 45 | 46 | 47 | ![](docs/overview.png) 48 | 49 | ## 🙏Acknowledgements 50 | 51 | Our codebase is adapted from [Tip-Adapter](https://github.com/gaopengcuhk/Tip-Adapter/), [CLIP](https://github.com/openai/CLIP/tree/main/clip), [TDA](https://github.com/kdiAAA/TDA), [TPT](https://github.com/azshue/TPT), and [CuPL](https://github.com/sarahpratt/CuPL). We thank the authors for releasing their code! 52 | 53 | ## 📧Contact 54 | 55 | If you have any questions, please contact [cezhang@cs.cmu.edu](mailto:cezhang@cs.cmu.edu).
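For reference, the benchmark scripts above appear to wrap the single entry point `main_dpe.py`. The sketch below shows what a direct invocation might look like; it is assembled only from the arguments listed in this README, so the calling convention, the exact flag spellings, and any additional required options (e.g. the dataset root) should be verified against `main_dpe.py` and the provided scripts:
```
python main_dpe.py --datasets caltech101 --backbone RN50 --wandb-log
```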
56 | 57 | ## 📌 BibTeX & Citation 58 | 59 | If you find this code useful, please consider citing our work: 60 | 61 | ```bibtex 62 | @article{zhang2024dual, 63 | title={Dual prototype evolving for test-time generalization of vision-language models}, 64 | author={Zhang, Ce and Stepputtis, Simon and Sycara, Katia and Xie, Yaqi}, 65 | journal={Advances in Neural Information Processing Systems}, 66 | volume={37}, 67 | pages={32111--32136}, 68 | year={2024} 69 | } 70 | ``` 71 | -------------------------------------------------------------------------------- /clip/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/clip/.DS_Store -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/clip/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/clip.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/clip/__pycache__/clip.cpython-311.pyc -------------------------------------------------------------------------------- /clip/__pycache__/clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/clip/__pycache__/clip.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/clip/__pycache__/model.cpython-311.pyc -------------------------------------------------------------------------------- /clip/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/clip/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/simple_tokenizer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/clip/__pycache__/simple_tokenizer.cpython-311.pyc -------------------------------------------------------------------------------- /clip/__pycache__/simple_tokenizer.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/clip/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been 
downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _convert_image_to_rgb(image): 72 | return image.convert("RGB") 73 | 74 | 75 | def _transform(n_px): 76 | return Compose([ 77 | Resize(n_px, interpolation=BICUBIC), 78 | CenterCrop(n_px), 79 | _convert_image_to_rgb, 80 | ToTensor(), 81 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 82 | ]) 83 | 84 | 85 | def available_models() -> List[str]: 86 | """Returns the names of available CLIP models""" 87 | return list(_MODELS.keys()) 88 | 89 | 90 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 91 | """Load a CLIP model 92 | 93 | Parameters 94 | ---------- 95 | name : str 96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 97 | 98 | device : Union[str, torch.device] 99 | The device to put the loaded model 100 | 101 | jit : bool 102 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 103 | 104 | download_root: str 105 | path to download the model files; by default, it uses "~/.cache/clip" 106 | 107 | Returns 108 | ------- 109 | model : torch.nn.Module 110 | The CLIP model 111 | 112 | preprocess : Callable[[PIL.Image], torch.Tensor] 113 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 114 | """ 115 | if name in _MODELS: 116 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 117 | elif os.path.isfile(name): 118 | model_path = name 119 | else: 120 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 121 | 122 | try: 123 | # loading JIT archive 124 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 125 | state_dict = None 126 | except RuntimeError: 127 | # loading saved state dict 128 | if jit: 129 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 130 | jit = False 131 | state_dict = torch.load(model_path, map_location="cpu") 132 | 133 | if not jit: 134 | model = build_model(state_dict or model.state_dict()).to(device) 135 | if str(device) == "cpu": 136 | model.float() 137 | return model, _transform(model.visual.input_resolution) 138 | 139 | # patch the device names 140 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 141 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 142 | 143 | def patch_device(module): 144 | try: 145 | graphs = [module.graph] if hasattr(module, "graph") else [] 146 | except RuntimeError: 147 | graphs = [] 148 | 149 | if hasattr(module, "forward1"): 150 | graphs.append(module.forward1.graph) 151 | 152 | for graph in graphs: 153 | for node in graph.findAllNodes("prim::Constant"): 154 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 155 | node.copyAttributes(device_node) 156 | 157 | model.apply(patch_device) 158 | patch_device(model.encode_image) 159 | patch_device(model.encode_text) 160 | 161 | # patch dtype to float32 on CPU 162 | if str(device) == "cpu": 163 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 164 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 165 | float_node = float_input.node() 166 | 167 | def patch_float(module): 168 | try: 169 | graphs = [module.graph] if hasattr(module, "graph") else [] 170 | except RuntimeError: 171 | graphs = [] 172 | 173 | if hasattr(module, "forward1"): 174 | graphs.append(module.forward1.graph) 175 | 176 | for graph in graphs: 177 | for node in graph.findAllNodes("aten::to"): 178 | inputs = list(node.inputs()) 179 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 180 | if inputs[i].node()["value"] == 5: 181 | inputs[i].node().copyAttributes(float_node) 182 | 183 | model.apply(patch_float) 184 | patch_float(model.encode_image) 185 | patch_float(model.encode_text) 186 | 187 | model.float() 188 | 189 | return model, _transform(model.input_resolution.item()) 190 | 191 | 192 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 193 | """ 194 | Returns the tokenized representation of given input string(s) 195 | 196 | Parameters 197 | ---------- 198 | texts : Union[str, List[str]] 199 | An input string or a list of input strings to tokenize 200 | 201 | context_length : int 202 | The context length to use; all CLIP models use 77 as the context length 203 | 204 | truncate: bool 205 | Whether to truncate the text in case its encoding is longer than the context length 206 | 207 | Returns 208 | ------- 209 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 210 | """ 211 | if isinstance(texts, str): 212 | texts = [texts] 213 | 214 | sot_token = _tokenizer.encoder["<|startoftext|>"] 215 | eot_token = _tokenizer.encoder["<|endoftext|>"] 216 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 217 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 218 | 219 | for i, tokens in enumerate(all_tokens): 220 | if len(tokens) > context_length: 221 | if truncate: 222 | tokens = tokens[:context_length] 223 | tokens[-1] = eot_token 224 | else: 225 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 226 | result[i, 
:len(tokens)] = torch.tensor(tokens) 227 | 228 | return result 229 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's 
but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = x + 
self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 204 | super().__init__() 205 | self.input_resolution = input_resolution 206 | self.output_dim = output_dim 207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 208 | 209 | scale = width ** -0.5 210 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 212 | self.ln_pre = LayerNorm(width) 213 | 214 | self.transformer = Transformer(width, layers, heads) 215 | 216 | self.ln_post = LayerNorm(width) 217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 218 | 219 | def forward(self, x: torch.Tensor): 220 | x = self.conv1(x) # shape = [*, width, grid, grid] 221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 224 | x = x + self.positional_embedding.to(x.dtype) 225 | x = self.ln_pre(x) 226 | 227 | x = x.permute(1, 0, 2) # NLD -> LND 228 | x = self.transformer(x) 229 | x = x.permute(1, 0, 2) # LND -> NLD 230 | 231 | x = self.ln_post(x[:, 0, :]) 232 | 233 | if self.proj is not None: 234 | x = x @ self.proj 235 | 236 | return x 237 | 238 | 239 | class CLIP(nn.Module): 240 | def __init__(self, 241 | embed_dim: int, 242 | # vision 243 | image_resolution: int, 244 | vision_layers: Union[Tuple[int, int, int, int], int], 245 | vision_width: int, 246 | vision_patch_size: int, 247 | # text 248 | context_length: int, 249 | vocab_size: int, 250 | transformer_width: int, 251 | transformer_heads: int, 252 | transformer_layers: int 253 | ): 254 | super().__init__() 255 | 256 | self.context_length = context_length 257 | 258 | if isinstance(vision_layers, (tuple, list)): 259 | vision_heads = vision_width * 32 // 64 260 | self.visual = ModifiedResNet( 261 | layers=vision_layers, 262 | output_dim=embed_dim, 263 | heads=vision_heads, 264 | input_resolution=image_resolution, 265 | width=vision_width 266 | ) 267 | else: 268 | vision_heads = vision_width // 64 269 | self.visual = VisionTransformer( 270 | input_resolution=image_resolution, 271 | patch_size=vision_patch_size, 272 | width=vision_width, 273 | layers=vision_layers, 274 | heads=vision_heads, 275 | output_dim=embed_dim 276 | ) 277 | 278 | self.transformer = Transformer( 279 | width=transformer_width, 280 | layers=transformer_layers, 281 | heads=transformer_heads, 282 | attn_mask=self.build_attention_mask() 283 | ) 284 | 285 | self.vocab_size = vocab_size 286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 288 | 
self.ln_final = LayerNorm(transformer_width) 289 | 290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 292 | 293 | self.initialize_parameters() 294 | 295 | def initialize_parameters(self): 296 | nn.init.normal_(self.token_embedding.weight, std=0.02) 297 | nn.init.normal_(self.positional_embedding, std=0.01) 298 | 299 | if isinstance(self.visual, ModifiedResNet): 300 | if self.visual.attnpool is not None: 301 | std = self.visual.attnpool.c_proj.in_features ** -0.5 302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 306 | 307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 308 | for name, param in resnet_block.named_parameters(): 309 | if name.endswith("bn3.weight"): 310 | nn.init.zeros_(param) 311 | 312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 313 | attn_std = self.transformer.width ** -0.5 314 | fc_std = (2 * self.transformer.width) ** -0.5 315 | for block in self.transformer.resblocks: 316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 320 | 321 | if self.text_projection is not None: 322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 323 | 324 | def build_attention_mask(self): 325 | # lazily create causal attention mask, with full attention between the vision tokens 326 | # pytorch uses additive attention mask; fill with -inf 327 | mask = torch.empty(self.context_length, self.context_length) 328 | mask.fill_(float("-inf")) 329 | mask.triu_(1) # zero out the lower diagonal 330 | return mask 331 | 332 | @property 333 | def dtype(self): 334 | return self.visual.conv1.weight.dtype 335 | 336 | def encode_image(self, image): 337 | return self.visual(image.type(self.dtype)) 338 | 339 | def encode_text(self, text): 340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 341 | 342 | x = x + self.positional_embedding.type(self.dtype) 343 | x = x.permute(1, 0, 2) # NLD -> LND 344 | x = self.transformer(x) 345 | x = x.permute(1, 0, 2) # LND -> NLD 346 | x = self.ln_final(x).type(self.dtype) 347 | 348 | # x.shape = [batch_size, n_ctx, transformer.width] 349 | # take features from the eot embedding (eot_token is the highest number in each sequence) 350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 351 | 352 | return x 353 | 354 | def forward(self, image, text): 355 | image_features = self.encode_image(image) 356 | text_features = self.encode_text(text) 357 | 358 | # normalized features 359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 361 | 362 | # cosine similarity as logits 363 | logit_scale = self.logit_scale.exp() 364 | logits_per_image = logit_scale * image_features @ text_features.t() 365 | logits_per_text = logits_per_image.t() 366 | 367 | # shape = [global_batch_size, global_batch_size] 368 | return logits_per_image, logits_per_text 369 | 370 | 371 | def convert_weights(model: nn.Module): 
372 | """Convert applicable model parameters to fp16""" 373 | 374 | def _convert_weights_to_fp16(l): 375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 376 | l.weight.data = l.weight.data.half() 377 | if l.bias is not None: 378 | l.bias.data = l.bias.data.half() 379 | 380 | if isinstance(l, nn.MultiheadAttention): 381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 382 | tensor = getattr(l, attr) 383 | if tensor is not None: 384 | tensor.data = tensor.data.half() 385 | 386 | for name in ["text_projection", "proj"]: 387 | if hasattr(l, name): 388 | attr = getattr(l, name) 389 | if attr is not None: 390 | attr.data = attr.data.half() 391 | 392 | model.apply(_convert_weights_to_fp16) 393 | 394 | 395 | def build_model(state_dict: dict): 396 | vit = "visual.proj" in state_dict 397 | 398 | if vit: 399 | vision_width = state_dict["visual.conv1.weight"].shape[0] 400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 403 | image_resolution = vision_patch_size * grid_size 404 | else: 405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 406 | vision_layers = tuple(counts) 407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 409 | vision_patch_size = None 410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 411 | image_resolution = output_width * 32 412 | 413 | embed_dim = state_dict["text_projection"].shape[1] 414 | context_length = state_dict["positional_embedding"].shape[0] 415 | vocab_size = state_dict["token_embedding.weight"].shape[0] 416 | transformer_width = state_dict["ln_final.weight"].shape[0] 417 | transformer_heads = transformer_width // 64 418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 419 | 420 | model = CLIP( 421 | embed_dim, 422 | image_resolution, vision_layers, vision_width, vision_patch_size, 423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 424 | ) 425 | 426 | for key in ["input_resolution", "context_length", "vocab_size"]: 427 | if key in state_dict: 428 | del state_dict[key] 429 | 430 | convert_weights(model) 431 | model.load_state_dict(state_dict) 432 | return model.eval() 433 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 
22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] 
for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /configs/caltech101.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 15.0 8 | beta: 5.0 9 | 10 | learning_rate: 11 | text: 0.0006 12 | image: 0.0006 13 | align: 0.2 -------------------------------------------------------------------------------- /configs/dtd.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 6.0 8 | beta: 3.0 9 | 10 | learning_rate: 11 | text: 0.0006 12 | image: 0.0006 13 | align: 0.2 -------------------------------------------------------------------------------- /configs/eurosat.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 3.0 8 | beta: 8.0 9 | 10 | learning_rate: 11 | text: 0.00005 12 | image: 0.00005 13 | align: 0.2 -------------------------------------------------------------------------------- /configs/fgvc.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 6.0 8 | beta: 2.0 9 | 10 | learning_rate: 11 | text: 0.0006 12 | image: 0.0006 13 | align: 0.2 -------------------------------------------------------------------------------- /configs/food101.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 3.0 8 | beta: 1.0 9 | 10 | learning_rate: 11 | text: 0.0002 12 | image: 0.0002 13 | align: 0.2 -------------------------------------------------------------------------------- /configs/imagenet.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 6.0 8 | beta: 5.0 9 | 10 | learning_rate: 11 | text: 0.0006 12 | image: 0.0006 13 | align: 0.5 -------------------------------------------------------------------------------- /configs/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 6.0 8 | beta: 5.0 9 | 10 | learning_rate: 11 | text: 0.0003 12 | image: 0.0003 13 | align: 2.5 -------------------------------------------------------------------------------- /configs/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 3.0 8 
| beta: 8.0 9 | 10 | learning_rate: 11 | text: 0.0006 12 | image: 0.0006 13 | align: 0.0 -------------------------------------------------------------------------------- /configs/imagenet_s.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 7 8 | beta: 7.45 9 | 10 | learning_rate: 11 | text: 0.0006 12 | image: 0.0006 13 | align: 0.2 -------------------------------------------------------------------------------- /configs/imagenet_v.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 3.0 8 | beta: 8.0 9 | 10 | learning_rate: 11 | text: 0.0005 12 | image: 0.0005 13 | align: 0.5 -------------------------------------------------------------------------------- /configs/oxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 2.5 8 | beta: 5.0 9 | 10 | learning_rate: 11 | text: 0.0001 12 | image: 0.0001 13 | align: 0.2 -------------------------------------------------------------------------------- /configs/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 4.0 8 | beta: 7.0 9 | 10 | learning_rate: 11 | text: 0.00005 12 | image: 0.00005 13 | align: 0.2 -------------------------------------------------------------------------------- /configs/stanford_cars.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 3.0 8 | beta: 7.0 9 | 10 | learning_rate: 11 | text: 0.0001 12 | image: 0.0001 13 | align: 1.5 -------------------------------------------------------------------------------- /configs/sun397.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 6.0 8 | beta: 3.0 9 | 10 | learning_rate: 11 | text: 0.0002 12 | image: 0.0002 13 | align: 0.2 -------------------------------------------------------------------------------- /configs/ucf101.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for DPE Hyperparameters 2 | 3 | # --- Positive Cache Configuration --- 4 | positive: 5 | enabled: True 6 | shot_capacity: 3 7 | alpha: 9.0 8 | beta: 8.0 9 | 10 | learning_rate: 11 | text: 0.0004 12 | image: 0.0004 13 | align: 0.2 -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .oxford_pets import OxfordPets 2 | from .eurosat import EuroSAT 3 | from .ucf101 import UCF101 4 | from .sun397 import SUN397 5 | from .caltech101 import Caltech101 6 | from .dtd import 
DescribableTextures 7 | from .fgvc import FGVCAircraft 8 | from .food101 import Food101 9 | from .oxford_flowers import OxfordFlowers 10 | from .stanford_cars import StanfordCars 11 | from .imagenetv2 import ImageNetV2 12 | from .imagenet_a import ImageNetA 13 | from .imagenet_r import ImageNetR 14 | from .imagenet_sketch import ImageNetSketch 15 | 16 | 17 | dataset_list = { 18 | "oxford_pets": OxfordPets, 19 | "eurosat": EuroSAT, 20 | "ucf101": UCF101, 21 | "sun397": SUN397, 22 | "caltech101": Caltech101, 23 | "dtd": DescribableTextures, 24 | "fgvc": FGVCAircraft, 25 | "food101": Food101, 26 | "oxford_flowers": OxfordFlowers, 27 | "stanford_cars": StanfordCars, 28 | "imagenet-a": ImageNetA, 29 | "imagenet-v": ImageNetV2, 30 | "imagenet-r": ImageNetR, 31 | "imagenet-s": ImageNetSketch, 32 | } 33 | 34 | 35 | def build_dataset(dataset, root_path): 36 | return dataset_list[dataset](root_path) -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/augmix_ops.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/augmix_ops.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/caltech101.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/caltech101.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dtd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/dtd.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/eurosat.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/eurosat.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/fgvc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/fgvc.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/food101.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/food101.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/imagenet.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_a.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/imagenet_a.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_r.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/imagenet_r.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_sketch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/imagenet_sketch.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/imagenetv2.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_flowers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/oxford_flowers.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_pets.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/oxford_pets.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_pets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/oxford_pets.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/stanford_cars.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/stanford_cars.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sun397.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/sun397.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/ucf101.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/ucf101.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/datasets/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/augmix_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Base augmentations operators.""" 16 | 17 | import numpy as np 18 | from PIL import Image, ImageOps, ImageEnhance 19 | 20 | # ImageNet code should change this value 21 | IMAGE_SIZE = 224 22 | 23 | 24 | def int_parameter(level, maxval): 25 | """Helper function to scale `val` between 0 and maxval . 26 | Args: 27 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 28 | maxval: Maximum value that the operation can have. This will be scaled to 29 | level/PARAMETER_MAX. 30 | Returns: 31 | An int that results from scaling `maxval` according to `level`. 32 | """ 33 | return int(level * maxval / 10) 34 | 35 | 36 | def float_parameter(level, maxval): 37 | """Helper function to scale `val` between 0 and maxval. 38 | Args: 39 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 40 | maxval: Maximum value that the operation can have. This will be scaled to 41 | level/PARAMETER_MAX. 42 | Returns: 43 | A float that results from scaling `maxval` according to `level`. 44 | """ 45 | return float(level) * maxval / 10. 
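# Worked example for the two helpers above: with `level` drawn from [0, 10] (the
# PARAMETER_MAX referenced in the docstrings), int_parameter(5, 4) == int(5 * 4 / 10) == 2
# and float_parameter(5, 1.8) == 0.9; both simply rescale `maxval` linearly by level / 10.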
46 | 47 | 48 | def sample_level(n): 49 | return np.random.uniform(low=0.1, high=n) 50 | 51 | 52 | def autocontrast(pil_img, _): 53 | return ImageOps.autocontrast(pil_img) 54 | 55 | 56 | def equalize(pil_img, _): 57 | return ImageOps.equalize(pil_img) 58 | 59 | 60 | def posterize(pil_img, level): 61 | level = int_parameter(sample_level(level), 4) 62 | return ImageOps.posterize(pil_img, 4 - level) 63 | 64 | 65 | def rotate(pil_img, level): 66 | degrees = int_parameter(sample_level(level), 30) 67 | if np.random.uniform() > 0.5: 68 | degrees = -degrees 69 | return pil_img.rotate(degrees, resample=Image.BILINEAR) 70 | 71 | 72 | def solarize(pil_img, level): 73 | level = int_parameter(sample_level(level), 256) 74 | return ImageOps.solarize(pil_img, 256 - level) 75 | 76 | 77 | def shear_x(pil_img, level): 78 | level = float_parameter(sample_level(level), 0.3) 79 | if np.random.uniform() > 0.5: 80 | level = -level 81 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 82 | Image.AFFINE, (1, level, 0, 0, 1, 0), 83 | resample=Image.BILINEAR) 84 | 85 | 86 | def shear_y(pil_img, level): 87 | level = float_parameter(sample_level(level), 0.3) 88 | if np.random.uniform() > 0.5: 89 | level = -level 90 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 91 | Image.AFFINE, (1, 0, 0, level, 1, 0), 92 | resample=Image.BILINEAR) 93 | 94 | 95 | def translate_x(pil_img, level): 96 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3) 97 | if np.random.random() > 0.5: 98 | level = -level 99 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 100 | Image.AFFINE, (1, 0, level, 0, 1, 0), 101 | resample=Image.BILINEAR) 102 | 103 | 104 | def translate_y(pil_img, level): 105 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3) 106 | if np.random.random() > 0.5: 107 | level = -level 108 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 109 | Image.AFFINE, (1, 0, 0, 0, 1, level), 110 | resample=Image.BILINEAR) 111 | 112 | 113 | # operation that overlaps with ImageNet-C's test set 114 | def color(pil_img, level): 115 | level = float_parameter(sample_level(level), 1.8) + 0.1 116 | return ImageEnhance.Color(pil_img).enhance(level) 117 | 118 | 119 | # operation that overlaps with ImageNet-C's test set 120 | def contrast(pil_img, level): 121 | level = float_parameter(sample_level(level), 1.8) + 0.1 122 | return ImageEnhance.Contrast(pil_img).enhance(level) 123 | 124 | 125 | # operation that overlaps with ImageNet-C's test set 126 | def brightness(pil_img, level): 127 | level = float_parameter(sample_level(level), 1.8) + 0.1 128 | return ImageEnhance.Brightness(pil_img).enhance(level) 129 | 130 | 131 | # operation that overlaps with ImageNet-C's test set 132 | def sharpness(pil_img, level): 133 | level = float_parameter(sample_level(level), 1.8) + 0.1 134 | return ImageEnhance.Sharpness(pil_img).enhance(level) 135 | 136 | 137 | augmentations = [ 138 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, 139 | translate_x, translate_y 140 | ] 141 | 142 | augmentations_all = [ 143 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, 144 | translate_x, translate_y, color, contrast, brightness, sharpness 145 | ] -------------------------------------------------------------------------------- /datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import DatasetBase 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ["itap of a {}.", 8 | "a bad photo of the {}.", 9 | "a origami 
{}.", 10 | "a photo of the large {}.", 11 | "a {} in a video game.", 12 | "art of the {}.", 13 | "a photo of the small {}."] 14 | 15 | 16 | class Caltech101(DatasetBase): 17 | 18 | dataset_dir = 'caltech-101' 19 | 20 | def __init__(self, root): 21 | self.dataset_dir = os.path.join(root, self.dataset_dir) 22 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories') 23 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Caltech101.json') 24 | self.cupl_path = './gpt3_prompts/CuPL_prompts_caltech101.json' 25 | 26 | self.template = template 27 | 28 | test = OxfordPets.read_split(self.split_path, self.image_dir) 29 | 30 | super().__init__(test=test) -------------------------------------------------------------------------------- /datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import DatasetBase 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ['{} texture.'] 8 | 9 | class DescribableTextures(DatasetBase): 10 | 11 | dataset_dir = 'dtd' 12 | 13 | def __init__(self, root): 14 | self.dataset_dir = os.path.join(root, self.dataset_dir) 15 | self.image_dir = os.path.join(self.dataset_dir, 'images') 16 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_DescribableTextures.json') 17 | self.cupl_path = './gpt3_prompts/CuPL_prompts_dtd.json' 18 | 19 | self.template = template 20 | 21 | test = OxfordPets.read_split(self.split_path, self.image_dir) 22 | 23 | super().__init__(test=test) 24 | -------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import DatasetBase 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ['a centered satellite photo of {}.'] 8 | 9 | NEW_CLASSNAMES = { 10 | 'AnnualCrop': 'Annual Crop Land', 11 | 'Forest': 'Forest', 12 | 'HerbaceousVegetation': 'Herbaceous Vegetation Land', 13 | 'Highway': 'Highway or Road', 14 | 'Industrial': 'Industrial Buildings', 15 | 'Pasture': 'Pasture Land', 16 | 'PermanentCrop': 'Permanent Crop Land', 17 | 'Residential': 'Residential Buildings', 18 | 'River': 'River', 19 | 'SeaLake': 'Sea or Lake' 20 | } 21 | 22 | 23 | class EuroSAT(DatasetBase): 24 | 25 | dataset_dir = 'eurosat' 26 | 27 | def __init__(self, root): 28 | self.dataset_dir = os.path.join(root, self.dataset_dir) 29 | self.image_dir = os.path.join(self.dataset_dir, '2750') 30 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_EuroSAT.json') 31 | self.cupl_path = './gpt3_prompts/CuPL_prompts_eurosat.json' 32 | 33 | self.template = template 34 | 35 | test = OxfordPets.read_split(self.split_path, self.image_dir) 36 | 37 | super().__init__(test=test) 38 | -------------------------------------------------------------------------------- /datasets/fgvc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase 4 | 5 | 6 | template = ['a photo of a {}, a type of aircraft.'] 7 | 8 | class FGVCAircraft(DatasetBase): 9 | 10 | dataset_dir = 'fgvc_aircraft' 11 | 12 | def __init__(self, root): 13 | 14 | self.dataset_dir = os.path.join(root, self.dataset_dir) 15 | self.image_dir = os.path.join(self.dataset_dir, 'images') 16 | 17 | self.template = template 18 | self.cupl_path = './gpt3_prompts/CuPL_prompts_fgvcaircraft.json' 19 | 20 | classnames = [] 21 | with open(os.path.join(self.dataset_dir, 
'variants.txt'), 'r') as f: 22 | lines = f.readlines() 23 | for line in lines: 24 | classnames.append(line.strip()) 25 | cname2lab = {c: i for i, c in enumerate(classnames)} 26 | 27 | test = self.read_data(cname2lab, 'images_variant_test.txt') 28 | 29 | 30 | super().__init__(test=test) 31 | 32 | def read_data(self, cname2lab, split_file): 33 | filepath = os.path.join(self.dataset_dir, split_file) 34 | items = [] 35 | 36 | with open(filepath, 'r') as f: 37 | lines = f.readlines() 38 | for line in lines: 39 | line = line.strip().split(' ') 40 | imname = line[0] + '.jpg' 41 | classname = ' '.join(line[1:]) 42 | impath = os.path.join(self.image_dir, imname) 43 | label = cname2lab[classname] 44 | item = Datum( 45 | impath=impath, 46 | label=label, 47 | classname=classname 48 | ) 49 | items.append(item) 50 | 51 | return items -------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import DatasetBase 4 | from .oxford_pets import OxfordPets 5 | 6 | template = ['a photo of {}, a type of food.'] 7 | 8 | class Food101(DatasetBase): 9 | 10 | dataset_dir = 'food-101' 11 | 12 | def __init__(self, root): 13 | self.dataset_dir = os.path.join(root, self.dataset_dir) 14 | self.image_dir = os.path.join(self.dataset_dir, 'images') 15 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Food101.json') 16 | self.cupl_path = './gpt3_prompts/CuPL_prompts_food101.json' 17 | 18 | self.template = template 19 | 20 | test = OxfordPets.read_split(self.split_path, self.image_dir) 21 | 22 | super().__init__(test=test) -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision 3 | from collections import OrderedDict 4 | 5 | 6 | 7 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 8 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 9 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 10 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 11 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 12 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 13 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 14 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 15 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 16 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 17 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 18 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 19 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 20 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 21 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 22 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 23 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", 
"lorikeet", 24 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 25 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 26 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 27 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 28 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 29 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 30 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 31 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 32 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 33 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 34 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 35 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 36 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 37 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 38 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 39 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 40 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 41 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 42 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 43 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 44 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 45 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 46 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 47 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 48 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 49 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 50 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 51 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 52 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 53 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 54 | "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 55 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 56 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 57 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 58 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 59 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 60 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 61 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 62 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 63 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 64 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 65 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 66 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 67 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 68 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 69 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 70 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 71 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 72 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 73 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 74 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 75 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 76 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 77 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 78 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 79 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 80 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 81 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 82 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 83 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 84 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 85 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 86 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 87 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 88 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 89 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 90 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 91 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 92 | "car 
wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 93 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 94 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 95 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 96 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 97 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 98 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 99 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 100 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 101 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 102 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 103 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 104 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 105 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 106 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 107 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 108 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 109 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 110 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 111 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 112 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 113 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 114 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 115 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 116 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 117 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 118 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 119 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 120 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 121 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 122 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 123 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 124 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 125 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 126 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 127 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 128 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", 
"Pickelhaube", 129 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 130 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 131 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 132 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 133 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 134 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 135 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 136 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 137 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 138 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 139 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 140 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 141 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 142 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 143 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 144 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 145 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 146 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 147 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 148 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 149 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 150 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 151 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 152 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 153 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 154 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 155 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 156 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 157 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 158 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 159 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 160 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 161 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 162 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 163 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 164 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 165 | "lemon", "fig", "pineapple", "banana", 
"jackfruit", "cherimoya (custard apple)", "pomegranate", 166 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 167 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 168 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 169 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 170 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 171 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 172 | 173 | imagenet_templates = ["itap of a {}.", 174 | "a bad photo of the {}.", 175 | "a origami {}.", 176 | "a photo of the large {}.", 177 | "a {} in a video game.", 178 | "art of the {}.", 179 | "a photo of the small {}."] 180 | 181 | class ImageNet(): 182 | 183 | dataset_dir = 'imagenet' 184 | 185 | def __init__(self, root, preprocess): 186 | 187 | self.dataset_dir = os.path.join(root, self.dataset_dir) 188 | self.image_dir = os.path.join(self.dataset_dir, 'images') 189 | 190 | test_preprocess = preprocess 191 | 192 | self.test = torchvision.datasets.ImageNet(self.image_dir, split='val', transform=test_preprocess) 193 | 194 | self.template = imagenet_templates 195 | self.classnames = imagenet_classes 196 | self.cupl_path = './gpt3_prompts/CuPL_prompts_imagenet.json' 197 | 198 | def read_classnames(text_file): 199 | """Return a dictionary containing 200 | key-value pairs of : . 201 | """ 202 | classnames = OrderedDict() 203 | with open(text_file, "r") as f: 204 | lines = f.readlines() 205 | for line in lines: 206 | line = line.strip().split(" ") 207 | folder = line[0] 208 | classname = " ".join(line[1:]) 209 | classnames[folder] = classname 210 | return classnames 211 | -------------------------------------------------------------------------------- /datasets/imagenet_a.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, listdir_nohidden 3 | 4 | from .imagenet import ImageNet 5 | 6 | TO_BE_IGNORED = ["README.txt"] 7 | template = [ 8 | "itap of a {}.", 9 | "a bad photo of the {}.", 10 | "a origami {}.", 11 | "a photo of the large {}.", 12 | "a {} in a video game.", 13 | "art of the {}.", 14 | "a photo of the small {}.", 15 | ] 16 | 17 | class ImageNetA(DatasetBase): 18 | """ImageNet-A(dversarial). 19 | 20 | This dataset is used for testing only. 
21 | """ 22 | 23 | dataset_dir = "imagenet-adversarial" 24 | 25 | def __init__(self, root): 26 | root = os.path.abspath(os.path.expanduser(root)) 27 | self.dataset_dir = os.path.join(root, self.dataset_dir) 28 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-a") 29 | self.template = template 30 | self.cupl_path = './gpt3_prompts/CuPL_prompts_imagenet.json' 31 | 32 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 33 | classnames = ImageNet.read_classnames(text_file) 34 | 35 | data = self.read_data(classnames) 36 | 37 | super().__init__(test=data) 38 | 39 | def read_data(self, classnames): 40 | image_dir = self.image_dir 41 | folders = listdir_nohidden(image_dir, sort=True) 42 | folders = [f for f in folders if f not in TO_BE_IGNORED] 43 | items = [] 44 | 45 | for label, folder in enumerate(folders): 46 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 47 | classname = classnames[folder] 48 | for imname in imnames: 49 | impath = os.path.join(image_dir, folder, imname) 50 | item = Datum(impath=impath, label=label, classname=classname) 51 | items.append(item) 52 | 53 | return items -------------------------------------------------------------------------------- /datasets/imagenet_r.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, listdir_nohidden 3 | 4 | from .imagenet import ImageNet 5 | 6 | TO_BE_IGNORED = ["README.txt"] 7 | template = ['a bad photo of a {}.', 8 | 'a photo of many {}.', 9 | 'a sculpture of a {}.', 10 | 'a photo of the hard to see {}.', 11 | 'a low resolution photo of the {}.', 12 | 'a rendering of a {}.', 13 | 'graffiti of a {}.', 14 | 'a bad photo of the {}.', 15 | 'a cropped photo of the {}.', 16 | 'a tattoo of a {}.', 17 | 'the embroidered {}.', 18 | 'a photo of a hard to see {}.', 19 | 'a bright photo of a {}.', 20 | 'a photo of a clean {}.', 21 | 'a photo of a dirty {}.', 22 | 'a dark photo of the {}.', 23 | 'a drawing of a {}.', 24 | 'a photo of my {}.', 25 | 'the plastic {}.', 26 | 'a photo of the cool {}.', 27 | 'a close-up photo of a {}.', 28 | 'a black and white photo of the {}.', 29 | 'a painting of the {}.', 30 | 'a painting of a {}.', 31 | 'a pixelated photo of the {}.', 32 | 'a sculpture of the {}.', 33 | 'a bright photo of the {}.', 34 | 'a cropped photo of a {}.', 35 | 'a plastic {}.', 36 | 'a photo of the dirty {}.', 37 | 'a jpeg corrupted photo of a {}.', 38 | 'a blurry photo of the {}.', 39 | 'a photo of the {}.', 40 | 'a good photo of the {}.', 41 | 'a rendering of the {}.', 42 | 'a {} in a video game.', 43 | 'a photo of one {}.', 44 | 'a doodle of a {}.', 45 | 'a close-up photo of the {}.', 46 | 'a photo of a {}.', 47 | 'the origami {}.', 48 | 'the {} in a video game.', 49 | 'a sketch of a {}.', 50 | 'a doodle of the {}.', 51 | 'a origami {}.', 52 | 'a low resolution photo of a {}.', 53 | 'the toy {}.', 54 | 'a rendition of the {}.', 55 | 'a photo of the clean {}.', 56 | 'a photo of a large {}.', 57 | 'a rendition of a {}.', 58 | 'a photo of a nice {}.', 59 | 'a photo of a weird {}.', 60 | 'a blurry photo of a {}.', 61 | 'a cartoon {}.', 62 | 'art of a {}.', 63 | 'a sketch of the {}.', 64 | 'a embroidered {}.', 65 | 'a pixelated photo of a {}.', 66 | 'itap of the {}.', 67 | 'a jpeg corrupted photo of the {}.', 68 | 'a good photo of a {}.', 69 | 'a plushie {}.', 70 | 'a photo of the nice {}.', 71 | 'a photo of the small {}.', 72 | 'a photo of the weird {}.', 73 | 'the cartoon {}.', 74 | 'art of the {}.', 75 | 'a drawing of the 
{}.', 76 | 'a photo of the large {}.', 77 | 'a black and white photo of a {}.', 78 | 'the plushie {}.', 79 | 'a dark photo of a {}.', 80 | 'itap of a {}.', 81 | 'graffiti of the {}.', 82 | 'a toy {}.', 83 | 'itap of my {}.', 84 | 'a photo of a cool {}.', 85 | 'a photo of a small {}.', 86 | 'a tattoo of the {}.'] 87 | 88 | class ImageNetR(DatasetBase): 89 | """ImageNet-R(endition). 90 | 91 | This dataset is used for testing only. 92 | """ 93 | 94 | dataset_dir = "imagenet-rendition" 95 | 96 | def __init__(self, root): 97 | root = os.path.abspath(os.path.expanduser(root)) 98 | self.dataset_dir = os.path.join(root, self.dataset_dir) 99 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-r") 100 | self.template = template 101 | self.cupl_path = './gpt3_prompts/CuPL_prompts_imagenet.json' 102 | 103 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 104 | classnames = ImageNet.read_classnames(text_file) 105 | 106 | data = self.read_data(classnames) 107 | 108 | super().__init__(test=data) 109 | 110 | 111 | def read_data(self, classnames): 112 | image_dir = self.image_dir 113 | folders = listdir_nohidden(image_dir, sort=True) 114 | folders = [f for f in folders if f not in TO_BE_IGNORED] 115 | items = [] 116 | 117 | for label, folder in enumerate(folders): 118 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 119 | classname = classnames[folder] 120 | for imname in imnames: 121 | impath = os.path.join(image_dir, folder, imname) 122 | item = Datum(impath=impath, label=label, classname=classname) 123 | items.append(item) 124 | 125 | return items -------------------------------------------------------------------------------- /datasets/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, listdir_nohidden 3 | 4 | from .imagenet import ImageNet 5 | 6 | template = ['a bad photo of a {}.', 7 | 'a photo of many {}.', 8 | 'a sculpture of a {}.', 9 | 'a photo of the hard to see {}.', 10 | 'a low resolution photo of the {}.', 11 | 'a rendering of a {}.', 12 | 'graffiti of a {}.', 13 | 'a bad photo of the {}.', 14 | 'a cropped photo of the {}.', 15 | 'a tattoo of a {}.', 16 | 'the embroidered {}.', 17 | 'a photo of a hard to see {}.', 18 | 'a bright photo of a {}.', 19 | 'a photo of a clean {}.', 20 | 'a photo of a dirty {}.', 21 | 'a dark photo of the {}.', 22 | 'a drawing of a {}.', 23 | 'a photo of my {}.', 24 | 'the plastic {}.', 25 | 'a photo of the cool {}.', 26 | 'a close-up photo of a {}.', 27 | 'a black and white photo of the {}.', 28 | 'a painting of the {}.', 29 | 'a painting of a {}.', 30 | 'a pixelated photo of the {}.', 31 | 'a sculpture of the {}.', 32 | 'a bright photo of the {}.', 33 | 'a cropped photo of a {}.', 34 | 'a plastic {}.', 35 | 'a photo of the dirty {}.', 36 | 'a jpeg corrupted photo of a {}.', 37 | 'a blurry photo of the {}.', 38 | 'a photo of the {}.', 39 | 'a good photo of the {}.', 40 | 'a rendering of the {}.', 41 | 'a {} in a video game.', 42 | 'a photo of one {}.', 43 | 'a doodle of a {}.', 44 | 'a close-up photo of the {}.', 45 | 'a photo of a {}.', 46 | 'the origami {}.', 47 | 'the {} in a video game.', 48 | 'a sketch of a {}.', 49 | 'a doodle of the {}.', 50 | 'a origami {}.', 51 | 'a low resolution photo of a {}.', 52 | 'the toy {}.', 53 | 'a rendition of the {}.', 54 | 'a photo of the clean {}.', 55 | 'a photo of a large {}.', 56 | 'a rendition of a {}.', 57 | 'a photo of a nice {}.', 58 | 'a photo of a weird {}.', 59 | 'a blurry 
photo of a {}.', 60 | 'a cartoon {}.', 61 | 'art of a {}.', 62 | 'a sketch of the {}.', 63 | 'a embroidered {}.', 64 | 'a pixelated photo of a {}.', 65 | 'itap of the {}.', 66 | 'a jpeg corrupted photo of the {}.', 67 | 'a good photo of a {}.', 68 | 'a plushie {}.', 69 | 'a photo of the nice {}.', 70 | 'a photo of the small {}.', 71 | 'a photo of the weird {}.', 72 | 'the cartoon {}.', 73 | 'art of the {}.', 74 | 'a drawing of the {}.', 75 | 'a photo of the large {}.', 76 | 'a black and white photo of a {}.', 77 | 'the plushie {}.', 78 | 'a dark photo of a {}.', 79 | 'itap of a {}.', 80 | 'graffiti of the {}.', 81 | 'a toy {}.', 82 | 'itap of my {}.', 83 | 'a photo of a cool {}.', 84 | 'a photo of a small {}.', 85 | 'a tattoo of the {}.'] 86 | 87 | 88 | class ImageNetSketch(DatasetBase): 89 | """ImageNet-Sketch. 90 | 91 | This dataset is used for testing only. 92 | """ 93 | 94 | dataset_dir = "imagenet-sketch" 95 | 96 | def __init__(self, root): 97 | root = os.path.abspath(os.path.expanduser(root)) 98 | self.dataset_dir = os.path.join(root, self.dataset_dir) 99 | self.image_dir = os.path.join(self.dataset_dir, "images") 100 | self.template = template 101 | self.cupl_path = './gpt3_prompts/CuPL_prompts_imagenet.json' 102 | 103 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 104 | classnames = ImageNet.read_classnames(text_file) 105 | 106 | data = self.read_data(classnames) 107 | 108 | super().__init__(test=data) 109 | 110 | def read_data(self, classnames): 111 | image_dir = self.image_dir 112 | folders = listdir_nohidden(image_dir, sort=True) 113 | items = [] 114 | 115 | for label, folder in enumerate(folders): 116 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 117 | classname = classnames[folder] 118 | for imname in imnames: 119 | impath = os.path.join(image_dir, folder, imname) 120 | item = Datum(impath=impath, label=label, classname=classname) 121 | items.append(item) 122 | 123 | return items 124 | -------------------------------------------------------------------------------- /datasets/imagenetv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, listdir_nohidden 3 | 4 | from .imagenet import ImageNet 5 | 6 | template = ["itap of a {}.", 7 | "a bad photo of the {}.", 8 | "a origami {}.", 9 | "a photo of the large {}.", 10 | "a {} in a video game.", 11 | "art of the {}.", 12 | "a photo of the small {}."] 13 | 14 | class ImageNetV2(DatasetBase): 15 | """ImageNetV2. 16 | 17 | This dataset is used for testing only. 
18 | """ 19 | 20 | dataset_dir = "imagenetv2" 21 | 22 | def __init__(self, root): 23 | root = os.path.abspath(os.path.expanduser(root)) 24 | self.dataset_dir = os.path.join(root, self.dataset_dir) 25 | image_dir = "imagenetv2-matched-frequency-format-val" 26 | self.image_dir = os.path.join(self.dataset_dir, image_dir) 27 | self.template = template 28 | self.cupl_path = './gpt3_prompts/CuPL_prompts_imagenet.json' 29 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 30 | classnames = ImageNet.read_classnames(text_file) 31 | 32 | data = self.read_data(classnames) 33 | 34 | super().__init__(test=data) 35 | 36 | def read_data(self, classnames): 37 | image_dir = self.image_dir 38 | folders = list(classnames.keys()) 39 | items = [] 40 | 41 | for label in range(1000): 42 | class_dir = os.path.join(image_dir, str(label)) 43 | imnames = listdir_nohidden(class_dir) 44 | folder = folders[label] 45 | classname = classnames[folder] 46 | for imname in imnames: 47 | impath = os.path.join(class_dir, imname) 48 | item = Datum(impath=impath, label=label, classname=classname) 49 | items.append(item) 50 | 51 | return items -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .oxford_pets import OxfordPets 4 | from .utils import DatasetBase 5 | 6 | 7 | template = ['a photo of a {}, a type of flower.'] 8 | 9 | class OxfordFlowers(DatasetBase): 10 | 11 | dataset_dir = 'flowers102' 12 | 13 | def __init__(self, root): 14 | self.dataset_dir = os.path.join(root, self.dataset_dir) 15 | self.image_dir = os.path.join(self.dataset_dir, 'jpg') 16 | self.label_file = os.path.join(self.dataset_dir, 'imagelabels.mat') 17 | self.lab2cname_file = os.path.join(self.dataset_dir, 'cat_to_name.json') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordFlowers.json') 19 | self.cupl_path = './gpt3_prompts/CuPL_prompts_flowers102.json' 20 | 21 | self.template = template 22 | 23 | test = OxfordPets.read_split(self.split_path, self.image_dir) 24 | 25 | super().__init__(test=test) -------------------------------------------------------------------------------- /datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json 4 | 5 | 6 | template = ['a photo of a {}, a type of pet.'] 7 | 8 | class OxfordPets(DatasetBase): 9 | 10 | dataset_dir = 'oxfordpets' 11 | 12 | def __init__(self, root): 13 | self.dataset_dir = os.path.join(root, self.dataset_dir) 14 | self.image_dir = os.path.join(self.dataset_dir, 'images') 15 | self.anno_dir = os.path.join(self.dataset_dir, 'annotations') 16 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordPets.json') 17 | self.cupl_path = './gpt3_prompts/CuPL_prompts_oxfordpets.json' 18 | 19 | self.template = template 20 | 21 | test = self.read_split(self.split_path, self.image_dir) 22 | 23 | super().__init__(test=test) 24 | 25 | @staticmethod 26 | def read_split(filepath, path_prefix): 27 | def _convert(items): 28 | out = [] 29 | for impath, label, classname in items: 30 | impath = os.path.join(path_prefix, impath) 31 | item = Datum( 32 | impath=impath, 33 | label=int(label), 34 | classname=classname 35 | ) 36 | out.append(item) 37 | return out 38 | 39 | print(f'Reading split from {filepath}') 40 | split = read_json(filepath) 41 | test = _convert(split['test']) 42 | 43 | return test 
-------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .oxford_pets import OxfordPets 4 | from .utils import DatasetBase 5 | 6 | 7 | template = ['a photo of a {}.', 8 | 'A {} featuring a wide range of color options for easy selection.'] 9 | 10 | class StanfordCars(DatasetBase): 11 | 12 | dataset_dir = 'stanfordcars' 13 | 14 | def __init__(self, root): 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_StanfordCars.json') 17 | self.cupl_path = './gpt3_prompts/CuPL_prompts_stanfordcars.json' 18 | 19 | self.template = template 20 | 21 | test = OxfordPets.read_split(self.split_path, self.dataset_dir) 22 | 23 | super().__init__(test=test) -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import DatasetBase 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ["itap of a {}.", 9 | "a bad photo of the {}.", 10 | "a origami {}.", 11 | "a photo of the large {}.", 12 | "a {} in a video game.", 13 | "art of the {}.", 14 | "a photo of the small {}."] 15 | 16 | class SUN397(DatasetBase): 17 | 18 | dataset_dir = 'sun397' 19 | 20 | def __init__(self, root): 21 | self.dataset_dir = os.path.join(root, self.dataset_dir) 22 | self.image_dir = os.path.join(self.dataset_dir, 'SUN397') 23 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_SUN397.json') 24 | self.cupl_path = './gpt3_prompts/CuPL_prompts_sun397.json' 25 | 26 | self.template = template 27 | 28 | test = OxfordPets.read_split(self.split_path, self.image_dir) 29 | 30 | super().__init__(test=test) 31 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import DatasetBase 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['a photo of a person doing {}.'] 9 | 10 | class UCF101(DatasetBase): 11 | 12 | dataset_dir = 'ucf101' 13 | 14 | def __init__(self, root): 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, 'UCF-101-midframes') 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_UCF101.json') 18 | self.cupl_path = './gpt3_prompts/CuPL_prompts_ucf101.json' 19 | 20 | self.template = template 21 | 22 | test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | 24 | super().__init__(test=test) 25 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import tarfile 4 | import zipfile 5 | from collections import defaultdict 6 | import gdown 7 | import json 8 | import torch 9 | from torch.utils.data import Dataset as TorchDataset 10 | import torchvision.transforms as T 11 | from PIL import Image 12 | 13 | import numpy as np 14 | import torchvision.transforms as transforms 15 | from datasets.augmix_ops import augmentations 16 | 17 | 18 | def listdir_nohidden(path, sort=False): 19 | """List non-hidden items in a directory. 20 | Args: 21 | path (str): directory path. 22 | sort (bool): sort the items. 
23 | """ 24 | items = [f for f in os.listdir(path) if not f.startswith(".")] 25 | if sort: 26 | items.sort() 27 | return items 28 | 29 | def read_json(fpath): 30 | """Read json file from a path.""" 31 | with open(fpath, 'r') as f: 32 | obj = json.load(f) 33 | return obj 34 | 35 | 36 | def write_json(obj, fpath): 37 | """Writes to a json file.""" 38 | if not osp.exists(osp.dirname(fpath)): 39 | os.makedirs(osp.dirname(fpath)) 40 | with open(fpath, 'w') as f: 41 | json.dump(obj, f, indent=4, separators=(',', ': ')) 42 | 43 | 44 | def read_image(path): 45 | """Read image from path using ``PIL.Image``. 46 | 47 | Args: 48 | path (str): path to an image. 49 | 50 | Returns: 51 | PIL image 52 | """ 53 | if not osp.exists(path): 54 | raise IOError('No file exists at {}'.format(path)) 55 | 56 | while True: 57 | try: 58 | img = Image.open(path).convert('RGB') 59 | return img 60 | except IOError: 61 | print( 62 | 'Cannot read image from {}, ' 63 | 'probably due to heavy IO. Will re-try'.format(path) 64 | ) 65 | 66 | 67 | def listdir_nohidden(path, sort=False): 68 | """List non-hidden items in a directory. 69 | 70 | Args: 71 | path (str): directory path. 72 | sort (bool): sort the items. 73 | """ 74 | items = [f for f in os.listdir(path) if not f.startswith('.') and 'sh' not in f] 75 | if sort: 76 | items.sort() 77 | return items 78 | 79 | 80 | class Datum: 81 | """Data instance which defines the basic attributes. 82 | 83 | Args: 84 | impath (str): image path. 85 | label (int): class label. 86 | domain (int): domain label. 87 | classname (str): class name. 88 | """ 89 | 90 | def __init__(self, impath='', label=0, domain=-1, classname=''): 91 | assert isinstance(impath, str) 92 | assert isinstance(label, int) 93 | assert isinstance(domain, int) 94 | assert isinstance(classname, str) 95 | 96 | self._impath = impath 97 | self._label = label 98 | self._domain = domain 99 | self._classname = classname 100 | 101 | @property 102 | def impath(self): 103 | return self._impath 104 | 105 | @property 106 | def label(self): 107 | return self._label 108 | 109 | @property 110 | def domain(self): 111 | return self._domain 112 | 113 | @property 114 | def classname(self): 115 | return self._classname 116 | 117 | 118 | class DatasetBase: 119 | """A unified dataset class for 120 | 1) domain adaptation 121 | 2) domain generalization 122 | 3) semi-supervised learning 123 | """ 124 | dataset_dir = '' # the directory where the dataset is stored 125 | domains = [] # string names of all domains 126 | 127 | def __init__(self, train_x=None, train_u=None, val=None, test=None): 128 | self._train_x = train_x # labeled training data 129 | self._train_u = train_u # unlabeled training data (optional) 130 | self._val = val # validation data (optional) 131 | self._test = test # test data 132 | 133 | self._num_classes = self.get_num_classes(test) 134 | self._lab2cname, self._classnames = self.get_lab2cname(test) 135 | 136 | @property 137 | def train_x(self): 138 | return self._train_x 139 | 140 | @property 141 | def train_u(self): 142 | return self._train_u 143 | 144 | @property 145 | def val(self): 146 | return self._val 147 | 148 | @property 149 | def test(self): 150 | return self._test 151 | 152 | @property 153 | def lab2cname(self): 154 | return self._lab2cname 155 | 156 | @property 157 | def classnames(self): 158 | return self._classnames 159 | 160 | @property 161 | def num_classes(self): 162 | return self._num_classes 163 | 164 | def get_num_classes(self, data_source): 165 | """Count number of classes. 
166 | 167 | Args: 168 | data_source (list): a list of Datum objects. 169 | """ 170 | label_set = set() 171 | for item in data_source: 172 | label_set.add(item.label) 173 | return max(label_set) + 1 174 | 175 | def get_lab2cname(self, data_source): 176 | """Get a label-to-classname mapping (dict). 177 | 178 | Args: 179 | data_source (list): a list of Datum objects. 180 | """ 181 | container = set() 182 | for item in data_source: 183 | container.add((item.label, item.classname)) 184 | mapping = {label: classname for label, classname in container} 185 | labels = list(mapping.keys()) 186 | labels.sort() 187 | classnames = [mapping[label] for label in labels] 188 | return mapping, classnames 189 | 190 | def check_input_domains(self, source_domains, target_domains): 191 | self.is_input_domain_valid(source_domains) 192 | self.is_input_domain_valid(target_domains) 193 | 194 | def is_input_domain_valid(self, input_domains): 195 | for domain in input_domains: 196 | if domain not in self.domains: 197 | raise ValueError( 198 | 'Input domain must belong to {}, ' 199 | 'but got [{}]'.format(self.domains, domain) 200 | ) 201 | 202 | def download_data(self, url, dst, from_gdrive=True): 203 | if not osp.exists(osp.dirname(dst)): 204 | os.makedirs(osp.dirname(dst)) 205 | 206 | if from_gdrive: 207 | gdown.download(url, dst, quiet=False) 208 | else: 209 | raise NotImplementedError 210 | 211 | print('Extracting file ...') 212 | 213 | try: 214 | tar = tarfile.open(dst) 215 | tar.extractall(path=osp.dirname(dst)) 216 | tar.close() 217 | except: 218 | zip_ref = zipfile.ZipFile(dst, 'r') 219 | zip_ref.extractall(osp.dirname(dst)) 220 | zip_ref.close() 221 | 222 | print('File extracted to {}'.format(osp.dirname(dst))) 223 | 224 | 225 | def split_dataset_by_label(self, data_source): 226 | """Split a dataset, i.e. a list of Datum objects, 227 | into class-specific groups stored in a dictionary. 228 | 229 | Args: 230 | data_source (list): a list of Datum objects. 231 | """ 232 | output = defaultdict(list) 233 | 234 | for item in data_source: 235 | output[item.label].append(item) 236 | 237 | return output 238 | 239 | def split_dataset_by_domain(self, data_source): 240 | """Split a dataset, i.e. a list of Datum objects, 241 | into domain-specific groups stored in a dictionary. 242 | 243 | Args: 244 | data_source (list): a list of Datum objects. 
245 | """ 246 | output = defaultdict(list) 247 | 248 | for item in data_source: 249 | output[item.domain].append(item) 250 | 251 | return output 252 | 253 | 254 | class DatasetWrapper(TorchDataset): 255 | def __init__(self, data_source, input_size, transform=None, is_train=False, 256 | return_img0=False, k_tfm=1): 257 | self.data_source = data_source 258 | self.transform = transform # accept list (tuple) as input 259 | self.is_train = is_train 260 | # Augmenting an image K>1 times is only allowed during training 261 | self.k_tfm = k_tfm if is_train else 1 262 | self.return_img0 = return_img0 263 | 264 | if self.k_tfm > 1 and transform is None: 265 | raise ValueError( 266 | 'Cannot augment the image {} times ' 267 | 'because transform is None'.format(self.k_tfm) 268 | ) 269 | 270 | # Build transform that doesn't apply any data augmentation 271 | interp_mode = T.InterpolationMode.BICUBIC 272 | to_tensor = [] 273 | to_tensor += [T.Resize(input_size, interpolation=interp_mode)] 274 | to_tensor += [T.ToTensor()] 275 | normalize = T.Normalize( 276 | mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) 277 | ) 278 | to_tensor += [normalize] 279 | self.to_tensor = T.Compose(to_tensor) 280 | 281 | def __len__(self): 282 | return len(self.data_source) 283 | 284 | def __getitem__(self, idx): 285 | item = self.data_source[idx] 286 | 287 | output = { 288 | 'label': item.label, 289 | 'domain': item.domain, 290 | 'impath': item.impath 291 | } 292 | 293 | img0 = read_image(item.impath) 294 | 295 | if self.transform is not None: 296 | if isinstance(self.transform, (list, tuple)): 297 | for i, tfm in enumerate(self.transform): 298 | img = self._transform_image(tfm, img0) 299 | keyname = 'img' 300 | if (i + 1) > 1: 301 | keyname += str(i + 1) 302 | output[keyname] = img 303 | else: 304 | img = self._transform_image(self.transform, img0) 305 | output['img'] = img 306 | 307 | if self.return_img0: 308 | output['img0'] = self.to_tensor(img0) 309 | 310 | return output['img'], output['label'] 311 | 312 | def _transform_image(self, tfm, img0): 313 | img_list = [] 314 | 315 | for k in range(self.k_tfm): 316 | img_list.append(tfm(img0)) 317 | 318 | img = img_list 319 | if len(img) == 1: 320 | img = img[0] 321 | 322 | return img 323 | 324 | 325 | def build_data_loader( 326 | data_source=None, 327 | batch_size=64, 328 | input_size=224, 329 | tfm=None, 330 | is_train=True, 331 | shuffle=False, 332 | dataset_wrapper=None 333 | ): 334 | 335 | if dataset_wrapper is None: 336 | dataset_wrapper = DatasetWrapper 337 | 338 | # Build data loader 339 | data_loader = torch.utils.data.DataLoader( 340 | dataset_wrapper(data_source, input_size=input_size, transform=tfm, is_train=is_train), 341 | batch_size=batch_size, 342 | num_workers=8, 343 | shuffle=shuffle, 344 | drop_last=False, 345 | pin_memory=True 346 | ) 347 | assert len(data_loader) > 0 348 | 349 | return data_loader 350 | 351 | 352 | def get_preaugment(): 353 | return transforms.Compose([ 354 | transforms.RandomResizedCrop(224), 355 | transforms.RandomHorizontalFlip(), 356 | ]) 357 | 358 | 359 | def augmix(image, preprocess, aug_list, severity=1): 360 | preaugment = get_preaugment() 361 | x_orig = preaugment(image) 362 | x_processed = preprocess(x_orig) 363 | if len(aug_list) == 0: 364 | return x_processed 365 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0])) 366 | m = np.float32(np.random.beta(1.0, 1.0)) 367 | 368 | mix = torch.zeros_like(x_processed) 369 | for i in range(3): 370 | x_aug = x_orig.copy() 371 | for _ in 
range(np.random.randint(1, 4)): 372 | x_aug = np.random.choice(aug_list)(x_aug, severity) 373 | mix += w[i] * preprocess(x_aug) 374 | mix = m * x_processed + (1 - m) * mix 375 | return mix 376 | 377 | 378 | class AugMixAugmenter(object): 379 | def __init__(self, base_transform, preprocess, n_views=2, augmix=False, 380 | severity=1): 381 | self.base_transform = base_transform 382 | self.preprocess = preprocess 383 | self.n_views = n_views 384 | if augmix: 385 | self.aug_list = augmentations 386 | else: 387 | self.aug_list = [] 388 | self.severity = severity 389 | 390 | def __call__(self, x): 391 | image = self.preprocess(self.base_transform(x)) 392 | views = [augmix(x, self.preprocess, self.aug_list, self.severity) for _ in range(self.n_views)] 393 | 394 | return [image] + views -------------------------------------------------------------------------------- /docs/DATASETS.md: -------------------------------------------------------------------------------- 1 | # How to install datasets 2 | ### Acknowledgement: We extend our gratitude to the authors of [CoOp/CoCoOp](https://github.com/KaiyangZhou/CoOp/blob/main/DATASETS.md#how-to-install-datasets) for their foundational work on data preparation instructions. In this project, we have adopted their data preparation format, implementing certain modifications to better suit our needs. 3 | 4 | We suggest putting all datasets under the same folder (say `$DATA`) to ease management and following the instructions below to organize datasets to avoid modifying the source code. The file structure looks like 5 | 6 | ``` 7 | $DATA/ 8 | |–– imagenet/ 9 | |–– caltech-101/ 10 | |–– oxford_pets/ 11 | |–– stanford_cars/ 12 | ``` 13 | 14 | If you have some datasets already installed somewhere else, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid duplicate download. 15 | 16 | Datasets list: 17 | - [ImageNet](#imagenet) 18 | - [Caltech101](#caltech101) 19 | - [OxfordPets](#oxfordpets) 20 | - [StanfordCars](#stanfordcars) 21 | - [Flowers102](#flowers102) 22 | - [Food101](#food101) 23 | - [FGVCAircraft](#fgvcaircraft) 24 | - [SUN397](#sun397) 25 | - [DTD](#dtd) 26 | - [EuroSAT](#eurosat) 27 | - [UCF101](#ucf101) 28 | - [ImageNetV2](#imagenetv2) 29 | - [ImageNet-Sketch](#imagenet-sketch) 30 | - [ImageNet-A](#imagenet-a) 31 | - [ImageNet-R](#imagenet-r) 32 | 33 | The instructions to prepare each dataset are detailed below. To ensure reproducibility and fair comparison for future work, we provide fixed train/val/test splits for all datasets except ImageNet where the validation set is used as test set. The fixed splits are either from the original datasets (if available) or created by us. 34 | 35 | ### ImageNet 36 | - Create a folder named `imagenet/` under `$DATA`. 37 | - Create `images/` under `imagenet/`. 38 | - Download the dataset from the [official website](https://image-net.org/index.php) and extract the validation set to `$DATA/imagenet/images`. The directory structure should look like 39 | ``` 40 | imagenet/ 41 | |–– images/ 42 | | |–– val/ # contains 1,000 folders like n01440764, n01443537, etc. 43 | ``` 44 | - If you had downloaded the ImageNet dataset before, you can create symbolic links to map the validation set to `$DATA/imagenet/images`. 45 | - Download the `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). 
The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb). 46 | 47 | ### Caltech101 48 | - Create a folder named `caltech-101/` under `$DATA`. 49 | - Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`. 50 | - Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`. 51 | 52 | The directory structure should look like 53 | ``` 54 | caltech-101/ 55 | |–– 101_ObjectCategories/ 56 | |–– split_zhou_Caltech101.json 57 | ``` 58 | 59 | ### OxfordPets 60 | - Create a folder named `oxford_pets/` under `$DATA`. 61 | - Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz. 62 | - Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz. 63 | - Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing). 64 | 65 | The directory structure should look like 66 | ``` 67 | oxford_pets/ 68 | |–– images/ 69 | |–– annotations/ 70 | |–– split_zhou_OxfordPets.json 71 | ``` 72 | 73 | ### StanfordCars 74 | - Create a folder named `stanford_cars/` under `$DATA`. 75 | - Download the test images http://ai.stanford.edu/~jkrause/car196/cars_test.tgz. 76 | - Download the test labels http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat. 77 | - Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing). 78 | 79 | The directory structure should look like 80 | ``` 81 | stanford_cars/ 82 | |–– cars_test\ 83 | |–– cars_test_annos_withlabels.mat 84 | |–– devkit\ 85 | |–– split_zhou_StanfordCars.json 86 | ``` 87 | 88 | ### Flowers102 89 | - Create a folder named `oxford_flowers/` under `$DATA`. 90 | - Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively. 91 | - Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing). 92 | - Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing). 93 | 94 | The directory structure should look like 95 | ``` 96 | oxford_flowers/ 97 | |–– cat_to_name.json 98 | |–– imagelabels.mat 99 | |–– jpg/ 100 | |–– split_zhou_OxfordFlowers.json 101 | ``` 102 | 103 | ### Food101 104 | - Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`. 105 | - Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing). 106 | 107 | The directory structure should look like 108 | ``` 109 | food-101/ 110 | |–– images/ 111 | |–– license_agreement.txt 112 | |–– meta/ 113 | |–– README.txt 114 | |–– split_zhou_Food101.json 115 | ``` 116 | 117 | ### FGVCAircraft 118 | - Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz. 119 | - Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`. 
120 | - Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`. 121 | 122 | The directory structure should look like 123 | ``` 124 | fgvc_aircraft/ 125 | |–– images/ 126 | |–– ... # a bunch of .txt files 127 | ``` 128 | 129 | ### SUN397 130 | - Create a folder named `sun397/` under `$DATA`. 131 | - Download the images http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz. 132 | - Download the partitions https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip. 133 | - Extract these files under `$DATA/sun397/`. 134 | - Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing). 135 | 136 | The directory structure should look like 137 | ``` 138 | sun397/ 139 | |–– SUN397/ 140 | |–– split_zhou_SUN397.json 141 | |–– ... # a bunch of .txt files 142 | ``` 143 | 144 | ### DTD 145 | - Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. This should lead to `$DATA/dtd/`. 146 | - Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing). 147 | 148 | The directory structure should look like 149 | ``` 150 | dtd/ 151 | |–– images/ 152 | |–– imdb/ 153 | |–– labels/ 154 | |–– split_zhou_DescribableTextures.json 155 | ``` 156 | 157 | ### EuroSAT 158 | - Create a folder named `eurosat/` under `$DATA`. 159 | - Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`. 160 | - Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing). 161 | 162 | The directory structure should look like 163 | ``` 164 | eurosat/ 165 | |–– 2750/ 166 | |–– split_zhou_EuroSAT.json 167 | ``` 168 | 169 | ### UCF101 170 | - Create a folder named `ucf101/` under `$DATA`. 171 | - Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames. 172 | - Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing). 173 | 174 | The directory structure should look like 175 | ``` 176 | ucf101/ 177 | |–– UCF-101-midframes/ 178 | |–– split_zhou_UCF101.json 179 | ``` 180 | 181 | ### ImageNetV2 182 | - Create a folder named `imagenetv2/` under `$DATA`. 183 | - Go to this github repo https://github.com/modestyachts/ImageNetV2. 184 | - Download the matched-frequency dataset from https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz and extract it to `$DATA/imagenetv2/`. 185 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenetv2/`. 186 | 187 | The directory structure should look like 188 | ``` 189 | imagenetv2/ 190 | |–– imagenetv2-matched-frequency-format-val/ 191 | |–– classnames.txt 192 | ``` 193 | 194 | ### ImageNet-Sketch 195 | - Download the dataset from https://github.com/HaohanWang/ImageNet-Sketch. 196 | - Extract the dataset to `$DATA/imagenet-sketch`. 197 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-sketch/`. 
198 | 199 | The directory structure should look like 200 | ``` 201 | imagenet-sketch/ 202 | |–– images/ # contains 1,000 folders whose names have the format of n* 203 | |–– classnames.txt 204 | ``` 205 | 206 | ### ImageNet-A 207 | - Create a folder named `imagenet-adversarial/` under `$DATA`. 208 | - Download the dataset from https://github.com/hendrycks/natural-adv-examples and extract it to `$DATA/imagenet-adversarial/`. 209 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-adversarial/`. 210 | 211 | The directory structure should look like 212 | ``` 213 | imagenet-adversarial/ 214 | |–– imagenet-a/ # contains 200 folders whose names have the format of n* 215 | |–– classnames.txt 216 | ``` 217 | 218 | ### ImageNet-R 219 | - Create a folder named `imagenet-rendition/` under `$DATA`. 220 | - Download the dataset from https://github.com/hendrycks/imagenet-r and extract it to `$DATA/imagenet-rendition/`. 221 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-rendition/`. 222 | 223 | The directory structure should look like 224 | ``` 225 | imagenet-rendition/ 226 | |–– imagenet-r/ # contains 200 folders whose names have the format of n* 227 | |–– classnames.txt 228 | -------------------------------------------------------------------------------- /docs/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/docs/comparison.png -------------------------------------------------------------------------------- /docs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangce01/DPE-CLIP/fecae5a50d7dc39fb944221be390201a56948416/docs/overview.png -------------------------------------------------------------------------------- /gpt3_prompts/CuPL_prompts_eurosat.json: -------------------------------------------------------------------------------- 1 | { 2 | "Annual Crop Land": [ 3 | "In an aerial satellite view of an annual crop land, one would see neatly arranged rows of crops, possibly of different types, growing in the fields.", 4 | "Aerial satellite view of an annual crop land would show various shades of green, depending on the type of crop planted.", 5 | "An annual crop land typically appears as a mosaic of different colors, depending on the type of crop that is growing.", 6 | "An aerial satellite view of an annual crop land would show fields of crops that are planted and harvested on a yearly basis.", 7 | "In an aerial satellite view of an annual crop land, one would see large expanses of land with crops growing in them.", 8 | "The aerial satellite view of an annual crop land would show the land being used to grow crops.", 9 | "An aerial satellite view of an Annual Crop Land would show a large expanse of land with crops growing in neat rows.", 10 | "The satellite view would show the vast amount of land that is used to grow crops on a yearly basis.", 11 | "The aerial satellite view of an annual crop land would show the farmers working in the field, the crops planted, and the irrigation system in place.", 12 | "Aerial satellite view of an annual crop land would show extensive green cover with well-defined patterns.", 13 | "A satellite photo of an annual crop land looks like a field that has been plowed and is ready for planting.", 14 | "A satellite photo of annual crop land looks like an aerial view of farmland with rows of crops.", 15 | "A satellite view of an annual 
crop land would show a large, open field with crops planted in evenly spaced rows.", 16 | "A satellite view of an annual crop land would show large fields with precise geometric patterns.", 17 | "An annual crop land would appear as a large, green space in the center of the image with fields delineated by lines or different shades of green.", 18 | "In a centered satellite view of an annual crop land, the land would appear as a large, green field with crops planted in straight rows.", 19 | "The image would show a flat, green field with rows of plants growing.", 20 | " CoverIn a centered satellite view of an annual crop land cover, the land would appear to be covered in a green blanket of crops.", 21 | "scapefields of green and brown separated by thin lines of trees; buildings in the distance.", 22 | "scapeAn annual crop landscape is typically a large, open field with crops planted in rows." 23 | ], 24 | "Forest": [ 25 | "In an aerial satellite view of a forest, the forest appears as a sea of trees with varying shades of green.", 26 | "A satellite view of a forest would show a large area of land covered in trees.", 27 | "An aerial satellite view of a forest would show a large expanse of trees, with varying shades of green.", 28 | "The forest would appear as a large green area with trees dotting the landscape.", 29 | "An aerial satellite view of a forest would show a large area of land with trees and vegetation.", 30 | "In a satellite view of a forest, the trees appear as a mass of green in the landscape.", 31 | "A satellite view of a forest would show a large area of trees, with various shades of green.", 32 | "The forest would appear as a large mass of green with individual trees distinguishable.", 33 | "Aerial satellite views of forests typically show a large expanse of trees, with varying degrees of green depending on the season.", 34 | "An aerial satellite view of a forest would show trees of various heights, densities, and colors spread out across the land.", 35 | "A satellite photo of a forest can look like a green and brown patchwork.", 36 | "In a centered satellite view of a forest, the trees would be evenly spaced apart and would form a dense canopy.", 37 | " fireIn the center of the image, there is a large forest fire.", 38 | "A satellite view of a forest would show a large area of trees, shrubs, and other plants.", 39 | " fireYou would see large, billowing clouds of smoke rising from the ground.", 40 | " fire at nightThe forest is ablaze with flames, and the smoke is billowing up into the night sky.", 41 | " fire.", 42 | "In the center of the image is a lush, green forest.", 43 | "In this image, a satellite is hovering above a forest, providing a clear view from above.", 44 | "The view would show a large mass of green in the center, with various shades of green throughout.", 45 | "In a centered satellite view of a forest, the trees would appear as a sea of green, with the occasional brown or tan patch where the trees are thinner or have died." 
46 | ], 47 | "Herbaceous Vegetation Land": [ 48 | " coverIn an aerial satellite view of a Herbaceous Vegetation Land cover, the vegetation appears as a green, continuous cover with little to no bare ground visible.", 49 | " CoverHerbaceous vegetation land cover would appear as a green, continuous carpet from an aerial satellite view.", 50 | "An aerial satellite view of a Herbaceous Vegetation Land would look like a large green field with patches of different colors.", 51 | " CoverThe aerial satellite view would show a large expanse of green vegetation with few trees present.", 52 | "scapeIn an aerial satellite view of a herbaceous vegetation landscape, the land would appear to be covered in a sea of green.", 53 | "scapeHerbaceous Vegetation Landscape would appear as a green sea with small islands of trees and bushes.", 54 | " coverIn an aerial satellite view of a Herbaceous Vegetation Land cover, the land would appear as a green and brown expanse with patches of bare ground.", 55 | " coverIn an aerial satellite view of a Herbaceous Vegetation Land cover, one would expect to see a mix of grasses, herbs, and other plants adapted to living in areas with low rainfall and nutrient levels.", 56 | " CoverHerbaceous vegetation land cover is characterized by dense, soft, leafy plants with short stems.", 57 | "scapeAn aerial satellite view of a herbaceous vegetation landscape would show a large expanse of green plants interspersed with areas of bare ground.", 58 | "scapeIn a herbaceous landscape, there is a mix of different plant life including trees, shrubs, and vines.", 59 | " Cover.", 60 | " CoverIn the center of the image there is a large green expanse with small patches of brown interspersed throughout.", 61 | " CoverI see a large, green field with some trees dotting the landscape.", 62 | " CoverIn this image, you would see a large amount of green, covering most of the land.", 63 | " CoverIn this view, herbs and other plants cover the ground evenly, with no large patches of bare ground visible.", 64 | " Cover typeIn a satellite view, herbaceous vegetation would appear as a green, continuous cover with no tree canopy.", 65 | " CoverIn a centered satellite view of a Herbaceous Vegetation Land Cover, the vegetation appears as a green and brown hued patchwork.", 66 | " Cover.", 67 | " CoverIn a centered satellite view of herbaceous vegetation land cover, the ground is covered in green plants with thin stems and leaves." 
68 | ], 69 | "Highway or Road": [ 70 | "An aerial satellite view of a highway or road would show a long, straight path with vehicles traveling along it.", 71 | "A satellite view of a highway or road would show a long, straight path with lines on either side.", 72 | "The view would show a long, straight road with cars or trucks traveling on it.", 73 | "This would be a view from space of a stretch of highway or road.", 74 | "On an aerial satellite view of a highway or road, you would see a long, straight path with two yellow lines running down the middle.", 75 | "You would see a long, straight ribbon of pavement cutting through a landscape of green trees and brown earth.", 76 | "An aerial satellite view of a highway or road would show a long, straight path with cars or other vehicles traveling along it.", 77 | "An aerial satellite view of a Highway or Road typically includes a long, straight stretch of pavement with two yellow lines down the middle, flanked by trees, green grass, and possibly some buildings or other structures in the distance.", 78 | "The aerial satellite view of a highway or road would show a long, straight, and typically flat stretch of pavement with two yellow lines running down the middle.", 79 | "A satellite view of a highway or road typically shows a long, straight path with two lane markings.", 80 | "In a centered satellite view of a highway or road, the viewer would see a long, straight path cutting through a landscape.", 81 | "A Highway or Road would appear as a long straight or winding line cutting through a landscape.", 82 | "In the center of the image is a long, straight highway flanked by trees on either side.", 83 | "A highway or road is typically a long and thin strip of asphalt that runs through a landscape.", 84 | "In the center of the satellite view is a highway or road.", 85 | ":A highway or road is typically a long, straight stretch of pavement with two lines down the middle demarcating lanes of traffic traveling in opposite directions.", 86 | "What you would see is a long, straight road stretching out ahead of you.", 87 | "In the center of the image is a long, straight highway.", 88 | "A highway or road is typically a long, narrow strip of pavement with lines down the middle indicating lanes of traffic." 
89 | ], 90 | "Industrial Buildings": [ 91 | "An aerial satellite view of an industrial building would show a large, rectangular building with a flat roof.", 92 | "An aerial satellite view of an industrial building would show a large rectangular building with a flat roof.", 93 | "The aerial satellite view of an industrial buildings would show a series of large rectangular buildings with flat roofs.", 94 | "An aerial satellite view of an industrial building would show a large rectangular or square building with a flat roof.", 95 | "An aerial satellite view of an Industrial building would show the building(s) surrounded by a parking lot, with a few trees or other landscaping.", 96 | "The view would show large buildings with parking lots and shipping containers.", 97 | "An aerial satellite view of an industrial building might show a large factory with smokestacks, towers, and other buildings nearby.", 98 | "Industrial buildings are typically large, single-story structures with small windows and few exterior embellishments.", 99 | "The satellite view of an industrial building would show a large structure with a flat roof and numerous exhaust vents.", 100 | "An aerial satellite view of an industrial building would show a structure with a large number of floors, typically made of concrete or metal.", 101 | "In the center of the satellite view is a large, square industrial building.", 102 | "A satellite view of an industrial building would show a cluster of large buildings with various chimneys and smokestacks.", 103 | " and Patterns:In this image, we see a satellite view of an industrial area with a number of buildings and patterns.", 104 | "In the center of the image is a cluster of large industrial buildings, surrounded by a parking lot.", 105 | " ( factories , warehouses , grain silos , etc.", 106 | ":An industrial building is typically a large structure that houses machinery, equipment, or inventory.", 107 | "/ComplexAn industrial buildings/complex would appear as a group of buildings with large parking areas and open spaces between the structures.", 108 | "An industrial building is a structure where workers manufacturing goods or performing services.", 109 | "\"In the center of the satellite view is a cluster of large industrial buildings surrounded by a parking lot." 
110 | ], 111 | "Pasture Land": [ 112 | "formFrom an aerial perspective, a pasture landform would likely appear as a large, flat expanse of land with some scattered trees or bushes.", 113 | "Pasture land is typically composed of areas of grassland maintained for the purpose of grazing animals.", 114 | "The pasture land would appear as a large, flat green area with some trees dotting the landscape.", 115 | "A pasture land is an area of land used primarily for grazing purposes.", 116 | "An aerial satellite view of a pasture land would be a flat area of land with green grass and some trees.", 117 | "The pasture land would be a large, open field with grassy.", 118 | "scapeAn aerial satellite view of a pasture landscape would show a large open area of grassland with grazing animals dotting the landscape.", 119 | "The aerial satellite view of a pasture land would show a large, flat piece of land with small patches of trees or bushes.", 120 | "In an aerial satellite view of a pasture land, one would see vast expanses of open green fields broken up by the occasional fence line.", 121 | "The view would look like a big green field with some trees.", 122 | "scapeThe view would be of a large, green field with patches of brown and white.", 123 | "In the center of the image is a large, green pasture.", 124 | "scapeIn the center of the image is a large, green pasture with trees and shrubs scattered throughout.", 125 | " Rolling hills with patches of trees and grazing animals.", 126 | " with some grazing cattle on itA pastureland is a flat, green expanse of land where grass grows and cattle graze.", 127 | "scapeIn the center of the image is a large, green pasture.", 128 | "scapeA pasture landscape is a flat, grassy area where livestock can graze.", 129 | "scaIn the center of the pasture is a clover-shaped field with a barn in the middle.", 130 | "The pasture land is a large, flat area of land with green grass.", 131 | "scapeA pasture landscape would appear as a large, open field with grassy vegetation." 
132 | ], 133 | "Permanent Crop Land": [ 134 | "Aerial satellite view of a Permanent Crop Land would show an area of land that is used for growing crops on a permanent basis.", 135 | "A satellite view of a Permanent Crop Land would show fields of crops with well-defined boundaries, irrigation systems, and roads or paths between the fields.", 136 | "scapeIn a permanent crop landscape, crops are planted in rows and spaced evenly apart.", 137 | "scapeAerial satellite view of a permanent crop landscape would show a fields with various crops growing in them.", 138 | "A permanent crop land is an area of land that is used for growing crops on a long-term basis.", 139 | "The Google Maps image of Permanent Crop Land is a satellite view of a large expanse of farmland with crops planted in rows.", 140 | "Aerial satellite view of a Permanent Crop Land would show an area of land that is used for growing crops on a permanent basis.", 141 | "The permanent crop land would appear as a green oasis in the midst of a desert.", 142 | "In an aerial satellite view of a Permanent Crop Land, you would see an area of land that is used for growing crops.", 143 | "A satellite view of permanent crop land would show neatly arranged rows of plants or trees.", 144 | "A satellite photo of a Permanent Crop Land usually looks like a large green area with some small dark spots.", 145 | "A satellite photo of permanent crop land looks like a field with crops growing in it.", 146 | " use typeA satellite view of permanent crop land would show an even distribution of crops throughout the field.", 147 | "scapeI see a huge, open field with crops that go on for as far as the eye can see.", 148 | "In the center of the frame is a large, green field with crops growing in neat rows.", 149 | "In a centered satellite view of a Permanent Crop Land, the land would appear to be a large, continuous stretch of green.", 150 | "scape\nA satellite view of a permanent crop landscape would show an orderly arrangement of crops, with fields of different crops spaced evenly apart.", 151 | "lord Tenant LawIn the center of the image is a large farm with fields of crops.", 152 | " within the collection areaIn the center of the image is a large green field with a crop planted in it.", 153 | " area with some roadsIn the center of the image is a large green area with criss-crossing roads.", 154 | " Cover in the SummerIn the center of the image is a lush, green field with trees lining the borders." 
155 | ], 156 | "Residential Buildings": [ 157 | "Aerial satellite view of a Residential Buildings would typically show a series of single family homes or multi-family dwellings arranged in a grid pattern.", 158 | "An aerial satellite view of a residential building would show the building from above, with the surrounding area also visible.", 159 | "The aerial satellite view of a Residential Building would show the different types of buildings that people live in.", 160 | "Aerial satellite view of a Residential Buildings would show a group of buildings that are close together and are typically used for housing.", 161 | "An aerial satellite view of a residential building would show a bird's eye view of the building and surrounding area.", 162 | "Satellite view of a neighborhood with houses and trees.", 163 | "An aerial satellite view of residential buildings would show a grid-like pattern of streets and houses.", 164 | "Aerial satellites provide high-resolution views of residential buildings and their surrounding areas.", 165 | "Residential buildings are typically arranged in a grid pattern, with streets running in straight lines between them.", 166 | "In an aerial satellite view of a residential area, one would see many houses with their surrounding yards.", 167 | "In a centered satellite view of a Residential Buildings, one would see a cluster of buildings with perhaps a few streets in between them.", 168 | " projectA group of medium to high-rise buildings are arranged in a symmetrical formation around a large, central park.", 169 | " surrounded by trees.", 170 | " in a large cityYou would see a lot of tall buildings that are close together.", 171 | "A satellite view of a centered Residential Buildings would show a cluster of buildings surrounded by a green space.", 172 | "Residential Buildings would be arranged in a neat grid pattern with evenly spaced roads running through them.", 173 | "The image would be a bird's eye view of a group of houses with neatly manicured lawns.", 174 | " in a small cityIn a small city, there are typically many residential buildings that are spread out and not centrally located." 
175 | ], 176 | "River": [ 177 | "An aerial view of a river would show the meandering path of the water as it flows through the landscape.", 178 | "The satellite view of a river would show the winding path of the river through the landscape.", 179 | "From an aerial satellite view, a river would look like a thin line of water winding through the landscape.", 180 | "The images below show an aerial satellite view of a river.", 181 | "The river is winding through the landscape, with trees and other vegetation lining its banks.", 182 | "A satellite view of a river would show a winding line of water cutting through a landscape.", 183 | "An aerial satellite view of a river might show the winding path of the river through a green landscape.", 184 | "An aerial satellite view of a river would show the river snaking through the landscape.", 185 | "A satellite view of a river would show a long, thin line of water winding through a landscape.", 186 | "The satellite view of a river would show the long, winding path of the river as it flows through the landscape.", 187 | "A satellite photo of a river looks like a thin line of water winding through a landscape.", 188 | " DeltaIn the center of the image is a large river delta, where the river has split into many smaller branches that flow into the ocean.", 189 | " with an island in the middleA blue river flows from the top of the frame to the bottom, passing through a small green island in the center.", 190 | "A river is a body of water that flows from a high point, such as a mountain, to a lower point, typically a lake or the ocean.", 191 | " and its Valley.", 192 | "A long, thin blue line winding its way through a green landscape.", 193 | "The cells of a river are close together and packed tightly.", 194 | "A river is a natural flowing watercourse, usually freshwater, flowing towards an ocean, sea, lake or another river.", 195 | " delta in the AmazonA satellite view of aRiver delta in the Amazon would show a small body of water surrounded by land on all sides.", 196 | " DeltaA river delta is a landform created at the mouth of a river where it meets a body of water, typically an ocean, sea, or lake.", 197 | "If one were to take a satellite view of a river from directly above, they would see a long, winding snake-like object made up of different shades of blue." 
198 | ], 199 | "Sea or Lake": [ 200 | "The colors in a satellite view of a sea or lake will depend on the type of sensor used and the time of day.", 201 | "An aerial satellite view of a sea or lake would show the large body of water with its surrounding landmass.", 202 | "The view would show a large body of water with shades of blue.", 203 | "This could describe the view from an aerial satellite of any sea or lake, but let's imagine a view of the Puget Sound in the Pacific Northwest.", 204 | "An aerial satellite view of a sea or lake would show a large body of water with ripples or waves on the surface.", 205 | "The satellite view of a sea or lake would be a bird's eye view of the water.", 206 | "An aerial satellite view of a sea or lake might show the different shades of blue in the water, as well as any islands that might be in the area.", 207 | "The aerial satellite view of a sea or lake would show the body of water with its surrounding landmass.", 208 | "An aerial satellite view of a Sea or Lake would show a large body of water with different shades of blue.", 209 | "A view from an aerial satellite of a sea or lake would show a large body of water with various shades of blue.", 210 | "A satellite photo of a Sea or Lake looks like a large body of water with small waves on the surface.", 211 | "A blue body of water with ripples.", 212 | "Centered in the middle of the frame is a large body of water.", 213 | "In a centered satellite view of a sea or lake, the water appears as a large, dark blue expanse with small waves rippling across its surface.", 214 | "In the center of the image would be a large body of water, possibly with some land masses visible along the edges.", 215 | "A flat, round body of water with rippling waves would be visible in the center of the scene, while the surrounding land would be visible in varying shades of green and brown.", 216 | "You would see a large body of water with ripples on the surface.", 217 | "In a centered satellite view of a sea or lake, the water appears as a deep blue expanse with rippling waves.", 218 | "You would see a large body of water with ripples or waves on the surface." 219 | ] 220 | } -------------------------------------------------------------------------------- /main_dpe.py: -------------------------------------------------------------------------------- 1 | import random 2 | import argparse 3 | import wandb 4 | from tqdm import tqdm 5 | from datetime import datetime 6 | from copy import deepcopy 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import operator 11 | import torch.nn as nn 12 | from info_nce import InfoNCE 13 | from sklearn.manifold import TSNE 14 | import matplotlib.pyplot as plt 15 | 16 | import clip 17 | from utils import * 18 | 19 | import open_clip 20 | 21 | 22 | def get_arguments(): 23 | """Get arguments of the test-time adaptation.""" 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--config', dest='config', required=True, help='settings of DPE on specific dataset in yaml format.') 26 | parser.add_argument('--wandb-log', dest='wandb', action='store_true', help='Whether you want to log to wandb. Include this flag to enable logging.') 27 | parser.add_argument('--datasets', dest='datasets', type=str, required=True, help="Datasets to process, separated by a slash (/). Example: I/A/V/R/S") 28 | parser.add_argument('--data-root', dest='data_root', type=str, default='../data/', help='Path to the datasets directory. 
Default is ../data/') 29 | parser.add_argument('--backbone', dest='backbone', type=str, choices=['RN50', 'ViT-B/16', 'SigLIP', 'OpenCLIP'], required=True, help='CLIP model backbone to use: RN50 or ViT-B/16.') 30 | parser.add_argument('--coop', dest='coop', action='store_true', help='Whether you want to use CoOp weights for initialization.') 31 | 32 | args = parser.parse_args() 33 | 34 | return args 35 | 36 | def InfoNCELoss(A, B): 37 | loss = InfoNCE(temperature=0.01, reduction='mean') 38 | return loss(A, B) 39 | 40 | def update_cache(cache, pred, features_loss, shot_capacity, include_prob_map=False): 41 | """Update cache with new features and loss, maintaining the maximum shot capacity.""" 42 | with torch.no_grad(): 43 | item = features_loss if not include_prob_map else features_loss[:2] + [features_loss[2]] 44 | if pred in cache: 45 | if len(cache[pred]) < shot_capacity: 46 | cache[pred].append(item) 47 | elif features_loss[1] < cache[pred][-1][1]: 48 | cache[pred][-1] = item 49 | cache[pred] = sorted(cache[pred], key=operator.itemgetter(1)) 50 | else: 51 | cache[pred] = [item] 52 | return 53 | 54 | def visualize_cache(cache, iter): 55 | # t-SNE visualization of cache features 56 | with torch.no_grad(): 57 | cache_features = [] 58 | cache_labels = [] 59 | for class_index in sorted(cache.keys()): 60 | for item in cache[class_index]: 61 | cache_features.append(item[0].reshape(-1)) 62 | cache_labels.append(class_index) 63 | cache_features = torch.stack(cache_features, dim=0) 64 | cache_labels = torch.Tensor(cache_labels).to(torch.int64) 65 | cache_features = F.normalize(cache_features, dim=1) 66 | cache_features = cache_features.cpu().numpy() 67 | cache_labels = cache_labels.cpu().numpy() 68 | tsne = TSNE(n_components=2) 69 | print(cache_features.shape) 70 | cache_features_fit = tsne.fit_transform(cache_features) 71 | 72 | # Assign different colors to different cache_labels 73 | colors = [ 74 | '#00429d', # Strong Blue 75 | '#93003a', # Deep Red 76 | '#007d34', # Vivid Green 77 | '#ff6800', # Vivid Orange 78 | '#e30022', # Bright Red 79 | '#a6bdd7', # Light Periwinkle 80 | '#ffcc00', # Vivid Yellow 81 | '#540d6e', # Dark Violet 82 | '#7f180d', # Dark Red 83 | '#00939c', # Cyan Process 84 | '#5f3c99', # Purplish Blue 85 | '#ff4a46', # Bright Red-Orange 86 | '#8f0075', # Strong Purple 87 | '#ff3c38', # Bright Red 88 | '#83a697', # Muted Cyan 89 | '#1e96be', # Strong Cyan 90 | '#d9e021', # Vivid Lime Green 91 | '#f18d05', # Rich Orange 92 | '#f6e120', # Bright Yellow 93 | '#8f2d56', # Strong Rose 94 | '#006837', # Dark Green 95 | '#e7298a', # Bright Pink 96 | '#ce1256', # Dark Pink 97 | '#01665e', # Dark Teal 98 | '#dfc27d', # Pale Gold 99 | '#35978f', # Muted Teal 100 | '#bf812d', # Mustard Brown 101 | '#543005', # Dark Brown 102 | '#8c510a', # Light Brown 103 | '#80cdc1', # Soft Turquoise 104 | ] 105 | colors_others = 'gray' 106 | figure, ax = plt.subplots(1, 1, dpi=600, figsize=(5, 5)) 107 | patch = ax.patch 108 | patch.set_color("#f5f5f5") 109 | ax.tick_params(axis='both', # Changes apply to both x and y axes 110 | which='both', # Apply changes to both major and minor ticks 111 | bottom=False, # No ticks along the bottom edge 112 | top=False, # No ticks along the top edge 113 | left=False, # No ticks along the left edge 114 | right=False, # No ticks along the right edge 115 | labelbottom=False, # No labels along the bottom edge 116 | labelleft=False) # No labels along the left edge 117 | plt.grid(color='w', zorder=0, linewidth=2) 118 | plt.gca().spines['bottom'].set_color('gray') 119 | 
plt.gca().spines['left'].set_color('gray') 120 | plt.gca().spines['top'].set_color('gray') 121 | plt.gca().spines['right'].set_color('gray') 122 | # In Food-101, we have 101 classes 123 | for i in range(101): 124 | if i < 30: 125 | plt.scatter(cache_features_fit[cache_labels == i, 0], cache_features_fit[cache_labels == i, 1], c=colors[i], s=15, marker='x', zorder=5) 126 | else: 127 | plt.scatter(cache_features_fit[cache_labels == i, 0], cache_features_fit[cache_labels == i, 1], c=colors_others, s=5, zorder=5) 128 | save_path = 'fig/cache_features_iter_{}.png'.format(iter) 129 | plt.savefig(save_path) 130 | plt.close() 131 | 132 | 133 | def cache_key_value(image_features, cache, alpha, beta, clip_weights): 134 | """Compute logits using positive/negative cache.""" 135 | with torch.no_grad(): 136 | cache_keys = [] 137 | cache_values = [] 138 | all_classes = [] 139 | for class_index in sorted(cache.keys()): 140 | num_items = len(cache[class_index]) 141 | # Compute the prototype of the class 142 | image_prototype = torch.zeros_like(image_features) 143 | for item in cache[class_index]: 144 | image_prototype += item[0] / num_items 145 | cache_keys.append(image_prototype) 146 | cache_values.append(class_index) 147 | all_classes.append(class_index) 148 | 149 | cache_keys = torch.cat(cache_keys, dim=0).permute(1, 0) 150 | cache_values = (F.one_hot(torch.Tensor(cache_values).to(torch.int64), num_classes=clip_weights.size(1))).cuda().half() 151 | 152 | return cache_keys, cache_values, all_classes 153 | 154 | def compute_cache_logits(image_features, cache_keys, cache_values, alpha, beta, clip_weights): 155 | affinity = image_features @ cache_keys 156 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 157 | return alpha * cache_logits 158 | 159 | class TextResidue(nn.Module): 160 | def __init__(self, clip_weights): 161 | super(TextResidue, self).__init__() 162 | self.feat_dim, self.cate_num = clip_weights.shape 163 | self.residual = nn.Parameter(torch.zeros([self.feat_dim, self.cate_num]).half().cuda(), requires_grad=True) 164 | 165 | def forward(self, x): 166 | new_clip_weights = x.clone() + self.residual 167 | new_clip_weights = F.normalize(new_clip_weights, dim=0) 168 | return new_clip_weights 169 | 170 | def reset(self): 171 | self.residual = nn.Parameter(torch.zeros([self.feat_dim, self.cate_num]).half().cuda(), requires_grad=True) 172 | 173 | class PositiveCacheResidue(nn.Module): 174 | def __init__(self, pos_cache_keys): 175 | super(PositiveCacheResidue, self).__init__() 176 | self.feat_dim, self.cache_size = pos_cache_keys.shape 177 | self.residual = nn.Parameter(torch.zeros([self.feat_dim, self.cache_size]).half().cuda(), requires_grad=True) 178 | 179 | def forward(self, x): 180 | new_pos_cache_keys = x.clone() + self.residual 181 | new_pos_cache_keys = F.normalize(new_pos_cache_keys, dim=0) 182 | return new_pos_cache_keys 183 | 184 | class SmoothCrossEntropy(nn.Module): 185 | def __init__(self, alpha=0.0): 186 | super(SmoothCrossEntropy, self).__init__() 187 | self.alpha = alpha 188 | 189 | def forward(self, logits, labels): 190 | num_classes = logits.shape[-1] 191 | alpha_div_k = self.alpha / num_classes 192 | target_probs = F.one_hot(labels, num_classes=num_classes).float() * \ 193 | (1. 
- self.alpha) + alpha_div_k 194 | loss = -(target_probs * torch.log_softmax(logits, dim=-1)).sum(dim=-1) 195 | return loss.mean() 196 | 197 | def run_test_dpe(pos_cfg, lr_cfg, loader, clip_model, clip_weights, dataset_name): 198 | with torch.cuda.amp.autocast(): 199 | pos_cache, accuracies = {}, [] 200 | 201 | # Unpack all hyperparameters 202 | pos_enabled = pos_cfg['enabled'] 203 | 204 | if pos_enabled: 205 | pos_params = {k: pos_cfg[k] for k in ['shot_capacity', 'alpha', 'beta']} 206 | 207 | clip_weights_global = clip_weights.clone() 208 | num_avg = 0 209 | total = len(loader) 210 | 211 | losses = [] 212 | all_clip_weights = [] 213 | distances = [] 214 | 215 | # Test-time adaptation 216 | for i, (images, target) in enumerate(tqdm(loader, desc='Processed test images: ')): 217 | clip_weights_local = clip_weights_global.clone().detach() 218 | text_residue = TextResidue(clip_weights_local) 219 | new_clip_weights = text_residue(clip_weights_local) 220 | 221 | image_features_x, clip_logits, entropy, prob_map, pred = get_clip_logits(images, clip_model, new_clip_weights) 222 | target = target.cuda() 223 | 224 | if pos_enabled: 225 | entropy = get_entropy(entropy, clip_weights) 226 | update_cache(pos_cache, pred, [image_features_x, entropy], pos_params['shot_capacity']) 227 | pos_cache_keys, pos_cache_values, all_classes = cache_key_value(image_features_x, pos_cache, pos_params['alpha'], pos_params['beta'], clip_weights) 228 | pos_cache_residue = PositiveCacheResidue(pos_cache_keys) 229 | # if i != 0 and i % 1000 == 0: 230 | # visualize_cache(pos_cache, i) 231 | steps = 1 # Update step, set to 1 in default 232 | for j in range(steps): 233 | new_clip_weights = text_residue(clip_weights_local) 234 | final_logits = clip_logits.clone() 235 | if pos_enabled and pos_cache: 236 | new_pos_cache_keys = pos_cache_residue(pos_cache_keys) 237 | final_logits += compute_cache_logits(image_features_x, new_pos_cache_keys, pos_cache_values, pos_params['alpha'], pos_params['beta'], clip_weights) 238 | loss = avg_entropy(final_logits) 239 | # alignment loss 240 | image2text_loss = InfoNCELoss(new_pos_cache_keys.T, new_clip_weights[:, all_classes].T) 241 | loss += image2text_loss * lr_cfg['align'] 242 | else: 243 | loss = avg_entropy(final_logits) 244 | 245 | lr_text = lr_cfg['text'] 246 | lr_image = lr_cfg['image'] 247 | if pos_enabled and pos_cache: 248 | optimizer = torch.optim.AdamW([ 249 | {'params': text_residue.parameters(), 'lr': lr_text, 'eps': 1e-3, 'weight_decay': 1e-1}, 250 | {'params': pos_cache_residue.parameters(), 'lr': lr_image, 'eps': 1e-3, 'weight_decay': 1e-1} 251 | ]) 252 | else: 253 | optimizer = torch.optim.AdamW([ 254 | {'params': text_residue.parameters(), 'lr': lr_text, 'eps': 1e-3, 'weight_decay': 1e-1} 255 | ]) 256 | 257 | optimizer.zero_grad() 258 | if j == steps - 1: 259 | loss.backward() 260 | else: 261 | loss.backward(retain_graph=True) 262 | optimizer.step() 263 | 264 | # Actual inference 265 | text_residue.eval() 266 | if pos_enabled and pos_cache: 267 | pos_cache_residue.eval() 268 | with torch.no_grad(): 269 | new_clip_weights = text_residue(clip_weights_local) 270 | if dataset_name == 'A': 271 | image_features, clip_logits, _, _, _ = get_clip_logits(images, clip_model, new_clip_weights) 272 | else: 273 | image_features, clip_logits, _, _, _ = get_clip_logits(images[0], clip_model, new_clip_weights) 274 | final_logits = clip_logits.clone() 275 | if pos_enabled and pos_cache: 276 | new_pos_cache_keys = pos_cache_residue(pos_cache_keys) 277 | final_logits += 
compute_cache_logits(image_features, new_pos_cache_keys, pos_cache_values, pos_params['alpha'], pos_params['beta'], clip_weights)
278 | 
279 | acc = cls_acc(final_logits, target.cuda())
280 | accuracies.append(acc)
281 | wandb.log({"Averaged test accuracy": sum(accuracies)/len(accuracies)}, commit=True)  # note: assumes wandb.init() was called (run with --wandb-log)
282 | 
283 | loss = avg_entropy(final_logits)
284 | 
285 | # Global update step: textual prototype evolution
286 | # lam = 0.99
287 | # clip_weights_global = sum([w * clip for w, clip in zip(weights, all_clip_weights)])
288 | if get_entropy(loss, clip_weights) < 0.1:
289 | # Full Update
290 | # clip_weights_global = new_clip_weights
291 | # Cumulative Avg
292 | num_avg += 1
293 | clip_weights_global = clip_weights_global * (num_avg / (num_avg + 1)) + new_clip_weights * (1 / (num_avg + 1))
294 | # clip_weights_global = clip_weights_global / clip_weights_global.norm(dim=0)
295 | # Exponential Avg
296 | # clip_weights_global = clip_weights_global * lam + new_clip_weights * (1 - lam)
297 | 
298 | if i % 1000 == 0:
299 | print("---- DPE's test accuracy: {:.2f}. ----\n".format(sum(accuracies)/len(accuracies)))
300 | print("---- DPE's test accuracy: {:.2f}. ----\n".format(sum(accuracies)/len(accuracies)))
301 | 
302 | 
303 | return sum(accuracies)/len(accuracies)
304 | 
305 | def main():
306 | args = get_arguments()
307 | config_path = args.config
308 | 
309 | # Initialize CLIP model
310 | if args.backbone == 'RN50' or args.backbone == 'ViT-B/16':
311 | clip_model, preprocess = clip.load(args.backbone)
312 | elif args.backbone == 'SigLIP':
313 | clip_model, preprocess = open_clip.create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP')
314 | elif args.backbone == 'OpenCLIP':
315 | clip_model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-laion2B-s32B-b82K')
316 | clip_model = clip_model.to('cuda')
317 | 
318 | if args.wandb:
319 | date = datetime.now().strftime("%b%d_%H-%M-%S")
320 | group_name = f"{args.backbone}_{args.datasets}_{date}"
321 | 
322 | # Run DPE on each dataset
323 | datasets = args.datasets.split('/')
324 | for dataset_name in datasets:
325 | # Set random seed
326 | random.seed(1)
327 | torch.manual_seed(1)
328 | print(f"Processing {dataset_name} dataset.")
329 | 
330 | cfg = get_config_file(config_path, dataset_name)
331 | print("\nRunning dataset configurations:")
332 | print(cfg, "\n")
333 | print(args.coop)
334 | print(args.backbone)
335 | 
336 | test_loader, classnames, template, cupl_path = build_test_data_loader(dataset_name, args.data_root, preprocess)
337 | clip_weights = clip_classifier(classnames, template, cupl_path, clip_model, args.coop, args.backbone)
338 | 
339 | if args.wandb:
340 | run_name = f"{dataset_name}"
341 | run = wandb.init(project="ETTA-CLIP", config=cfg, group=group_name, name=run_name)
342 | 
343 | acc = run_test_dpe(cfg['positive'], cfg['learning_rate'], test_loader, clip_model, clip_weights, dataset_name)
344 | 
345 | if args.wandb:
346 | wandb.log({f"{dataset_name}": acc})
347 | run.finish()
348 | 
349 | if __name__ == "__main__":
350 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | flake8==3.7.9
5 | yapf==0.29.0
6 | isort==4.3.21
7 | yacs
8 | gdown
9 | tb-nightly
10 | future
11 | scipy
12 | scikit-learn
13 | tqdm
14 | wandb==0.15.12
15 | chardet
16 | ftfy
17 | regex
18 | 
wilds==1.2.2 19 | tabulate -------------------------------------------------------------------------------- /scripts/run_cd_benchmark_rn50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=0 python main_dpe.py --config configs \ 3 | --wandb-log \ 4 | --datasets caltech101/dtd/eurosat/fgvc/food101/oxford_flowers/oxford_pets/stanford_cars/sun397/ucf101 \ 5 | --backbone RN50 \ 6 | # --coop -------------------------------------------------------------------------------- /scripts/run_cd_benchmark_vit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=0 python main_dpe.py --config configs \ 3 | --wandb-log \ 4 | --datasets caltech101/dtd/eurosat/fgvc/food101/oxford_flowers/oxford_pets/stanford_cars/sun397/ucf101 \ 5 | --backbone ViT-B/16 \ 6 | # --coop -------------------------------------------------------------------------------- /scripts/run_ood_benchmark_rn50.sh: -------------------------------------------------------------------------------- 1 | # #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=0 python main_dpe.py --config configs \ 3 | --wandb-log \ 4 | --datasets I/A/V/R/S \ 5 | --backbone RN50 \ 6 | # --coop -------------------------------------------------------------------------------- /scripts/run_ood_benchmark_vit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=0 python main_dpe.py --config configs \ 3 | --wandb-log \ 4 | --datasets I/A/V/R/S \ 5 | --backbone ViT-B/16 \ 6 | # --coop -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | import math 5 | import numpy as np 6 | import clip 7 | from datasets.imagenet import ImageNet 8 | from datasets import build_dataset 9 | from datasets.utils import build_data_loader, AugMixAugmenter 10 | import torchvision.transforms as transforms 11 | from PIL import Image 12 | import json 13 | 14 | import open_clip 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | class TextEncoderWithPrompt(nn.Module): 25 | def __init__(self, clip_model): 26 | super().__init__() 27 | self.transformer = clip_model.transformer 28 | self.positional_embedding = clip_model.positional_embedding 29 | self.ln_final = clip_model.ln_final 30 | self.text_projection = clip_model.text_projection 31 | self.dtype = clip_model.dtype 32 | 33 | def forward(self, prompts, tokenized_prompts): 34 | x = prompts + self.positional_embedding.type(self.dtype) 35 | x = x.permute(1, 0, 2) # NLD -> LND 36 | x = self.transformer(x) 37 | x = x.permute(1, 0, 2) # LND -> NLD 38 | x = self.ln_final(x).type(self.dtype) 39 | 40 | # x.shape = [batch_size, n_ctx, transformer.width] 41 | # take features from the eot embedding (eot_token is the highest number in each sequence) 42 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 43 | 44 | return x 45 | 46 | def get_entropy(loss, clip_weights): 47 | max_entropy = math.log2(clip_weights.size(1)) 48 | return float(loss / max_entropy) 49 | 50 | 51 | def softmax_entropy(x): 52 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 53 | 54 | 55 | def avg_entropy(outputs): 
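# Entropy of the marginal prediction over the augmented views: per-view
# log-probabilities are averaged in probability space via logsumexp (for
# numerical stability), and the entropy of that averaged distribution is returned.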
56 | logits = outputs - outputs.logsumexp(dim=-1, keepdim=True) 57 | avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0]) 58 | min_real = torch.finfo(avg_logits.dtype).min 59 | avg_logits = torch.clamp(avg_logits, min=min_real) 60 | return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1) 61 | 62 | 63 | def cls_acc(output, target, topk=1): 64 | pred = output.topk(topk, 1, True, True)[1].t() 65 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 66 | acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 67 | acc = 100 * acc / target.shape[0] 68 | return acc 69 | 70 | 71 | def clip_classifier(classnames, template, cupl_path, clip_model, coop=False, backbone='RN50'): 72 | if coop: 73 | n_ctx = 4 74 | if backbone == 'RN50': 75 | print('Using CoOp weights (RN50) for initialization.') 76 | coop_path = '/home/ce/DiffTPT/coop_weights/rn50_ep50_16shots/nctx4_cscFalse_ctpend/seed1/prompt_learner/model.pth.tar-50' 77 | elif backbone == 'ViT-B/16': 78 | print('Using CoOp weights (ViT-B/16) for initialization.') 79 | coop_path = '/home/ce/DiffTPT/coop_weights/vit_b16_ep50_16shots/nctx4_cscFalse_ctpend/seed2/prompt_learner/model.pth.tar-50' 80 | ctx = torch.load(coop_path)['state_dict']['ctx'].unsqueeze(0).cuda() 81 | f = open(cupl_path) 82 | cupl = json.load(f) 83 | 84 | if backbone == 'OpenCLIP': 85 | tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-laion2B-s32B-b82K') 86 | with torch.no_grad(): 87 | clip_weights = [] 88 | 89 | for classname in classnames: 90 | # Tokenize the prompts 91 | classname = classname.replace('_', ' ') 92 | texts = [t.format(classname) for t in template] 93 | texts += cupl[classname] 94 | 95 | if coop: 96 | prompts = [f'a photo of a {classname}.'] 97 | tokenized_prompts = clip.tokenize(prompts).cuda() 98 | embedding = clip_model.token_embedding(tokenized_prompts).type(clip_model.visual.conv1.weight.dtype) 99 | 100 | prefix = embedding[:, :1, :] 101 | suffix = embedding[:, 1 + n_ctx :, :] # CLS, EOS 102 | 103 | # print(prefix.shape, ctx.shape, suffix.shape) 104 | 105 | prompts = torch.cat( 106 | [ 107 | prefix, # (n_cls, 1, dim) 108 | ctx, # (n_cls, n_ctx, dim) 109 | suffix, # (n_cls, *, dim) 110 | ], 111 | dim=-2, 112 | ) 113 | text_encoder_w_prompt = TextEncoderWithPrompt(clip_model) 114 | class_embedding = text_encoder_w_prompt(prompts, tokenized_prompts) 115 | class_embedding = class_embedding.squeeze() 116 | else: 117 | if backbone == 'RN50' or backbone == 'ViT-B/16': 118 | texts = clip.tokenize(texts).cuda() 119 | elif backbone == 'OpenCLIP': 120 | texts = tokenizer(texts).cuda() 121 | class_embeddings = clip_model.encode_text(texts) 122 | # prompt ensemble for ImageNet 123 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 124 | class_embedding = class_embeddings.mean(dim=0) 125 | class_embedding /= class_embedding.norm() 126 | clip_weights.append(class_embedding) 127 | 128 | clip_weights = torch.stack(clip_weights, dim=1).cuda() 129 | return clip_weights 130 | 131 | 132 | def get_clip_logits(images, clip_model, clip_weights): 133 | # with torch.no_grad(): 134 | if isinstance(images, list): 135 | images = torch.cat(images, dim=0).cuda() 136 | else: 137 | images = images.cuda() 138 | 139 | # Change 3D tensor to 4D tensor 140 | if len(images.shape) == 3: 141 | images = images.unsqueeze(0) 142 | 143 | image_features = clip_model.encode_image(images) 144 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 145 | 146 | clip_logits = 100. 
* image_features @ clip_weights
147 | 
148 | if image_features.size(0) > 1:
149 | batch_entropy = softmax_entropy(clip_logits)
150 | selected_idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * 0.1)]  # keep the 10% lowest-entropy (most confident) augmented views
151 | output = clip_logits[selected_idx]
152 | image_features = image_features[selected_idx].mean(0).unsqueeze(0)
153 | clip_logits = output.mean(0).unsqueeze(0)
154 | 
155 | loss = avg_entropy(output)
156 | prob_map = output.softmax(1).mean(0).unsqueeze(0)
157 | pred = int(output.mean(0).unsqueeze(0).topk(1, 1, True, True)[1].t())
158 | else:
159 | loss = softmax_entropy(clip_logits)
160 | prob_map = clip_logits.softmax(1)
161 | pred = int(clip_logits.topk(1, 1, True, True)[1].t()[0])
162 | 
163 | return image_features, clip_logits, loss, prob_map, pred
164 | 
165 | 
166 | def get_preprocess():
167 | normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
168 | std=[0.26862954, 0.26130258, 0.27577711])
169 | # normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
170 | # std=[0.5, 0.5, 0.5]) # For OpenCLIP
171 | base_transform = transforms.Compose([
172 | transforms.Resize(224, interpolation=BICUBIC),
173 | transforms.CenterCrop(224)])
174 | preprocess = transforms.Compose([
175 | transforms.ToTensor(),
176 | normalize])
177 | aug_preprocess = AugMixAugmenter(base_transform, preprocess, n_views=63, augmix=True)
178 | 
179 | return aug_preprocess
180 | 
181 | 
182 | def get_config_file(config_path, dataset_name):
183 | if dataset_name == "I":
184 | config_name = "imagenet.yaml"
185 | elif dataset_name in ["A", "V", "R", "S"]:
186 | config_name = f"imagenet_{dataset_name.lower()}.yaml"
187 | else:
188 | config_name = f"{dataset_name}.yaml"
189 | 
190 | config_file = os.path.join(config_path, config_name)
191 | 
192 | if not os.path.exists(config_file):
193 | raise FileNotFoundError(f"The configuration file {config_file} was not found.")
194 | 
195 | with open(config_file, 'r') as file:
196 | cfg = yaml.load(file, Loader=yaml.SafeLoader)
197 | 
198 | return cfg
199 | 
200 | 
201 | def build_test_data_loader(dataset_name, root_path, preprocess):
202 | if dataset_name == 'I':
203 | preprocess = get_preprocess()
204 | dataset = ImageNet(root_path, preprocess)
205 | test_loader = torch.utils.data.DataLoader(dataset.test, batch_size=1, num_workers=8, shuffle=True, pin_memory=True)
206 | 
207 | elif dataset_name in ['A','V','R','S']:
208 | preprocess = get_preprocess()
209 | dataset = build_dataset(f"imagenet-{dataset_name.lower()}", root_path)
210 | test_loader = build_data_loader(data_source=dataset.test, batch_size=1, is_train=False, tfm=preprocess, shuffle=True)
211 | 
212 | elif dataset_name in ['caltech101','dtd','eurosat','fgvc','food101','oxford_flowers','oxford_pets','stanford_cars','sun397','ucf101']:
213 | # preprocess = get_preprocess()
214 | dataset = build_dataset(dataset_name, root_path)
215 | test_loader = build_data_loader(data_source=dataset.test, batch_size=1, is_train=False, tfm=preprocess, shuffle=True)
216 | 
217 | else:
218 | raise ValueError("Dataset is not from the chosen list")
219 | 
220 | return test_loader, dataset.classnames, dataset.template, dataset.cupl_path --------------------------------------------------------------------------------