├── CaFo.png ├── CaFo_arXiv.pdf ├── LICENSE ├── README.md ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── configs ├── caltech101 │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml ├── cars │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml ├── chat_caltech101 │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml ├── dtd │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml ├── eurosat │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml ├── fgvc │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml ├── food101 │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml ├── imagenet │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml ├── oxford_flowers │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml ├── pets │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml ├── sd_caltech101 │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml ├── sun │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml └── ucf │ ├── 16shot.yaml │ ├── 1shot.yaml │ ├── 2shot.yaml │ ├── 4shot.yaml │ └── 8shot.yaml ├── datasets ├── __init__.py ├── caltech101.py ├── dalle_caltech.py ├── dalle_cars.py ├── dalle_dtd.py ├── dalle_eurosat.py ├── dalle_fgvc.py ├── dalle_flowers.py ├── dalle_food.py ├── dalle_imagenet.py ├── dalle_pets.py ├── dalle_sun.py ├── dalle_ucf.py ├── dtd.py ├── eurosat.py ├── fgvc.py ├── food101.py ├── imagenet.py ├── oxford_flowers.py ├── oxford_pets.py ├── sd_caltech.py ├── stanford_cars.py ├── sun397.py ├── ucf101.py └── utils.py ├── dino ├── __pycache__ │ └── utils.cpython-36.pyc └── utils.py ├── exp.log ├── gpt_file ├── caltech_prompt.json ├── caltech_prompt_chat.json ├── dtd_prompt.json ├── eurosat_prompt.json ├── fgvc_prompt.json ├── food101_prompt.json ├── imagenet_prompt.json ├── oxford_flowers_prompt.json ├── oxford_pets_prompt.json ├── stanford_cars_prompt.json ├── sun397_prompt.json └── ucf101_prompt.json ├── main.py ├── main_imagenet.py ├── requirements.txt └── utils.py /CaFo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/CaFo/a805a2aefc6757fdbe10ac9a3165520ceb0e01cb/CaFo.png -------------------------------------------------------------------------------- /CaFo_arXiv.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/CaFo/a805a2aefc6757fdbe10ac9a3165520ceb0e01cb/CaFo_arXiv.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Renrui Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright 
notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prompt, Generate, then Cache 2 | 3 | Official implementation of ['Prompt, Generate, then Cache: Cascade of Foundation Models makes Strong Few-shot Learners'](https://arxiv.org/pdf/2303.02151.pdf). 4 | 5 | The paper has been accepted by **CVPR 2023** 🔥. 6 | 7 | ## News 8 | * Please check our latest work ['Point-NN, Parameter is Not All You Need'](https://arxiv.org/pdf/2303.08134.pdf) with [code](https://github.com/ZrrSkywalker/Point-NN), accepted by **CVPR 2023** 🔥, which conducts 3D understanding without any parameters or training. 9 | * CaFo cascaded with [ChatGPT](https://openai.com/blog/chatgpt) and [Stable Diffusion](https://github.com/CompVis/stable-diffusion) on the Caltech-101 dataset has been released 📌. 10 | * The code of CaFo has been released. 11 | * The CaFo model is built upon [Tip-Adapter](https://arxiv.org/pdf/2207.09519), which was accepted by **ECCV 2022** and [open-sourced](https://github.com/gaopengcuhk/Tip-Adapter). 12 | 13 | ## Introduction 14 | We propose **CaFo**, a **Ca**scade of **Fo**undation models that incorporates the diverse prior knowledge of various pre-training paradigms for better few-shot learning, including CLIP, DINO, DALL-E, and GPT-3. Specifically, CaFo works by **'Prompt, Generate, then Cache'**. We leverage GPT-3 to prompt CLIP with rich linguistic semantics and generate synthetic images via DALL-E to expand the few-shot training data. Then, we introduce a learnable cache model to adaptively blend the predictions from CLIP and DINO. Through this collaboration, CaFo can fully unleash the potential of different pre-training methods and unify them to achieve *state-of-the-art* performance for few-shot classification. 15 | 16 |
17 | [Figure: CaFo framework overview (CaFo.png in the repository root)] 18 |
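To give a feel for what the cache model does before diving into the code, here is a minimal sketch in the spirit of Tip-Adapter (which CaFo builds upon). It is an illustration only, not the repository's exact implementation; the function name, tensor shapes, and the ensembling note at the end are assumptions made for the example.

```python
import torch

def cache_blended_logits(test_feat, cache_keys, cache_values, zero_shot_logits, alpha, beta):
    """Blend zero-shot CLIP logits with a key-value cache built from the few-shot
    (and DALL-E-generated) training images, Tip-Adapter style.

    test_feat:        [B, D] L2-normalized test image features
    cache_keys:       [D, N] L2-normalized training image features (the cache keys)
    cache_values:     [N, C] one-hot labels of the cached training images
    zero_shot_logits: [B, C] CLIP logits computed from GPT-3-prompted text features
    """
    affinity = test_feat @ cache_keys                                       # [B, N] cosine similarities
    cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values   # [B, C] cache predictions
    return zero_shot_logits + alpha * cache_logits
```

A CaFo-style ensemble would compute such cache logits for both the CLIP and the DINO visual features and adaptively weight the two branches by how well each agrees with the zero-shot prediction; in the trainable variant, the cache keys are the learnable parameters while the foundation models stay frozen.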
19 | 20 | ## Requirements 21 | 22 | ### Installation 23 | Create a conda environment and install dependencies: 24 | ```bash 25 | git clone https://github.com/ZrrSkywalker/CaFo.git 26 | cd CaFo 27 | 28 | conda create -n cafo python=3.7 29 | conda activate cafo 30 | 31 | pip install -r requirements.txt 32 | 33 | # Install the appropriate versions of torch and torchvision 34 | conda install pytorch torchvision cudatoolkit 35 | ``` 36 | 37 | ### Dataset 38 | Please follow [DATASET.md](https://github.com/gaopengcuhk/Tip-Adapter/blob/main/DATASET.md) to download the official ImageNet and the other 10 datasets. 39 | 40 | ### Foundation Models 41 | * The pre-trained weights of **CLIP** will be downloaded automatically when the code is run. 42 | * The prompts produced by **GPT-3** have been stored in `gpt_file/`. 43 | * Please download **DINO's** pre-trained ResNet-50 from [here](https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth), and put it under `dino/`. 44 | * Please download **DALL-E's** generated images from [here](https://drive.google.com/drive/folders/1e249OgUFCmpfEDPsxCVR-nNb6Q1VaZVW?usp=sharing), and organize them alongside the official datasets as follows: 45 | ``` 46 | $DATA/ 47 | |–– imagenet/ 48 | |–– caltech-101/ 49 | |–– oxford_pets/ 50 | |–– ... 51 | |–– dalle_imagenet/ 52 | |–– dalle_caltech-101/ 53 | |–– dalle_oxford_pets/ 54 | |–– ... 55 | |–– sd_caltech-101/ 56 | ``` 57 | * For the Caltech-101 dataset, we also provide **Stable Diffusion's** images from [here](https://drive.google.com/drive/folders/1e249OgUFCmpfEDPsxCVR-nNb6Q1VaZVW?usp=sharing), and **ChatGPT's** prompts in `gpt_file/`. 58 | 59 | ## Get Started 60 | ### Configs 61 | The running configurations for different `[dataset]` with `[k]` shots can be modified in `configs/[dataset]/[k]shot.yaml`, including visual encoders and hyperparameters. We have provided the configurations for reproducing the results in the paper. You can edit the `search_scale`, `search_step`, `init_beta` and `init_alpha` for fine-grained tuning and better results (a sketch of such a grid search is given after the Acknowledgement section). 62 | 63 | Note that the default `load_cache` and `load_pre_feat` are `False` for the first run, which will store the cache model and val/test features in `configs/dataset/`. For later runs, they can be set to `True` for faster hyperparameter tuning. 64 | 65 | For the Caltech-101 dataset, the configs for Stable Diffusion's images and ChatGPT's prompts are in `configs/sd_caltech101` and `configs/chat_caltech101`, respectively. 66 | 67 | ### Running 68 | For the 16-shot ImageNet dataset: 69 | ```bash 70 | CUDA_VISIBLE_DEVICES=0 python main_imagenet.py --config configs/imagenet/16shot.yaml 71 | ``` 72 | For the other 10 datasets: 73 | ```bash 74 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/dataset/16shot.yaml 75 | ``` 76 | 77 | ### Numerical Results 78 | 79 | We provide CaFo's numerical results on 11 datasets from 1 to 16 shots at [exp_Cafo.log](https://github.com/ZrrSkywalker/CaFo/blob/main/exp.log). 80 | The results for Tip-Adapter and Tip-Adapter-F are at [exp_Tip.log](https://github.com/gaopengcuhk/Tip-Adapter/blob/main/exp.log). 81 | 82 | 83 | ## Acknowledgement 84 | This repo benefits from [Tip-Adapter](https://github.com/gaopengcuhk/Tip-Adapter), [CLIP](https://github.com/openai/CLIP), [DINO](https://github.com/facebookresearch/dino), [DALL-E](https://github.com/borisdayma/dalle-mini) and [CuPL](https://github.com/sarahpratt/CuPL). Thanks for their wonderful work.
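As a rough illustration of the `search_scale`/`search_step` hyperparameter search mentioned in the Configs section above, the snippet below shows one way such a grid over the residual ratio `alpha` and the sharpness `beta` could be built. The meaning assigned to the two lists and the helper name are assumptions for the example, not the repository's code.

```python
import numpy as np

def hp_grid(search_scale, search_step):
    # Assumed convention: search_scale = [beta_max, alpha_max],
    # search_step = [num_beta_points, num_alpha_points].
    betas = np.linspace(0.1, search_scale[0], search_step[0])
    alphas = np.linspace(0.1, search_scale[1], search_step[1])
    return [(beta, alpha) for beta in betas for alpha in alphas]

# With the Caltech-101 16-shot config (search_scale: [12, 5], search_step: [200, 20])
# this yields 200 * 20 = 4000 (beta, alpha) candidates; each candidate is scored on
# the validation split and the best pair is then evaluated on the test split.
grid = hp_grid([12, 5], [200, 20])
print(len(grid))  # 4000
```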
85 | 86 | 87 | ## Citation 88 | ```bash 89 | @article{zhang2023prompt, 90 | title={Prompt, Generate, then Cache: Cascade of Foundation Models makes Strong Few-shot Learners}, 91 | author={Renrui Zhang and Xiangfei Hu and Bohao Li and Siyuan Huang and Hanqiu Deng and Hongsheng Li and Yu Qiao and Peng Gao}, 92 | journal={arXiv preprint arXiv:2303.02151}, 93 | year={2023} 94 | } 95 | ``` 96 | 97 | ## Contributors 98 | [Renrui Zhang](https://github.com/ZrrSkywalker), [Xiangfei Hu](https://github.com/hxf42), [Bohao Li](https://github.com/Bohao-Lee) 99 | 100 | ## Contact 101 | If you have any question about this project, please feel free to contact zhangrenrui@pjlab.org.cn and sjtuhxf@sjtu.edu.cn. 102 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/CaFo/a805a2aefc6757fdbe10ac9a3165520ceb0e01cb/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, 
"rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _convert_image_to_rgb(image): 72 | return image.convert("RGB") 73 | 74 | 75 | def _transform(n_px): 76 | return Compose([ 77 | Resize(n_px, interpolation=BICUBIC), 78 | CenterCrop(n_px), 79 | _convert_image_to_rgb, 80 | ToTensor(), 81 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 82 | ]) 83 | 84 | 85 | def available_models() -> List[str]: 86 | """Returns the names of available CLIP models""" 87 | return list(_MODELS.keys()) 88 | 89 | 90 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 91 | """Load a CLIP model 92 | 93 | Parameters 94 | ---------- 95 | name : str 96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 97 | 98 | device : Union[str, torch.device] 99 | The device to put the loaded model 100 | 101 | jit : bool 102 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 103 | 104 | download_root: str 105 | path to download the model files; by default, it uses "~/.cache/clip" 106 | 107 | Returns 108 | ------- 109 | model : torch.nn.Module 110 | The CLIP model 111 | 112 | preprocess : Callable[[PIL.Image], torch.Tensor] 113 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 114 | """ 115 | if name in _MODELS: 116 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 117 | elif os.path.isfile(name): 118 | model_path = name 119 | else: 120 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 121 | 122 | try: 123 | # loading JIT archive 124 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 125 | state_dict = None 126 | except RuntimeError: 127 | # loading saved state dict 128 | if jit: 129 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 130 | jit = False 131 | state_dict = torch.load(model_path, map_location="cpu") 132 | 133 | if not jit: 134 | model = build_model(state_dict or model.state_dict()).to(device) 135 | if str(device) == "cpu": 136 | model.float() 137 | return model, _transform(model.visual.input_resolution) 138 | 139 | # patch the device names 140 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 141 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 142 | 143 | def patch_device(module): 144 | try: 145 | graphs = [module.graph] if hasattr(module, "graph") else [] 146 | except RuntimeError: 147 | graphs = [] 148 | 149 | if hasattr(module, "forward1"): 150 | graphs.append(module.forward1.graph) 151 | 152 | for graph in graphs: 153 | for node in graph.findAllNodes("prim::Constant"): 154 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 155 | node.copyAttributes(device_node) 156 | 157 | model.apply(patch_device) 158 | patch_device(model.encode_image) 159 | patch_device(model.encode_text) 160 | 161 | # patch dtype to float32 on CPU 162 | if str(device) == "cpu": 163 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 164 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 165 | float_node = float_input.node() 166 | 167 | def patch_float(module): 168 | try: 169 | graphs = [module.graph] if hasattr(module, "graph") else [] 170 | except RuntimeError: 171 | graphs = [] 172 | 173 | if hasattr(module, "forward1"): 174 | graphs.append(module.forward1.graph) 175 | 176 | for graph in graphs: 177 | for node in graph.findAllNodes("aten::to"): 178 | inputs = list(node.inputs()) 179 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 180 | if inputs[i].node()["value"] == 5: 181 | inputs[i].node().copyAttributes(float_node) 182 | 183 | model.apply(patch_float) 184 | patch_float(model.encode_image) 185 | patch_float(model.encode_text) 186 | 187 | model.float() 188 | 189 | return model, _transform(model.input_resolution.item()) 190 | 191 | 192 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 193 | """ 194 | Returns the tokenized representation of given input string(s) 195 | 196 | Parameters 197 | ---------- 198 | texts : Union[str, List[str]] 199 | An input string or a list of input strings to tokenize 200 | 201 | context_length : int 202 | The context length to use; all CLIP models use 77 as the context length 203 | 204 | truncate: bool 205 | Whether to truncate the text in case its encoding is longer than the context length 206 | 207 | Returns 208 | ------- 209 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 210 | """ 211 | if isinstance(texts, str): 212 | texts = [texts] 213 | 214 | sot_token = _tokenizer.encoder["<|startoftext|>"] 215 | eot_token = _tokenizer.encoder["<|endoftext|>"] 216 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 217 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 218 | 219 | for i, tokens in enumerate(all_tokens): 220 | if len(tokens) > context_length: 221 | if truncate: 222 | tokens = tokens[:context_length] 223 | tokens[-1] = eot_token 224 | else: 225 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 226 | result[i, 
:len(tokens)] = torch.tensor(tokens) 227 | 228 | return result 229 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's 
but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = x + 
self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 204 | super().__init__() 205 | self.input_resolution = input_resolution 206 | self.output_dim = output_dim 207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 208 | 209 | scale = width ** -0.5 210 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 212 | self.ln_pre = LayerNorm(width) 213 | 214 | self.transformer = Transformer(width, layers, heads) 215 | 216 | self.ln_post = LayerNorm(width) 217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 218 | 219 | def forward(self, x: torch.Tensor): 220 | x = self.conv1(x) # shape = [*, width, grid, grid] 221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 224 | x = x + self.positional_embedding.to(x.dtype) 225 | x = self.ln_pre(x) 226 | 227 | x = x.permute(1, 0, 2) # NLD -> LND 228 | x = self.transformer(x) 229 | x = x.permute(1, 0, 2) # LND -> NLD 230 | 231 | x = self.ln_post(x[:, 0, :]) 232 | 233 | if self.proj is not None: 234 | x = x @ self.proj 235 | 236 | return x 237 | 238 | 239 | class CLIP(nn.Module): 240 | def __init__(self, 241 | embed_dim: int, 242 | # vision 243 | image_resolution: int, 244 | vision_layers: Union[Tuple[int, int, int, int], int], 245 | vision_width: int, 246 | vision_patch_size: int, 247 | # text 248 | context_length: int, 249 | vocab_size: int, 250 | transformer_width: int, 251 | transformer_heads: int, 252 | transformer_layers: int 253 | ): 254 | super().__init__() 255 | 256 | self.context_length = context_length 257 | 258 | if isinstance(vision_layers, (tuple, list)): 259 | vision_heads = vision_width * 32 // 64 260 | self.visual = ModifiedResNet( 261 | layers=vision_layers, 262 | output_dim=embed_dim, 263 | heads=vision_heads, 264 | input_resolution=image_resolution, 265 | width=vision_width 266 | ) 267 | else: 268 | vision_heads = vision_width // 64 269 | self.visual = VisionTransformer( 270 | input_resolution=image_resolution, 271 | patch_size=vision_patch_size, 272 | width=vision_width, 273 | layers=vision_layers, 274 | heads=vision_heads, 275 | output_dim=embed_dim 276 | ) 277 | 278 | self.transformer = Transformer( 279 | width=transformer_width, 280 | layers=transformer_layers, 281 | heads=transformer_heads, 282 | attn_mask=self.build_attention_mask() 283 | ) 284 | 285 | self.vocab_size = vocab_size 286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 288 | 
self.ln_final = LayerNorm(transformer_width) 289 | 290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 292 | 293 | self.initialize_parameters() 294 | 295 | def initialize_parameters(self): 296 | nn.init.normal_(self.token_embedding.weight, std=0.02) 297 | nn.init.normal_(self.positional_embedding, std=0.01) 298 | 299 | if isinstance(self.visual, ModifiedResNet): 300 | if self.visual.attnpool is not None: 301 | std = self.visual.attnpool.c_proj.in_features ** -0.5 302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 306 | 307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 308 | for name, param in resnet_block.named_parameters(): 309 | if name.endswith("bn3.weight"): 310 | nn.init.zeros_(param) 311 | 312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 313 | attn_std = self.transformer.width ** -0.5 314 | fc_std = (2 * self.transformer.width) ** -0.5 315 | for block in self.transformer.resblocks: 316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 320 | 321 | if self.text_projection is not None: 322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 323 | 324 | def build_attention_mask(self): 325 | # lazily create causal attention mask, with full attention between the vision tokens 326 | # pytorch uses additive attention mask; fill with -inf 327 | mask = torch.empty(self.context_length, self.context_length) 328 | mask.fill_(float("-inf")) 329 | mask.triu_(1) # zero out the lower diagonal 330 | return mask 331 | 332 | @property 333 | def dtype(self): 334 | return self.visual.conv1.weight.dtype 335 | 336 | def encode_image(self, image): 337 | return self.visual(image.type(self.dtype)) 338 | 339 | def encode_text(self, text): 340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 341 | 342 | x = x + self.positional_embedding.type(self.dtype) 343 | x = x.permute(1, 0, 2) # NLD -> LND 344 | x = self.transformer(x) 345 | x = x.permute(1, 0, 2) # LND -> NLD 346 | x = self.ln_final(x).type(self.dtype) 347 | 348 | # x.shape = [batch_size, n_ctx, transformer.width] 349 | # take features from the eot embedding (eot_token is the highest number in each sequence) 350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 351 | 352 | return x 353 | 354 | def forward(self, image, text): 355 | image_features = self.encode_image(image) 356 | text_features = self.encode_text(text) 357 | 358 | # normalized features 359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 361 | 362 | # cosine similarity as logits 363 | logit_scale = self.logit_scale.exp() 364 | logits_per_image = logit_scale * image_features @ text_features.t() 365 | logits_per_text = logits_per_image.t() 366 | 367 | # shape = [global_batch_size, global_batch_size] 368 | return logits_per_image, logits_per_text 369 | 370 | 371 | def convert_weights(model: nn.Module): 
372 | """Convert applicable model parameters to fp16""" 373 | 374 | def _convert_weights_to_fp16(l): 375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 376 | l.weight.data = l.weight.data.half() 377 | if l.bias is not None: 378 | l.bias.data = l.bias.data.half() 379 | 380 | if isinstance(l, nn.MultiheadAttention): 381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 382 | tensor = getattr(l, attr) 383 | if tensor is not None: 384 | tensor.data = tensor.data.half() 385 | 386 | for name in ["text_projection", "proj"]: 387 | if hasattr(l, name): 388 | attr = getattr(l, name) 389 | if attr is not None: 390 | attr.data = attr.data.half() 391 | 392 | model.apply(_convert_weights_to_fp16) 393 | 394 | 395 | def build_model(state_dict: dict): 396 | vit = "visual.proj" in state_dict 397 | 398 | if vit: 399 | vision_width = state_dict["visual.conv1.weight"].shape[0] 400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 403 | image_resolution = vision_patch_size * grid_size 404 | else: 405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 406 | vision_layers = tuple(counts) 407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 409 | vision_patch_size = None 410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 411 | image_resolution = output_width * 32 412 | 413 | embed_dim = state_dict["text_projection"].shape[1] 414 | context_length = state_dict["positional_embedding"].shape[0] 415 | vocab_size = state_dict["token_embedding.weight"].shape[0] 416 | transformer_width = state_dict["ln_final.weight"].shape[0] 417 | transformer_heads = transformer_width // 64 418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 419 | 420 | model = CLIP( 421 | embed_dim, 422 | image_resolution, vision_layers, vision_width, vision_patch_size, 423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 424 | ) 425 | 426 | for key in ["input_resolution", "context_length", "vocab_size"]: 427 | if key in state_dict: 428 | del state_dict[key] 429 | 430 | convert_weights(model) 431 | model.load_state_dict(state_dict) 432 | return model.eval() 433 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 
22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] 
for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /configs/caltech101/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.3 22 | 23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'caltech101' 27 | shots: 16 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_caltech' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/caltech101/1shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.5 22 | 23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'caltech101' 27 | shots: 1 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_caltech' 33 | dalle_shots: 8 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/caltech101/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | # load_cache: True 9 | # load_pre_feat: True 10 | 11 | 12 | # ------ Hyperparamters ------ 13 | search_hp: True 14 | # search_hp: False 15 | 16 | search_scale: [12, 5] 17 | search_step: [200, 20] 18 | 19 | init_beta: 1 20 | init_alpha: 0.8 21 | 22 | gpt3_prompt_file: './gpt_file/caltech_prompt.json' 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'caltech101' 26 | shots: 2 27 | clip_backbone: 'RN50' 28 | dino_backbone: 'resnet50' 29 | 30 | # ------ Dalle Dataset ----- 31 | dalle_dataset: 'dalle_caltech' 32 | dalle_shots: 16 33 | 34 | lr: 0.001 35 | augment_epoch: 10 36 | train_epoch: 20 37 | -------------------------------------------------------------------------------- /configs/caltech101/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: 
[200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.1 22 | 23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'caltech101' 27 | shots: 4 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_caltech' 33 | dalle_shots: 2 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/caltech101/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.1 22 | 23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'caltech101' 27 | shots: 8 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_caltech' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/cars/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | # load_cache: True 9 | # load_pre_feat: True 10 | 11 | 12 | # ------ Hyperparamters ------ 13 | search_hp: True 14 | # search_hp: False 15 | 16 | search_scale: [20, 10] 17 | search_step: [200, 20] 18 | 19 | init_beta: 1 20 | init_alpha: 0.6 21 | 22 | gpt3_prompt_file: './gpt_file/stanford_cars_prompt.json' 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'stanford_cars' 26 | shots: 16 27 | clip_backbone: 'RN50' 28 | dino_backbone: 'resnet50' 29 | 30 | # ------ Dalle Dataset ----- 31 | dalle_dataset: 'dalle_cars' 32 | dalle_shots: 1 33 | 34 | lr: 0.001 35 | augment_epoch: 10 36 | train_epoch: 200 37 | -------------------------------------------------------------------------------- /configs/cars/1shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [20, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.4 22 | 23 | gpt3_prompt_file: './gpt_file/stanford_cars_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'stanford_cars' 27 | shots: 1 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_cars' 33 | dalle_shots: 16 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 100 38 | -------------------------------------------------------------------------------- /configs/cars/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and 
Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [20, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.4 22 | 23 | gpt3_prompt_file: './gpt_file/stanford_cars_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'stanford_cars' 27 | shots: 2 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_cars' 33 | dalle_shots: 16 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 200 38 | -------------------------------------------------------------------------------- /configs/cars/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [20, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.8 22 | 23 | gpt3_prompt_file: './gpt_file/stanford_cars_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'stanford_cars' 27 | shots: 4 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_cars' 33 | dalle_shots: 16 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 400 38 | -------------------------------------------------------------------------------- /configs/cars/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [20, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.5 22 | 23 | gpt3_prompt_file: './gpt_file/stanford_cars_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'stanford_cars' 27 | shots: 8 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_cars' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 200 38 | -------------------------------------------------------------------------------- /configs/chat_caltech101/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.3 22 | 23 | gpt3_prompt_file: './gpt_file/caltech_prompt_chat.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'caltech101' 27 | shots: 16 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_caltech' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 
38 | -------------------------------------------------------------------------------- /configs/chat_caltech101/1shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.5 22 | 23 | gpt3_prompt_file: './gpt_file/caltech_prompt_chat.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'caltech101' 27 | shots: 1 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_caltech' 33 | dalle_shots: 8 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/chat_caltech101/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | # load_cache: True 9 | # load_pre_feat: True 10 | 11 | 12 | # ------ Hyperparamters ------ 13 | search_hp: True 14 | # search_hp: False 15 | 16 | search_scale: [12, 5] 17 | search_step: [200, 20] 18 | 19 | init_beta: 1 20 | init_alpha: 0.8 21 | 22 | gpt3_prompt_file: './gpt_file/caltech_prompt_chat.json' 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'caltech101' 26 | shots: 2 27 | clip_backbone: 'RN50' 28 | dino_backbone: 'resnet50' 29 | 30 | # ------ Dalle Dataset ----- 31 | dalle_dataset: 'dalle_caltech' 32 | dalle_shots: 16 33 | 34 | lr: 0.001 35 | augment_epoch: 10 36 | train_epoch: 20 37 | -------------------------------------------------------------------------------- /configs/chat_caltech101/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.1 22 | 23 | gpt3_prompt_file: './gpt_file/caltech_prompt_chat.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'caltech101' 27 | shots: 4 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_caltech' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/chat_caltech101/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.1 22 | 23 | gpt3_prompt_file: './gpt_file/caltech_prompt_chat.json' 24 
| 25 | # ------ Basic Config ------ 26 | dataset: 'caltech101' 27 | shots: 8 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_caltech' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/dtd/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: True 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [13, 13] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 2 22 | 23 | gpt3_prompt_file: './gpt_file/dtd_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'dtd' 27 | shots: 16 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_dtd' 33 | dalle_shots: 1 34 | 35 | 36 | lr: 0.001 37 | augment_epoch: 10 38 | train_epoch: 20 39 | -------------------------------------------------------------------------------- /configs/dtd/1shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [13, 13] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 2 22 | 23 | gpt3_prompt_file: './gpt_file/dtd_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'dtd' 27 | shots: 1 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_dtd' 33 | dalle_shots: 1 34 | 35 | 36 | lr: 0.001 37 | augment_epoch: 10 38 | train_epoch: 20 39 | -------------------------------------------------------------------------------- /configs/dtd/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [13, 13] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 2 22 | 23 | gpt3_prompt_file: './gpt_file/dtd_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'dtd' 27 | shots: 2 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_dtd' 33 | dalle_shots: 1 34 | 35 | 36 | lr: 0.001 37 | augment_epoch: 10 38 | train_epoch: 20 39 | -------------------------------------------------------------------------------- /configs/dtd/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: 
True 15 | # search_hp: False 16 | 17 | search_scale: [13, 13] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 2 22 | 23 | gpt3_prompt_file: './gpt_file/dtd_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'dtd' 27 | shots: 4 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_dtd' 33 | dalle_shots: 1 34 | 35 | 36 | lr: 0.001 37 | augment_epoch: 10 38 | train_epoch: 20 39 | -------------------------------------------------------------------------------- /configs/dtd/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [13, 13] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 2 22 | 23 | gpt3_prompt_file: './gpt_file/dtd_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'dtd' 27 | shots: 8 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_dtd' 33 | dalle_shots: 1 34 | 35 | 36 | lr: 0.001 37 | augment_epoch: 10 38 | train_epoch: 20 39 | -------------------------------------------------------------------------------- /configs/eurosat/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 2 22 | 23 | gpt3_prompt_file: './gpt_file/eurosat_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'eurosat' 27 | shots: 16 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_eurosat' 33 | dalle_shots: 8 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 100 38 | -------------------------------------------------------------------------------- /configs/eurosat/1shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | # load_cache: True 9 | # load_pre_feat: True 10 | 11 | 12 | # ------ Hyperparamters ------ 13 | search_hp: True 14 | # search_hp: False 15 | 16 | search_scale: [12, 10] 17 | search_step: [200, 20] 18 | 19 | init_beta: 1 20 | init_alpha: 3 21 | 22 | gpt3_prompt_file: './gpt_file/eurosat_prompt.json' 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'eurosat' 26 | shots: 1 27 | clip_backbone: 'RN50' 28 | dino_backbone: 'resnet50' 29 | 30 | # ------ Dalle Dataset ----- 31 | dalle_dataset: 'dalle_eurosat' 32 | dalle_shots: 4 33 | 34 | lr: 0.001 35 | augment_epoch: 10 36 | train_epoch: 20 37 | -------------------------------------------------------------------------------- /configs/eurosat/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 
| 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.5 22 | 23 | gpt3_prompt_file: './gpt_file/eurosat_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'eurosat' 27 | shots: 2 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_eurosat' 33 | dalle_shots: 8 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 100 38 | -------------------------------------------------------------------------------- /configs/eurosat/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 2 22 | 23 | gpt3_prompt_file: './gpt_file/eurosat_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'eurosat' 27 | shots: 4 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_eurosat' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 100 38 | -------------------------------------------------------------------------------- /configs/eurosat/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 2 22 | 23 | gpt3_prompt_file: './gpt_file/eurosat_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'eurosat' 27 | shots: 8 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_eurosat' 33 | dalle_shots: 8 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 100 38 | -------------------------------------------------------------------------------- /configs/fgvc/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | # load_cache: True 9 | # load_pre_feat: True 10 | 11 | 12 | # ------ Hyperparamters ------ 13 | search_hp: True 14 | # search_hp: False 15 | 16 | search_scale: [30, 30] 17 | search_step: [200, 20] 18 | 19 | init_beta: 1 20 | init_alpha: 1 21 | 22 | gpt3_prompt_file: './gpt_file/fgvc_prompt.json' 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'fgvc' 26 | shots: 16 27 | clip_backbone: 'RN50' 28 | dino_backbone: 'resnet50' 29 | 30 | # ------ Dalle Dataset ----- 31 | dalle_dataset: 'dalle_fgvc' 32 | dalle_shots: 1 33 | 34 | lr: 0.001 35 | augment_epoch: 10 36 | train_epoch: 100 37 | 
-------------------------------------------------------------------------------- /configs/fgvc/1shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [30, 30] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1 22 | 23 | gpt3_prompt_file: './gpt_file/fgvc_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'fgvc' 27 | shots: 1 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_fgvc' 33 | dalle_shots: 8 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 100 38 | -------------------------------------------------------------------------------- /configs/fgvc/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | # load_cache: True 9 | # load_pre_feat: True 10 | 11 | 12 | # ------ Hyperparamters ------ 13 | search_hp: True 14 | # search_hp: False 15 | 16 | search_scale: [30, 30] 17 | search_step: [200, 20] 18 | 19 | init_beta: 1 20 | init_alpha: 0.8 21 | 22 | gpt3_prompt_file: './gpt_file/fgvc_prompt.json' 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'fgvc' 26 | shots: 2 27 | clip_backbone: 'RN50' 28 | dino_backbone: 'resnet50' 29 | 30 | # ------ Dalle Dataset ----- 31 | dalle_dataset: 'dalle_fgvc' 32 | dalle_shots: 4 33 | 34 | lr: 0.001 35 | augment_epoch: 10 36 | train_epoch: 100 37 | -------------------------------------------------------------------------------- /configs/fgvc/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [30, 30] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.9 22 | 23 | gpt3_prompt_file: './gpt_file/fgvc_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'fgvc' 27 | shots: 4 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_fgvc' 33 | dalle_shots: 2 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 100 38 | -------------------------------------------------------------------------------- /configs/fgvc/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [30, 30] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1 22 | 23 | gpt3_prompt_file: './gpt_file/fgvc_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'fgvc' 27 | shots: 8 28 | clip_backbone: 'RN50' 29 | 
dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_fgvc' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 100 38 | -------------------------------------------------------------------------------- /configs/food101/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [10, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.22 22 | 23 | gpt3_prompt_file: './gpt_file/food101_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'food101' 27 | shots: 16 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_food' 33 | dalle_shots: 16 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 200 38 | -------------------------------------------------------------------------------- /configs/food101/1shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [20, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.2 22 | 23 | gpt3_prompt_file: './gpt_file/food101_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'food101' 27 | shots: 1 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_food' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 200 38 | -------------------------------------------------------------------------------- /configs/food101/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [10, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.2 22 | 23 | gpt3_prompt_file: './gpt_file/food101_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'food101' 27 | shots: 2 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_food' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 200 38 | -------------------------------------------------------------------------------- /configs/food101/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [10, 10] 18 | 
search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.22 22 | 23 | gpt3_prompt_file: './gpt_file/food101_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'food101' 27 | shots: 4 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_food' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 200 38 | -------------------------------------------------------------------------------- /configs/food101/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [10, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.22 22 | 23 | gpt3_prompt_file: './gpt_file/food101_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'food101' 27 | shots: 8 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_food' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 200 38 | -------------------------------------------------------------------------------- /configs/imagenet/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | 10 | # ------ Hyperparamters ------ 11 | search_hp: True 12 | # search_hp: False 13 | 14 | search_scale: [7, 3] 15 | search_step: [200, 20] 16 | 17 | init_beta: 1 18 | init_alpha: 0.6 19 | 20 | gpt3_prompt_file: './gpt_file/imagenet_prompt.json' 21 | 22 | # ------ Basic Config ------ 23 | dataset: 'ImageNet' 24 | shots: 16 25 | clip_backbone: 'RN50' # ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16'] 26 | dino_backbone: 'resnet50' 27 | 28 | # ------ Dalle Dataset ----- 29 | dalle_dataset: 'dalle_imagenet' 30 | dalle_shots: 2 31 | 32 | lr: 0.001 33 | augment_epoch: 1 34 | train_epoch: 20 35 | -------------------------------------------------------------------------------- /configs/imagenet/1shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.3 22 | 23 | gpt3_prompt_file: './gpt_file/imagenet_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'ImageNet' 27 | shots: 1 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_imagenet' 33 | dalle_shots: 8 34 | 35 | lr: 0.001 36 | augment_epoch: 1 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/imagenet/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache 
and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | 10 | # ------ Hyperparamters ------ 11 | search_hp: True 12 | # search_hp: False 13 | 14 | search_scale: [7, 3] 15 | search_step: [200, 20] 16 | 17 | init_beta: 1 18 | init_alpha: 0.3 19 | 20 | gpt3_prompt_file: './gpt_file/imagenet_prompt.json' 21 | 22 | 23 | # ------ Basic Config ------ 24 | dataset: 'ImageNet' 25 | shots: 2 26 | clip_backbone: 'RN50' 27 | dino_backbone: 'resnet50' 28 | 29 | # ------ Dalle Dataset ----- 30 | dalle_dataset: 'dalle_imagenet' 31 | dalle_shots: 2 32 | 33 | lr: 0.001 34 | augment_epoch: 1 35 | train_epoch: 20 36 | -------------------------------------------------------------------------------- /configs/imagenet/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # ------ Hyperparamters ------ 10 | search_hp: True 11 | # search_hp: False 12 | 13 | search_scale: [7, 3] 14 | search_step: [200, 20] 15 | 16 | init_beta: 1 17 | init_alpha: 0.4 18 | 19 | gpt3_prompt_file: './gpt_file/imagenet_prompt.json' 20 | 21 | # ------ Basic Config ------ 22 | dataset: 'ImageNet' 23 | shots: 4 24 | clip_backbone: 'RN50' 25 | dino_backbone: 'resnet50' 26 | 27 | # ------ Dalle Dataset ----- 28 | dalle_dataset: 'dalle_imagenet' 29 | dalle_shots: 8 30 | 31 | lr: 0.001 32 | augment_epoch: 1 33 | train_epoch: 20 34 | -------------------------------------------------------------------------------- /configs/imagenet/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | 10 | # ------ Hyperparamters ------ 11 | search_hp: True 12 | # search_hp: False 13 | 14 | search_scale: [7, 3] 15 | search_step: [200, 20] 16 | 17 | init_beta: 1 18 | init_alpha: 0.5 19 | 20 | gpt3_prompt_file: './gpt_file/imagenet_prompt.json' 21 | 22 | # ------ Basic Config ------ 23 | dataset: 'ImageNet' 24 | shots: 8 25 | clip_backbone: 'RN50' 26 | dino_backbone: 'resnet50' 27 | 28 | # ------ Dalle Dataset ----- 29 | dalle_dataset: 'dalle_imagenet' 30 | dalle_shots: 2 31 | 32 | lr: 0.001 33 | augment_epoch: 1 34 | train_epoch: 20 35 | -------------------------------------------------------------------------------- /configs/oxford_flowers/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [50, 50] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 4 22 | 23 | gpt3_prompt_file: './gpt_file/oxford_flowers_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'oxford_flowers' 27 | shots: 16 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_flowers' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/oxford_flowers/1shot.yaml: 
-------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [50, 50] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.2 22 | 23 | gpt3_prompt_file: './gpt_file/oxford_flowers_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'oxford_flowers' 27 | shots: 1 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_flowers' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/oxford_flowers/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [50, 50] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.7 22 | 23 | gpt3_prompt_file: './gpt_file/oxford_flowers_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'oxford_flowers' 27 | shots: 2 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_flowers' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/oxford_flowers/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [50, 50] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 2.2 22 | 23 | gpt3_prompt_file: './gpt_file/oxford_flowers_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'oxford_flowers' 27 | shots: 4 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_flowers' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/oxford_flowers/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [50, 50] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 3.7 22 | 23 | gpt3_prompt_file: './gpt_file/oxford_flowers_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'oxford_flowers' 27 | shots: 8 28 | 
clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_flowers' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/pets/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.5 22 | 23 | gpt3_prompt_file: './gpt_file/oxford_pets_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'oxford_pets' 27 | shots: 16 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_pets' 33 | dalle_shots: 8 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/pets/1shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.4 22 | 23 | gpt3_prompt_file: './gpt_file/oxford_pets_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'oxford_pets' 27 | shots: 1 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_pets' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/pets/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.4 22 | 23 | gpt3_prompt_file: './gpt_file/oxford_pets_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'oxford_pets' 27 | shots: 2 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_pets' 33 | dalle_shots: 2 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/pets/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | 
search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.4 22 | 23 | gpt3_prompt_file: './gpt_file/oxford_pets_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'oxford_pets' 27 | shots: 4 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_pets' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/pets/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.6 22 | 23 | gpt3_prompt_file: './gpt_file/oxford_pets_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'oxford_pets' 27 | shots: 8 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_pets' 33 | dalle_shots: 8 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/sd_caltech101/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.3 22 | 23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'caltech101' 27 | shots: 16 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Stable Diffusion Dataset ----- 32 | dalle_dataset: 'sd_caltech' 33 | dalle_shots: 2 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/sd_caltech101/1shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.5 22 | 23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'caltech101' 27 | shots: 1 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Stable Diffusion Dataset ----- 32 | dalle_dataset: 'sd_caltech' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/sd_caltech101/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ 
root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | # load_cache: True 9 | # load_pre_feat: True 10 | 11 | 12 | # ------ Hyperparamters ------ 13 | search_hp: True 14 | # search_hp: False 15 | 16 | search_scale: [12, 5] 17 | search_step: [200, 20] 18 | 19 | init_beta: 1 20 | init_alpha: 0.8 21 | 22 | gpt3_prompt_file: './gpt_file/caltech_prompt.json' 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'caltech101' 26 | shots: 2 27 | clip_backbone: 'RN50' 28 | dino_backbone: 'resnet50' 29 | 30 | # ------ Stable Diffusion Dataset ----- 31 | dalle_dataset: 'sd_caltech' 32 | dalle_shots: 8 33 | 34 | lr: 0.001 35 | augment_epoch: 10 36 | train_epoch: 20 37 | -------------------------------------------------------------------------------- /configs/sd_caltech101/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.1 22 | 23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'caltech101' 27 | shots: 4 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Stable Diffusion Dataset ----- 32 | dalle_dataset: 'sd_caltech' 33 | dalle_shots: 2 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/sd_caltech101/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.1 22 | 23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'caltech101' 27 | shots: 8 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Stable Diffusion Dataset ----- 32 | dalle_dataset: 'sd_caltech' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/sun/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.8 22 | 23 | gpt3_prompt_file: './gpt_file/sun397_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'sun397' 27 | shots: 16 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_sun' 33 | dalle_shots: 1 34 | 
35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/sun/1shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | # load_cache: True 9 | # load_pre_feat: True 10 | 11 | 12 | # ------ Hyperparamters ------ 13 | search_hp: True 14 | # search_hp: False 15 | 16 | search_scale: [12, 10] 17 | search_step: [200, 20] 18 | 19 | init_beta: 1 20 | init_alpha: 0.5 21 | 22 | gpt3_prompt_file: './gpt_file/sun397_prompt.json' 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'sun397' 26 | shots: 1 27 | clip_backbone: 'RN50' 28 | dino_backbone: 'resnet50' 29 | 30 | # ------ Dalle Dataset ----- 31 | dalle_dataset: 'dalle_sun' 32 | dalle_shots: 1 33 | 34 | lr: 0.001 35 | augment_epoch: 10 36 | train_epoch: 20 37 | -------------------------------------------------------------------------------- /configs/sun/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.5 22 | 23 | gpt3_prompt_file: './gpt_file/sun397_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'sun397' 27 | shots: 2 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_sun' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/sun/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.6 22 | 23 | gpt3_prompt_file: './gpt_file/sun397_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'sun397' 27 | shots: 4 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_sun' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/sun/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 0.7 22 | 23 | gpt3_prompt_file: './gpt_file/sun397_prompt.json' 24 | 25 | # ------ Basic Config 
------ 26 | dataset: 'sun397' 27 | shots: 8 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_sun' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/ucf/16shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 2 22 | 23 | gpt3_prompt_file: './gpt_file/ucf101_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'ucf101' 27 | shots: 16 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_ucf' 33 | dalle_shots: 2 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 40 38 | -------------------------------------------------------------------------------- /configs/ucf/1shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1 22 | 23 | gpt3_prompt_file: './gpt_file/ucf101_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'ucf101' 27 | shots: 1 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_ucf' 33 | dalle_shots: 8 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/ucf/2shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1 22 | 23 | gpt3_prompt_file: './gpt_file/ucf101_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'ucf101' 27 | shots: 2 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_ucf' 33 | dalle_shots: 1 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/ucf/4shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | 
search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1 22 | 23 | gpt3_prompt_file: './gpt_file/ucf101_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'ucf101' 27 | shots: 4 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_ucf' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 20 38 | -------------------------------------------------------------------------------- /configs/ucf/8shot.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.5 22 | 23 | gpt3_prompt_file: './gpt_file/ucf101_prompt.json' 24 | 25 | # ------ Basic Config ------ 26 | dataset: 'ucf101' 27 | shots: 8 28 | clip_backbone: 'RN50' 29 | dino_backbone: 'resnet50' 30 | 31 | # ------ Dalle Dataset ----- 32 | dalle_dataset: 'dalle_ucf' 33 | dalle_shots: 4 34 | 35 | lr: 0.001 36 | augment_epoch: 10 37 | train_epoch: 40 38 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .oxford_pets import OxfordPets 2 | from .eurosat import EuroSAT 3 | from .ucf101 import UCF101 4 | from .sun397 import SUN397 5 | from .caltech101 import Caltech101 6 | from .dtd import DescribableTextures 7 | from .fgvc import FGVCAircraft 8 | from .food101 import Food101 9 | from .oxford_flowers import OxfordFlowers 10 | from .stanford_cars import StanfordCars 11 | from .dalle_imagenet import Dalle_Imagenet 12 | from .dalle_caltech import Dalle_Caltech 13 | from .dalle_flowers import Dalle_Flowers 14 | from .dalle_food import Dalle_Food 15 | from .dalle_cars import Dalle_Cars 16 | from .dalle_dtd import Dalle_DTD 17 | from .dalle_eurosat import Dalle_Eurosat 18 | from .dalle_pets import Dalle_Pets 19 | from .dalle_sun import Dalle_Sun 20 | from .dalle_ucf import Dalle_UCF 21 | from .dalle_fgvc import Dalle_fgvc 22 | from .sd_caltech import SD_Caltech 23 | 24 | dataset_list = { 25 | "oxford_pets": OxfordPets, 26 | "eurosat": EuroSAT, 27 | "ucf101": UCF101, 28 | "sun397": SUN397, 29 | "caltech101": Caltech101, 30 | "dtd": DescribableTextures, 31 | "fgvc": FGVCAircraft, 32 | "food101": Food101, 33 | "oxford_flowers": OxfordFlowers, 34 | "stanford_cars": StanfordCars, 35 | "dalle_imagenet": Dalle_Imagenet, 36 | "dalle_caltech": Dalle_Caltech, 37 | "dalle_flowers": Dalle_Flowers, 38 | "dalle_food": Dalle_Food, 39 | "dalle_cars": Dalle_Cars, 40 | "dalle_dtd": Dalle_DTD, 41 | "dalle_eurosat": Dalle_Eurosat, 42 | "dalle_pets": Dalle_Pets, 43 | "dalle_sun": Dalle_Sun, 44 | "dalle_ucf": Dalle_UCF, 45 | "dalle_fgvc": Dalle_fgvc, 46 | "sd_caltech": SD_Caltech 47 | } 48 | 49 | 50 | def build_dataset(dataset, root_path, shots): 51 | return dataset_list[dataset](root_path, shots) -------------------------------------------------------------------------------- /datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase 4 | from .oxford_pets import 
OxfordPets 5 | 6 | 7 | template = ['a photo of a {}.'] 8 | 9 | 10 | class Caltech101(DatasetBase): 11 | 12 | dataset_dir = 'caltech-101' 13 | 14 | def __init__(self, root, num_shots): 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories') 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Caltech101.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 22 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 23 | 24 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dalle_caltech.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 3 | from .oxford_pets import OxfordPets 4 | 5 | class Dalle_Caltech(DatasetBase): 6 | 7 | dataset_dir = 'dalle_caltech-101' 8 | 9 | def __init__(self, root, num_shots): 10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 11 | self.dataset_dir = os.path.join(root, self.dataset_dir) 12 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories') 13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_caltech.json') 14 | 15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 17 | 18 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dalle_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 3 | from .oxford_pets import OxfordPets 4 | 5 | class Dalle_Cars(DatasetBase): 6 | 7 | dataset_dir = 'dalle_stanford_cars' 8 | 9 | def __init__(self, root, num_shots): 10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 11 | self.dataset_dir = os.path.join(root, self.dataset_dir) 12 | self.image_dir = os.path.join(self.dataset_dir, 'cars_train') 13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_cars.json') 14 | 15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 17 | 18 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dalle_dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 3 | from .oxford_pets import OxfordPets 4 | 5 | class Dalle_DTD(DatasetBase): 6 | 7 | dataset_dir = 'dalle_dtd' 8 | 9 | def __init__(self, root, num_shots): 10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 11 | self.dataset_dir = os.path.join(root, self.dataset_dir) 12 | self.image_dir = os.path.join(self.dataset_dir, 'images') 13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_dtd.json') 14 | 15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 17 | 18 | super().__init__(train_x=train, val=val, test=test) 
-------------------------------------------------------------------------------- /datasets/dalle_eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 3 | from .oxford_pets import OxfordPets 4 | 5 | class Dalle_Eurosat(DatasetBase): 6 | 7 | dataset_dir = 'dalle_eurosat' 8 | 9 | def __init__(self, root, num_shots): 10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 11 | self.dataset_dir = os.path.join(root, self.dataset_dir) 12 | self.image_dir = os.path.join(self.dataset_dir, '2750') 13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_eurosat.json') 14 | 15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 17 | 18 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dalle_fgvc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 3 | from .oxford_pets import OxfordPets 4 | 5 | class Dalle_fgvc(DatasetBase): 6 | 7 | dataset_dir = 'dalle_fgvc_aircraft' 8 | 9 | def __init__(self, root, num_shots): 10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 11 | self.dataset_dir = os.path.join(root, self.dataset_dir) 12 | self.image_dir = os.path.join(self.dataset_dir, 'images') 13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_fgvc.json') 14 | 15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 17 | 18 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dalle_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 3 | from .oxford_pets import OxfordPets 4 | 5 | class Dalle_Flowers(DatasetBase): 6 | 7 | dataset_dir = 'dalle_oxford_flowers' 8 | 9 | def __init__(self, root, num_shots): 10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 11 | self.dataset_dir = os.path.join(root, self.dataset_dir) 12 | self.image_dir = os.path.join(self.dataset_dir, 'jpg') 13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_flower.json') 14 | 15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 17 | 18 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dalle_food.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 3 | from .oxford_pets import OxfordPets 4 | 5 | class Dalle_Food(DatasetBase): 6 | 7 | dataset_dir = 'dalle_food-101' 8 | 9 | def __init__(self, root, num_shots): 10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 11 | self.dataset_dir = os.path.join(root, self.dataset_dir) 12 | self.image_dir = os.path.join(self.dataset_dir, 'images') 13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_food.json') 14 | 15 | train, val, test = 
OxfordPets.read_split(self.split_path, self.image_dir) 16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 17 | 18 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dalle_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 3 | from .oxford_pets import OxfordPets 4 | 5 | class Dalle_Imagenet(DatasetBase): 6 | 7 | dataset_dir = 'dalle_imagenet' 8 | 9 | def __init__(self, root, num_shots): 10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 11 | self.dataset_dir = os.path.join(root, self.dataset_dir) 12 | self.image_dir = os.path.join(self.dataset_dir, 'data') 13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_imagenet.json') 14 | 15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 17 | 18 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dalle_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 3 | from .oxford_pets import OxfordPets 4 | 5 | class Dalle_Pets(DatasetBase): 6 | 7 | dataset_dir = 'dalle_oxford_pets' 8 | 9 | def __init__(self, root, num_shots): 10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 11 | self.dataset_dir = os.path.join(root, self.dataset_dir) 12 | self.image_dir = os.path.join(self.dataset_dir, 'images') 13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_pet.json') 14 | 15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 17 | 18 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dalle_sun.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 3 | from .oxford_pets import OxfordPets 4 | 5 | class Dalle_Sun(DatasetBase): 6 | 7 | dataset_dir = 'dalle_sun397' 8 | 9 | def __init__(self, root, num_shots): 10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 11 | self.dataset_dir = os.path.join(root, self.dataset_dir) 12 | self.image_dir = os.path.join(self.dataset_dir, 'SUN397') 13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_sun.json') 14 | 15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 17 | 18 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dalle_ucf.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 3 | from .oxford_pets import OxfordPets 4 | 5 | class Dalle_UCF(DatasetBase): 6 | 7 | dataset_dir = 'dalle_ucf101' 8 | 9 | def __init__(self, root, num_shots): 10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 11 | self.dataset_dir = os.path.join(root, 
self.dataset_dir) 12 | self.image_dir = os.path.join(self.dataset_dir, 'ucf101_midframes') 13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_ucf.json') 14 | 15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 17 | 18 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from .utils import Datum, DatasetBase, listdir_nohidden 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['{} texture.'] 9 | 10 | 11 | class DescribableTextures(DatasetBase): 12 | 13 | dataset_dir = 'dtd' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'images') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_DescribableTextures.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 24 | 25 | super().__init__(train_x=train, val=val, test=test) 26 | 27 | @staticmethod 28 | def read_and_split_data( 29 | image_dir, 30 | p_trn=0.5, 31 | p_val=0.2, 32 | ignored=[], 33 | new_cnames=None 34 | ): 35 | # The data are supposed to be organized into the following structure 36 | # ============= 37 | # images/ 38 | # dog/ 39 | # cat/ 40 | # horse/ 41 | # ============= 42 | categories = listdir_nohidden(image_dir) 43 | categories = [c for c in categories if c not in ignored] 44 | categories.sort() 45 | 46 | p_tst = 1 - p_trn - p_val 47 | print(f'Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test') 48 | 49 | def _collate(ims, y, c): 50 | items = [] 51 | for im in ims: 52 | item = Datum( 53 | impath=im, 54 | label=y, # is already 0-based 55 | classname=c 56 | ) 57 | items.append(item) 58 | return items 59 | 60 | train, val, test = [], [], [] 61 | for label, category in enumerate(categories): 62 | category_dir = os.path.join(image_dir, category) 63 | images = listdir_nohidden(category_dir) 64 | images = [os.path.join(category_dir, im) for im in images] 65 | random.shuffle(images) 66 | n_total = len(images) 67 | n_train = round(n_total * p_trn) 68 | n_val = round(n_total * p_val) 69 | n_test = n_total - n_train - n_val 70 | assert n_train > 0 and n_val > 0 and n_test > 0 71 | 72 | if new_cnames is not None and category in new_cnames: 73 | category = new_cnames[category] 74 | 75 | train.extend(_collate(images[:n_train], label, category)) 76 | val.extend(_collate(images[n_train:n_train+n_val], label, category)) 77 | test.extend(_collate(images[n_train+n_val:], label, category)) 78 | 79 | return train, val, test 80 | -------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ['a centered satellite photo of {}.'] 8 | 9 | 10 | NEW_CNAMES = { 11 | 'AnnualCrop': 'Annual Crop Land', 12 | 'Forest': 'Forest', 13 | 'HerbaceousVegetation': 'Herbaceous Vegetation Land', 14 | 'Highway': 'Highway or Road', 15 | 'Industrial': 'Industrial Buildings', 
16 | 'Pasture': 'Pasture Land', 17 | 'PermanentCrop': 'Permanent Crop Land', 18 | 'Residential': 'Residential Buildings', 19 | 'River': 'River', 20 | 'SeaLake': 'Sea or Lake' 21 | } 22 | 23 | 24 | class EuroSAT(DatasetBase): 25 | 26 | dataset_dir = 'eurosat' 27 | 28 | def __init__(self, root, num_shots): 29 | self.dataset_dir = os.path.join(root, self.dataset_dir) 30 | self.image_dir = os.path.join(self.dataset_dir, '2750') 31 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_EuroSAT.json') 32 | 33 | self.template = template 34 | 35 | train_u, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 36 | train = self.generate_fewshot_dataset(train_u, num_shots=num_shots) 37 | 38 | super().__init__(train_x=train, val=val, test=test, train_u=train_u) 39 | 40 | def update_classname(self, dataset_old): 41 | dataset_new = [] 42 | for item_old in dataset_old: 43 | cname_old = item_old.classname 44 | cname_new = NEW_CNAMES[cname_old] 45 | item_new = Datum( 46 | impath=item_old.impath, 47 | label=item_old.label, 48 | classname=cname_new 49 | ) 50 | dataset_new.append(item_new) 51 | return dataset_new 52 | -------------------------------------------------------------------------------- /datasets/fgvc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | 6 | template = ['a photo of a {}, a type of aircraft.'] 7 | 8 | 9 | class FGVCAircraft(DatasetBase): 10 | 11 | dataset_dir = 'fgvc_aircraft' 12 | 13 | def __init__(self, root, num_shots): 14 | 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, 'images') 17 | 18 | self.template = template 19 | 20 | classnames = [] 21 | with open(os.path.join(self.dataset_dir, 'variants.txt'), 'r') as f: 22 | lines = f.readlines() 23 | for line in lines: 24 | classnames.append(line.strip()) 25 | cname2lab = {c: i for i, c in enumerate(classnames)} 26 | 27 | train = self.read_data(cname2lab, 'images_variant_train.txt') 28 | val = self.read_data(cname2lab, 'images_variant_val.txt') 29 | test = self.read_data(cname2lab, 'images_variant_test.txt') 30 | 31 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 32 | 33 | super().__init__(train_x=train, val=val, test=test) 34 | 35 | def read_data(self, cname2lab, split_file): 36 | filepath = os.path.join(self.dataset_dir, split_file) 37 | items = [] 38 | 39 | with open(filepath, 'r') as f: 40 | lines = f.readlines() 41 | for line in lines: 42 | line = line.strip().split(' ') 43 | imname = line[0] + '.jpg' 44 | classname = ' '.join(line[1:]) 45 | impath = os.path.join(self.image_dir, imname) 46 | label = cname2lab[classname] 47 | item = Datum( 48 | impath=impath, 49 | label=label, 50 | classname=classname 51 | ) 52 | items.append(item) 53 | 54 | return items -------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ['a photo of {}, a type of food.'] 8 | 9 | 10 | class Food101(DatasetBase): 11 | 12 | dataset_dir = 'food-101' 13 | 14 | def __init__(self, root, num_shots): 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, 'images') 17 | 
self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Food101.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 22 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 23 | 24 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | from collections import defaultdict 5 | 6 | import torch 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | 10 | 11 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 12 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 13 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 14 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 15 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 16 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 17 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 18 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 19 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 20 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 21 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 22 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 23 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 24 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 25 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 26 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 27 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 28 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 29 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 30 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 31 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 32 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 33 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 34 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 35 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 36 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 37 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 38 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 39 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 40 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 41 
| "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 42 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 43 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 44 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 45 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 46 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 47 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 48 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 49 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 50 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 51 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 52 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 53 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 54 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 55 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 56 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 57 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 58 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 59 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 60 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 61 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 62 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 63 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 64 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 65 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 66 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 67 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 68 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 69 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 70 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 71 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 72 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 73 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 74 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 75 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 76 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 77 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 78 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 79 | "langur", 
"black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 80 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 81 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 82 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 83 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 84 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 85 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 86 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 87 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 88 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 89 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 90 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 91 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 92 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 93 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 94 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 95 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 96 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 97 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 98 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 99 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 100 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 101 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 102 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 103 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 104 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 105 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 106 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 107 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 108 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 109 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 110 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 111 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 112 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 113 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 114 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 115 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 116 | "horse-drawn vehicle", "hourglass", 
"iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 117 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 118 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 119 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 120 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 121 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 122 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 123 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 124 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 125 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 126 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 127 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 128 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 129 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 130 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 131 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 132 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", 133 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 134 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 135 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 136 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 137 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 138 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 139 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 140 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 141 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 142 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 143 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 144 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 145 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 146 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 147 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 148 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 149 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 150 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 151 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 152 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 
153 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 154 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 155 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 156 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 157 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 158 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 159 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 160 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 161 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 162 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 163 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 164 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 165 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 166 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 167 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 168 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 169 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", 170 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 171 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 172 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 173 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 174 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 175 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 176 | 177 | imagenet_templates = ["itap of a {}.", 178 | "a bad photo of the {}.", 179 | "a origami {}.", 180 | "a photo of the large {}.", 181 | "a {} in a video game.", 182 | "art of the {}.", 183 | "a photo of the small {}."] 184 | 185 | 186 | class ImageNet(): 187 | 188 | dataset_dir = 'imagenet' 189 | 190 | def __init__(self, root, num_shots, preprocess): 191 | 192 | self.dataset_dir = os.path.join(root, self.dataset_dir) 193 | self.image_dir = os.path.join(self.dataset_dir, 'images') 194 | 195 | train_preprocess = transforms.Compose([ 196 | transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC), 197 | transforms.RandomHorizontalFlip(p=0.5), 198 | transforms.ToTensor(), 199 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 200 | ]) 201 | test_preprocess = preprocess 202 | 203 | self.train = torchvision.datasets.ImageNet(self.image_dir, split='train', transform=train_preprocess) 204 | self.val = torchvision.datasets.ImageNet(self.image_dir, split='val', transform=test_preprocess) 205 | self.test = torchvision.datasets.ImageNet(self.image_dir, split='val', transform=test_preprocess) 206 | 207 | self.template = imagenet_templates 208 | self.classnames = imagenet_classes 209 | 210 | 
split_by_label_dict = defaultdict(list) 211 | for i in range(len(self.train.imgs)): 212 | split_by_label_dict[self.train.targets[i]].append(self.train.imgs[i]) 213 | imgs = [] 214 | targets = [] 215 | 216 | for label, items in split_by_label_dict.items(): 217 | imgs = imgs + random.sample(items, num_shots) 218 | targets = targets + [label for i in range(num_shots)] 219 | self.train.imgs = imgs 220 | self.train.targets = targets 221 | self.train.samples = imgs 222 | 223 | if __name__ == '__main__': 224 | print('screw' in imagenet_classes) -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from scipy.io import loadmat 4 | from collections import defaultdict 5 | 6 | from .oxford_pets import OxfordPets 7 | from .utils import Datum, DatasetBase, read_json 8 | 9 | 10 | template = ['a photo of a {}, a type of flower.'] 11 | 12 | 13 | class OxfordFlowers(DatasetBase): 14 | 15 | dataset_dir = 'oxford_flowers' 16 | 17 | def __init__(self, root, num_shots): 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, 'jpg') 20 | self.label_file = os.path.join(self.dataset_dir, 'imagelabels.mat') 21 | self.lab2cname_file = os.path.join(self.dataset_dir, 'cat_to_name.json') 22 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordFlowers.json') 23 | 24 | self.template = template 25 | 26 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 27 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 28 | 29 | super().__init__(train_x=train, val=val, test=test) 30 | 31 | def read_data(self): 32 | tracker = defaultdict(list) 33 | label_file = loadmat(self.label_file)['labels'][0] 34 | for i, label in enumerate(label_file): 35 | imname = f'image_{str(i + 1).zfill(5)}.jpg' 36 | impath = os.path.join(self.image_dir, imname) 37 | label = int(label) 38 | tracker[label].append(impath) 39 | 40 | print('Splitting data into 50% train, 20% val, and 30% test') 41 | 42 | def _collate(ims, y, c): 43 | items = [] 44 | for im in ims: 45 | item = Datum( 46 | impath=im, 47 | label=y-1, # convert to 0-based label 48 | classname=c 49 | ) 50 | items.append(item) 51 | return items 52 | 53 | lab2cname = read_json(self.lab2cname_file) 54 | train, val, test = [], [], [] 55 | for label, impaths in tracker.items(): 56 | random.shuffle(impaths) 57 | n_total = len(impaths) 58 | n_train = round(n_total * 0.5) 59 | n_val = round(n_total * 0.2) 60 | n_test = n_total - n_train - n_val 61 | assert n_train > 0 and n_val > 0 and n_test > 0 62 | cname = lab2cname[str(label)] 63 | train.extend(_collate(impaths[:n_train], label, cname)) 64 | val.extend(_collate(impaths[n_train:n_train+n_val], label, cname)) 65 | test.extend(_collate(impaths[n_train+n_val:], label, cname)) 66 | 67 | return train, val, test -------------------------------------------------------------------------------- /datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | from collections import defaultdict 5 | 6 | import torchvision.transforms as transforms 7 | 8 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 9 | 10 | 11 | template = ['a photo of a {}, a type of pet.'] 12 | 13 | 14 | class OxfordPets(DatasetBase): 15 | 16 | dataset_dir = 'oxford_pets' 17 | 18 | def 
__init__(self, root, num_shots): 19 | self.dataset_dir = os.path.join(root, self.dataset_dir) 20 | self.image_dir = os.path.join(self.dataset_dir, 'images') 21 | self.anno_dir = os.path.join(self.dataset_dir, 'annotations') 22 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordPets.json') 23 | 24 | self.template = template 25 | 26 | train, val, test = self.read_split(self.split_path, self.image_dir) 27 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 28 | 29 | super().__init__(train_x=train, val=val, test=test) 30 | 31 | def read_data(self, split_file): 32 | filepath = os.path.join(self.anno_dir, split_file) 33 | items = [] 34 | 35 | with open(filepath, 'r') as f: 36 | lines = f.readlines() 37 | for line in lines: 38 | line = line.strip() 39 | imname, label, species, _ = line.split(' ') 40 | breed = imname.split('_')[:-1] 41 | breed = '_'.join(breed) 42 | breed = breed.lower() 43 | imname += '.jpg' 44 | impath = os.path.join(self.image_dir, imname) 45 | label = int(label) - 1 # convert to 0-based index 46 | item = Datum( 47 | impath=impath, 48 | label=label, 49 | classname=breed 50 | ) 51 | items.append(item) 52 | 53 | return items 54 | 55 | @staticmethod 56 | def split_trainval(trainval, p_val=0.2): 57 | p_trn = 1 - p_val 58 | print(f'Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val') 59 | tracker = defaultdict(list) 60 | for idx, item in enumerate(trainval): 61 | label = item.label 62 | tracker[label].append(idx) 63 | 64 | train, val = [], [] 65 | for label, idxs in tracker.items(): 66 | n_val = round(len(idxs) * p_val) 67 | assert n_val > 0 68 | random.shuffle(idxs) 69 | for n, idx in enumerate(idxs): 70 | item = trainval[idx] 71 | if n < n_val: 72 | val.append(item) 73 | else: 74 | train.append(item) 75 | 76 | return train, val 77 | 78 | @staticmethod 79 | def save_split(train, val, test, filepath, path_prefix): 80 | def _extract(items): 81 | out = [] 82 | for item in items: 83 | impath = item.impath 84 | label = item.label 85 | classname = item.classname 86 | impath = impath.replace(path_prefix, '') 87 | if impath.startswith('/'): 88 | impath = impath[1:] 89 | out.append((impath, label, classname)) 90 | return out 91 | 92 | train = _extract(train) 93 | val = _extract(val) 94 | test = _extract(test) 95 | 96 | split = { 97 | 'train': train, 98 | 'val': val, 99 | 'test': test 100 | } 101 | 102 | write_json(split, filepath) 103 | print(f'Saved split to {filepath}') 104 | 105 | @staticmethod 106 | def read_split(filepath, path_prefix): 107 | def _convert(items): 108 | out = [] 109 | for impath, label, classname in items: 110 | impath = os.path.join(path_prefix, impath) 111 | item = Datum( 112 | impath=impath, 113 | label=int(label), 114 | classname=classname 115 | ) 116 | out.append(item) 117 | return out 118 | 119 | print(f'Reading split from {filepath}') 120 | split = read_json(filepath) 121 | train = _convert(split['train']) 122 | val = _convert(split['val']) 123 | test = _convert(split['test']) 124 | 125 | return train, val, test -------------------------------------------------------------------------------- /datasets/sd_caltech.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 3 | from .oxford_pets import OxfordPets 4 | 5 | class SD_Caltech(DatasetBase): 6 | 7 | dataset_dir = 'sd_caltech_101' 8 | 9 | def __init__(self, root, num_shots): 10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 
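        # This loader mirrors the dalle_* datasets (see dalle_ucf.py above): it reads its own
        # 'sd_caltech.json' split with OxfordPets.read_split and builds a few-shot train set
        # from the separate 'sd_caltech_101' image root -- presumably Stable Diffusion-generated
        # Caltech-101 images (the 'sd_' interpretation is an assumption, not stated in this file).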
11 | self.dataset_dir = os.path.join(root, self.dataset_dir) 12 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories') 13 | self.split_path = os.path.join(self.dataset_dir, 'sd_caltech.json') 14 | 15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 17 | 18 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | 4 | from .oxford_pets import OxfordPets 5 | from .utils import Datum, DatasetBase 6 | 7 | 8 | template = ['a photo of a {}.'] 9 | 10 | 11 | class StanfordCars(DatasetBase): 12 | 13 | dataset_dir = 'stanford_cars' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_StanfordCars.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 22 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 23 | 24 | super().__init__(train_x=train, val=val, test=test) 25 | 26 | def read_data(self, image_dir, anno_file, meta_file): 27 | anno_file = loadmat(anno_file)['annotations'][0] 28 | meta_file = loadmat(meta_file)['class_names'][0] 29 | items = [] 30 | 31 | for i in range(len(anno_file)): 32 | imname = anno_file[i]['fname'][0] 33 | impath = os.path.join(self.dataset_dir, image_dir, imname) 34 | label = anno_file[i]['class'][0, 0] 35 | label = int(label) - 1 # convert to 0-based index 36 | classname = meta_file[label][0] 37 | names = classname.split(' ') 38 | year = names.pop(-1) 39 | names.insert(0, year) 40 | classname = ' '.join(names) 41 | item = Datum( 42 | impath=impath, 43 | label=label, 44 | classname=classname 45 | ) 46 | items.append(item) 47 | 48 | return items -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['a photo of a {}.'] 9 | 10 | 11 | class SUN397(DatasetBase): 12 | 13 | dataset_dir = 'sun397' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'SUN397') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_SUN397.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 24 | 25 | super().__init__(train_x=train, val=val, test=test) 26 | 27 | def read_data(self, cname2lab, text_file): 28 | text_file = os.path.join(self.dataset_dir, text_file) 29 | items = [] 30 | 31 | with open(text_file, 'r') as f: 32 | lines = f.readlines() 33 | for line in lines: 34 | imname = line.strip()[1:] # remove / 35 | classname = os.path.dirname(imname) 36 | label = cname2lab[classname] 37 | impath = os.path.join(self.image_dir, imname) 38 | 39 | names = classname.split('/')[1:] # remove 1st letter 40 | names = names[::-1] # put words like indoor/outdoor at first 41 | classname = ' '.join(names) 
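                # Worked example (hypothetical filename): '/c/church/indoor/sun_abc.jpg' gives
                # classname 'c/church/indoor'; dropping the single-letter bucket and reversing
                # yields ['indoor', 'church'], so the final classname is 'indoor church'.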
42 | 43 | item = Datum( 44 | impath=impath, 45 | label=label, 46 | classname=classname 47 | ) 48 | items.append(item) 49 | 50 | return items 51 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['a photo of a person doing {}.'] 9 | 10 | 11 | class UCF101(DatasetBase): 12 | 13 | dataset_dir = 'ucf101' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'UCF-101-midframes') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_UCF101.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 24 | 25 | super().__init__(train_x=train, val=val, test=test) 26 | 27 | def read_data(self, cname2lab, text_file): 28 | text_file = os.path.join(self.dataset_dir, text_file) 29 | items = [] 30 | 31 | with open(text_file, 'r') as f: 32 | lines = f.readlines() 33 | for line in lines: 34 | line = line.strip().split(' ')[0] # trainlist: filename, label 35 | action, filename = line.split('/') 36 | label = cname2lab[action] 37 | 38 | elements = re.findall('[A-Z][^A-Z]*', action) 39 | renamed_action = '_'.join(elements) 40 | 41 | filename = filename.replace('.avi', '.jpg') 42 | impath = os.path.join(self.image_dir, renamed_action, filename) 43 | 44 | item = Datum( 45 | impath=impath, 46 | label=label, 47 | classname=renamed_action 48 | ) 49 | items.append(item) 50 | 51 | return items 52 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import os.path as osp 4 | import tarfile 5 | import zipfile 6 | from collections import defaultdict 7 | import gdown 8 | import json 9 | import torch 10 | from torch.utils.data import Dataset as TorchDataset 11 | import torchvision.transforms as T 12 | from PIL import Image 13 | 14 | 15 | def read_json(fpath): 16 | """Read json file from a path.""" 17 | with open(fpath, 'r') as f: 18 | obj = json.load(f) 19 | return obj 20 | 21 | 22 | def write_json(obj, fpath): 23 | """Writes to a json file.""" 24 | if not osp.exists(osp.dirname(fpath)): 25 | os.makedirs(osp.dirname(fpath)) 26 | with open(fpath, 'w') as f: 27 | json.dump(obj, f, indent=4, separators=(',', ': ')) 28 | 29 | 30 | def read_image(path): 31 | """Read image from path using ``PIL.Image``. 32 | 33 | Args: 34 | path (str): path to an image. 35 | 36 | Returns: 37 | PIL image 38 | """ 39 | if not osp.exists(path): 40 | raise IOError('No file exists at {}'.format(path)) 41 | 42 | while True: 43 | try: 44 | img = Image.open(path).convert('RGB') 45 | return img 46 | except IOError: 47 | print( 48 | 'Cannot read image from {}, ' 49 | 'probably due to heavy IO. Will re-try'.format(path) 50 | ) 51 | 52 | 53 | def listdir_nohidden(path, sort=False): 54 | """List non-hidden items in a directory. 55 | 56 | Args: 57 | path (str): directory path. 58 | sort (bool): sort the items. 
59 | """ 60 | items = [f for f in os.listdir(path) if not f.startswith('.') and 'sh' not in f] 61 | if sort: 62 | items.sort() 63 | return items 64 | 65 | 66 | class Datum: 67 | """Data instance which defines the basic attributes. 68 | 69 | Args: 70 | impath (str): image path. 71 | label (int): class label. 72 | domain (int): domain label. 73 | classname (str): class name. 74 | """ 75 | 76 | def __init__(self, impath='', label=0, domain=-1, classname=''): 77 | assert isinstance(impath, str) 78 | assert isinstance(label, int) 79 | assert isinstance(domain, int) 80 | assert isinstance(classname, str) 81 | 82 | self._impath = impath 83 | self._label = label 84 | self._domain = domain 85 | self._classname = classname 86 | 87 | @property 88 | def impath(self): 89 | return self._impath 90 | 91 | @property 92 | def label(self): 93 | return self._label 94 | 95 | @property 96 | def domain(self): 97 | return self._domain 98 | 99 | @property 100 | def classname(self): 101 | return self._classname 102 | 103 | 104 | class DatasetBase: 105 | """A unified dataset class for 106 | 1) domain adaptation 107 | 2) domain generalization 108 | 3) semi-supervised learning 109 | """ 110 | dataset_dir = '' # the directory where the dataset is stored 111 | domains = [] # string names of all domains 112 | 113 | def __init__(self, train_x=None, train_u=None, val=None, test=None): 114 | self._train_x = train_x # labeled training data 115 | self._train_u = train_u # unlabeled training data (optional) 116 | self._val = val # validation data (optional) 117 | self._test = test # test data 118 | 119 | self._num_classes = self.get_num_classes(train_x) 120 | self._lab2cname, self._classnames = self.get_lab2cname(train_x) 121 | 122 | @property 123 | def train_x(self): 124 | return self._train_x 125 | 126 | @property 127 | def train_u(self): 128 | return self._train_u 129 | 130 | @property 131 | def val(self): 132 | return self._val 133 | 134 | @property 135 | def test(self): 136 | return self._test 137 | 138 | @property 139 | def lab2cname(self): 140 | return self._lab2cname 141 | 142 | @property 143 | def classnames(self): 144 | return self._classnames 145 | 146 | @property 147 | def num_classes(self): 148 | return self._num_classes 149 | 150 | def get_num_classes(self, data_source): 151 | """Count number of classes. 152 | 153 | Args: 154 | data_source (list): a list of Datum objects. 155 | """ 156 | label_set = set() 157 | for item in data_source: 158 | label_set.add(item.label) 159 | return max(label_set) + 1 160 | 161 | def get_lab2cname(self, data_source): 162 | """Get a label-to-classname mapping (dict). 163 | 164 | Args: 165 | data_source (list): a list of Datum objects. 
166 | """ 167 | container = set() 168 | for item in data_source: 169 | container.add((item.label, item.classname)) 170 | mapping = {label: classname for label, classname in container} 171 | labels = list(mapping.keys()) 172 | labels.sort() 173 | classnames = [mapping[label] for label in labels] 174 | return mapping, classnames 175 | 176 | def check_input_domains(self, source_domains, target_domains): 177 | self.is_input_domain_valid(source_domains) 178 | self.is_input_domain_valid(target_domains) 179 | 180 | def is_input_domain_valid(self, input_domains): 181 | for domain in input_domains: 182 | if domain not in self.domains: 183 | raise ValueError( 184 | 'Input domain must belong to {}, ' 185 | 'but got [{}]'.format(self.domains, domain) 186 | ) 187 | 188 | def download_data(self, url, dst, from_gdrive=True): 189 | if not osp.exists(osp.dirname(dst)): 190 | os.makedirs(osp.dirname(dst)) 191 | 192 | if from_gdrive: 193 | gdown.download(url, dst, quiet=False) 194 | else: 195 | raise NotImplementedError 196 | 197 | print('Extracting file ...') 198 | 199 | try: 200 | tar = tarfile.open(dst) 201 | tar.extractall(path=osp.dirname(dst)) 202 | tar.close() 203 | except: 204 | zip_ref = zipfile.ZipFile(dst, 'r') 205 | zip_ref.extractall(osp.dirname(dst)) 206 | zip_ref.close() 207 | 208 | print('File extracted to {}'.format(osp.dirname(dst))) 209 | 210 | def generate_fewshot_dataset( 211 | self, *data_sources, num_shots=-1, repeat=True 212 | ): 213 | """Generate a few-shot dataset (typically for the training set). 214 | 215 | This function is useful when one wants to evaluate a model 216 | in a few-shot learning setting where each class only contains 217 | a few number of images. 218 | 219 | Args: 220 | data_sources: each individual is a list containing Datum objects. 221 | num_shots (int): number of instances per class to sample. 222 | repeat (bool): repeat images if needed. 223 | """ 224 | if num_shots < 1: 225 | if len(data_sources) == 1: 226 | return data_sources[0] 227 | return data_sources 228 | 229 | print(f'Creating a {num_shots}-shot dataset') 230 | 231 | output = [] 232 | 233 | for data_source in data_sources: 234 | tracker = self.split_dataset_by_label(data_source) 235 | dataset = [] 236 | 237 | for label, items in tracker.items(): 238 | if len(items) >= num_shots: 239 | sampled_items = random.sample(items, num_shots) 240 | else: 241 | if repeat: 242 | sampled_items = random.choices(items, k=num_shots) 243 | else: 244 | sampled_items = items 245 | dataset.extend(sampled_items) 246 | 247 | output.append(dataset) 248 | 249 | if len(output) == 1: 250 | return output[0] 251 | 252 | return output 253 | 254 | def split_dataset_by_label(self, data_source): 255 | """Split a dataset, i.e. a list of Datum objects, 256 | into class-specific groups stored in a dictionary. 257 | 258 | Args: 259 | data_source (list): a list of Datum objects. 260 | """ 261 | output = defaultdict(list) 262 | 263 | for item in data_source: 264 | output[item.label].append(item) 265 | 266 | return output 267 | 268 | def split_dataset_by_domain(self, data_source): 269 | """Split a dataset, i.e. a list of Datum objects, 270 | into domain-specific groups stored in a dictionary. 271 | 272 | Args: 273 | data_source (list): a list of Datum objects. 
274 | """ 275 | output = defaultdict(list) 276 | 277 | for item in data_source: 278 | output[item.domain].append(item) 279 | 280 | return output 281 | 282 | 283 | class DatasetWrapper(TorchDataset): 284 | def __init__(self, data_source, input_size, transform=None, is_train=False, 285 | return_img0=False, k_tfm=1): 286 | self.data_source = data_source 287 | self.transform = transform # accept list (tuple) as input 288 | self.is_train = is_train 289 | # Augmenting an image K>1 times is only allowed during training 290 | self.k_tfm = k_tfm if is_train else 1 291 | self.return_img0 = return_img0 292 | 293 | if self.k_tfm > 1 and transform is None: 294 | raise ValueError( 295 | 'Cannot augment the image {} times ' 296 | 'because transform is None'.format(self.k_tfm) 297 | ) 298 | 299 | # Build transform that doesn't apply any data augmentation 300 | interp_mode = T.InterpolationMode.BICUBIC 301 | to_tensor = [] 302 | to_tensor += [T.Resize(input_size, interpolation=interp_mode)] 303 | to_tensor += [T.ToTensor()] 304 | normalize = T.Normalize( 305 | mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) 306 | ) 307 | to_tensor += [normalize] 308 | self.to_tensor = T.Compose(to_tensor) 309 | 310 | def __len__(self): 311 | return len(self.data_source) 312 | 313 | def __getitem__(self, idx): 314 | item = self.data_source[idx] 315 | 316 | output = { 317 | 'label': item.label, 318 | 'domain': item.domain, 319 | 'impath': item.impath 320 | } 321 | 322 | img0 = read_image(item.impath) 323 | 324 | if self.transform is not None: 325 | if isinstance(self.transform, (list, tuple)): 326 | for i, tfm in enumerate(self.transform): 327 | img = self._transform_image(tfm, img0) 328 | keyname = 'img' 329 | if (i + 1) > 1: 330 | keyname += str(i + 1) 331 | output[keyname] = img 332 | else: 333 | img = self._transform_image(self.transform, img0) 334 | output['img'] = img 335 | 336 | if self.return_img0: 337 | output['img0'] = self.to_tensor(img0) 338 | 339 | return output['img'], output['label'] 340 | 341 | def _transform_image(self, tfm, img0): 342 | img_list = [] 343 | 344 | for k in range(self.k_tfm): 345 | img_list.append(tfm(img0)) 346 | 347 | img = img_list 348 | if len(img) == 1: 349 | img = img[0] 350 | 351 | return img 352 | 353 | 354 | def build_data_loader( 355 | data_source=None, 356 | batch_size=64, 357 | input_size=224, 358 | tfm=None, 359 | is_train=True, 360 | shuffle=False, 361 | dataset_wrapper=None 362 | ): 363 | 364 | if dataset_wrapper is None: 365 | dataset_wrapper = DatasetWrapper 366 | 367 | # Build data loader 368 | data_loader = torch.utils.data.DataLoader( 369 | dataset_wrapper(data_source, input_size=input_size, transform=tfm, is_train=is_train), 370 | batch_size=batch_size, 371 | num_workers=8, 372 | shuffle=shuffle, 373 | drop_last=False, 374 | pin_memory=False 375 | ) 376 | assert len(data_loader) > 0 377 | 378 | return data_loader 379 | -------------------------------------------------------------------------------- /dino/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/CaFo/a805a2aefc6757fdbe10ac9a3165520ceb0e01cb/dino/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /exp.log: -------------------------------------------------------------------------------- 1 | 1 2 4 8 16 2 | ImageNet 63.80 64.34 65.64 66.86 68.79 3 | StanfordCars 61.98 63.36 65.69 70.31 76.73 4 | UCF101 68.60 70.45 
72.96 78.06 79.94 5 | Caltech101 91.85 92.37 93.14 93.83 94.60 6 | Flowers102 80.88 84.94 90.95 92.98 95.86 7 | SUN397 64.89 66.81 69.17 70.34 72.60 8 | DTD 53.43 56.32 60.99 66.19 69.62 9 | Eurosat 69.00 72.86 83.90 86.48 88.68 10 | FGVCAircraft 24.96 26.04 32.94 40.38 49.05 11 | OxfordPets 89.21 89.10 90.11 90.52 91.55 12 | Food101 77.99 78.10 78.32 78.84 79.30 -------------------------------------------------------------------------------- /gpt_file/eurosat_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "Annual Crop Land": [ 3 | "A centered satellite photo of Annual Crop Land would look like a large green field with small patches of brown or bare earth in between.", 4 | "A centered satellite photo of Annual Crop Land would look like a large green field with small patches of brown or bare earth in between.", 5 | "A centered satellite photo of Annual Crop Land would look like a large green field with small patches of brown or bare earth in between.", 6 | "A centered satellite photo of Annual Crop Land would also include any buildings or roads that are near the field.", 7 | "A centered satellite photo of Annual Crop Land would also show any roads or paths that lead to the crop land.", 8 | "A centered satellite photo of Annual Crop Land would look like a large green field with small patches of brown or bare earth in between.", 9 | "A centered satellite photo of Annual Crop Land would look like a large green field with small patches of brown or bare earth in between.", 10 | "A centered satellite photo of Annual Crop Land would look like a large green field with small patches of brown or bare earth in between.", 11 | "A centered satellite photo of Annual Crop Land would look like one large, continuous field of green.", 12 | "A centered satellite photo of Annual Crop Land may also show irrigation systems or other farming infrastructure." 13 | ], 14 | "Forest": [ 15 | "A centered satellite photo of Forest Land would look like a large green field with small patches of brown or bare earth in between.", 16 | "A centered satellite photo of Forest Land would look like a large green area with small patches of brown or bare earth in between.", 17 | "A centered satellite photo of Forest Land would look like a dense, green area with few or no bare patches of earth.", 18 | "A centered satellite photo of Forest Land would look like a large green field with small patches of brown or bare earth in between.", 19 | "A centered satellite photo of Forest Land would look like a large green field with small patches of brown or bare earth in between. A centered satellite photo of Grassland would look like a large green field with small patches of brown or bare earth in between.", 20 | "A centered satellite photo of Forest would look like a large green field with small patches of brown or bare earth in between.", 21 | "A centered satellite photo of Forest Land would look like a large playing field with lots of trees.", 22 | "A centered satellite photo of Forest Land would look like a large green field with small patches of brown or bare earth in between.", 23 | "A centered satellite photo of Forest would look like a green or dark green field with patches of brown or bare earth in between.", 24 | "A centered satellite photo of Forest Land would look like a large green area with small patches of brown or bare earth in between." 
25 | ], 26 | "Herbaceous Vegetation Land": [ 27 | "A centered satellite photo of Herbaceous Vegetation Land would look like a large green or brown field with small patches of green or brown in between.", 28 | "A centered satellite photo of Herbaceous Vegetation Land would look like a large green field with small patches of brown or bare earth in between.", 29 | "A centered satellite photo of Herbaceous Vegetation Land would look like a green field with small patches of brown of bare earth in between. A centered satellite photo of Tree Cover would look like a green field with small patches of brown or bare earth in between, and a few trees scattered throughout.", 30 | "A centered satellite photo of Herbaceous Vegetation Land would look like a field of green with a few brown spots in between.", 31 | "A centered satellite photo of Herbaceous Vegetation Land would look like a green field with a few trees or bushes mixed in.", 32 | "A centered satellite photo of Herbaceous Vegetation Land would look like a large green field with small patches of brown or bare earth in between.", 33 | "A centered satellite photo of Herbaceous Vegetation Land would look like a green field with small patches of brown or bare earth in between.", 34 | "A centered satellite photo of Herbaceous Vegetation Land would look like a green field with very small patches of brown or bare earth in between.", 35 | "A centered satellite photo of Herbaceous Vegetation Land would look like a large green field with small patches of brown or bare earth in between. A centered satellite photo of Perennial Crop Land would look like a large green field with small patches of brown or bare earth in between.", 36 | "A centered satellite photo of Herbaceous Vegetation Land would look like a large, green field with small patches of brown or bare earth in between." 37 | ], 38 | "Highway or Road": [ 39 | "A centered satellite photo of Highway or Road Land would look like a long, thin, dark strip with small patches of green or brown on either side.", 40 | "A centered satellite photo of Highway or Road Land would look like a large paved road with small patches of green or brown on either side.", 41 | "A centered satellite photo of Highway or Road would look like a thin, dark line winding through a lighter-colored background.", 42 | "A centered satellite photo of Highway or Road Infrastructure would look like a large number of dark lines running across the landscape.", 43 | "A centered satellite photo of Highway or Road Infrastructure would look like a thin line of asphalt with a small patch of gravel or dirt on each side.", 44 | "A centered satellite photo of Highway or Road Land would look like a long, straight, grey line with small patches of green or brown on either side.", 45 | "A centered satellite photo of Highway or Road Infrastructure would look like a spider web of grey or white lines with small patches of green or brown in between.", 46 | "A centered satellite photo of Highway or Road Land would look like a large number of thin, dark lines criss-crossing each other.", 47 | "A centered satellite photo of Highway or Road Land would look like a large brown or gray road with green fields on either side.", 48 | "A centered satellite photo of Highway or Road Land would look like a spider web of thin, black lines." 
49 | ], 50 | "Industrial Buildings": [ 51 | "A centered satellite photo of Industrial Buildings would look like a cluster of buildings, usually gray or white, surrounded by a parking lot.", 52 | "A centered satellite photo of Industrial Buildings would look like a group of large structures with small parking lots around them.", 53 | "A centered satellite photo of Industrial Buildings would look like a series of low, rectangular buildings with roofs of different colors.", 54 | "A centered satellite photo of Industrial Buildings would look like large, dark buildings amid a matrix of smaller, lighter buildings.", 55 | "A centered satellite photo of Industrial Buildings would look like a city with large buildings and smokestacks.", 56 | "A centered satellite photo of Industrial Buildings would look like a city with a few buildings that are taller than the others.", 57 | "A centered satellite photo of Industrial Buildings would look like a densely populated area with many buildings and roads.", 58 | "A centered satellite photo of Industrial Buildings would look like large connected buildings surrounded by asphalt parking lots.", 59 | "A centered satellite photo of Industrial Buildings would look like a bunch of large angular buildings with small streets in between them.", 60 | "A centered satellite photo of Industrial Buildings would look like a series of large Modern highrises in an urban area." 61 | ], 62 | "Pasture Land": [ 63 | "A centered satellite photo of Pasture Land would look like large green fields with animals grazing on them.", 64 | "A centered satellite photo of Pasture Land would look like a large green field with some areas of brown or bare earth in between.", 65 | "A centered satellite photo of Pasture Land would look like a large green field broken up by areas of trees, bushes, or other foliage.", 66 | "A centered satellite photo of Pasture Land would look like large green fields with small areas of brown or bare earth in between.", 67 | "A centered satellite photo of Pasture Land would look like a large green field with small patches of brown or bare earth in between.", 68 | "A centered satellite photo of Pasture Land would look like a large green or tan field with small patches of brown or bare earth in between.", 69 | "A centered satellite photo of Pasture Land would look like a large green or brown field with small patches of different colors in between.", 70 | "A centered satellite photo of Pasture Land would look like a large field of green with small brown or black spots (cows).", 71 | "A centered satellite photo of Pasture Land would look like large green fields with some areas of brown or bare earth in between.", 72 | "A centered satellite photo of Pasture Land would look like large green fields with small patches of brown or bare earth in between." 
73 | ], 74 | "Permanent Crop Land": [ 75 | "A centered satellite photo of Permanent Crop Land would look like a large field with different colors depending on what crop is being grown.", 76 | "A centered satellite photo of Permanent Crop Land would look like a large green field with small patches of brown or bare earth in between.", 77 | "A centered satellite photo of Permanent Crop Land would look like a large green field with smaller, more uniform green patches in between.", 78 | "A centered satellite photo of Permanent Crop Land would look like a green field with small patches of brown earth or water in between.", 79 | "A centered satellite photo of Permanent Crop Land would look like a large green field with a few smaller green or brown fields in between.", 80 | "A centered satellite photo of Permanent Crop Land would look like a large green field with small patches of brown or bare earth in between, and there would also be small patches of different colors representing different types of permanent crops.", 81 | "A centered satellite photo of Permanent Crop Land would look like a large green field with small patches of brown or bare earth in between.", 82 | "A centered satellite photo of Permanent Crop Land would look like a mosaic of different colors, depending on the type of crop being grown.", 83 | "A centered satellite photo of Permanent Crop Land would look like a similar green field, however the patches of brown or bare earth would be much smaller, as there is less open land in between crops.", 84 | "A centered satellite photo of Permanent Crop Land would look like a large green or brown field with small patches of bare earth in between." 85 | ], 86 | "Residential Buildings": [ 87 | "A centered satellite photo of Residential Buildings would look like a city with tall buildings in the center and smaller buildings on the outskirts.", 88 | "A centered satellite photo of Residential Buildings would look like a city with large buildings and concrete roads. A centered satellite photo of a Commercial Harbor would look like a harbor with many boats and a few warehouses.", 89 | "A centered satellite photo of Residential Buildings would look like many small rectangular buildings that are close together with some green space in between them.", 90 | "A centered satellite photo of Residential Buildings would look like a lot of small buildings close together with some green space in between them.", 91 | "A centered satellite photo of Residential Buildings would look like a city with areas of green trees and parks throughout.", 92 | "A centered satellite photo of Residential Buildings would look like a city with tall buildings in the center and lower buildings or houses on the outskirts.", 93 | "A centered satellite photo of Residential Buildings would look like a bunch of small squares with a variety of colors.", 94 | "A centered satellite photo of Residential Buildings would look like a large number of small, square or rectangular shaped buildings with large open spaces in between.", 95 | "A centered satellite photo of Residential Buildings would look like a large number of small, square or rectangular buildings with small patches of green or bare earth in between.", 96 | "A centered satellite photo of Residential Buildings would look like a small city with many houses and buildings." 
97 | ], 98 | "River": [ 99 | "A centered satellite photo of River Delta would look like a large mass of water with small islands or patches of land in between.", 100 | "A centered satellite photo of River would look like many small streams or rivers flowing through a larger body of water.", 101 | "A centered satellite photo of River would look like a long, thin blue line with small tributaries branching off of it.", 102 | "A centered satellite photo of River would look like a thin blue line winding through a larger green area.", 103 | "A centered satellite photo of River would look like a long, thin blue or green line winding its way through a landscape.", 104 | "A centered satellite photo of River would look like a large blue or green body of water with smaller tributaries feeding into it.", 105 | "A centered satellite photo of River would look like a large blue body of water with small patches of green or brown land on either side.", 106 | "A centered satellite photo of River Delta would look like a series of branching streams or rivers flowing into a larger body of water.", 107 | "A centered satellite photo of River would look like a long, thin body of water with trees or other landforms surrounding it.", 108 | "A centered satellite photo of River Delta would look like a large body of water with many small waterways flowing into it." 109 | ], 110 | "Sea or Lake": [ 111 | "A centered satellite photo of Sea or Lake would look like a large blue circle with small patches of green, white, or brown around the edge.", 112 | "A centered satellite photo of Sea or Lake Ice would look like a large white or blue field with small patches of ocean water in between.", 113 | "A centered satellite photo of Sea or Lake would look like a large dark blue body of water with small white or light-colored areas around the edge.", 114 | "A centered satellite photo of Sea or Lake Ice would look like a large body of white with small patches of blue in between.", 115 | "A centered satellite photo of Sea or Lake ice would look like a large white or light blue field with small patches of dark blue or black in between.", 116 | "A centered satellite photo of Sea or Lake would look like a large blue or green body of water with small islands in it.", 117 | "A centered satellite photo of Sea or Lake would look like a large dark blue body with small areas of whitecaps where the waves are crashing.", 118 | "A centered satellite photo of Sea or Lake ice would look like large white fields with small patches of dark water in between.", 119 | "A centered satellite photo of Sea or Lake ice would look like large white areas with smaller areas of dark water in between.", 120 | "A centered satellite photo of Sea or Lake Ice would look like a large white or light blue area with bits of dark blue in the middle." 
121 | ] 122 | } -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | import yaml 5 | from tqdm import tqdm 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | import torchvision.transforms as transforms 11 | from torchvision import models as torchvision_models 12 | 13 | from datasets import build_dataset 14 | from datasets.utils import build_data_loader 15 | import clip 16 | from utils import * 17 | import dino.utils as utils 18 | import itertools 19 | import json 20 | 21 | def get_arguments(): 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--config', dest='config', help='settings of Tip-Adapter in yaml format') 25 | args = parser.parse_args() 26 | 27 | return args 28 | 29 | def run_ensemble_tip_dalle_adapter_F(cfg, 30 | clip_cache_keys, 31 | clip_cache_values, 32 | clip_val_features, 33 | clip_test_features, 34 | dino_cache_keys, 35 | dino_cache_values, 36 | dino_val_features, 37 | dino_test_features, 38 | val_labels, 39 | test_labels, 40 | clip_weights, 41 | clip_model, 42 | dino_model, 43 | train_loader_F, 44 | dalle_train_loader_F): 45 | 46 | # Enable the cached keys to be learnable 47 | clip_adapter = nn.Linear(clip_cache_keys.shape[0], clip_cache_keys.shape[1], bias=False).to(clip_model.dtype).cuda() 48 | clip_adapter.weight = nn.Parameter(clip_cache_keys.t()) 49 | dino_adapter = nn.Linear(dino_cache_keys.shape[0], dino_cache_keys.shape[1], bias=False).to(clip_model.dtype).cuda() 50 | dino_adapter.weight = nn.Parameter(dino_cache_keys.t()) 51 | 52 | optimizer = torch.optim.AdamW( 53 | itertools.chain(dino_adapter.parameters(), clip_adapter.parameters()), 54 | lr=cfg['lr'], 55 | eps=1e-4) 56 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg['train_epoch'] * len(train_loader_F)) 57 | 58 | beta, alpha = cfg['init_beta'], cfg['init_alpha'] 59 | best_acc, best_epoch = 0.0, 0 60 | 61 | for train_idx in range(cfg['train_epoch']): 62 | # Train 63 | clip_adapter.train() 64 | dino_adapter.train() 65 | correct_samples, all_samples = 0, 0 66 | loss_list = [] 67 | print('Train Epoch: {:} / {:}'.format(train_idx, cfg['train_epoch'])) 68 | 69 | # origin image 70 | for i, (images, target) in enumerate(tqdm(train_loader_F)): 71 | images, target = images.cuda(), target.cuda() 72 | with torch.no_grad(): 73 | clip_image_features = clip_model.encode_image(images) 74 | clip_image_features /= clip_image_features.norm(dim=-1, keepdim=True) 75 | dino_image_features = dino_model(images) 76 | dino_image_features /= dino_image_features.norm(dim=-1, keepdim=True) 77 | 78 | clip_affinity = clip_adapter(clip_image_features) 79 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values 80 | dino_affinity = dino_adapter(dino_image_features).to(dino_cache_values.dtype) 81 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values 82 | clip_logits = 100. 
* clip_image_features @ clip_weights 83 | 84 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits]) 85 | tip_logits = clip_logits + cache_logits * alpha 86 | loss = F.cross_entropy(tip_logits, target) 87 | 88 | acc = cls_acc(tip_logits, target) 89 | correct_samples += acc / 100 * len(tip_logits) 90 | all_samples += len(tip_logits) 91 | loss_list.append(loss.item()) 92 | 93 | optimizer.zero_grad() 94 | loss.backward() 95 | optimizer.step() 96 | scheduler.step() 97 | 98 | # dalle image 99 | for i, (images, target) in enumerate(tqdm(dalle_train_loader_F)): 100 | images, target = images.cuda(), target.cuda() 101 | with torch.no_grad(): 102 | clip_image_features = clip_model.encode_image(images) 103 | clip_image_features /= clip_image_features.norm(dim=-1, keepdim=True) 104 | dino_image_features = dino_model(images) 105 | dino_image_features /= dino_image_features.norm(dim=-1, keepdim=True) 106 | 107 | clip_affinity = clip_adapter(clip_image_features) 108 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values 109 | dino_affinity = dino_adapter(dino_image_features).to(dino_cache_values.dtype) 110 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values 111 | clip_logits = 100. * clip_image_features @ clip_weights 112 | 113 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits]) 114 | tip_logits = clip_logits + cache_logits * alpha 115 | loss = F.cross_entropy(tip_logits, target) 116 | 117 | acc = cls_acc(tip_logits, target) 118 | correct_samples += acc / 100 * len(tip_logits) 119 | all_samples += len(tip_logits) 120 | loss_list.append(loss.item()) 121 | 122 | optimizer.zero_grad() 123 | loss.backward() 124 | optimizer.step() 125 | scheduler.step() 126 | 127 | current_lr = scheduler.get_last_lr()[0] 128 | print('LR: {:.6f}, Acc: {:.4f} ({:}/{:}), Loss: {:.4f}'.format(current_lr, correct_samples / all_samples, correct_samples, all_samples, sum(loss_list)/len(loss_list))) 129 | 130 | # Eval 131 | clip_adapter.eval() 132 | dino_adapter.eval() 133 | 134 | clip_affinity = clip_adapter(clip_test_features) 135 | dino_affinity = dino_adapter(dino_test_features).to(dino_cache_values.dtype) 136 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values 137 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values 138 | clip_logits = 100. * clip_test_features @ clip_weights 139 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits]) 140 | tip_logits = clip_logits + cache_logits * alpha 141 | acc = cls_acc(tip_logits, test_labels) 142 | 143 | print("**** CaFo's test accuracy: {:.2f}. ****\n".format(acc)) 144 | if acc > best_acc: 145 | best_acc = acc 146 | best_epoch = train_idx 147 | torch.save(clip_adapter.weight, cfg['cache_dir'] + "/best_F_clip_adapter_" + str(cfg['shots']) + "shots.pt") 148 | torch.save(dino_adapter.weight, cfg['cache_dir'] + "/best_F_dino_adapter_" + str(cfg['shots']) + "shots.pt") 149 | 150 | clip_adapter.weight = torch.load(cfg['cache_dir'] + "/best_F_clip_adapter_" + str(cfg['shots']) + "shots.pt") 151 | dino_adapter.weight = torch.load(cfg['cache_dir'] + "/best_F_dino_adapter_" + str(cfg['shots']) + "shots.pt") 152 | print(f"**** After fine-tuning, CaFo's best test accuracy: {best_acc:.2f}, at epoch: {best_epoch}. ****\n") 153 | 154 | print("\n-------- Searching hyperparameters on the val set. 
--------") 155 | 156 | # Search Hyperparameters 157 | best_beta, best_alpha = best_beta, best_alpha = search_ensemble_hp(cfg, clip_cache_keys, clip_cache_values, clip_val_features, dino_cache_keys, dino_cache_values, dino_val_features, val_labels, clip_weights) 158 | 159 | print("\n-------- Evaluating on the test set. --------") 160 | 161 | clip_affinity = clip_adapter(clip_test_features) 162 | dino_affinity = dino_adapter(dino_test_features).to(dino_cache_values.dtype) 163 | clip_cache_logits = ((-1) * (best_beta - best_beta * clip_affinity)).exp() @ clip_cache_values 164 | dino_cache_logits = ((-1) * (best_beta - best_beta * dino_affinity)).exp() @ dino_cache_values 165 | 166 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits]) 167 | tip_logits = clip_logits + cache_logits * best_alpha 168 | acc = cls_acc(tip_logits, test_labels) 169 | print("**** CaFo's test accuracy: {:.2f}. ****\n".format(max(best_acc, acc))) 170 | 171 | def main(): 172 | 173 | # Load config file 174 | args = get_arguments() 175 | assert (os.path.exists(args.config)) 176 | 177 | cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 178 | 179 | cache_dir = os.path.join('./caches', cfg['dataset']) 180 | os.makedirs(cache_dir, exist_ok=True) 181 | cfg['cache_dir'] = cache_dir 182 | 183 | print("\nRunning configs.") 184 | print(cfg, "\n") 185 | 186 | # CLIP 187 | clip_model, preprocess = clip.load(cfg['clip_backbone']) 188 | clip_model.eval() 189 | 190 | # DINO 191 | dino_model = torchvision_models.__dict__[cfg['dino_backbone']](num_classes=0) 192 | dino_model.fc = nn.Identity() 193 | dino_model.cuda() 194 | utils.load_pretrained_weights(dino_model, "dino/dino_resnet50_pretrain.pth", "teacher", "vit_small'", 16) 195 | dino_model.eval() 196 | 197 | # Prepare dataset 198 | random.seed(1) 199 | torch.manual_seed(1) 200 | 201 | print("Preparing dataset.") 202 | dataset = build_dataset(cfg['dataset'], cfg['root_path'], cfg['shots']) 203 | 204 | val_loader = build_data_loader(data_source=dataset.val, batch_size=64, is_train=False, tfm=preprocess, shuffle=False) 205 | test_loader = build_data_loader(data_source=dataset.test, batch_size=64, is_train=False, tfm=preprocess, shuffle=False) 206 | 207 | train_tranform = transforms.Compose([ 208 | transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC), 209 | transforms.RandomHorizontalFlip(p=0.5), 210 | transforms.ToTensor(), 211 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 212 | ]) 213 | 214 | train_loader_cache = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_tranform, is_train=True, shuffle=False) 215 | train_loader_F = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_tranform, is_train=True, shuffle=True) 216 | 217 | dalle_dataset = build_dataset(cfg['dalle_dataset'], cfg['root_path'], cfg['dalle_shots']) 218 | dalle_train_loader_cache = build_data_loader(data_source=dalle_dataset.train_x, batch_size=256, tfm=train_tranform, is_train=True, shuffle=False) 219 | dalle_train_loader_F = build_data_loader(data_source=dalle_dataset.train_x, batch_size=256, tfm=train_tranform, is_train=True, shuffle=True) 220 | 221 | with open(cfg['gpt3_prompt_file']) as f: 222 | gpt3_prompt = json.load(f) 223 | 224 | # Textual features 225 | print("\nGetting textual features as CLIP's classifier.") 226 | #clip_weights = clip_classifier(dataset.classnames, dataset.template, clip_model) 227 | clip_weights = 
gpt_clip_classifier(dataset.classnames, gpt3_prompt, clip_model, dataset.template) 228 | 229 | # Construct the cache model by few-shot training set 230 | print("\nConstructing cache model by few-shot visual features and labels.") 231 | print("\nConstructing CLIP cache model.") 232 | clip_cache_keys, clip_cache_values = build_clip_cache_model(cfg, clip_model, train_loader_cache) 233 | print("\nConstructing DINO cache model.") 234 | dino_cache_keys, dino_cache_values = build_dino_cache_model(cfg, dino_model, train_loader_cache) 235 | 236 | print("\nConstructing cache model by dalle image.") 237 | print("\nConstructing CLIP cache model.") 238 | clip_dalle_cache_keys, clip_dalle_cache_values = build_clip_dalle_cache_model(cfg, clip_model, dalle_train_loader_cache) 239 | print("\nConstructing DINO cache model.") 240 | dino_dalle_cache_keys, dino_dalle_cache_values = build_dino_dalle_cache_model(cfg, dino_model, dalle_train_loader_cache) 241 | 242 | # Pre-load val features 243 | print("\nLoading visual features and labels from val set.") 244 | print("\nLoading CLIP feature.") 245 | val_clip_features, val_labels = pre_CLIP_load_features(cfg, "val", clip_model, val_loader) 246 | print("\nLoading DINO feature.") 247 | val_dino_features, val_labels = pre_DINO_load_features(cfg, "val", dino_model, val_loader) 248 | 249 | # Pre-load test features 250 | print("\nLoading visual features and labels from test set.") 251 | print("\nLoading CLIP feature.") 252 | test_clip_features, test_labels = pre_CLIP_load_features(cfg, "test", clip_model, test_loader) 253 | print("\nLoading DINO feature.") 254 | test_dino_features, test_labels = pre_DINO_load_features(cfg, "test", dino_model, test_loader) 255 | 256 | # ------------------------------------------ Tip-Adapter-F ------------------------------------------ 257 | 258 | run_ensemble_tip_dalle_adapter_F(cfg, 259 | torch.cat((clip_cache_keys, clip_dalle_cache_keys), dim=1), 260 | torch.cat((clip_cache_values, clip_dalle_cache_values), dim=0), 261 | val_clip_features, 262 | test_clip_features, 263 | torch.cat((dino_cache_keys, dino_dalle_cache_keys), dim=1), 264 | torch.cat((dino_cache_values, dino_dalle_cache_values), dim=0), 265 | val_dino_features, 266 | test_dino_features, 267 | val_labels, 268 | test_labels, 269 | clip_weights, 270 | clip_model, 271 | dino_model, 272 | train_loader_F, 273 | dalle_train_loader_F) 274 | 275 | if __name__ == '__main__': 276 | main() -------------------------------------------------------------------------------- /main_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | import yaml 5 | from tqdm import tqdm 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | import torchvision.transforms as transforms 11 | from torchvision import models as torchvision_models 12 | 13 | from datasets.imagenet import ImageNet 14 | from datasets import build_dataset 15 | from datasets.utils import build_data_loader 16 | import clip 17 | from utils import * 18 | import dino.utils as utils 19 | import itertools 20 | import json 21 | 22 | 23 | def get_arguments(): 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--config', dest='config', help='settings of Tip-Adapter in yaml format') 27 | args = parser.parse_args() 28 | 29 | return args 30 | 31 | def run_ensemble_tip_dalle_adapter_F(cfg, 32 | clip_cache_keys, 33 | clip_cache_values, 34 | clip_test_features, 35 | dino_cache_keys, 36 | 
dino_cache_values, 37 | dino_test_features, 38 | test_labels, 39 | clip_weights, 40 | clip_model, 41 | dino_model, 42 | train_loader_F, 43 | dalle_train_loader_F): 44 | 45 | # Enable the cached keys to be learnable 46 | clip_adapter = nn.Linear(clip_cache_keys.shape[0], clip_cache_keys.shape[1], bias=False).to(clip_model.dtype).cuda() 47 | clip_adapter.weight = nn.Parameter(clip_cache_keys.t()) 48 | dino_adapter = nn.Linear(dino_cache_keys.shape[0], dino_cache_keys.shape[1], bias=False).to(clip_model.dtype).cuda() 49 | dino_adapter.weight = nn.Parameter(dino_cache_keys.t()) 50 | 51 | optimizer = torch.optim.AdamW( 52 | itertools.chain(dino_adapter.parameters(), clip_adapter.parameters()), 53 | lr=cfg['lr'], 54 | eps=1e-4) 55 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg['train_epoch'] * len(train_loader_F)) 56 | 57 | beta, alpha = cfg['init_beta'], cfg['init_alpha'] 58 | best_acc, best_epoch = 0.0, 0 59 | 60 | for train_idx in range(cfg['train_epoch']): 61 | # Train 62 | clip_adapter.train() 63 | dino_adapter.train() 64 | correct_samples, all_samples = 0, 0 65 | loss_list = [] 66 | print('Train Epoch: {:} / {:}'.format(train_idx, cfg['train_epoch'])) 67 | 68 | # origin image 69 | for i, (images, target) in enumerate(tqdm(train_loader_F)): 70 | images, target = images.cuda(), target.cuda() 71 | with torch.no_grad(): 72 | clip_image_features = clip_model.encode_image(images) 73 | clip_image_features /= clip_image_features.norm(dim=-1, keepdim=True) 74 | dino_image_features = dino_model(images) 75 | dino_image_features /= dino_image_features.norm(dim=-1, keepdim=True) 76 | 77 | clip_affinity = clip_adapter(clip_image_features) 78 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values 79 | dino_affinity = dino_adapter(dino_image_features).to(dino_cache_values.dtype) 80 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values 81 | clip_logits = 100. * clip_image_features @ clip_weights 82 | 83 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits]) 84 | tip_logits = clip_logits + cache_logits * alpha 85 | loss = F.cross_entropy(tip_logits, target) 86 | 87 | acc = cls_acc(tip_logits, target) 88 | correct_samples += acc / 100 * len(tip_logits) 89 | all_samples += len(tip_logits) 90 | loss_list.append(loss.item()) 91 | 92 | optimizer.zero_grad() 93 | loss.backward() 94 | optimizer.step() 95 | scheduler.step() 96 | 97 | # dalle image 98 | for i, (images, target) in enumerate(tqdm(dalle_train_loader_F)): 99 | images, target = images.cuda(), target.cuda() 100 | with torch.no_grad(): 101 | clip_image_features = clip_model.encode_image(images) 102 | clip_image_features /= clip_image_features.norm(dim=-1, keepdim=True) 103 | dino_image_features = dino_model(images) 104 | dino_image_features /= dino_image_features.norm(dim=-1, keepdim=True) 105 | 106 | clip_affinity = clip_adapter(clip_image_features) 107 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values 108 | dino_affinity = dino_adapter(dino_image_features).to(dino_cache_values.dtype) 109 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values 110 | clip_logits = 100. 
* clip_image_features @ clip_weights 111 | 112 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits]) 113 | tip_logits = clip_logits + cache_logits * alpha 114 | loss = F.cross_entropy(tip_logits, target) 115 | 116 | acc = cls_acc(tip_logits, target) 117 | correct_samples += acc / 100 * len(tip_logits) 118 | all_samples += len(tip_logits) 119 | loss_list.append(loss.item()) 120 | 121 | optimizer.zero_grad() 122 | loss.backward() 123 | optimizer.step() 124 | scheduler.step() 125 | 126 | current_lr = scheduler.get_last_lr()[0] 127 | print('LR: {:.6f}, Acc: {:.4f} ({:}/{:}), Loss: {:.4f}'.format(current_lr, correct_samples / all_samples, correct_samples, all_samples, sum(loss_list)/len(loss_list))) 128 | 129 | # Eval 130 | clip_adapter.eval() 131 | dino_adapter.eval() 132 | 133 | clip_affinity = clip_adapter(clip_test_features) 134 | dino_affinity = dino_adapter(dino_test_features).to(dino_cache_values.dtype) 135 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values 136 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values 137 | clip_logits = 100. * clip_test_features @ clip_weights 138 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits]) 139 | tip_logits = clip_logits + cache_logits * alpha 140 | acc = cls_acc(tip_logits, test_labels) 141 | 142 | print("**** CaFo's test accuracy: {:.2f}. ****\n".format(acc)) 143 | if acc > best_acc: 144 | best_acc = acc 145 | best_epoch = train_idx 146 | torch.save(clip_adapter.weight, cfg['cache_dir'] + "/best_F_clip_adapter_" + str(cfg['shots']) + "shots.pt") 147 | torch.save(dino_adapter.weight, cfg['cache_dir'] + "/best_F_dino_adapter_" + str(cfg['shots']) + "shots.pt") 148 | 149 | clip_adapter.weight = torch.load(cfg['cache_dir'] + "/best_F_clip_adapter_" + str(cfg['shots']) + "shots.pt") 150 | dino_adapter.weight = torch.load(cfg['cache_dir'] + "/best_F_dino_adapter_" + str(cfg['shots']) + "shots.pt") 151 | print(f"**** After fine-tuning, CaFo's best test accuracy: {best_acc:.2f}, at epoch: {best_epoch}. ****\n") 152 | 153 | del clip_logits, tip_logits, cache_logits, clip_cache_logits, dino_cache_logits, clip_affinity, dino_affinity 154 | # Search Hyperparameters 155 | # _ = search_hp(cfg, affinity, clip_cache_values, clip_test_features, test_labels, clip_weights, clip_adapter=adapter) 156 | best_beta, best_alpha = search_ensemble_hp(cfg, clip_cache_keys, clip_cache_values, clip_test_features, dino_cache_keys, dino_cache_values, dino_test_features, test_labels, clip_weights, clip_adapter=clip_adapter, dino_adapter=dino_adapter) 157 | clip_affinity = clip_adapter(clip_test_features) 158 | dino_affinity = dino_adapter(dino_test_features).to(dino_cache_values.dtype) 159 | clip_cache_logits = ((-1) * (best_beta - best_beta * clip_affinity)).exp() @ clip_cache_values 160 | dino_cache_logits = ((-1) * (best_beta - best_beta * dino_affinity)).exp() @ dino_cache_values 161 | clip_logits = 100. 
* clip_test_features @ clip_weights 162 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits]) 163 | tip_logits = clip_logits + cache_logits * best_alpha 164 | print("save logits!!!!!!!!!!!!!") 165 | torch.save(tip_logits, cfg['cache_dir'] + "/best_tip_dino_dalle_logits_" + str(cfg['shots']) + "shots.pt") 166 | 167 | def main(): 168 | 169 | # Load config file 170 | args = get_arguments() 171 | assert (os.path.exists(args.config)) 172 | 173 | cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 174 | 175 | cache_dir = os.path.join('./caches', cfg['dataset']) 176 | os.makedirs(cache_dir, exist_ok=True) 177 | cfg['cache_dir'] = cache_dir 178 | 179 | print("\nRunning configs.") 180 | print(cfg, "\n") 181 | 182 | # CLIP 183 | clip_model, preprocess = clip.load(cfg['clip_backbone']) 184 | clip_model.eval() 185 | 186 | # DINO 187 | dino_model = torchvision_models.__dict__[cfg['dino_backbone']](num_classes=0) 188 | dino_model.fc = nn.Identity() 189 | dino_model.cuda() 190 | utils.load_pretrained_weights(dino_model, "dino/dino_resnet50_pretrain.pth", "teacher", "vit_small'", 16) 191 | dino_model.eval() 192 | 193 | # ImageNet dataset 194 | random.seed(2) 195 | torch.manual_seed(1) 196 | 197 | print("Preparing ImageNet dataset.") 198 | imagenet = ImageNet(cfg['root_path'], cfg['shots'], preprocess) 199 | 200 | test_loader = torch.utils.data.DataLoader(imagenet.test, batch_size=64, num_workers=8, shuffle=False) 201 | 202 | train_loader_cache = torch.utils.data.DataLoader(imagenet.train, batch_size=256, num_workers=8, shuffle=False) 203 | train_loader_F = torch.utils.data.DataLoader(imagenet.train, batch_size=256, num_workers=8, shuffle=True) 204 | 205 | dalle_dataset = build_dataset(cfg['dalle_dataset'], cfg['root_path'], cfg['dalle_shots']) 206 | train_tranform = transforms.Compose([ 207 | transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC), 208 | transforms.RandomHorizontalFlip(p=0.5), 209 | transforms.ToTensor(), 210 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 211 | ]) 212 | dalle_train_loader_cache = build_data_loader(data_source=dalle_dataset.train_x, batch_size=256, tfm=train_tranform, is_train=True, shuffle=False) 213 | dalle_train_loader_F = build_data_loader(data_source=dalle_dataset.train_x, batch_size=256, tfm=train_tranform, is_train=True, shuffle=True) 214 | 215 | with open(cfg['gpt3_prompt_file']) as f: 216 | gpt3_prompt = json.load(f) 217 | 218 | # Textual features 219 | print("Getting textual features as CLIP's classifier.") 220 | clip_weights = gpt_clip_classifier(imagenet.classnames, gpt3_prompt, clip_model, imagenet.template) 221 | 222 | 223 | # Construct the cache model by few-shot training set 224 | print("\nConstructing cache model by few-shot visual features and labels.") 225 | print("\nConstructing CLIP cache model.") 226 | clip_cache_keys, clip_cache_values = build_clip_cache_model(cfg, clip_model, train_loader_cache) 227 | print("\nConstructing DINO cache model.") 228 | dino_cache_keys, dino_cache_values = build_dino_cache_model(cfg, dino_model, train_loader_cache) 229 | 230 | print("\nConstructing cache model by dalle image.") 231 | print("\nConstructing CLIP cache model.") 232 | clip_dalle_cache_keys, clip_dalle_cache_values = build_clip_dalle_cache_model(cfg, clip_model, dalle_train_loader_cache) 233 | print("\nConstructing DINO cache model.") 234 | dino_dalle_cache_keys, dino_dalle_cache_values = 
build_dino_dalle_cache_model(cfg, dino_model, dalle_train_loader_cache) 235 | 236 | # Pre-load test features 237 | print("\nLoading visual features and labels from test set.") 238 | print("\nLoading CLIP feature.") 239 | test_clip_features, test_labels = pre_CLIP_load_features(cfg, "test", clip_model, test_loader) 240 | print("\nLoading DINO feature.") 241 | test_dino_features, test_labels = pre_DINO_load_features(cfg, "test", dino_model, test_loader) 242 | 243 | # ------------------------------------------ Tip-Adapter-F ------------------------------------------ 244 | 245 | run_ensemble_tip_dalle_adapter_F(cfg, 246 | torch.cat((clip_cache_keys, clip_dalle_cache_keys), dim=1), 247 | torch.cat((clip_cache_values, clip_dalle_cache_values), dim=0), 248 | test_clip_features, 249 | torch.cat((dino_cache_keys, dino_dalle_cache_keys), dim=1), 250 | torch.cat((dino_cache_values, dino_dalle_cache_values), dim=0), 251 | test_dino_features, 252 | test_labels, 253 | clip_weights, 254 | clip_model, 255 | dino_model, 256 | train_loader_F, 257 | dalle_train_loader_F) 258 | 259 | if __name__ == '__main__': 260 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flake8==3.7.9 2 | yapf==0.29.0 3 | isort==4.3.21 4 | yacs 5 | gdown 6 | tb-nightly 7 | future 8 | scipy 9 | scikit-learn 10 | tqdm 11 | ftfy 12 | regex 13 | wilds==1.2.2 14 | tabulate -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | import clip 8 | 9 | 10 | def cls_acc(output, target, topk=1): 11 | pred = output.topk(topk, 1, True, True)[1].t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 14 | acc = 100 * acc / target.shape[0] 15 | return acc 16 | 17 | def gpt_clip_classifier(classnames, gpt_prompts, clip_model, template): 18 | with torch.no_grad(): 19 | clip_weights = [] 20 | for classname in classnames: 21 | # Tokenize the prompts 22 | classname = classname.replace('_', ' ') 23 | texts = [] 24 | for t in gpt_prompts[classname]: 25 | texts.append(t) 26 | texts = clip.tokenize(texts).cuda() 27 | # prompt ensemble for ImageNet 28 | class_embeddings = clip_model.encode_text(texts) 29 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 30 | class_embedding = class_embeddings.mean(dim=0) 31 | class_embedding /= class_embedding.norm() 32 | clip_weights.append(class_embedding) 33 | 34 | clip_weights = torch.stack(clip_weights, dim=1).cuda() 35 | return clip_weights 36 | 37 | 38 | def build_clip_cache_model(cfg, clip_model, train_loader_cache): 39 | 40 | if cfg['load_cache'] == False: 41 | cache_keys = [] 42 | cache_values = [] 43 | 44 | with torch.no_grad(): 45 | # Data augmentation for the cache model 46 | for augment_idx in range(cfg['augment_epoch']): 47 | train_features = [] 48 | 49 | print('Augment Epoch: {:} / {:}'.format(augment_idx, cfg['augment_epoch'])) 50 | for i, (images, target) in enumerate(tqdm(train_loader_cache)): 51 | images = images.cuda() 52 | image_features = clip_model.encode_image(images) 53 | train_features.append(image_features) 54 | if augment_idx == 0: 55 | target = target.cuda() 56 | cache_values.append(target) 57 | 
cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0)) 58 | 59 | cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0) 60 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True) 61 | cache_keys = cache_keys.permute(1, 0) 62 | cache_values = F.one_hot(torch.cat(cache_values, dim=0)).half() 63 | 64 | torch.save(cache_keys, cfg['cache_dir'] + '/clip_keys_' + str(cfg['shots']) + "shots.pt") 65 | torch.save(cache_values, cfg['cache_dir'] + '/clip_values_' + str(cfg['shots']) + "shots.pt") 66 | 67 | else: 68 | cache_keys = torch.load(cfg['cache_dir'] + '/clip_keys_' + str(cfg['shots']) + "shots.pt") 69 | cache_values = torch.load(cfg['cache_dir'] + '/clip_values_' + str(cfg['shots']) + "shots.pt") 70 | 71 | return cache_keys, cache_values 72 | 73 | def build_dino_cache_model(cfg, dino_model, train_loader_cache): 74 | 75 | if cfg['load_cache'] == False: 76 | cache_keys = [] 77 | cache_values = [] 78 | 79 | with torch.no_grad(): 80 | # Data augmentation for the cache model 81 | for augment_idx in range(cfg['augment_epoch']): 82 | train_features = [] 83 | 84 | print('Augment Epoch: {:} / {:}'.format(augment_idx, cfg['augment_epoch'])) 85 | for i, (images, target) in enumerate(tqdm(train_loader_cache)): 86 | images = images.cuda() 87 | image_features = dino_model(images) 88 | train_features.append(image_features) 89 | if augment_idx == 0: 90 | target = target.cuda() 91 | cache_values.append(target) 92 | cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0)) 93 | 94 | cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0) 95 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True) 96 | cache_keys = cache_keys.permute(1, 0) 97 | cache_values = F.one_hot(torch.cat(cache_values, dim=0)).half() 98 | 99 | torch.save(cache_keys, cfg['cache_dir'] + '/dino_keys_' + str(cfg['shots']) + "shots.pt") 100 | torch.save(cache_values, cfg['cache_dir'] + '/dino_values_' + str(cfg['shots']) + "shots.pt") 101 | 102 | else: 103 | cache_keys = torch.load(cfg['cache_dir'] + '/dino_keys_' + str(cfg['shots']) + "shots.pt") 104 | cache_values = torch.load(cfg['cache_dir'] + '/dino_values_' + str(cfg['shots']) + "shots.pt") 105 | 106 | return cache_keys, cache_values 107 | 108 | def build_clip_dalle_cache_model(cfg, clip_model, train_loader_cache): 109 | 110 | if cfg['load_cache'] == False: 111 | cache_keys = [] 112 | cache_values = [] 113 | 114 | with torch.no_grad(): 115 | # Data augmentation for the cache model 116 | for augment_idx in range(cfg['augment_epoch']): 117 | train_features = [] 118 | 119 | print('Augment Epoch: {:} / {:}'.format(augment_idx, cfg['augment_epoch'])) 120 | for i, (images, target) in enumerate(tqdm(train_loader_cache)): 121 | images = images.cuda() 122 | image_features = clip_model.encode_image(images) 123 | train_features.append(image_features) 124 | if augment_idx == 0: 125 | target = target.cuda() 126 | cache_values.append(target) 127 | cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0)) 128 | 129 | cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0) 130 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True) 131 | cache_keys = cache_keys.permute(1, 0) 132 | cache_values = F.one_hot(torch.cat(cache_values, dim=0)).half() 133 | 134 | torch.save(cache_keys, cfg['cache_dir'] + '/clip_dalle_keys_' + str(cfg['dalle_shots']) + "shots.pt") 135 | torch.save(cache_values, cfg['cache_dir'] + '/clip_dalle_values_' + str(cfg['dalle_shots']) + "shots.pt") 136 | 137 | else: 138 | cache_keys = torch.load(cfg['cache_dir'] + '/clip_dalle_keys_' + str(cfg['dalle_shots']) + "shots.pt") 
139 | cache_values = torch.load(cfg['cache_dir'] + '/clip_dalle_values_' + str(cfg['dalle_shots']) + "shots.pt") 140 | 141 | return cache_keys, cache_values 142 | 143 | def build_dino_dalle_cache_model(cfg, dino_model, train_loader_cache): 144 | 145 | if cfg['load_cache'] == False: 146 | cache_keys = [] 147 | cache_values = [] 148 | 149 | with torch.no_grad(): 150 | # Data augmentation for the cache model 151 | for augment_idx in range(cfg['augment_epoch']): 152 | train_features = [] 153 | 154 | print('Augment Epoch: {:} / {:}'.format(augment_idx, cfg['augment_epoch'])) 155 | for i, (images, target) in enumerate(tqdm(train_loader_cache)): 156 | images = images.cuda() 157 | image_features = dino_model(images) 158 | train_features.append(image_features) 159 | if augment_idx == 0: 160 | target = target.cuda() 161 | cache_values.append(target) 162 | cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0)) 163 | 164 | cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0) 165 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True) 166 | cache_keys = cache_keys.permute(1, 0) 167 | cache_values = F.one_hot(torch.cat(cache_values, dim=0)).half() 168 | 169 | torch.save(cache_keys, cfg['cache_dir'] + '/dino_dalle_keys_' + str(cfg['dalle_shots']) + "shots.pt") 170 | torch.save(cache_values, cfg['cache_dir'] + '/dino_dalle_values_' + str(cfg['dalle_shots']) + "shots.pt") 171 | 172 | else: 173 | cache_keys = torch.load(cfg['cache_dir'] + '/dino_dalle_keys_' + str(cfg['dalle_shots']) + "shots.pt") 174 | cache_values = torch.load(cfg['cache_dir'] + '/dino_dalle_values_' + str(cfg['dalle_shots']) + "shots.pt") 175 | 176 | return cache_keys, cache_values 177 | 178 | 179 | def pre_CLIP_load_features(cfg, split, clip_model, loader): 180 | 181 | if cfg['load_pre_feat'] == False: 182 | features, labels = [], [] 183 | 184 | with torch.no_grad(): 185 | for i, (images, target) in enumerate(tqdm(loader)): 186 | images, target = images.cuda(), target.cuda() 187 | image_features = clip_model.encode_image(images) 188 | image_features /= image_features.norm(dim=-1, keepdim=True) 189 | features.append(image_features) 190 | labels.append(target) 191 | 192 | features, labels = torch.cat(features), torch.cat(labels) 193 | 194 | torch.save(features, cfg['cache_dir'] + "/" + split + "_clip_f.pt") 195 | torch.save(labels, cfg['cache_dir'] + "/" + split + "_clip_l.pt") 196 | 197 | else: 198 | features = torch.load(cfg['cache_dir'] + "/" + split + "_clip_f.pt") 199 | labels = torch.load(cfg['cache_dir'] + "/" + split + "_clip_l.pt") 200 | 201 | return features, labels 202 | 203 | 204 | def pre_DINO_load_features(cfg, split, dino_model, loader): 205 | 206 | if cfg['load_pre_feat'] == False: 207 | features, labels = [], [] 208 | 209 | with torch.no_grad(): 210 | for i, (images, target) in enumerate(tqdm(loader)): 211 | images, target = images.cuda(), target.cuda() 212 | image_features = dino_model(images) 213 | image_features /= image_features.norm(dim=-1, keepdim=True) 214 | features.append(image_features) 215 | labels.append(target) 216 | 217 | features, labels = torch.cat(features), torch.cat(labels) 218 | 219 | torch.save(features, cfg['cache_dir'] + "/" + split + "_dino_f.pt") 220 | torch.save(labels, cfg['cache_dir'] + "/" + split + "_dino_l.pt") 221 | 222 | else: 223 | features = torch.load(cfg['cache_dir'] + "/" + split + "_dino_f.pt") 224 | labels = torch.load(cfg['cache_dir'] + "/" + split + "_dino_l.pt") 225 | 226 | return features, labels 227 | 228 | 229 | def search_hp(cfg, cache_keys, cache_values, features, 
labels, clip_weights, adapter=None): 230 | 231 | if cfg['search_hp'] == True: 232 | 233 | beta_list = [i * (cfg['search_scale'][0] - 0.1) / cfg['search_step'][0] + 0.1 for i in range(cfg['search_step'][0])] 234 | alpha_list = [i * (cfg['search_scale'][1] - 0.1) / cfg['search_step'][1] + 0.1 for i in range(cfg['search_step'][1])] 235 | 236 | best_acc = 0 237 | best_beta, best_alpha = 0, 0 238 | 239 | for beta in beta_list: 240 | for alpha in alpha_list: 241 | if adapter: 242 | affinity = adapter(features) 243 | else: 244 | affinity = features @ cache_keys 245 | 246 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 247 | clip_logits = 100. * features @ clip_weights 248 | tip_logits = clip_logits + cache_logits * alpha 249 | acc = cls_acc(tip_logits, labels) 250 | 251 | if acc > best_acc: 252 | print("New best setting, beta: {:.2f}, alpha: {:.2f}; accuracy: {:.2f}".format(beta, alpha, acc)) 253 | best_acc = acc 254 | best_beta = beta 255 | best_alpha = alpha 256 | 257 | print("\nAfter searching, the best accuarcy: {:.2f}.\n".format(best_acc)) 258 | 259 | return best_beta, best_alpha 260 | 261 | def search_no_clip_hp(cfg, cache_keys, cache_values, features, labels, adapter=None): 262 | 263 | if cfg['search_hp'] == True: 264 | 265 | beta_list = [i * (cfg['search_scale'][0] - 0.1) / cfg['search_step'][0] + 0.1 for i in range(cfg['search_step'][0])] 266 | alpha_list = [i * (cfg['search_scale'][1] - 0.1) / cfg['search_step'][1] + 0.1 for i in range(cfg['search_step'][1])] 267 | 268 | best_acc = 0 269 | best_beta, best_alpha = 0, 0 270 | 271 | for beta in beta_list: 272 | for alpha in alpha_list: 273 | if adapter: 274 | affinity = adapter(features).to(torch.float16) 275 | else: 276 | affinity = features @ cache_keys 277 | 278 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 279 | # clip_logits = 100. * features @ clip_weights 280 | # tip_logits = clip_logits + cache_logits * alpha 281 | tip_logits = cache_logits 282 | acc = cls_acc(tip_logits, labels) 283 | 284 | if acc > best_acc: 285 | print("New best setting, beta: {:.2f}, alpha: {:.2f}; accuracy: {:.2f}".format(beta, alpha, acc)) 286 | best_acc = acc 287 | best_beta = beta 288 | best_alpha = alpha 289 | 290 | print("\nAfter searching, the best accuarcy: {:.2f}.\n".format(best_acc)) 291 | 292 | return best_beta, best_alpha 293 | 294 | 295 | def search_ensemble_hp(cfg, 296 | clip_cache_keys, 297 | clip_cache_values, 298 | clip_features, 299 | dino_cache_keys, 300 | dino_cache_values, 301 | dino_features, 302 | labels, 303 | clip_weights, 304 | clip_adapter=None, 305 | dino_adapter=None): 306 | 307 | if cfg['search_hp'] == True: 308 | 309 | beta_list = [i * (cfg['search_scale'][0] - 0.1) / cfg['search_step'][0] + 0.1 for i in range(cfg['search_step'][0])] 310 | alpha_list = [i * (cfg['search_scale'][1] - 0.1) / cfg['search_step'][1] + 0.1 for i in range(cfg['search_step'][1])] 311 | 312 | best_acc = 0 313 | best_beta, best_alpha = 0, 0 314 | 315 | for beta in beta_list: 316 | for alpha in alpha_list: 317 | if clip_adapter: 318 | clip_affinity = clip_adapter(clip_features) 319 | dino_affinity = dino_adapter(dino_features).to(dino_cache_values) 320 | else: 321 | clip_affinity = clip_features @ clip_cache_keys 322 | dino_affinity = (dino_features @ dino_cache_keys).to(dino_cache_values) 323 | 324 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values 325 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values 326 | clip_logits = 100. 
* clip_features @ clip_weights 327 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits]) 328 | tip_logits = clip_logits + cache_logits * alpha 329 | acc = cls_acc(tip_logits, labels) 330 | 331 | if acc > best_acc: 332 | print("New best setting, beta: {:.2f}, alpha: {:.2f}; accuracy: {:.2f}".format(beta, alpha, acc)) 333 | best_acc = acc 334 | best_beta = beta 335 | best_alpha = alpha 336 | 337 | print("\nAfter searching, the best accuracy: {:.2f}.\n".format(best_acc)) 338 | with open("best.txt","w") as f: 339 | f.write("After searching, the best accuracy: {:.2f}.\n".format(best_acc)) 340 | return best_beta, best_alpha 341 | 342 | 343 | # clip zero_shot as baseline 344 | def logits_fuse(zero_logits, logits, normalize='mean'): 345 | # normalize logits 346 | softmax_fun = nn.Softmax(dim=1) 347 | if normalize == 'softmax': 348 | zero_logits = softmax_fun(zero_logits) 349 | elif normalize == 'linear': 350 | zero_logits /= torch.norm(zero_logits, p=2, dim=1, keepdim=True) 351 | elif normalize == 'mean': 352 | logits_std = torch.std(zero_logits, dim=1, keepdim=True) 353 | logits_mean = torch.mean(zero_logits, dim=1, keepdim=True) 354 | zero_logits = (zero_logits - logits_mean) / logits_std 355 | else: 356 | raise ValueError("error normalize!") 357 | similarity_matrix = [] 358 | normalize_logits = [] 359 | for logit in logits: 360 | if normalize == 'softmax': 361 | current_normalize_logits = softmax_fun(logit) 362 | elif normalize == 'linear': 363 | current_normalize_logits = logit / torch.norm(logit, p=2, dim=1, keepdim=True) 364 | elif normalize == 'mean': 365 | logits_std = torch.std(logit, dim=1, keepdim=True) 366 | logits_mean = torch.mean(logit, dim=1, keepdim=True) 367 | current_normalize_logits = (logit - logits_mean) / logits_std 368 | else: 369 | raise ValueError("error normalize!") 370 | current_similarity = current_normalize_logits * zero_logits 371 | current_similarity = torch.sum(current_similarity, dim=1, keepdim=True) 372 | similarity_matrix.append(current_similarity) 373 | normalize_logits.append(current_normalize_logits) 374 | similarity_matrix = torch.stack(similarity_matrix, dim=-2) 375 | similarity_matrix = softmax_fun(similarity_matrix) 376 | normalize_logits = torch.stack(normalize_logits, dim=-2) 377 | result_logits = torch.sum(normalize_logits * similarity_matrix, dim=1) 378 | 379 | return result_logits 380 | def logits_fuse_s(zero_logits, logits, normalize='mean'): 381 | # normalize logits 382 | softmax_fun = nn.Softmax(dim=1) 383 | if normalize == 'softmax': 384 | zero_logits = softmax_fun(zero_logits) 385 | elif normalize == 'linear': 386 | zero_logits /= torch.norm(zero_logits, p=2, dim=1, keepdim=True) 387 | elif normalize == 'mean': 388 | logits_std = torch.std(zero_logits, dim=1, keepdim=True) 389 | logits_mean = torch.mean(zero_logits, dim=1, keepdim=True) 390 | zero_logits = (zero_logits - logits_mean) / logits_std 391 | else: 392 | raise ValueError("error normalize!") 393 | similarity_matrix = [] 394 | normalize_logits = [] 395 | for logit in logits: 396 | if normalize == 'softmax': 397 | current_normalize_logits = softmax_fun(logit) 398 | elif normalize == 'linear': 399 | current_normalize_logits = logit / torch.norm(logit, p=2, dim=1, keepdim=True) 400 | elif normalize == 'mean': 401 | logits_std = torch.std(logit, dim=1, keepdim=True) 402 | logits_mean = torch.mean(logit, dim=1, keepdim=True) 403 | current_normalize_logits = (logit - logits_mean) / logits_std 404 | else: 405 | raise ValueError("error normalize!") 406 | current_similarity = current_normalize_logits * zero_logits 
407 | current_similarity = torch.sum(current_similarity, dim=1, keepdim=True) 408 | similarity_matrix.append(current_similarity) 409 | normalize_logits.append(current_normalize_logits) 410 | similarity_matrix = torch.stack(similarity_matrix, dim=-2) 411 | similarity_matrix = softmax_fun(similarity_matrix) 412 | count = 0 413 | for i in similarity_matrix: 414 | if i[0]>0.4 and i[0]<0.6: 415 | count += 1 416 | normalize_logits = torch.stack(normalize_logits, dim=-2) 417 | result_logits = torch.sum(normalize_logits * similarity_matrix, dim=1) 418 | 419 | return result_logits, count 420 | --------------------------------------------------------------------------------
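
Illustrative note (not part of the repository): a minimal, self-contained sketch of the inference-time arithmetic that main.py and utils.py implement for a single cache model. All tensors below are random stand-ins with assumed shapes, and beta/alpha are placeholders for cfg['init_beta'] / cfg['init_alpha'] or the values found by search_ensemble_hp; in the actual scripts there are two such cache terms (CLIP and DINO, each built from both the few-shot and the DALL-E images), and they are blended by logits_fuse, which weights each cache term by its agreement with the CLIP zero-shot logits, before being scaled by alpha.

import torch
import torch.nn.functional as F

# Assumed shapes (stand-ins, not the repo's real features):
#   features:     [N, D]   L2-normalized test image features
#   cache_keys:   [D, NK]  normalized few-shot features, one column per cached sample
#   cache_values: [NK, C]  one-hot labels of the cached samples
#   clip_weights: [D, C]   normalized CLIP text classifier
N, D, NK, C = 4, 512, 32, 10
features = torch.randn(N, D)
features /= features.norm(dim=-1, keepdim=True)
cache_keys = torch.randn(D, NK)
cache_keys /= cache_keys.norm(dim=0, keepdim=True)
cache_values = F.one_hot(torch.randint(0, C, (NK,)), C).float()
clip_weights = torch.randn(D, C)
clip_weights /= clip_weights.norm(dim=0, keepdim=True)

beta, alpha = 1.0, 1.0  # placeholder hyperparameters

affinity = features @ cache_keys                                       # cosine similarity to cached keys, [N, NK]
cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values  # cache-model logits, [N, C]
clip_logits = 100. * features @ clip_weights                           # CLIP zero-shot logits, [N, C]
tip_logits = clip_logits + cache_logits * alpha                        # combined prediction, [N, C]
print(tip_logits.shape)  # torch.Size([4, 10])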