├── CaFo.png
├── CaFo_arXiv.pdf
├── LICENSE
├── README.md
├── clip
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── configs
│   ├── caltech101
│   │   ├── 16shot.yaml
│   │   ├── 1shot.yaml
│   │   ├── 2shot.yaml
│   │   ├── 4shot.yaml
│   │   └── 8shot.yaml
│   ├── cars
│   │   ├── 16shot.yaml
│   │   ├── 1shot.yaml
│   │   ├── 2shot.yaml
│   │   ├── 4shot.yaml
│   │   └── 8shot.yaml
│   ├── chat_caltech101
│   │   ├── 16shot.yaml
│   │   ├── 1shot.yaml
│   │   ├── 2shot.yaml
│   │   ├── 4shot.yaml
│   │   └── 8shot.yaml
│   ├── dtd
│   │   ├── 16shot.yaml
│   │   ├── 1shot.yaml
│   │   ├── 2shot.yaml
│   │   ├── 4shot.yaml
│   │   └── 8shot.yaml
│   ├── eurosat
│   │   ├── 16shot.yaml
│   │   ├── 1shot.yaml
│   │   ├── 2shot.yaml
│   │   ├── 4shot.yaml
│   │   └── 8shot.yaml
│   ├── fgvc
│   │   ├── 16shot.yaml
│   │   ├── 1shot.yaml
│   │   ├── 2shot.yaml
│   │   ├── 4shot.yaml
│   │   └── 8shot.yaml
│   ├── food101
│   │   ├── 16shot.yaml
│   │   ├── 1shot.yaml
│   │   ├── 2shot.yaml
│   │   ├── 4shot.yaml
│   │   └── 8shot.yaml
│   ├── imagenet
│   │   ├── 16shot.yaml
│   │   ├── 1shot.yaml
│   │   ├── 2shot.yaml
│   │   ├── 4shot.yaml
│   │   └── 8shot.yaml
│   ├── oxford_flowers
│   │   ├── 16shot.yaml
│   │   ├── 1shot.yaml
│   │   ├── 2shot.yaml
│   │   ├── 4shot.yaml
│   │   └── 8shot.yaml
│   ├── pets
│   │   ├── 16shot.yaml
│   │   ├── 1shot.yaml
│   │   ├── 2shot.yaml
│   │   ├── 4shot.yaml
│   │   └── 8shot.yaml
│   ├── sd_caltech101
│   │   ├── 16shot.yaml
│   │   ├── 1shot.yaml
│   │   ├── 2shot.yaml
│   │   ├── 4shot.yaml
│   │   └── 8shot.yaml
│   ├── sun
│   │   ├── 16shot.yaml
│   │   ├── 1shot.yaml
│   │   ├── 2shot.yaml
│   │   ├── 4shot.yaml
│   │   └── 8shot.yaml
│   └── ucf
│       ├── 16shot.yaml
│       ├── 1shot.yaml
│       ├── 2shot.yaml
│       ├── 4shot.yaml
│       └── 8shot.yaml
├── datasets
│   ├── __init__.py
│   ├── caltech101.py
│   ├── dalle_caltech.py
│   ├── dalle_cars.py
│   ├── dalle_dtd.py
│   ├── dalle_eurosat.py
│   ├── dalle_fgvc.py
│   ├── dalle_flowers.py
│   ├── dalle_food.py
│   ├── dalle_imagenet.py
│   ├── dalle_pets.py
│   ├── dalle_sun.py
│   ├── dalle_ucf.py
│   ├── dtd.py
│   ├── eurosat.py
│   ├── fgvc.py
│   ├── food101.py
│   ├── imagenet.py
│   ├── oxford_flowers.py
│   ├── oxford_pets.py
│   ├── sd_caltech.py
│   ├── stanford_cars.py
│   ├── sun397.py
│   ├── ucf101.py
│   └── utils.py
├── dino
│   ├── __pycache__
│   │   └── utils.cpython-36.pyc
│   └── utils.py
├── exp.log
├── gpt_file
│   ├── caltech_prompt.json
│   ├── caltech_prompt_chat.json
│   ├── dtd_prompt.json
│   ├── eurosat_prompt.json
│   ├── fgvc_prompt.json
│   ├── food101_prompt.json
│   ├── imagenet_prompt.json
│   ├── oxford_flowers_prompt.json
│   ├── oxford_pets_prompt.json
│   ├── stanford_cars_prompt.json
│   ├── sun397_prompt.json
│   └── ucf101_prompt.json
├── main.py
├── main_imagenet.py
├── requirements.txt
└── utils.py
/CaFo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/CaFo/a805a2aefc6757fdbe10ac9a3165520ceb0e01cb/CaFo.png
--------------------------------------------------------------------------------
/CaFo_arXiv.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/CaFo/a805a2aefc6757fdbe10ac9a3165520ceb0e01cb/CaFo_arXiv.pdf
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Renrui Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Prompt, Generate, then Cache
2 |
3 | Official implementation of ['Prompt, Generate, then Cache: Cascade of Foundation Models makes Strong Few-shot Learners'](https://arxiv.org/pdf/2303.02151.pdf).
4 |
5 | The paper has been accepted by **CVPR 2023** 🔥.
6 |
7 | ## News
8 | * Please check our latest work ['Point-NN, Parameter is Not All You Need'](https://arxiv.org/pdf/2303.08134.pdf) with [code](https://github.com/ZrrSkywalker/Point-NN), accepted by **CVPR 2023** 🔥, which conducts 3D understanding without any parameters or training.
9 | * CaFo cascaded with [ChatGPT](https://openai.com/blog/chatgpt) and [Stable Diffusion](https://github.com/CompVis/stable-diffusion) on the Caltech-101 dataset has been released 📌.
10 | * The code of CaFo has been released.
11 | * The CaFo model is built upon [Tip-Adapter](https://arxiv.org/pdf/2207.09519), which was accepted by **ECCV 2022** and [open-sourced](https://github.com/gaopengcuhk/Tip-Adapter).
12 |
13 | ## Introduction
14 | We propose **CaFo**, a **Ca**scade of **Fo**undation models that incorporates diverse prior knowledge of various pre-training paradigms for better few-shot learning, including CLIP, DINO, DALL-E, and GPT-3. Specifically, CaFo works by **'Prompt, Generate, then Cache'**. We leverage GPT-3 to prompt CLIP with rich linguistic semantics and generate synthetic images via DALL-E to expand the few-shot training data. Then, we introduce a learnable cache model to adaptively blend the predictions from CLIP and DINO. Through such collaboration, CaFo can fully unleash the potential of different pre-training methods and unify them to achieve *state-of-the-art* performance for few-shot classification.
15 |
16 |
17 |
18 |
19 |
20 | ## Requirements
21 |
22 | ### Installation
23 | Create a conda environment and install dependencies:
24 | ```bash
25 | git clone https://github.com/ZrrSkywalker/CaFo.git
26 | cd CaFo
27 |
28 | conda create -n cafo python=3.7
29 | conda activate cafo
30 |
31 | pip install -r requirements.txt
32 |
33 | # Install the according versions of torch and torchvision
34 | conda install pytorch torchvision cudatoolkit
35 | ```
36 |
37 | ### Dataset
38 | Please follow [DATASET.md](https://github.com/gaopengcuhk/Tip-Adapter/blob/main/DATASET.md) to download the official ImageNet and the other 10 datasets.
39 |
40 | ### Foundation Models
41 | * The pre-trained weights of **CLIP** will be downloaded automatically when running the code.
42 | * The prompts produced by **GPT-3** have been stored at `gpt_file/`.
43 | * Please download **DINO's** pre-trained ResNet-50 from [here](https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth), and put it under `dino/`.
44 | * Please download **DALL-E's** generated images from [here](https://drive.google.com/drive/folders/1e249OgUFCmpfEDPsxCVR-nNb6Q1VaZVW?usp=sharing), and organize them alongside the official datasets as follows:
45 | ```
46 | $DATA/
47 | |–– imagenet/
48 | |–– caltech-101/
49 | |–– oxford_pets/
50 | |–– ...
51 | |–– dalle_imagenet/
52 | |–– dalle_caltech-101/
53 | |–– dalle_oxford_pets/
54 | |–– ...
55 | |–– sd_caltech-101/
56 | ```
57 | * For the Caltech-101 dataset, we also provide **Stable Diffusion's** generated images from [here](https://drive.google.com/drive/folders/1e249OgUFCmpfEDPsxCVR-nNb6Q1VaZVW?usp=sharing), and **ChatGPT's** prompts in `gpt_file/`.
58 |
59 | ## Get Started
60 | ### Configs
61 | The running configurations for different `[dataset]` with `[k]` shots can be modified in `configs/[dataset]/[k]shot.yaml`, including the visual encoders and hyperparameters. We have provided the configurations for reproducing the results in the paper. You can edit the `search_scale`, `search_step`, `init_beta` and `init_alpha` for fine-grained tuning and better results.
62 |
63 | Note that the default `load_cache` and `load_pre_feat` are `False` for the first run, which will store the cache model and val/test features in `configs/dataset/`. For subsequent runs, they can be set to `True` for faster hyperparameter tuning.
64 |
65 | For the Caltech-101 dataset, the configs for Stable Diffusion's images and ChatGPT's prompts are in `configs/sd_caltech101` and `configs/chat_caltech101`, respectively.
66 |
67 | ### Running
68 | For the 16-shot ImageNet dataset:
69 | ```bash
70 | CUDA_VISIBLE_DEVICES=0 python main_imagenet.py --config configs/imagenet/16shot.yaml
71 | ```
72 | For the other 10 datasets:
73 | ```bash
74 | CUDA_VISIBLE_DEVICES=0 python main.py --config configs/dataset/16shot.yaml
75 | ```
76 |
77 | ### Numerical Results
78 |
79 | We provide CaFo's numerical results on 11 datasets from 1 to 16 shots at [exp_Cafo.log](https://github.com/ZrrSkywalker/CaFo/blob/main/exp.log).
80 | The results for Tip-Adapter and Tip-Adapter-F are at [exp_Tip.log](https://github.com/gaopengcuhk/Tip-Adapter/blob/main/exp.log).
81 |
82 |
83 | ## Acknowledgement
84 | This repo benefits from [Tip-Adapter](https://github.com/gaopengcuhk/Tip-Adapter), [CLIP](https://github.com/openai/CLIP), [DINO](https://github.com/facebookresearch/dino), [DALL-E](https://github.com/borisdayma/dalle-mini) and [CuPL](https://github.com/sarahpratt/CuPL). Thanks for their wonderful work.
85 |
86 |
87 | ## Citation
88 | ```bash
89 | @article{zhang2023prompt,
90 | title={Prompt, Generate, then Cache: Cascade of Foundation Models makes Strong Few-shot Learners},
91 | author={Renrui Zhang and Xiangfei Hu and Bohao Li and Siyuan Huang and Hanqiu Deng and Hongsheng Li and Yu Qiao and Peng Gao},
92 | journal={arXiv preprint arXiv:2303.02151},
93 | year={2023}
94 | }
95 | ```
96 |
97 | ## Contributors
98 | [Renrui Zhang](https://github.com/ZrrSkywalker), [Xiangfei Hu](https://github.com/hxf42), [Bohao Li](https://github.com/Bohao-Lee)
99 |
100 | ## Contact
101 | If you have any questions about this project, please feel free to contact zhangrenrui@pjlab.org.cn and sjtuhxf@sjtu.edu.cn.
102 |
--------------------------------------------------------------------------------
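To make the "Prompt, Generate, then Cache" pipeline described in the README above more concrete, here is a minimal, hypothetical sketch of the Tip-Adapter-style cache blending that CaFo builds on: zero-shot CLIP logits from prompted text features are combined with cache logits computed against the few-shot (and DALL-E-generated) training features, weighted by values that roughly correspond to the `init_beta`/`init_alpha` entries in the configs. All tensor and function names here are illustrative, not the repo's actual variables, and the equal-weight fusion of the two cache branches is a simplification of CaFo's adaptive ensembling.

```python
import torch

def cache_logits(test_feat, cache_keys, cache_values, beta):
    """Tip-Adapter-style cache prediction.

    test_feat:    [N, D]   L2-normalized test features
    cache_keys:   [D, NK]  L2-normalized few-shot training features
    cache_values: [NK, C]  one-hot labels of the cached training samples
    """
    affinity = test_feat @ cache_keys                               # [N, NK] cosine similarities
    return ((-1) * (beta - beta * affinity)).exp() @ cache_values   # exp(-beta * (1 - sim)) -> [N, C]

def cafo_style_logits(clip_feat, dino_feat, clip_weights,
                      clip_cache, dino_cache, alpha, beta):
    """Blend zero-shot CLIP logits with CLIP- and DINO-based cache logits.

    clip_weights:            [D_clip, C] text classifier built from the GPT-3 prompts.
    clip_cache / dino_cache: (keys, values) tuples for the two cache models.
    Averaging the two cache branches is a placeholder for CaFo's adaptive ensemble.
    """
    zero_shot = 100.0 * clip_feat @ clip_weights                    # CLIP zero-shot logits
    clip_branch = cache_logits(clip_feat, *clip_cache, beta)
    dino_branch = cache_logits(dino_feat, *dino_cache, beta)
    return zero_shot + alpha * (clip_branch + dino_branch) / 2
```

Under this reading, the `search_scale`/`search_step` entries in each config define the grid over which `alpha` and `beta` are searched on the validation set when `search_hp` is enabled.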
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/CaFo/a805a2aefc6757fdbe10ac9a3165520ceb0e01cb/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 |
7 | import torch
8 | from PIL import Image
9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10 | from tqdm import tqdm
11 |
12 | from .model import build_model
13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14 |
15 | try:
16 | from torchvision.transforms import InterpolationMode
17 | BICUBIC = InterpolationMode.BICUBIC
18 | except ImportError:
19 | BICUBIC = Image.BICUBIC
20 |
21 |
22 | if tuple(int(v) for v in torch.__version__.split(".")[:3] if v.isdigit()) < (1, 7, 1):
23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
24 |
25 |
26 | __all__ = ["available_models", "load", "tokenize"]
27 | _tokenizer = _Tokenizer()
28 |
29 | _MODELS = {
30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
36 | }
37 |
38 |
39 | def _download(url: str, root: str):
40 | os.makedirs(root, exist_ok=True)
41 | filename = os.path.basename(url)
42 |
43 | expected_sha256 = url.split("/")[-2]
44 | download_target = os.path.join(root, filename)
45 |
46 | if os.path.exists(download_target) and not os.path.isfile(download_target):
47 | raise RuntimeError(f"{download_target} exists and is not a regular file")
48 |
49 | if os.path.isfile(download_target):
50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
51 | return download_target
52 | else:
53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
54 |
55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
57 | while True:
58 | buffer = source.read(8192)
59 | if not buffer:
60 | break
61 |
62 | output.write(buffer)
63 | loop.update(len(buffer))
64 |
65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
67 |
68 | return download_target
69 |
70 |
71 | def _convert_image_to_rgb(image):
72 | return image.convert("RGB")
73 |
74 |
75 | def _transform(n_px):
76 | return Compose([
77 | Resize(n_px, interpolation=BICUBIC),
78 | CenterCrop(n_px),
79 | _convert_image_to_rgb,
80 | ToTensor(),
81 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
82 | ])
83 |
84 |
85 | def available_models() -> List[str]:
86 | """Returns the names of available CLIP models"""
87 | return list(_MODELS.keys())
88 |
89 |
90 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
91 | """Load a CLIP model
92 |
93 | Parameters
94 | ----------
95 | name : str
96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
97 |
98 | device : Union[str, torch.device]
99 | The device to put the loaded model
100 |
101 | jit : bool
102 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
103 |
104 | download_root: str
105 | path to download the model files; by default, it uses "~/.cache/clip"
106 |
107 | Returns
108 | -------
109 | model : torch.nn.Module
110 | The CLIP model
111 |
112 | preprocess : Callable[[PIL.Image], torch.Tensor]
113 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
114 | """
115 | if name in _MODELS:
116 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
117 | elif os.path.isfile(name):
118 | model_path = name
119 | else:
120 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
121 |
122 | try:
123 | # loading JIT archive
124 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
125 | state_dict = None
126 | except RuntimeError:
127 | # loading saved state dict
128 | if jit:
129 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
130 | jit = False
131 | state_dict = torch.load(model_path, map_location="cpu")
132 |
133 | if not jit:
134 | model = build_model(state_dict or model.state_dict()).to(device)
135 | if str(device) == "cpu":
136 | model.float()
137 | return model, _transform(model.visual.input_resolution)
138 |
139 | # patch the device names
140 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
141 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
142 |
143 | def patch_device(module):
144 | try:
145 | graphs = [module.graph] if hasattr(module, "graph") else []
146 | except RuntimeError:
147 | graphs = []
148 |
149 | if hasattr(module, "forward1"):
150 | graphs.append(module.forward1.graph)
151 |
152 | for graph in graphs:
153 | for node in graph.findAllNodes("prim::Constant"):
154 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
155 | node.copyAttributes(device_node)
156 |
157 | model.apply(patch_device)
158 | patch_device(model.encode_image)
159 | patch_device(model.encode_text)
160 |
161 | # patch dtype to float32 on CPU
162 | if str(device) == "cpu":
163 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
164 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
165 | float_node = float_input.node()
166 |
167 | def patch_float(module):
168 | try:
169 | graphs = [module.graph] if hasattr(module, "graph") else []
170 | except RuntimeError:
171 | graphs = []
172 |
173 | if hasattr(module, "forward1"):
174 | graphs.append(module.forward1.graph)
175 |
176 | for graph in graphs:
177 | for node in graph.findAllNodes("aten::to"):
178 | inputs = list(node.inputs())
179 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
180 | if inputs[i].node()["value"] == 5:
181 | inputs[i].node().copyAttributes(float_node)
182 |
183 | model.apply(patch_float)
184 | patch_float(model.encode_image)
185 | patch_float(model.encode_text)
186 |
187 | model.float()
188 |
189 | return model, _transform(model.input_resolution.item())
190 |
191 |
192 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
193 | """
194 | Returns the tokenized representation of given input string(s)
195 |
196 | Parameters
197 | ----------
198 | texts : Union[str, List[str]]
199 | An input string or a list of input strings to tokenize
200 |
201 | context_length : int
202 | The context length to use; all CLIP models use 77 as the context length
203 |
204 | truncate: bool
205 | Whether to truncate the text in case its encoding is longer than the context length
206 |
207 | Returns
208 | -------
209 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
210 | """
211 | if isinstance(texts, str):
212 | texts = [texts]
213 |
214 | sot_token = _tokenizer.encoder["<|startoftext|>"]
215 | eot_token = _tokenizer.encoder["<|endoftext|>"]
216 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
217 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
218 |
219 | for i, tokens in enumerate(all_tokens):
220 | if len(tokens) > context_length:
221 | if truncate:
222 | tokens = tokens[:context_length]
223 | tokens[-1] = eot_token
224 | else:
225 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
226 | result[i, :len(tokens)] = torch.tensor(tokens)
227 |
228 | return result
229 |
--------------------------------------------------------------------------------
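As a usage note for the API above, the following hedged sketch shows how `clip.load` and `clip.tokenize` could be combined with the GPT-3 prompt files under `gpt_file/` to build per-class text features. The JSON layout (a dict mapping class names to lists of prompts) and the variable names are assumptions for illustration only.

```python
import json
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("RN50", device=device)   # RN50 is the backbone used in the configs

# Assumed layout: {"class name": ["prompt 1", "prompt 2", ...], ...}
with open("gpt_file/caltech_prompt.json") as f:
    prompts = json.load(f)

with torch.no_grad():
    weights = []
    for classname, class_prompts in prompts.items():
        tokens = clip.tokenize(class_prompts, truncate=True).to(device)
        feats = model.encode_text(tokens)
        feats = feats / feats.norm(dim=-1, keepdim=True)   # normalize each prompt embedding
        mean_feat = feats.mean(dim=0)
        weights.append(mean_feat / mean_feat.norm())        # ensemble the prompts per class
    clip_weights = torch.stack(weights, dim=1)              # [embed_dim, num_classes]
```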
/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 |
20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24 |
25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27 |
28 | self.relu = nn.ReLU(inplace=True)
29 | self.downsample = None
30 | self.stride = stride
31 |
32 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34 | self.downsample = nn.Sequential(OrderedDict([
35 | ("-1", nn.AvgPool2d(stride)),
36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
37 | ("1", nn.BatchNorm2d(planes * self.expansion))
38 | ]))
39 |
40 | def forward(self, x: torch.Tensor):
41 | identity = x
42 |
43 | out = self.relu(self.bn1(self.conv1(x)))
44 | out = self.relu(self.bn2(self.conv2(out)))
45 | out = self.avgpool(out)
46 | out = self.bn3(self.conv3(out))
47 |
48 | if self.downsample is not None:
49 | identity = self.downsample(x)
50 |
51 | out += identity
52 | out = self.relu(out)
53 | return out
54 |
55 |
56 | class AttentionPool2d(nn.Module):
57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
58 | super().__init__()
59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
60 | self.k_proj = nn.Linear(embed_dim, embed_dim)
61 | self.q_proj = nn.Linear(embed_dim, embed_dim)
62 | self.v_proj = nn.Linear(embed_dim, embed_dim)
63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
64 | self.num_heads = num_heads
65 |
66 | def forward(self, x):
67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
70 | x, _ = F.multi_head_attention_forward(
71 | query=x, key=x, value=x,
72 | embed_dim_to_check=x.shape[-1],
73 | num_heads=self.num_heads,
74 | q_proj_weight=self.q_proj.weight,
75 | k_proj_weight=self.k_proj.weight,
76 | v_proj_weight=self.v_proj.weight,
77 | in_proj_weight=None,
78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79 | bias_k=None,
80 | bias_v=None,
81 | add_zero_attn=False,
82 | dropout_p=0,
83 | out_proj_weight=self.c_proj.weight,
84 | out_proj_bias=self.c_proj.bias,
85 | use_separate_proj_weight=True,
86 | training=self.training,
87 | need_weights=False
88 | )
89 |
90 | return x[0]
91 |
92 |
93 | class ModifiedResNet(nn.Module):
94 | """
95 | A ResNet class that is similar to torchvision's but contains the following changes:
96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98 | - The final pooling layer is a QKV attention instead of an average pool
99 | """
100 |
101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
102 | super().__init__()
103 | self.output_dim = output_dim
104 | self.input_resolution = input_resolution
105 |
106 | # the 3-layer stem
107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108 | self.bn1 = nn.BatchNorm2d(width // 2)
109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
110 | self.bn2 = nn.BatchNorm2d(width // 2)
111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
112 | self.bn3 = nn.BatchNorm2d(width)
113 | self.avgpool = nn.AvgPool2d(2)
114 | self.relu = nn.ReLU(inplace=True)
115 |
116 | # residual layers
117 | self._inplanes = width # this is a *mutable* variable used during construction
118 | self.layer1 = self._make_layer(width, layers[0])
119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
122 |
123 | embed_dim = width * 32 # the ResNet feature dimension
124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
125 |
126 | def _make_layer(self, planes, blocks, stride=1):
127 | layers = [Bottleneck(self._inplanes, planes, stride)]
128 |
129 | self._inplanes = planes * Bottleneck.expansion
130 | for _ in range(1, blocks):
131 | layers.append(Bottleneck(self._inplanes, planes))
132 |
133 | return nn.Sequential(*layers)
134 |
135 | def forward(self, x):
136 | def stem(x):
137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
138 | x = self.relu(bn(conv(x)))
139 | x = self.avgpool(x)
140 | return x
141 |
142 | x = x.type(self.conv1.weight.dtype)
143 | x = stem(x)
144 | x = self.layer1(x)
145 | x = self.layer2(x)
146 | x = self.layer3(x)
147 | x = self.layer4(x)
148 | x = self.attnpool(x)
149 |
150 | return x
151 |
152 |
153 | class LayerNorm(nn.LayerNorm):
154 | """Subclass torch's LayerNorm to handle fp16."""
155 |
156 | def forward(self, x: torch.Tensor):
157 | orig_type = x.dtype
158 | ret = super().forward(x.type(torch.float32))
159 | return ret.type(orig_type)
160 |
161 |
162 | class QuickGELU(nn.Module):
163 | def forward(self, x: torch.Tensor):
164 | return x * torch.sigmoid(1.702 * x)
165 |
166 |
167 | class ResidualAttentionBlock(nn.Module):
168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
169 | super().__init__()
170 |
171 | self.attn = nn.MultiheadAttention(d_model, n_head)
172 | self.ln_1 = LayerNorm(d_model)
173 | self.mlp = nn.Sequential(OrderedDict([
174 | ("c_fc", nn.Linear(d_model, d_model * 4)),
175 | ("gelu", QuickGELU()),
176 | ("c_proj", nn.Linear(d_model * 4, d_model))
177 | ]))
178 | self.ln_2 = LayerNorm(d_model)
179 | self.attn_mask = attn_mask
180 |
181 | def attention(self, x: torch.Tensor):
182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
184 |
185 | def forward(self, x: torch.Tensor):
186 | x = x + self.attention(self.ln_1(x))
187 | x = x + self.mlp(self.ln_2(x))
188 | return x
189 |
190 |
191 | class Transformer(nn.Module):
192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
193 | super().__init__()
194 | self.width = width
195 | self.layers = layers
196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
197 |
198 | def forward(self, x: torch.Tensor):
199 | return self.resblocks(x)
200 |
201 |
202 | class VisionTransformer(nn.Module):
203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
204 | super().__init__()
205 | self.input_resolution = input_resolution
206 | self.output_dim = output_dim
207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
208 |
209 | scale = width ** -0.5
210 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
212 | self.ln_pre = LayerNorm(width)
213 |
214 | self.transformer = Transformer(width, layers, heads)
215 |
216 | self.ln_post = LayerNorm(width)
217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
218 |
219 | def forward(self, x: torch.Tensor):
220 | x = self.conv1(x) # shape = [*, width, grid, grid]
221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
224 | x = x + self.positional_embedding.to(x.dtype)
225 | x = self.ln_pre(x)
226 |
227 | x = x.permute(1, 0, 2) # NLD -> LND
228 | x = self.transformer(x)
229 | x = x.permute(1, 0, 2) # LND -> NLD
230 |
231 | x = self.ln_post(x[:, 0, :])
232 |
233 | if self.proj is not None:
234 | x = x @ self.proj
235 |
236 | return x
237 |
238 |
239 | class CLIP(nn.Module):
240 | def __init__(self,
241 | embed_dim: int,
242 | # vision
243 | image_resolution: int,
244 | vision_layers: Union[Tuple[int, int, int, int], int],
245 | vision_width: int,
246 | vision_patch_size: int,
247 | # text
248 | context_length: int,
249 | vocab_size: int,
250 | transformer_width: int,
251 | transformer_heads: int,
252 | transformer_layers: int
253 | ):
254 | super().__init__()
255 |
256 | self.context_length = context_length
257 |
258 | if isinstance(vision_layers, (tuple, list)):
259 | vision_heads = vision_width * 32 // 64
260 | self.visual = ModifiedResNet(
261 | layers=vision_layers,
262 | output_dim=embed_dim,
263 | heads=vision_heads,
264 | input_resolution=image_resolution,
265 | width=vision_width
266 | )
267 | else:
268 | vision_heads = vision_width // 64
269 | self.visual = VisionTransformer(
270 | input_resolution=image_resolution,
271 | patch_size=vision_patch_size,
272 | width=vision_width,
273 | layers=vision_layers,
274 | heads=vision_heads,
275 | output_dim=embed_dim
276 | )
277 |
278 | self.transformer = Transformer(
279 | width=transformer_width,
280 | layers=transformer_layers,
281 | heads=transformer_heads,
282 | attn_mask=self.build_attention_mask()
283 | )
284 |
285 | self.vocab_size = vocab_size
286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
288 | self.ln_final = LayerNorm(transformer_width)
289 |
290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
292 |
293 | self.initialize_parameters()
294 |
295 | def initialize_parameters(self):
296 | nn.init.normal_(self.token_embedding.weight, std=0.02)
297 | nn.init.normal_(self.positional_embedding, std=0.01)
298 |
299 | if isinstance(self.visual, ModifiedResNet):
300 | if self.visual.attnpool is not None:
301 | std = self.visual.attnpool.c_proj.in_features ** -0.5
302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
306 |
307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
308 | for name, param in resnet_block.named_parameters():
309 | if name.endswith("bn3.weight"):
310 | nn.init.zeros_(param)
311 |
312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
313 | attn_std = self.transformer.width ** -0.5
314 | fc_std = (2 * self.transformer.width) ** -0.5
315 | for block in self.transformer.resblocks:
316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
320 |
321 | if self.text_projection is not None:
322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
323 |
324 | def build_attention_mask(self):
325 | # lazily create causal attention mask, with full attention between the vision tokens
326 | # pytorch uses additive attention mask; fill with -inf
327 | mask = torch.empty(self.context_length, self.context_length)
328 | mask.fill_(float("-inf"))
329 | mask.triu_(1) # zero out the lower diagonal
330 | return mask
331 |
332 | @property
333 | def dtype(self):
334 | return self.visual.conv1.weight.dtype
335 |
336 | def encode_image(self, image):
337 | return self.visual(image.type(self.dtype))
338 |
339 | def encode_text(self, text):
340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
341 |
342 | x = x + self.positional_embedding.type(self.dtype)
343 | x = x.permute(1, 0, 2) # NLD -> LND
344 | x = self.transformer(x)
345 | x = x.permute(1, 0, 2) # LND -> NLD
346 | x = self.ln_final(x).type(self.dtype)
347 |
348 | # x.shape = [batch_size, n_ctx, transformer.width]
349 | # take features from the eot embedding (eot_token is the highest number in each sequence)
350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
351 |
352 | return x
353 |
354 | def forward(self, image, text):
355 | image_features = self.encode_image(image)
356 | text_features = self.encode_text(text)
357 |
358 | # normalized features
359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
361 |
362 | # cosine similarity as logits
363 | logit_scale = self.logit_scale.exp()
364 | logits_per_image = logit_scale * image_features @ text_features.t()
365 | logits_per_text = logits_per_image.t()
366 |
367 | # shape = [global_batch_size, global_batch_size]
368 | return logits_per_image, logits_per_text
369 |
370 |
371 | def convert_weights(model: nn.Module):
372 | """Convert applicable model parameters to fp16"""
373 |
374 | def _convert_weights_to_fp16(l):
375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
376 | l.weight.data = l.weight.data.half()
377 | if l.bias is not None:
378 | l.bias.data = l.bias.data.half()
379 |
380 | if isinstance(l, nn.MultiheadAttention):
381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
382 | tensor = getattr(l, attr)
383 | if tensor is not None:
384 | tensor.data = tensor.data.half()
385 |
386 | for name in ["text_projection", "proj"]:
387 | if hasattr(l, name):
388 | attr = getattr(l, name)
389 | if attr is not None:
390 | attr.data = attr.data.half()
391 |
392 | model.apply(_convert_weights_to_fp16)
393 |
394 |
395 | def build_model(state_dict: dict):
396 | vit = "visual.proj" in state_dict
397 |
398 | if vit:
399 | vision_width = state_dict["visual.conv1.weight"].shape[0]
400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
403 | image_resolution = vision_patch_size * grid_size
404 | else:
405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
406 | vision_layers = tuple(counts)
407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
409 | vision_patch_size = None
410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
411 | image_resolution = output_width * 32
412 |
413 | embed_dim = state_dict["text_projection"].shape[1]
414 | context_length = state_dict["positional_embedding"].shape[0]
415 | vocab_size = state_dict["token_embedding.weight"].shape[0]
416 | transformer_width = state_dict["ln_final.weight"].shape[0]
417 | transformer_heads = transformer_width // 64
418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
419 |
420 | model = CLIP(
421 | embed_dim,
422 | image_resolution, vision_layers, vision_width, vision_patch_size,
423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
424 | )
425 |
426 | for key in ["input_resolution", "context_length", "vocab_size"]:
427 | if key in state_dict:
428 | del state_dict[key]
429 |
430 | convert_weights(model)
431 | model.load_state_dict(state_dict)
432 | return model.eval()
433 |
--------------------------------------------------------------------------------
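For reference, a small sketch of how the `CLIP` module defined above is typically driven end to end: `encode_image`/`encode_text` produce features, and `forward` returns scaled cosine-similarity logits in both directions. The model is assumed to come from `clip.load`, and the image path and text prompts are placeholders.

```python
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("RN50", device=device)

image = preprocess(Image.open("CaFo.png")).unsqueeze(0).to(device)   # any RGB image works here
texts = clip.tokenize(["a photo of a dog", "a photo of a car"]).to(device)

with torch.no_grad():
    # forward() normalizes both feature sets and applies the learned logit_scale
    logits_per_image, logits_per_text = model(image, texts)
    probs = logits_per_image.softmax(dim=-1)                          # [1, num_texts]

print(probs.cpu().numpy())
```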
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
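A quick, illustrative round trip through the BPE tokenizer above; the example string is arbitrary. `encode` lower-cases and cleans the text before applying BPE, and `decode` re-inserts spaces at the end-of-word markers.

```python
from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                 # uses the bundled bpe_simple_vocab_16e6.txt.gz

ids = tokenizer.encode("A photo of a Golden Retriever.")
print(ids)                                    # list of BPE token ids (no start/end tokens)
print(tokenizer.decode(ids))                  # roughly "a photo of a golden retriever ."

# clip.tokenize() wraps this encoder and adds <|startoftext|>/<|endoftext|> plus padding to length 77.
```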
/configs/caltech101/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.3
22 |
23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'caltech101'
27 | shots: 16
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_caltech'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
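The config above is plain YAML, so it can be inspected or tweaked programmatically before a run. Below is a minimal sketch; the key names match the file above, but the output filename is hypothetical and how `main.py` actually consumes the config is not shown here.

```python
import yaml

# Load the 16-shot Caltech-101 config and inspect the searchable hyperparameters.
with open("configs/caltech101/16shot.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["dataset"], cfg["shots"], cfg["clip_backbone"], cfg["dino_backbone"])
print("initial (beta, alpha):", cfg["init_beta"], cfg["init_alpha"])
print("search space:", cfg["search_scale"], cfg["search_step"])

# Example tweak: widen the alpha/beta grid before re-running with load_cache: True.
cfg["search_scale"] = [20, 10]
with open("configs/caltech101/16shot_wide.yaml", "w") as f:   # hypothetical output path
    yaml.safe_dump(cfg, f)
```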
/configs/caltech101/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.5
22 |
23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'caltech101'
27 | shots: 1
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_caltech'
33 | dalle_shots: 8
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/caltech101/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 | # load_cache: True
9 | # load_pre_feat: True
10 |
11 |
12 | # ------ Hyperparameters ------
13 | search_hp: True
14 | # search_hp: False
15 |
16 | search_scale: [12, 5]
17 | search_step: [200, 20]
18 |
19 | init_beta: 1
20 | init_alpha: 0.8
21 |
22 | gpt3_prompt_file: './gpt_file/caltech_prompt.json'
23 |
24 | # ------ Basic Config ------
25 | dataset: 'caltech101'
26 | shots: 2
27 | clip_backbone: 'RN50'
28 | dino_backbone: 'resnet50'
29 |
30 | # ------ Dalle Dataset -----
31 | dalle_dataset: 'dalle_caltech'
32 | dalle_shots: 16
33 |
34 | lr: 0.001
35 | augment_epoch: 10
36 | train_epoch: 20
37 |
--------------------------------------------------------------------------------
/configs/caltech101/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.1
22 |
23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'caltech101'
27 | shots: 4
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_caltech'
33 | dalle_shots: 2
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/caltech101/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.1
22 |
23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'caltech101'
27 | shots: 8
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_caltech'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/cars/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 | # load_cache: True
9 | # load_pre_feat: True
10 |
11 |
12 | # ------ Hyperparameters ------
13 | search_hp: True
14 | # search_hp: False
15 |
16 | search_scale: [20, 10]
17 | search_step: [200, 20]
18 |
19 | init_beta: 1
20 | init_alpha: 0.6
21 |
22 | gpt3_prompt_file: './gpt_file/stanford_cars_prompt.json'
23 |
24 | # ------ Basic Config ------
25 | dataset: 'stanford_cars'
26 | shots: 16
27 | clip_backbone: 'RN50'
28 | dino_backbone: 'resnet50'
29 |
30 | # ------ Dalle Dataset -----
31 | dalle_dataset: 'dalle_cars'
32 | dalle_shots: 1
33 |
34 | lr: 0.001
35 | augment_epoch: 10
36 | train_epoch: 200
37 |
--------------------------------------------------------------------------------
/configs/cars/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [20, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.4
22 |
23 | gpt3_prompt_file: './gpt_file/stanford_cars_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'stanford_cars'
27 | shots: 1
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_cars'
33 | dalle_shots: 16
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 100
38 |
--------------------------------------------------------------------------------
/configs/cars/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [20, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.4
22 |
23 | gpt3_prompt_file: './gpt_file/stanford_cars_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'stanford_cars'
27 | shots: 2
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_cars'
33 | dalle_shots: 16
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 200
38 |
--------------------------------------------------------------------------------
/configs/cars/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [20, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.8
22 |
23 | gpt3_prompt_file: './gpt_file/stanford_cars_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'stanford_cars'
27 | shots: 4
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_cars'
33 | dalle_shots: 16
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 400
38 |
--------------------------------------------------------------------------------
/configs/cars/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [20, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.5
22 |
23 | gpt3_prompt_file: './gpt_file/stanford_cars_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'stanford_cars'
27 | shots: 8
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_cars'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 200
38 |
--------------------------------------------------------------------------------
/configs/chat_caltech101/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.3
22 |
23 | gpt3_prompt_file: './gpt_file/caltech_prompt_chat.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'caltech101'
27 | shots: 16
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_caltech'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/chat_caltech101/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.5
22 |
23 | gpt3_prompt_file: './gpt_file/caltech_prompt_chat.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'caltech101'
27 | shots: 1
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_caltech'
33 | dalle_shots: 8
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/chat_caltech101/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 | # load_cache: True
9 | # load_pre_feat: True
10 |
11 |
12 | # ------ Hyperparameters ------
13 | search_hp: True
14 | # search_hp: False
15 |
16 | search_scale: [12, 5]
17 | search_step: [200, 20]
18 |
19 | init_beta: 1
20 | init_alpha: 0.8
21 |
22 | gpt3_prompt_file: './gpt_file/caltech_prompt_chat.json'
23 |
24 | # ------ Basic Config ------
25 | dataset: 'caltech101'
26 | shots: 2
27 | clip_backbone: 'RN50'
28 | dino_backbone: 'resnet50'
29 |
30 | # ------ Dalle Dataset -----
31 | dalle_dataset: 'dalle_caltech'
32 | dalle_shots: 16
33 |
34 | lr: 0.001
35 | augment_epoch: 10
36 | train_epoch: 20
37 |
--------------------------------------------------------------------------------
/configs/chat_caltech101/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.1
22 |
23 | gpt3_prompt_file: './gpt_file/caltech_prompt_chat.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'caltech101'
27 | shots: 4
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_caltech'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/chat_caltech101/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.1
22 |
23 | gpt3_prompt_file: './gpt_file/caltech_prompt_chat.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'caltech101'
27 | shots: 8
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_caltech'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/dtd/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [13, 13]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 2
22 |
23 | gpt3_prompt_file: './gpt_file/dtd_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'dtd'
27 | shots: 16
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_dtd'
33 | dalle_shots: 1
34 |
35 |
36 | lr: 0.001
37 | augment_epoch: 10
38 | train_epoch: 20
39 |
--------------------------------------------------------------------------------
/configs/dtd/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [13, 13]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 2
22 |
23 | gpt3_prompt_file: './gpt_file/dtd_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'dtd'
27 | shots: 1
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_dtd'
33 | dalle_shots: 1
34 |
35 |
36 | lr: 0.001
37 | augment_epoch: 10
38 | train_epoch: 20
39 |
--------------------------------------------------------------------------------
/configs/dtd/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [13, 13]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 2
22 |
23 | gpt3_prompt_file: './gpt_file/dtd_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'dtd'
27 | shots: 2
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_dtd'
33 | dalle_shots: 1
34 |
35 |
36 | lr: 0.001
37 | augment_epoch: 10
38 | train_epoch: 20
39 |
--------------------------------------------------------------------------------
/configs/dtd/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [13, 13]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 2
22 |
23 | gpt3_prompt_file: './gpt_file/dtd_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'dtd'
27 | shots: 4
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_dtd'
33 | dalle_shots: 1
34 |
35 |
36 | lr: 0.001
37 | augment_epoch: 10
38 | train_epoch: 20
39 |
--------------------------------------------------------------------------------
/configs/dtd/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [13, 13]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 2
22 |
23 | gpt3_prompt_file: './gpt_file/dtd_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'dtd'
27 | shots: 8
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_dtd'
33 | dalle_shots: 1
34 |
35 |
36 | lr: 0.001
37 | augment_epoch: 10
38 | train_epoch: 20
39 |
--------------------------------------------------------------------------------
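
Note: load_cache and load_pre_feat decide whether the cached few-shot features and the pre-extracted test features are rebuilt or reloaded from disk. A hypothetical sketch of that toggle (the real helpers live in utils.py / main.py and may be named and laid out differently):

    import torch

    # illustrative helper: reuse a tensor cache if the flag is set, otherwise
    # recompute it and save it for the next run
    def load_or_build(path, build_fn, reuse):
        if reuse:                        # e.g. cfg['load_cache'] or cfg['load_pre_feat']
            return torch.load(path)
        tensors = build_fn()             # recompute features from the images
        torch.save(tensors, path)
        return tensors
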
/configs/eurosat/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 2
22 |
23 | gpt3_prompt_file: './gpt_file/eurosat_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'eurosat'
27 | shots: 16
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_eurosat'
33 | dalle_shots: 8
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 100
38 |
--------------------------------------------------------------------------------
/configs/eurosat/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 | # load_cache: True
9 | # load_pre_feat: True
10 |
11 |
12 | # ------ Hyperparameters ------
13 | search_hp: True
14 | # search_hp: False
15 |
16 | search_scale: [12, 10]
17 | search_step: [200, 20]
18 |
19 | init_beta: 1
20 | init_alpha: 3
21 |
22 | gpt3_prompt_file: './gpt_file/eurosat_prompt.json'
23 |
24 | # ------ Basic Config ------
25 | dataset: 'eurosat'
26 | shots: 1
27 | clip_backbone: 'RN50'
28 | dino_backbone: 'resnet50'
29 |
30 | # ------ Dalle Dataset -----
31 | dalle_dataset: 'dalle_eurosat'
32 | dalle_shots: 4
33 |
34 | lr: 0.001
35 | augment_epoch: 10
36 | train_epoch: 20
37 |
--------------------------------------------------------------------------------
/configs/eurosat/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.5
22 |
23 | gpt3_prompt_file: './gpt_file/eurosat_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'eurosat'
27 | shots: 2
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_eurosat'
33 | dalle_shots: 8
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 100
38 |
--------------------------------------------------------------------------------
/configs/eurosat/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 2
22 |
23 | gpt3_prompt_file: './gpt_file/eurosat_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'eurosat'
27 | shots: 4
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_eurosat'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 100
38 |
--------------------------------------------------------------------------------
/configs/eurosat/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 2
22 |
23 | gpt3_prompt_file: './gpt_file/eurosat_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'eurosat'
27 | shots: 8
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_eurosat'
33 | dalle_shots: 8
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 100
38 |
--------------------------------------------------------------------------------
/configs/fgvc/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 | # load_cache: True
9 | # load_pre_feat: True
10 |
11 |
12 | # ------ Hyperparameters ------
13 | search_hp: True
14 | # search_hp: False
15 |
16 | search_scale: [30, 30]
17 | search_step: [200, 20]
18 |
19 | init_beta: 1
20 | init_alpha: 1
21 |
22 | gpt3_prompt_file: './gpt_file/fgvc_prompt.json'
23 |
24 | # ------ Basic Config ------
25 | dataset: 'fgvc'
26 | shots: 16
27 | clip_backbone: 'RN50'
28 | dino_backbone: 'resnet50'
29 |
30 | # ------ Dalle Dataset -----
31 | dalle_dataset: 'dalle_fgvc'
32 | dalle_shots: 1
33 |
34 | lr: 0.001
35 | augment_epoch: 10
36 | train_epoch: 100
37 |
--------------------------------------------------------------------------------
/configs/fgvc/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [30, 30]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1
22 |
23 | gpt3_prompt_file: './gpt_file/fgvc_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'fgvc'
27 | shots: 1
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_fgvc'
33 | dalle_shots: 8
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 100
38 |
--------------------------------------------------------------------------------
/configs/fgvc/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 | # load_cache: True
9 | # load_pre_feat: True
10 |
11 |
12 | # ------ Hyperparameters ------
13 | search_hp: True
14 | # search_hp: False
15 |
16 | search_scale: [30, 30]
17 | search_step: [200, 20]
18 |
19 | init_beta: 1
20 | init_alpha: 0.8
21 |
22 | gpt3_prompt_file: './gpt_file/fgvc_prompt.json'
23 |
24 | # ------ Basic Config ------
25 | dataset: 'fgvc'
26 | shots: 2
27 | clip_backbone: 'RN50'
28 | dino_backbone: 'resnet50'
29 |
30 | # ------ Dalle Dataset -----
31 | dalle_dataset: 'dalle_fgvc'
32 | dalle_shots: 4
33 |
34 | lr: 0.001
35 | augment_epoch: 10
36 | train_epoch: 100
37 |
--------------------------------------------------------------------------------
/configs/fgvc/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [30, 30]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.9
22 |
23 | gpt3_prompt_file: './gpt_file/fgvc_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'fgvc'
27 | shots: 4
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_fgvc'
33 | dalle_shots: 2
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 100
38 |
--------------------------------------------------------------------------------
/configs/fgvc/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [30, 30]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1
22 |
23 | gpt3_prompt_file: './gpt_file/fgvc_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'fgvc'
27 | shots: 8
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_fgvc'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 100
38 |
--------------------------------------------------------------------------------
/configs/food101/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [10, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.22
22 |
23 | gpt3_prompt_file: './gpt_file/food101_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'food101'
27 | shots: 16
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_food'
33 | dalle_shots: 16
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 200
38 |
--------------------------------------------------------------------------------
/configs/food101/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [20, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.2
22 |
23 | gpt3_prompt_file: './gpt_file/food101_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'food101'
27 | shots: 1
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_food'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 200
38 |
--------------------------------------------------------------------------------
/configs/food101/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [10, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.2
22 |
23 | gpt3_prompt_file: './gpt_file/food101_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'food101'
27 | shots: 2
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_food'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 200
38 |
--------------------------------------------------------------------------------
/configs/food101/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [10, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.22
22 |
23 | gpt3_prompt_file: './gpt_file/food101_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'food101'
27 | shots: 4
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_food'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 200
38 |
--------------------------------------------------------------------------------
/configs/food101/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [10, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.22
22 |
23 | gpt3_prompt_file: './gpt_file/food101_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'food101'
27 | shots: 8
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_food'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 200
38 |
--------------------------------------------------------------------------------
/configs/imagenet/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 |
10 | # ------ Hyperparameters ------
11 | search_hp: True
12 | # search_hp: False
13 |
14 | search_scale: [7, 3]
15 | search_step: [200, 20]
16 |
17 | init_beta: 1
18 | init_alpha: 0.6
19 |
20 | gpt3_prompt_file: './gpt_file/imagenet_prompt.json'
21 |
22 | # ------ Basic Config ------
23 | dataset: 'ImageNet'
24 | shots: 16
25 | clip_backbone: 'RN50' # ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
26 | dino_backbone: 'resnet50'
27 |
28 | # ------ Dalle Dataset -----
29 | dalle_dataset: 'dalle_imagenet'
30 | dalle_shots: 2
31 |
32 | lr: 0.001
33 | augment_epoch: 1
34 | train_epoch: 20
35 |
--------------------------------------------------------------------------------
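
Note: the clip_backbone comment above lists the CLIP variants accepted by the bundled clip package, and dino_backbone selects the DINO encoder. One way the two names could be resolved (the torch.hub entry point for DINO is an assumption; the repo's dino/utils.py may load its checkpoint differently):

    import clip           # the bundled clip/ package in this repo
    import torch

    # clip_backbone / dino_backbone values from the config above
    clip_model, preprocess = clip.load('RN50')
    dino_model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
    clip_model.eval()
    dino_model.eval()
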
/configs/imagenet/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [7, 3]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.3
22 |
23 | gpt3_prompt_file: './gpt_file/imagenet_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'ImageNet'
27 | shots: 1
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_imagenet'
33 | dalle_shots: 8
34 |
35 | lr: 0.001
36 | augment_epoch: 1
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/imagenet/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 |
10 | # ------ Hyperparameters ------
11 | search_hp: True
12 | # search_hp: False
13 |
14 | search_scale: [7, 3]
15 | search_step: [200, 20]
16 |
17 | init_beta: 1
18 | init_alpha: 0.3
19 |
20 | gpt3_prompt_file: './gpt_file/imagenet_prompt.json'
21 |
22 |
23 | # ------ Basic Config ------
24 | dataset: 'ImageNet'
25 | shots: 2
26 | clip_backbone: 'RN50'
27 | dino_backbone: 'resnet50'
28 |
29 | # ------ Dalle Dataset -----
30 | dalle_dataset: 'dalle_imagenet'
31 | dalle_shots: 2
32 |
33 | lr: 0.001
34 | augment_epoch: 1
35 | train_epoch: 20
36 |
--------------------------------------------------------------------------------
/configs/imagenet/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # ------ Hyperparameters ------
10 | search_hp: True
11 | # search_hp: False
12 |
13 | search_scale: [7, 3]
14 | search_step: [200, 20]
15 |
16 | init_beta: 1
17 | init_alpha: 0.4
18 |
19 | gpt3_prompt_file: './gpt_file/imagenet_prompt.json'
20 |
21 | # ------ Basic Config ------
22 | dataset: 'ImageNet'
23 | shots: 4
24 | clip_backbone: 'RN50'
25 | dino_backbone: 'resnet50'
26 |
27 | # ------ Dalle Dataset -----
28 | dalle_dataset: 'dalle_imagenet'
29 | dalle_shots: 8
30 |
31 | lr: 0.001
32 | augment_epoch: 1
33 | train_epoch: 20
34 |
--------------------------------------------------------------------------------
/configs/imagenet/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 |
10 | # ------ Hyperparameters ------
11 | search_hp: True
12 | # search_hp: False
13 |
14 | search_scale: [7, 3]
15 | search_step: [200, 20]
16 |
17 | init_beta: 1
18 | init_alpha: 0.5
19 |
20 | gpt3_prompt_file: './gpt_file/imagenet_prompt.json'
21 |
22 | # ------ Basic Config ------
23 | dataset: 'ImageNet'
24 | shots: 8
25 | clip_backbone: 'RN50'
26 | dino_backbone: 'resnet50'
27 |
28 | # ------ Dalle Dataset -----
29 | dalle_dataset: 'dalle_imagenet'
30 | dalle_shots: 2
31 |
32 | lr: 0.001
33 | augment_epoch: 1
34 | train_epoch: 20
35 |
--------------------------------------------------------------------------------
/configs/oxford_flowers/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [50, 50]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 4
22 |
23 | gpt3_prompt_file: './gpt_file/oxford_flowers_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'oxford_flowers'
27 | shots: 16
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_flowers'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/oxford_flowers/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [50, 50]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.2
22 |
23 | gpt3_prompt_file: './gpt_file/oxford_flowers_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'oxford_flowers'
27 | shots: 1
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_flowers'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/oxford_flowers/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [50, 50]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.7
22 |
23 | gpt3_prompt_file: './gpt_file/oxford_flowers_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'oxford_flowers'
27 | shots: 2
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_flowers'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/oxford_flowers/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [50, 50]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 2.2
22 |
23 | gpt3_prompt_file: './gpt_file/oxford_flowers_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'oxford_flowers'
27 | shots: 4
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_flowers'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/oxford_flowers/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [50, 50]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 3.7
22 |
23 | gpt3_prompt_file: './gpt_file/oxford_flowers_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'oxford_flowers'
27 | shots: 8
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_flowers'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/pets/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [7, 3]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.5
22 |
23 | gpt3_prompt_file: './gpt_file/oxford_pets_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'oxford_pets'
27 | shots: 16
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_pets'
33 | dalle_shots: 8
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/pets/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [7, 3]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.4
22 |
23 | gpt3_prompt_file: './gpt_file/oxford_pets_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'oxford_pets'
27 | shots: 1
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_pets'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/pets/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [7, 3]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.4
22 |
23 | gpt3_prompt_file: './gpt_file/oxford_pets_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'oxford_pets'
27 | shots: 2
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_pets'
33 | dalle_shots: 2
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/pets/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [7, 3]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.4
22 |
23 | gpt3_prompt_file: './gpt_file/oxford_pets_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'oxford_pets'
27 | shots: 4
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_pets'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/pets/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [7, 3]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.6
22 |
23 | gpt3_prompt_file: './gpt_file/oxford_pets_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'oxford_pets'
27 | shots: 8
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_pets'
33 | dalle_shots: 8
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/sd_caltech101/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.3
22 |
23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'caltech101'
27 | shots: 16
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Stable Diffusion Dataset -----
32 | dalle_dataset: 'sd_caltech'
33 | dalle_shots: 2
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/sd_caltech101/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.5
22 |
23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'caltech101'
27 | shots: 1
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Stable Diffusion Dataset -----
32 | dalle_dataset: 'sd_caltech'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/sd_caltech101/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 | # load_cache: True
9 | # load_pre_feat: True
10 |
11 |
12 | # ------ Hyperparameters ------
13 | search_hp: True
14 | # search_hp: False
15 |
16 | search_scale: [12, 5]
17 | search_step: [200, 20]
18 |
19 | init_beta: 1
20 | init_alpha: 0.8
21 |
22 | gpt3_prompt_file: './gpt_file/caltech_prompt.json'
23 |
24 | # ------ Basic Config ------
25 | dataset: 'caltech101'
26 | shots: 2
27 | clip_backbone: 'RN50'
28 | dino_backbone: 'resnet50'
29 |
30 | # ------ Stable Diffusion Dataset -----
31 | dalle_dataset: 'sd_caltech'
32 | dalle_shots: 8
33 |
34 | lr: 0.001
35 | augment_epoch: 10
36 | train_epoch: 20
37 |
--------------------------------------------------------------------------------
/configs/sd_caltech101/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.1
22 |
23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'caltech101'
27 | shots: 4
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Stable Diffusion Dataset -----
32 | dalle_dataset: 'sd_caltech'
33 | dalle_shots: 2
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/sd_caltech101/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.1
22 |
23 | gpt3_prompt_file: './gpt_file/caltech_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'caltech101'
27 | shots: 8
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Stable Diffusion Dataset -----
32 | dalle_dataset: 'sd_caltech'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/sun/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.8
22 |
23 | gpt3_prompt_file: './gpt_file/sun397_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'sun397'
27 | shots: 16
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_sun'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/sun/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 | # load_cache: True
9 | # load_pre_feat: True
10 |
11 |
12 | # ------ Hyperparameters ------
13 | search_hp: True
14 | # search_hp: False
15 |
16 | search_scale: [12, 10]
17 | search_step: [200, 20]
18 |
19 | init_beta: 1
20 | init_alpha: 0.5
21 |
22 | gpt3_prompt_file: './gpt_file/sun397_prompt.json'
23 |
24 | # ------ Basic Config ------
25 | dataset: 'sun397'
26 | shots: 1
27 | clip_backbone: 'RN50'
28 | dino_backbone: 'resnet50'
29 |
30 | # ------ Dalle Dataset -----
31 | dalle_dataset: 'dalle_sun'
32 | dalle_shots: 1
33 |
34 | lr: 0.001
35 | augment_epoch: 10
36 | train_epoch: 20
37 |
--------------------------------------------------------------------------------
/configs/sun/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.5
22 |
23 | gpt3_prompt_file: './gpt_file/sun397_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'sun397'
27 | shots: 2
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_sun'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/sun/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.6
22 |
23 | gpt3_prompt_file: './gpt_file/sun397_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'sun397'
27 | shots: 4
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_sun'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/sun/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [12, 10]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 0.7
22 |
23 | gpt3_prompt_file: './gpt_file/sun397_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'sun397'
27 | shots: 8
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_sun'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/ucf/16shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [7, 3]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 2
22 |
23 | gpt3_prompt_file: './gpt_file/ucf101_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'ucf101'
27 | shots: 16
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_ucf'
33 | dalle_shots: 2
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 40
38 |
--------------------------------------------------------------------------------
/configs/ucf/1shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [7, 3]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1
22 |
23 | gpt3_prompt_file: './gpt_file/ucf101_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'ucf101'
27 | shots: 1
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_ucf'
33 | dalle_shots: 8
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/ucf/2shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [7, 3]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1
22 |
23 | gpt3_prompt_file: './gpt_file/ucf101_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'ucf101'
27 | shots: 2
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_ucf'
33 | dalle_shots: 1
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/ucf/4shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [7, 3]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1
22 |
23 | gpt3_prompt_file: './gpt_file/ucf101_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'ucf101'
27 | shots: 4
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_ucf'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 20
38 |
--------------------------------------------------------------------------------
/configs/ucf/8shot.yaml:
--------------------------------------------------------------------------------
1 | # ------ root_path ------
2 | root_path: ''
3 |
4 |
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 |
9 | # load_cache: True
10 | # load_pre_feat: True
11 |
12 |
13 | # ------ Hyperparameters ------
14 | search_hp: True
15 | # search_hp: False
16 |
17 | search_scale: [7, 3]
18 | search_step: [200, 20]
19 |
20 | init_beta: 1
21 | init_alpha: 1.5
22 |
23 | gpt3_prompt_file: './gpt_file/ucf101_prompt.json'
24 |
25 | # ------ Basic Config ------
26 | dataset: 'ucf101'
27 | shots: 8
28 | clip_backbone: 'RN50'
29 | dino_backbone: 'resnet50'
30 |
31 | # ------ Dalle Dataset -----
32 | dalle_dataset: 'dalle_ucf'
33 | dalle_shots: 4
34 |
35 | lr: 0.001
36 | augment_epoch: 10
37 | train_epoch: 40
38 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .oxford_pets import OxfordPets
2 | from .eurosat import EuroSAT
3 | from .ucf101 import UCF101
4 | from .sun397 import SUN397
5 | from .caltech101 import Caltech101
6 | from .dtd import DescribableTextures
7 | from .fgvc import FGVCAircraft
8 | from .food101 import Food101
9 | from .oxford_flowers import OxfordFlowers
10 | from .stanford_cars import StanfordCars
11 | from .dalle_imagenet import Dalle_Imagenet
12 | from .dalle_caltech import Dalle_Caltech
13 | from .dalle_flowers import Dalle_Flowers
14 | from .dalle_food import Dalle_Food
15 | from .dalle_cars import Dalle_Cars
16 | from .dalle_dtd import Dalle_DTD
17 | from .dalle_eurosat import Dalle_Eurosat
18 | from .dalle_pets import Dalle_Pets
19 | from .dalle_sun import Dalle_Sun
20 | from .dalle_ucf import Dalle_UCF
21 | from .dalle_fgvc import Dalle_fgvc
22 | from .sd_caltech import SD_Caltech
23 |
24 | dataset_list = {
25 | "oxford_pets": OxfordPets,
26 | "eurosat": EuroSAT,
27 | "ucf101": UCF101,
28 | "sun397": SUN397,
29 | "caltech101": Caltech101,
30 | "dtd": DescribableTextures,
31 | "fgvc": FGVCAircraft,
32 | "food101": Food101,
33 | "oxford_flowers": OxfordFlowers,
34 | "stanford_cars": StanfordCars,
35 | "dalle_imagenet": Dalle_Imagenet,
36 | "dalle_caltech": Dalle_Caltech,
37 | "dalle_flowers": Dalle_Flowers,
38 | "dalle_food": Dalle_Food,
39 | "dalle_cars": Dalle_Cars,
40 | "dalle_dtd": Dalle_DTD,
41 | "dalle_eurosat": Dalle_Eurosat,
42 | "dalle_pets": Dalle_Pets,
43 | "dalle_sun": Dalle_Sun,
44 | "dalle_ucf": Dalle_UCF,
45 | "dalle_fgvc": Dalle_fgvc,
46 | "sd_caltech": SD_Caltech
47 | }
48 |
49 |
50 | def build_dataset(dataset, root_path, shots):
51 | return dataset_list[dataset](root_path, shots)
--------------------------------------------------------------------------------
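
Note: build_dataset simply indexes dataset_list, so the dataset and dalle_dataset strings in the YAML configs must match the keys registered here. A small usage sketch with an illustrative config dict:

    from datasets import build_dataset

    # cfg stands for a dict loaded from one of the YAML configs above
    cfg = {'dataset': 'caltech101', 'root_path': '', 'shots': 16,
           'dalle_dataset': 'dalle_caltech', 'dalle_shots': 2}

    few_shot  = build_dataset(cfg['dataset'], cfg['root_path'], cfg['shots'])
    synthetic = build_dataset(cfg['dalle_dataset'], cfg['root_path'], cfg['dalle_shots'])
    # train_x / val / test splits are assumed to be exposed by the CoOp-style DatasetBase
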
/datasets/caltech101.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .utils import Datum, DatasetBase
4 | from .oxford_pets import OxfordPets
5 |
6 |
7 | template = ['a photo of a {}.']
8 |
9 |
10 | class Caltech101(DatasetBase):
11 |
12 | dataset_dir = 'caltech-101'
13 |
14 | def __init__(self, root, num_shots):
15 | self.dataset_dir = os.path.join(root, self.dataset_dir)
16 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories')
17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Caltech101.json')
18 |
19 | self.template = template
20 |
21 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
22 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
23 |
24 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
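
Note: the template list supplies the hand-written prompt fed to CLIP's text encoder for each class name. A minimal tokenization sketch using the bundled clip package (the class names are illustrative):

    import clip

    template = ['a photo of a {}.']
    classnames = ['accordion', 'butterfly']      # illustrative Caltech101 categories
    prompts = [t.format(name) for name in classnames for t in template]
    tokens = clip.tokenize(prompts)              # LongTensor of shape [len(prompts), 77]
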
/datasets/dalle_caltech.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
3 | from .oxford_pets import OxfordPets
4 |
5 | class Dalle_Caltech(DatasetBase):
6 |
7 | dataset_dir = 'dalle_caltech-101'
8 |
9 | def __init__(self, root, num_shots):
10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
11 | self.dataset_dir = os.path.join(root, self.dataset_dir)
12 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories')
13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_caltech.json')
14 |
15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
17 |
18 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
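
Note: every dalle_* dataset module below follows this same pattern; only dataset_dir, image_dir, and the split JSON name change. A usage sketch for the generated Caltech images (the directory layout is an assumption about where the released DALL-E data sits under root_path):

    from datasets.dalle_caltech import Dalle_Caltech

    # root is the same root_path as in the YAML configs; the generated images are
    # expected under <root>/dalle_caltech-101/101_ObjectCategories with the
    # dalle_caltech.json split alongside them
    dalle_caltech = Dalle_Caltech(root='', num_shots=2)
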
/datasets/dalle_cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
3 | from .oxford_pets import OxfordPets
4 |
5 | class Dalle_Cars(DatasetBase):
6 |
7 | dataset_dir = 'dalle_stanford_cars'
8 |
9 | def __init__(self, root, num_shots):
10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
11 | self.dataset_dir = os.path.join(root, self.dataset_dir)
12 | self.image_dir = os.path.join(self.dataset_dir, 'cars_train')
13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_cars.json')
14 |
15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
17 |
18 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
/datasets/dalle_dtd.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
3 | from .oxford_pets import OxfordPets
4 |
5 | class Dalle_DTD(DatasetBase):
6 |
7 | dataset_dir = 'dalle_dtd'
8 |
9 | def __init__(self, root, num_shots):
10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
11 | self.dataset_dir = os.path.join(root, self.dataset_dir)
12 | self.image_dir = os.path.join(self.dataset_dir, 'images')
13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_dtd.json')
14 |
15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
17 |
18 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
/datasets/dalle_eurosat.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
3 | from .oxford_pets import OxfordPets
4 |
5 | class Dalle_Eurosat(DatasetBase):
6 |
7 | dataset_dir = 'dalle_eurosat'
8 |
9 | def __init__(self, root, num_shots):
10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
11 | self.dataset_dir = os.path.join(root, self.dataset_dir)
12 | self.image_dir = os.path.join(self.dataset_dir, '2750')
13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_eurosat.json')
14 |
15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
17 |
18 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
/datasets/dalle_fgvc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
3 | from .oxford_pets import OxfordPets
4 |
5 | class Dalle_fgvc(DatasetBase):
6 |
7 | dataset_dir = 'dalle_fgvc_aircraft'
8 |
9 | def __init__(self, root, num_shots):
10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
11 | self.dataset_dir = os.path.join(root, self.dataset_dir)
12 | self.image_dir = os.path.join(self.dataset_dir, 'images')
13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_fgvc.json')
14 |
15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
17 |
18 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
/datasets/dalle_flowers.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
3 | from .oxford_pets import OxfordPets
4 |
5 | class Dalle_Flowers(DatasetBase):
6 |
7 | dataset_dir = 'dalle_oxford_flowers'
8 |
9 | def __init__(self, root, num_shots):
10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
11 | self.dataset_dir = os.path.join(root, self.dataset_dir)
12 | self.image_dir = os.path.join(self.dataset_dir, 'jpg')
13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_flower.json')
14 |
15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
17 |
18 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
/datasets/dalle_food.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
3 | from .oxford_pets import OxfordPets
4 |
5 | class Dalle_Food(DatasetBase):
6 |
7 | dataset_dir = 'dalle_food-101'
8 |
9 | def __init__(self, root, num_shots):
10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
11 | self.dataset_dir = os.path.join(root, self.dataset_dir)
12 | self.image_dir = os.path.join(self.dataset_dir, 'images')
13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_food.json')
14 |
15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
17 |
18 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
/datasets/dalle_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
3 | from .oxford_pets import OxfordPets
4 |
5 | class Dalle_Imagenet(DatasetBase):
6 |
7 | dataset_dir = 'dalle_imagenet'
8 |
9 | def __init__(self, root, num_shots):
10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
11 | self.dataset_dir = os.path.join(root, self.dataset_dir)
12 | self.image_dir = os.path.join(self.dataset_dir, 'data')
13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_imagenet.json')
14 |
15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
17 |
18 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
/datasets/dalle_pets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
3 | from .oxford_pets import OxfordPets
4 |
5 | class Dalle_Pets(DatasetBase):
6 |
7 | dataset_dir = 'dalle_oxford_pets'
8 |
9 | def __init__(self, root, num_shots):
10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
11 | self.dataset_dir = os.path.join(root, self.dataset_dir)
12 | self.image_dir = os.path.join(self.dataset_dir, 'images')
13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_pet.json')
14 |
15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
17 |
18 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
/datasets/dalle_sun.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
3 | from .oxford_pets import OxfordPets
4 |
5 | class Dalle_Sun(DatasetBase):
6 |
7 | dataset_dir = 'dalle_sun397'
8 |
9 | def __init__(self, root, num_shots):
10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
11 | self.dataset_dir = os.path.join(root, self.dataset_dir)
12 | self.image_dir = os.path.join(self.dataset_dir, 'SUN397')
13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_sun.json')
14 |
15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
17 |
18 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
/datasets/dalle_ucf.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
3 | from .oxford_pets import OxfordPets
4 |
5 | class Dalle_UCF(DatasetBase):
6 |
7 | dataset_dir = 'dalle_ucf101'
8 |
9 | def __init__(self, root, num_shots):
10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
11 | self.dataset_dir = os.path.join(root, self.dataset_dir)
12 | self.image_dir = os.path.join(self.dataset_dir, 'ucf101_midframes')
13 | self.split_path = os.path.join(self.dataset_dir, 'dalle_ucf.json')
14 |
15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
17 |
18 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
/datasets/dtd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | from .utils import Datum, DatasetBase, listdir_nohidden
5 | from .oxford_pets import OxfordPets
6 |
7 |
8 | template = ['{} texture.']
9 |
10 |
11 | class DescribableTextures(DatasetBase):
12 |
13 | dataset_dir = 'dtd'
14 |
15 | def __init__(self, root, num_shots):
16 | self.dataset_dir = os.path.join(root, self.dataset_dir)
17 | self.image_dir = os.path.join(self.dataset_dir, 'images')
18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_DescribableTextures.json')
19 |
20 | self.template = template
21 |
22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
24 |
25 | super().__init__(train_x=train, val=val, test=test)
26 |
27 | @staticmethod
28 | def read_and_split_data(
29 | image_dir,
30 | p_trn=0.5,
31 | p_val=0.2,
32 | ignored=[],
33 | new_cnames=None
34 | ):
35 | # The data are supposed to be organized into the following structure
36 | # =============
37 | # images/
38 | # dog/
39 | # cat/
40 | # horse/
41 | # =============
42 | categories = listdir_nohidden(image_dir)
43 | categories = [c for c in categories if c not in ignored]
44 | categories.sort()
45 |
46 | p_tst = 1 - p_trn - p_val
47 | print(f'Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test')
48 |
49 | def _collate(ims, y, c):
50 | items = []
51 | for im in ims:
52 | item = Datum(
53 | impath=im,
54 | label=y, # is already 0-based
55 | classname=c
56 | )
57 | items.append(item)
58 | return items
59 |
60 | train, val, test = [], [], []
61 | for label, category in enumerate(categories):
62 | category_dir = os.path.join(image_dir, category)
63 | images = listdir_nohidden(category_dir)
64 | images = [os.path.join(category_dir, im) for im in images]
65 | random.shuffle(images)
66 | n_total = len(images)
67 | n_train = round(n_total * p_trn)
68 | n_val = round(n_total * p_val)
69 | n_test = n_total - n_train - n_val
70 | assert n_train > 0 and n_val > 0 and n_test > 0
71 |
72 | if new_cnames is not None and category in new_cnames:
73 | category = new_cnames[category]
74 |
75 | train.extend(_collate(images[:n_train], label, category))
76 | val.extend(_collate(images[n_train:n_train+n_val], label, category))
77 | test.extend(_collate(images[n_train+n_val:], label, category))
78 |
79 | return train, val, test
80 |
--------------------------------------------------------------------------------
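A minimal worked example of the per-class split arithmetic in read_and_split_data above, using an illustrative class size (the defaults p_trn=0.5 and p_val=0.2 leave 30% for test):

    # Illustrative class size only; each category is split independently.
    n_total = 120
    n_train = round(n_total * 0.5)      # 60
    n_val = round(n_total * 0.2)        # 24
    n_test = n_total - n_train - n_val  # 36
    print(n_train, n_val, n_test)       # 60 24 36

--------------------------------------------------------------------------------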
/datasets/eurosat.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
4 | from .oxford_pets import OxfordPets
5 |
6 |
7 | template = ['a centered satellite photo of {}.']
8 |
9 |
10 | NEW_CNAMES = {
11 | 'AnnualCrop': 'Annual Crop Land',
12 | 'Forest': 'Forest',
13 | 'HerbaceousVegetation': 'Herbaceous Vegetation Land',
14 | 'Highway': 'Highway or Road',
15 | 'Industrial': 'Industrial Buildings',
16 | 'Pasture': 'Pasture Land',
17 | 'PermanentCrop': 'Permanent Crop Land',
18 | 'Residential': 'Residential Buildings',
19 | 'River': 'River',
20 | 'SeaLake': 'Sea or Lake'
21 | }
22 |
23 |
24 | class EuroSAT(DatasetBase):
25 |
26 | dataset_dir = 'eurosat'
27 |
28 | def __init__(self, root, num_shots):
29 | self.dataset_dir = os.path.join(root, self.dataset_dir)
30 | self.image_dir = os.path.join(self.dataset_dir, '2750')
31 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_EuroSAT.json')
32 |
33 | self.template = template
34 |
35 | train_u, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
36 | train = self.generate_fewshot_dataset(train_u, num_shots=num_shots)
37 |
38 |         super().__init__(train_x=train, val=val, test=test, train_u=train_u)
39 |
40 | def update_classname(self, dataset_old):
41 | dataset_new = []
42 | for item_old in dataset_old:
43 | cname_old = item_old.classname
44 |             cname_new = NEW_CNAMES[cname_old]
45 | item_new = Datum(
46 | impath=item_old.impath,
47 | label=item_old.label,
48 | classname=cname_new
49 | )
50 | dataset_new.append(item_new)
51 | return dataset_new
52 |
--------------------------------------------------------------------------------
/datasets/fgvc.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
4 |
5 |
6 | template = ['a photo of a {}, a type of aircraft.']
7 |
8 |
9 | class FGVCAircraft(DatasetBase):
10 |
11 | dataset_dir = 'fgvc_aircraft'
12 |
13 | def __init__(self, root, num_shots):
14 |
15 | self.dataset_dir = os.path.join(root, self.dataset_dir)
16 | self.image_dir = os.path.join(self.dataset_dir, 'images')
17 |
18 | self.template = template
19 |
20 | classnames = []
21 | with open(os.path.join(self.dataset_dir, 'variants.txt'), 'r') as f:
22 | lines = f.readlines()
23 | for line in lines:
24 | classnames.append(line.strip())
25 | cname2lab = {c: i for i, c in enumerate(classnames)}
26 |
27 | train = self.read_data(cname2lab, 'images_variant_train.txt')
28 | val = self.read_data(cname2lab, 'images_variant_val.txt')
29 | test = self.read_data(cname2lab, 'images_variant_test.txt')
30 |
31 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
32 |
33 | super().__init__(train_x=train, val=val, test=test)
34 |
35 | def read_data(self, cname2lab, split_file):
36 | filepath = os.path.join(self.dataset_dir, split_file)
37 | items = []
38 |
39 | with open(filepath, 'r') as f:
40 | lines = f.readlines()
41 | for line in lines:
42 | line = line.strip().split(' ')
43 | imname = line[0] + '.jpg'
44 | classname = ' '.join(line[1:])
45 | impath = os.path.join(self.image_dir, imname)
46 | label = cname2lab[classname]
47 | item = Datum(
48 | impath=impath,
49 | label=label,
50 | classname=classname
51 | )
52 | items.append(item)
53 |
54 | return items
--------------------------------------------------------------------------------
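A standalone sketch of how read_data above parses one line of the images_variant_*.txt annotation files; the sample line is an illustrative assumption about the format (image id followed by the variant name, which may contain spaces):

    # Hypothetical annotation line, for illustration only.
    line = '0034309 Boeing 717'.strip().split(' ')
    imname = line[0] + '.jpg'       # '0034309.jpg'
    classname = ' '.join(line[1:])  # 'Boeing 717'
    print(imname, '->', classname)

--------------------------------------------------------------------------------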
/datasets/food101.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
4 | from .oxford_pets import OxfordPets
5 |
6 |
7 | template = ['a photo of {}, a type of food.']
8 |
9 |
10 | class Food101(DatasetBase):
11 |
12 | dataset_dir = 'food-101'
13 |
14 | def __init__(self, root, num_shots):
15 | self.dataset_dir = os.path.join(root, self.dataset_dir)
16 | self.image_dir = os.path.join(self.dataset_dir, 'images')
17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Food101.json')
18 |
19 | self.template = template
20 |
21 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
22 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
23 |
24 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import random
4 | from collections import defaultdict
5 |
6 | import torch
7 | import torchvision
8 | import torchvision.transforms as transforms
9 |
10 |
11 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
12 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
13 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
14 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
15 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
16 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
17 | "box turtle", "banded gecko", "green iguana", "Carolina anole",
18 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
19 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
20 | "American alligator", "triceratops", "worm snake", "ring-necked snake",
21 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
22 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
23 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
24 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
25 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
26 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
27 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
28 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
29 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
30 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
31 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
32 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
33 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
34 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
35 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
36 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
37 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
38 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
39 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
40 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
41 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
42 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
43 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
44 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
45 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
46 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
47 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
48 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
49 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
50 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
51 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
52 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
53 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
54 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
55 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
56 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
57 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
58 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
59 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
60 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
61 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
62 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
63 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
64 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
65 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
66 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
67 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
68 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
69 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
70 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
71 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
72 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
73 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
74 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
75 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
76 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
77 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
78 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
79 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
80 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
81 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
82 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
83 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
84 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
85 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
86 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
87 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
88 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
89 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
90 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
91 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
92 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
93 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
94 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
95 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
96 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
97 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
98 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
99 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
100 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
101 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
102 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
103 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
104 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
105 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
106 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
107 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
108 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
109 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
110 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
111 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
112 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
113 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
114 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
115 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
116 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
117 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
118 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
119 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
120 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
121 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
122 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
123 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
124 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
125 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
126 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
127 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
128 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
129 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
130 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
131 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
132 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
133 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
134 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
135 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
136 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
137 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
138 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
139 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
140 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
141 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
142 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
143 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
144 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
145 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
146 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
147 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
148 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
149 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
150 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
151 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
152 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
153 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
154 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
155 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
156 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
157 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
158 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
159 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
160 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
161 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
162 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
163 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
164 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
165 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
166 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
167 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
168 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
169 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
170 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
171 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
172 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
173 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
174 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
175 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
176 |
177 | imagenet_templates = ["itap of a {}.",
178 | "a bad photo of the {}.",
179 | "a origami {}.",
180 | "a photo of the large {}.",
181 | "a {} in a video game.",
182 | "art of the {}.",
183 | "a photo of the small {}."]
184 |
185 |
186 | class ImageNet():
187 |
188 | dataset_dir = 'imagenet'
189 |
190 | def __init__(self, root, num_shots, preprocess):
191 |
192 | self.dataset_dir = os.path.join(root, self.dataset_dir)
193 | self.image_dir = os.path.join(self.dataset_dir, 'images')
194 |
195 | train_preprocess = transforms.Compose([
196 | transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
197 | transforms.RandomHorizontalFlip(p=0.5),
198 | transforms.ToTensor(),
199 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
200 | ])
201 | test_preprocess = preprocess
202 |
203 | self.train = torchvision.datasets.ImageNet(self.image_dir, split='train', transform=train_preprocess)
204 | self.val = torchvision.datasets.ImageNet(self.image_dir, split='val', transform=test_preprocess)
205 | self.test = torchvision.datasets.ImageNet(self.image_dir, split='val', transform=test_preprocess)
206 |
207 | self.template = imagenet_templates
208 | self.classnames = imagenet_classes
209 |
210 | split_by_label_dict = defaultdict(list)
211 | for i in range(len(self.train.imgs)):
212 | split_by_label_dict[self.train.targets[i]].append(self.train.imgs[i])
213 | imgs = []
214 | targets = []
215 |
216 | for label, items in split_by_label_dict.items():
217 | imgs = imgs + random.sample(items, num_shots)
218 | targets = targets + [label for i in range(num_shots)]
219 | self.train.imgs = imgs
220 | self.train.targets = targets
221 | self.train.samples = imgs
222 |
223 | if __name__ == '__main__':
224 | print('screw' in imagenet_classes)
--------------------------------------------------------------------------------
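The constructor above turns the full torchvision ImageNet training set into a few-shot subset by grouping samples per label and drawing num_shots from each class. A self-contained sketch of that subsampling logic on toy data (the (path, label) pairs are placeholders, not real ImageNet entries):

    import random
    from collections import defaultdict

    samples = [(f'img_{i}.jpg', i % 3) for i in range(30)]  # toy data: 3 classes, 10 images each
    num_shots = 4

    split_by_label = defaultdict(list)
    for path, label in samples:
        split_by_label[label].append((path, label))

    imgs, targets = [], []
    for label, items in split_by_label.items():
        imgs += random.sample(items, num_shots)
        targets += [label] * num_shots

    print(len(imgs), len(targets))  # 12 12

--------------------------------------------------------------------------------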
/datasets/oxford_flowers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from scipy.io import loadmat
4 | from collections import defaultdict
5 |
6 | from .oxford_pets import OxfordPets
7 | from .utils import Datum, DatasetBase, read_json
8 |
9 |
10 | template = ['a photo of a {}, a type of flower.']
11 |
12 |
13 | class OxfordFlowers(DatasetBase):
14 |
15 | dataset_dir = 'oxford_flowers'
16 |
17 | def __init__(self, root, num_shots):
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.image_dir = os.path.join(self.dataset_dir, 'jpg')
20 | self.label_file = os.path.join(self.dataset_dir, 'imagelabels.mat')
21 | self.lab2cname_file = os.path.join(self.dataset_dir, 'cat_to_name.json')
22 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordFlowers.json')
23 |
24 | self.template = template
25 |
26 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
27 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
28 |
29 | super().__init__(train_x=train, val=val, test=test)
30 |
31 | def read_data(self):
32 | tracker = defaultdict(list)
33 | label_file = loadmat(self.label_file)['labels'][0]
34 | for i, label in enumerate(label_file):
35 | imname = f'image_{str(i + 1).zfill(5)}.jpg'
36 | impath = os.path.join(self.image_dir, imname)
37 | label = int(label)
38 | tracker[label].append(impath)
39 |
40 | print('Splitting data into 50% train, 20% val, and 30% test')
41 |
42 | def _collate(ims, y, c):
43 | items = []
44 | for im in ims:
45 | item = Datum(
46 | impath=im,
47 | label=y-1, # convert to 0-based label
48 | classname=c
49 | )
50 | items.append(item)
51 | return items
52 |
53 | lab2cname = read_json(self.lab2cname_file)
54 | train, val, test = [], [], []
55 | for label, impaths in tracker.items():
56 | random.shuffle(impaths)
57 | n_total = len(impaths)
58 | n_train = round(n_total * 0.5)
59 | n_val = round(n_total * 0.2)
60 | n_test = n_total - n_train - n_val
61 | assert n_train > 0 and n_val > 0 and n_test > 0
62 | cname = lab2cname[str(label)]
63 | train.extend(_collate(impaths[:n_train], label, cname))
64 | val.extend(_collate(impaths[n_train:n_train+n_val], label, cname))
65 | test.extend(_collate(impaths[n_train+n_val:], label, cname))
66 |
67 | return train, val, test
--------------------------------------------------------------------------------
/datasets/oxford_pets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import random
4 | from collections import defaultdict
5 |
6 | import torchvision.transforms as transforms
7 |
8 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
9 |
10 |
11 | template = ['a photo of a {}, a type of pet.']
12 |
13 |
14 | class OxfordPets(DatasetBase):
15 |
16 | dataset_dir = 'oxford_pets'
17 |
18 | def __init__(self, root, num_shots):
19 | self.dataset_dir = os.path.join(root, self.dataset_dir)
20 | self.image_dir = os.path.join(self.dataset_dir, 'images')
21 | self.anno_dir = os.path.join(self.dataset_dir, 'annotations')
22 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordPets.json')
23 |
24 | self.template = template
25 |
26 | train, val, test = self.read_split(self.split_path, self.image_dir)
27 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
28 |
29 | super().__init__(train_x=train, val=val, test=test)
30 |
31 | def read_data(self, split_file):
32 | filepath = os.path.join(self.anno_dir, split_file)
33 | items = []
34 |
35 | with open(filepath, 'r') as f:
36 | lines = f.readlines()
37 | for line in lines:
38 | line = line.strip()
39 | imname, label, species, _ = line.split(' ')
40 | breed = imname.split('_')[:-1]
41 | breed = '_'.join(breed)
42 | breed = breed.lower()
43 | imname += '.jpg'
44 | impath = os.path.join(self.image_dir, imname)
45 | label = int(label) - 1 # convert to 0-based index
46 | item = Datum(
47 | impath=impath,
48 | label=label,
49 | classname=breed
50 | )
51 | items.append(item)
52 |
53 | return items
54 |
55 | @staticmethod
56 | def split_trainval(trainval, p_val=0.2):
57 | p_trn = 1 - p_val
58 | print(f'Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val')
59 | tracker = defaultdict(list)
60 | for idx, item in enumerate(trainval):
61 | label = item.label
62 | tracker[label].append(idx)
63 |
64 | train, val = [], []
65 | for label, idxs in tracker.items():
66 | n_val = round(len(idxs) * p_val)
67 | assert n_val > 0
68 | random.shuffle(idxs)
69 | for n, idx in enumerate(idxs):
70 | item = trainval[idx]
71 | if n < n_val:
72 | val.append(item)
73 | else:
74 | train.append(item)
75 |
76 | return train, val
77 |
78 | @staticmethod
79 | def save_split(train, val, test, filepath, path_prefix):
80 | def _extract(items):
81 | out = []
82 | for item in items:
83 | impath = item.impath
84 | label = item.label
85 | classname = item.classname
86 | impath = impath.replace(path_prefix, '')
87 | if impath.startswith('/'):
88 | impath = impath[1:]
89 | out.append((impath, label, classname))
90 | return out
91 |
92 | train = _extract(train)
93 | val = _extract(val)
94 | test = _extract(test)
95 |
96 | split = {
97 | 'train': train,
98 | 'val': val,
99 | 'test': test
100 | }
101 |
102 | write_json(split, filepath)
103 | print(f'Saved split to {filepath}')
104 |
105 | @staticmethod
106 | def read_split(filepath, path_prefix):
107 | def _convert(items):
108 | out = []
109 | for impath, label, classname in items:
110 | impath = os.path.join(path_prefix, impath)
111 | item = Datum(
112 | impath=impath,
113 | label=int(label),
114 | classname=classname
115 | )
116 | out.append(item)
117 | return out
118 |
119 | print(f'Reading split from {filepath}')
120 | split = read_json(filepath)
121 | train = _convert(split['train'])
122 | val = _convert(split['val'])
123 | test = _convert(split['test'])
124 |
125 | return train, val, test
--------------------------------------------------------------------------------
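save_split and read_split above exchange a JSON file in which each split maps to a list of [relative_image_path, label, classname] triples; read_split re-joins the relative paths onto the image directory prefix. A minimal sketch of that layout (the entries and output path are hypothetical):

    import json

    split = {
        'train': [['images/abyssinian_1.jpg', 0, 'abyssinian']],
        'val':   [['images/abyssinian_2.jpg', 0, 'abyssinian']],
        'test':  [['images/beagle_10.jpg', 4, 'beagle']],
    }
    with open('split_example.json', 'w') as f:  # hypothetical output path
        json.dump(split, f, indent=4)

--------------------------------------------------------------------------------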
/datasets/sd_caltech.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
3 | from .oxford_pets import OxfordPets
4 |
5 | class SD_Caltech(DatasetBase):
6 |
7 | dataset_dir = 'sd_caltech_101'
8 |
9 | def __init__(self, root, num_shots):
10 | # root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
11 | self.dataset_dir = os.path.join(root, self.dataset_dir)
12 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories')
13 | self.split_path = os.path.join(self.dataset_dir, 'sd_caltech.json')
14 |
15 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
16 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
17 |
18 | super().__init__(train_x=train, val=val, test=test)
--------------------------------------------------------------------------------
/datasets/stanford_cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | from scipy.io import loadmat
3 |
4 | from .oxford_pets import OxfordPets
5 | from .utils import Datum, DatasetBase
6 |
7 |
8 | template = ['a photo of a {}.']
9 |
10 |
11 | class StanfordCars(DatasetBase):
12 |
13 | dataset_dir = 'stanford_cars'
14 |
15 | def __init__(self, root, num_shots):
16 | self.dataset_dir = os.path.join(root, self.dataset_dir)
17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_StanfordCars.json')
18 |
19 | self.template = template
20 |
21 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir)
22 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
23 |
24 | super().__init__(train_x=train, val=val, test=test)
25 |
26 | def read_data(self, image_dir, anno_file, meta_file):
27 | anno_file = loadmat(anno_file)['annotations'][0]
28 | meta_file = loadmat(meta_file)['class_names'][0]
29 | items = []
30 |
31 | for i in range(len(anno_file)):
32 | imname = anno_file[i]['fname'][0]
33 | impath = os.path.join(self.dataset_dir, image_dir, imname)
34 | label = anno_file[i]['class'][0, 0]
35 | label = int(label) - 1 # convert to 0-based index
36 | classname = meta_file[label][0]
37 | names = classname.split(' ')
38 | year = names.pop(-1)
39 | names.insert(0, year)
40 | classname = ' '.join(names)
41 | item = Datum(
42 | impath=impath,
43 | label=label,
44 | classname=classname
45 | )
46 | items.append(item)
47 |
48 | return items
--------------------------------------------------------------------------------
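read_data above rewrites each class name from the Cars devkit by moving the trailing model year to the front. A standalone sketch of that reordering (the class name is an illustrative example):

    classname = 'AM General Hummer SUV 2000'  # illustrative devkit-style name
    names = classname.split(' ')
    year = names.pop(-1)          # '2000'
    names.insert(0, year)
    print(' '.join(names))        # '2000 AM General Hummer SUV'

--------------------------------------------------------------------------------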
/datasets/sun397.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
4 |
5 | from .oxford_pets import OxfordPets
6 |
7 |
8 | template = ['a photo of a {}.']
9 |
10 |
11 | class SUN397(DatasetBase):
12 |
13 | dataset_dir = 'sun397'
14 |
15 | def __init__(self, root, num_shots):
16 | self.dataset_dir = os.path.join(root, self.dataset_dir)
17 | self.image_dir = os.path.join(self.dataset_dir, 'SUN397')
18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_SUN397.json')
19 |
20 | self.template = template
21 |
22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
24 |
25 | super().__init__(train_x=train, val=val, test=test)
26 |
27 | def read_data(self, cname2lab, text_file):
28 | text_file = os.path.join(self.dataset_dir, text_file)
29 | items = []
30 |
31 | with open(text_file, 'r') as f:
32 | lines = f.readlines()
33 | for line in lines:
34 | imname = line.strip()[1:] # remove /
35 | classname = os.path.dirname(imname)
36 | label = cname2lab[classname]
37 | impath = os.path.join(self.image_dir, imname)
38 |
39 |                 names = classname.split('/')[1:] # drop the leading single-letter directory
40 | names = names[::-1] # put words like indoor/outdoor at first
41 | classname = ' '.join(names)
42 |
43 | item = Datum(
44 | impath=impath,
45 | label=label,
46 | classname=classname
47 | )
48 | items.append(item)
49 |
50 | return items
51 |
--------------------------------------------------------------------------------
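read_data above cleans the SUN397 class names by dropping the leading single-letter directory and moving qualifiers such as indoor/outdoor to the front. A standalone sketch of that transformation (the path is an illustrative example of the split-file format):

    classname = 'c/church/indoor'     # illustrative; i.e. os.path.dirname of a split entry
    names = classname.split('/')[1:]  # drop the single-letter directory -> ['church', 'indoor']
    names = names[::-1]               # put 'indoor'/'outdoor' first -> ['indoor', 'church']
    print(' '.join(names))            # 'indoor church'

--------------------------------------------------------------------------------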
/datasets/ucf101.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
4 |
5 | from .oxford_pets import OxfordPets
6 |
7 |
8 | template = ['a photo of a person doing {}.']
9 |
10 |
11 | class UCF101(DatasetBase):
12 |
13 | dataset_dir = 'ucf101'
14 |
15 | def __init__(self, root, num_shots):
16 | self.dataset_dir = os.path.join(root, self.dataset_dir)
17 | self.image_dir = os.path.join(self.dataset_dir, 'UCF-101-midframes')
18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_UCF101.json')
19 |
20 | self.template = template
21 |
22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
24 |
25 | super().__init__(train_x=train, val=val, test=test)
26 |
27 | def read_data(self, cname2lab, text_file):
28 | text_file = os.path.join(self.dataset_dir, text_file)
29 | items = []
30 |
31 | with open(text_file, 'r') as f:
32 | lines = f.readlines()
33 | for line in lines:
34 | line = line.strip().split(' ')[0] # trainlist: filename, label
35 | action, filename = line.split('/')
36 | label = cname2lab[action]
37 |
38 | elements = re.findall('[A-Z][^A-Z]*', action)
39 | renamed_action = '_'.join(elements)
40 |
41 | filename = filename.replace('.avi', '.jpg')
42 | impath = os.path.join(self.image_dir, renamed_action, filename)
43 |
44 | item = Datum(
45 | impath=impath,
46 | label=label,
47 | classname=renamed_action
48 | )
49 | items.append(item)
50 |
51 | return items
52 |
--------------------------------------------------------------------------------
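read_data above splits UCF101's CamelCase action names into words to recover the underscore-separated folder names of the extracted mid-frames. A standalone sketch of that renaming (the action string is an illustrative example):

    import re

    action = 'ApplyEyeMakeup'                      # illustrative action name
    elements = re.findall('[A-Z][^A-Z]*', action)  # ['Apply', 'Eye', 'Makeup']
    print('_'.join(elements))                      # 'Apply_Eye_Makeup'

--------------------------------------------------------------------------------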
/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import os.path as osp
4 | import tarfile
5 | import zipfile
6 | from collections import defaultdict
7 | import gdown
8 | import json
9 | import torch
10 | from torch.utils.data import Dataset as TorchDataset
11 | import torchvision.transforms as T
12 | from PIL import Image
13 |
14 |
15 | def read_json(fpath):
16 | """Read json file from a path."""
17 | with open(fpath, 'r') as f:
18 | obj = json.load(f)
19 | return obj
20 |
21 |
22 | def write_json(obj, fpath):
23 | """Writes to a json file."""
24 | if not osp.exists(osp.dirname(fpath)):
25 | os.makedirs(osp.dirname(fpath))
26 | with open(fpath, 'w') as f:
27 | json.dump(obj, f, indent=4, separators=(',', ': '))
28 |
29 |
30 | def read_image(path):
31 | """Read image from path using ``PIL.Image``.
32 |
33 | Args:
34 | path (str): path to an image.
35 |
36 | Returns:
37 | PIL image
38 | """
39 | if not osp.exists(path):
40 | raise IOError('No file exists at {}'.format(path))
41 |
42 | while True:
43 | try:
44 | img = Image.open(path).convert('RGB')
45 | return img
46 | except IOError:
47 | print(
48 | 'Cannot read image from {}, '
49 | 'probably due to heavy IO. Will re-try'.format(path)
50 | )
51 |
52 |
53 | def listdir_nohidden(path, sort=False):
54 | """List non-hidden items in a directory.
55 |
56 | Args:
57 | path (str): directory path.
58 | sort (bool): sort the items.
59 | """
60 | items = [f for f in os.listdir(path) if not f.startswith('.') and 'sh' not in f]
61 | if sort:
62 | items.sort()
63 | return items
64 |
65 |
66 | class Datum:
67 | """Data instance which defines the basic attributes.
68 |
69 | Args:
70 | impath (str): image path.
71 | label (int): class label.
72 | domain (int): domain label.
73 | classname (str): class name.
74 | """
75 |
76 | def __init__(self, impath='', label=0, domain=-1, classname=''):
77 | assert isinstance(impath, str)
78 | assert isinstance(label, int)
79 | assert isinstance(domain, int)
80 | assert isinstance(classname, str)
81 |
82 | self._impath = impath
83 | self._label = label
84 | self._domain = domain
85 | self._classname = classname
86 |
87 | @property
88 | def impath(self):
89 | return self._impath
90 |
91 | @property
92 | def label(self):
93 | return self._label
94 |
95 | @property
96 | def domain(self):
97 | return self._domain
98 |
99 | @property
100 | def classname(self):
101 | return self._classname
102 |
103 |
104 | class DatasetBase:
105 | """A unified dataset class for
106 | 1) domain adaptation
107 | 2) domain generalization
108 | 3) semi-supervised learning
109 | """
110 | dataset_dir = '' # the directory where the dataset is stored
111 | domains = [] # string names of all domains
112 |
113 | def __init__(self, train_x=None, train_u=None, val=None, test=None):
114 | self._train_x = train_x # labeled training data
115 | self._train_u = train_u # unlabeled training data (optional)
116 | self._val = val # validation data (optional)
117 | self._test = test # test data
118 |
119 | self._num_classes = self.get_num_classes(train_x)
120 | self._lab2cname, self._classnames = self.get_lab2cname(train_x)
121 |
122 | @property
123 | def train_x(self):
124 | return self._train_x
125 |
126 | @property
127 | def train_u(self):
128 | return self._train_u
129 |
130 | @property
131 | def val(self):
132 | return self._val
133 |
134 | @property
135 | def test(self):
136 | return self._test
137 |
138 | @property
139 | def lab2cname(self):
140 | return self._lab2cname
141 |
142 | @property
143 | def classnames(self):
144 | return self._classnames
145 |
146 | @property
147 | def num_classes(self):
148 | return self._num_classes
149 |
150 | def get_num_classes(self, data_source):
151 | """Count number of classes.
152 |
153 | Args:
154 | data_source (list): a list of Datum objects.
155 | """
156 | label_set = set()
157 | for item in data_source:
158 | label_set.add(item.label)
159 | return max(label_set) + 1
160 |
161 | def get_lab2cname(self, data_source):
162 | """Get a label-to-classname mapping (dict).
163 |
164 | Args:
165 | data_source (list): a list of Datum objects.
166 | """
167 | container = set()
168 | for item in data_source:
169 | container.add((item.label, item.classname))
170 | mapping = {label: classname for label, classname in container}
171 | labels = list(mapping.keys())
172 | labels.sort()
173 | classnames = [mapping[label] for label in labels]
174 | return mapping, classnames
175 |
176 | def check_input_domains(self, source_domains, target_domains):
177 | self.is_input_domain_valid(source_domains)
178 | self.is_input_domain_valid(target_domains)
179 |
180 | def is_input_domain_valid(self, input_domains):
181 | for domain in input_domains:
182 | if domain not in self.domains:
183 | raise ValueError(
184 | 'Input domain must belong to {}, '
185 | 'but got [{}]'.format(self.domains, domain)
186 | )
187 |
188 | def download_data(self, url, dst, from_gdrive=True):
189 | if not osp.exists(osp.dirname(dst)):
190 | os.makedirs(osp.dirname(dst))
191 |
192 | if from_gdrive:
193 | gdown.download(url, dst, quiet=False)
194 | else:
195 | raise NotImplementedError
196 |
197 | print('Extracting file ...')
198 |
199 | try:
200 | tar = tarfile.open(dst)
201 | tar.extractall(path=osp.dirname(dst))
202 | tar.close()
203 | except:
204 | zip_ref = zipfile.ZipFile(dst, 'r')
205 | zip_ref.extractall(osp.dirname(dst))
206 | zip_ref.close()
207 |
208 | print('File extracted to {}'.format(osp.dirname(dst)))
209 |
210 | def generate_fewshot_dataset(
211 | self, *data_sources, num_shots=-1, repeat=True
212 | ):
213 | """Generate a few-shot dataset (typically for the training set).
214 |
215 | This function is useful when one wants to evaluate a model
216 | in a few-shot learning setting where each class only contains
217 |         a small number of images.
218 |
219 | Args:
220 |             data_sources: each positional argument is a list of Datum objects.
221 | num_shots (int): number of instances per class to sample.
222 | repeat (bool): repeat images if needed.
223 | """
224 | if num_shots < 1:
225 | if len(data_sources) == 1:
226 | return data_sources[0]
227 | return data_sources
228 |
229 | print(f'Creating a {num_shots}-shot dataset')
230 |
231 | output = []
232 |
233 | for data_source in data_sources:
234 | tracker = self.split_dataset_by_label(data_source)
235 | dataset = []
236 |
237 | for label, items in tracker.items():
238 | if len(items) >= num_shots:
239 | sampled_items = random.sample(items, num_shots)
240 | else:
241 | if repeat:
242 | sampled_items = random.choices(items, k=num_shots)
243 | else:
244 | sampled_items = items
245 | dataset.extend(sampled_items)
246 |
247 | output.append(dataset)
248 |
249 | if len(output) == 1:
250 | return output[0]
251 |
252 | return output
253 |
254 | def split_dataset_by_label(self, data_source):
255 | """Split a dataset, i.e. a list of Datum objects,
256 | into class-specific groups stored in a dictionary.
257 |
258 | Args:
259 | data_source (list): a list of Datum objects.
260 | """
261 | output = defaultdict(list)
262 |
263 | for item in data_source:
264 | output[item.label].append(item)
265 |
266 | return output
267 |
268 | def split_dataset_by_domain(self, data_source):
269 | """Split a dataset, i.e. a list of Datum objects,
270 | into domain-specific groups stored in a dictionary.
271 |
272 | Args:
273 | data_source (list): a list of Datum objects.
274 | """
275 | output = defaultdict(list)
276 |
277 | for item in data_source:
278 | output[item.domain].append(item)
279 |
280 | return output
281 |
282 |
283 | class DatasetWrapper(TorchDataset):
284 | def __init__(self, data_source, input_size, transform=None, is_train=False,
285 | return_img0=False, k_tfm=1):
286 | self.data_source = data_source
287 | self.transform = transform # accept list (tuple) as input
288 | self.is_train = is_train
289 | # Augmenting an image K>1 times is only allowed during training
290 | self.k_tfm = k_tfm if is_train else 1
291 | self.return_img0 = return_img0
292 |
293 | if self.k_tfm > 1 and transform is None:
294 | raise ValueError(
295 | 'Cannot augment the image {} times '
296 | 'because transform is None'.format(self.k_tfm)
297 | )
298 |
299 | # Build transform that doesn't apply any data augmentation
300 | interp_mode = T.InterpolationMode.BICUBIC
301 | to_tensor = []
302 | to_tensor += [T.Resize(input_size, interpolation=interp_mode)]
303 | to_tensor += [T.ToTensor()]
304 | normalize = T.Normalize(
305 | mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
306 | )
307 | to_tensor += [normalize]
308 | self.to_tensor = T.Compose(to_tensor)
309 |
310 | def __len__(self):
311 | return len(self.data_source)
312 |
313 | def __getitem__(self, idx):
314 | item = self.data_source[idx]
315 |
316 | output = {
317 | 'label': item.label,
318 | 'domain': item.domain,
319 | 'impath': item.impath
320 | }
321 |
322 | img0 = read_image(item.impath)
323 |
324 | if self.transform is not None:
325 | if isinstance(self.transform, (list, tuple)):
326 | for i, tfm in enumerate(self.transform):
327 | img = self._transform_image(tfm, img0)
328 | keyname = 'img'
329 | if (i + 1) > 1:
330 | keyname += str(i + 1)
331 | output[keyname] = img
332 | else:
333 | img = self._transform_image(self.transform, img0)
334 | output['img'] = img
335 |
336 | if self.return_img0:
337 | output['img0'] = self.to_tensor(img0)
338 |
339 | return output['img'], output['label']
340 |
341 | def _transform_image(self, tfm, img0):
342 | img_list = []
343 |
344 | for k in range(self.k_tfm):
345 | img_list.append(tfm(img0))
346 |
347 | img = img_list
348 | if len(img) == 1:
349 | img = img[0]
350 |
351 | return img
352 |
353 |
354 | def build_data_loader(
355 | data_source=None,
356 | batch_size=64,
357 | input_size=224,
358 | tfm=None,
359 | is_train=True,
360 | shuffle=False,
361 | dataset_wrapper=None
362 | ):
363 |
364 | if dataset_wrapper is None:
365 | dataset_wrapper = DatasetWrapper
366 |
367 | # Build data loader
368 | data_loader = torch.utils.data.DataLoader(
369 | dataset_wrapper(data_source, input_size=input_size, transform=tfm, is_train=is_train),
370 | batch_size=batch_size,
371 | num_workers=8,
372 | shuffle=shuffle,
373 | drop_last=False,
374 | pin_memory=False
375 | )
376 | assert len(data_loader) > 0
377 |
378 | return data_loader
379 |
--------------------------------------------------------------------------------
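generate_fewshot_dataset above subsamples num_shots items per class, re-sampling with replacement when a class has too few images and repeat=True. A self-contained sketch of that rule on toy data (the items are placeholders):

    import random
    from collections import defaultdict

    items = [('a', 0)] * 20 + [('b', 1)] * 3  # toy data: class 0 has 20 items, class 1 only 3
    num_shots, repeat = 8, True

    by_label = defaultdict(list)
    for name, label in items:
        by_label[label].append((name, label))

    fewshot = []
    for label, group in by_label.items():
        if len(group) >= num_shots:
            fewshot.extend(random.sample(group, num_shots))
        elif repeat:
            fewshot.extend(random.choices(group, k=num_shots))
        else:
            fewshot.extend(group)

    print(len(fewshot))  # 16

--------------------------------------------------------------------------------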
/dino/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/CaFo/a805a2aefc6757fdbe10ac9a3165520ceb0e01cb/dino/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/exp.log:
--------------------------------------------------------------------------------
1 | 1 2 4 8 16
2 | ImageNet 63.80 64.34 65.64 66.86 68.79
3 | StanfordCars 61.98 63.36 65.69 70.31 76.73
4 | UCF101 68.60 70.45 72.96 78.06 79.94
5 | Caltech101 91.85 92.37 93.14 93.83 94.60
6 | Flowers102 80.88 84.94 90.95 92.98 95.86
7 | SUN397 64.89 66.81 69.17 70.34 72.60
8 | DTD 53.43 56.32 60.99 66.19 69.62
9 | Eurosat 69.00 72.86 83.90 86.48 88.68
10 | FGVCAircraft 24.96 26.04 32.94 40.38 49.05
11 | OxfordPets 89.21 89.10 90.11 90.52 91.55
12 | Food101 77.99 78.10 78.32 78.84 79.30
--------------------------------------------------------------------------------
/gpt_file/eurosat_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "Annual Crop Land": [
3 | "A centered satellite photo of Annual Crop Land would look like a large green field with small patches of brown or bare earth in between.",
4 | "A centered satellite photo of Annual Crop Land would look like a large green field with small patches of brown or bare earth in between.",
5 | "A centered satellite photo of Annual Crop Land would look like a large green field with small patches of brown or bare earth in between.",
6 | "A centered satellite photo of Annual Crop Land would also include any buildings or roads that are near the field.",
7 | "A centered satellite photo of Annual Crop Land would also show any roads or paths that lead to the crop land.",
8 | "A centered satellite photo of Annual Crop Land would look like a large green field with small patches of brown or bare earth in between.",
9 | "A centered satellite photo of Annual Crop Land would look like a large green field with small patches of brown or bare earth in between.",
10 | "A centered satellite photo of Annual Crop Land would look like a large green field with small patches of brown or bare earth in between.",
11 | "A centered satellite photo of Annual Crop Land would look like one large, continuous field of green.",
12 | "A centered satellite photo of Annual Crop Land may also show irrigation systems or other farming infrastructure."
13 | ],
14 | "Forest": [
15 | "A centered satellite photo of Forest Land would look like a large green field with small patches of brown or bare earth in between.",
16 | "A centered satellite photo of Forest Land would look like a large green area with small patches of brown or bare earth in between.",
17 | "A centered satellite photo of Forest Land would look like a dense, green area with few or no bare patches of earth.",
18 | "A centered satellite photo of Forest Land would look like a large green field with small patches of brown or bare earth in between.",
19 | "A centered satellite photo of Forest Land would look like a large green field with small patches of brown or bare earth in between. A centered satellite photo of Grassland would look like a large green field with small patches of brown or bare earth in between.",
20 | "A centered satellite photo of Forest would look like a large green field with small patches of brown or bare earth in between.",
21 | "A centered satellite photo of Forest Land would look like a large playing field with lots of trees.",
22 | "A centered satellite photo of Forest Land would look like a large green field with small patches of brown or bare earth in between.",
23 | "A centered satellite photo of Forest would look like a green or dark green field with patches of brown or bare earth in between.",
24 | "A centered satellite photo of Forest Land would look like a large green area with small patches of brown or bare earth in between."
25 | ],
26 | "Herbaceous Vegetation Land": [
27 | "A centered satellite photo of Herbaceous Vegetation Land would look like a large green or brown field with small patches of green or brown in between.",
28 | "A centered satellite photo of Herbaceous Vegetation Land would look like a large green field with small patches of brown or bare earth in between.",
29 | "A centered satellite photo of Herbaceous Vegetation Land would look like a green field with small patches of brown of bare earth in between. A centered satellite photo of Tree Cover would look like a green field with small patches of brown or bare earth in between, and a few trees scattered throughout.",
30 | "A centered satellite photo of Herbaceous Vegetation Land would look like a field of green with a few brown spots in between.",
31 | "A centered satellite photo of Herbaceous Vegetation Land would look like a green field with a few trees or bushes mixed in.",
32 | "A centered satellite photo of Herbaceous Vegetation Land would look like a large green field with small patches of brown or bare earth in between.",
33 | "A centered satellite photo of Herbaceous Vegetation Land would look like a green field with small patches of brown or bare earth in between.",
34 | "A centered satellite photo of Herbaceous Vegetation Land would look like a green field with very small patches of brown or bare earth in between.",
35 | "A centered satellite photo of Herbaceous Vegetation Land would look like a large green field with small patches of brown or bare earth in between. A centered satellite photo of Perennial Crop Land would look like a large green field with small patches of brown or bare earth in between.",
36 | "A centered satellite photo of Herbaceous Vegetation Land would look like a large, green field with small patches of brown or bare earth in between."
37 | ],
38 | "Highway or Road": [
39 | "A centered satellite photo of Highway or Road Land would look like a long, thin, dark strip with small patches of green or brown on either side.",
40 | "A centered satellite photo of Highway or Road Land would look like a large paved road with small patches of green or brown on either side.",
41 | "A centered satellite photo of Highway or Road would look like a thin, dark line winding through a lighter-colored background.",
42 | "A centered satellite photo of Highway or Road Infrastructure would look like a large number of dark lines running across the landscape.",
43 | "A centered satellite photo of Highway or Road Infrastructure would look like a thin line of asphalt with a small patch of gravel or dirt on each side.",
44 | "A centered satellite photo of Highway or Road Land would look like a long, straight, grey line with small patches of green or brown on either side.",
45 | "A centered satellite photo of Highway or Road Infrastructure would look like a spider web of grey or white lines with small patches of green or brown in between.",
46 | "A centered satellite photo of Highway or Road Land would look like a large number of thin, dark lines criss-crossing each other.",
47 | "A centered satellite photo of Highway or Road Land would look like a large brown or gray road with green fields on either side.",
48 | "A centered satellite photo of Highway or Road Land would look like a spider web of thin, black lines."
49 | ],
50 | "Industrial Buildings": [
51 | "A centered satellite photo of Industrial Buildings would look like a cluster of buildings, usually gray or white, surrounded by a parking lot.",
52 | "A centered satellite photo of Industrial Buildings would look like a group of large structures with small parking lots around them.",
53 | "A centered satellite photo of Industrial Buildings would look like a series of low, rectangular buildings with roofs of different colors.",
54 | "A centered satellite photo of Industrial Buildings would look like large, dark buildings amid a matrix of smaller, lighter buildings.",
55 | "A centered satellite photo of Industrial Buildings would look like a city with large buildings and smokestacks.",
56 | "A centered satellite photo of Industrial Buildings would look like a city with a few buildings that are taller than the others.",
57 | "A centered satellite photo of Industrial Buildings would look like a densely populated area with many buildings and roads.",
58 | "A centered satellite photo of Industrial Buildings would look like large connected buildings surrounded by asphalt parking lots.",
59 | "A centered satellite photo of Industrial Buildings would look like a bunch of large angular buildings with small streets in between them.",
60 | "A centered satellite photo of Industrial Buildings would look like a series of large Modern highrises in an urban area."
61 | ],
62 | "Pasture Land": [
63 | "A centered satellite photo of Pasture Land would look like large green fields with animals grazing on them.",
64 | "A centered satellite photo of Pasture Land would look like a large green field with some areas of brown or bare earth in between.",
65 | "A centered satellite photo of Pasture Land would look like a large green field broken up by areas of trees, bushes, or other foliage.",
66 | "A centered satellite photo of Pasture Land would look like large green fields with small areas of brown or bare earth in between.",
67 | "A centered satellite photo of Pasture Land would look like a large green field with small patches of brown or bare earth in between.",
68 | "A centered satellite photo of Pasture Land would look like a large green or tan field with small patches of brown or bare earth in between.",
69 | "A centered satellite photo of Pasture Land would look like a large green or brown field with small patches of different colors in between.",
70 | "A centered satellite photo of Pasture Land would look like a large field of green with small brown or black spots (cows).",
71 | "A centered satellite photo of Pasture Land would look like large green fields with some areas of brown or bare earth in between.",
72 | "A centered satellite photo of Pasture Land would look like large green fields with small patches of brown or bare earth in between."
73 | ],
74 | "Permanent Crop Land": [
75 | "A centered satellite photo of Permanent Crop Land would look like a large field with different colors depending on what crop is being grown.",
76 | "A centered satellite photo of Permanent Crop Land would look like a large green field with small patches of brown or bare earth in between.",
77 | "A centered satellite photo of Permanent Crop Land would look like a large green field with smaller, more uniform green patches in between.",
78 | "A centered satellite photo of Permanent Crop Land would look like a green field with small patches of brown earth or water in between.",
79 | "A centered satellite photo of Permanent Crop Land would look like a large green field with a few smaller green or brown fields in between.",
80 | "A centered satellite photo of Permanent Crop Land would look like a large green field with small patches of brown or bare earth in between, and there would also be small patches of different colors representing different types of permanent crops.",
81 | "A centered satellite photo of Permanent Crop Land would look like a large green field with small patches of brown or bare earth in between.",
82 | "A centered satellite photo of Permanent Crop Land would look like a mosaic of different colors, depending on the type of crop being grown.",
83 | "A centered satellite photo of Permanent Crop Land would look like a similar green field, however the patches of brown or bare earth would be much smaller, as there is less open land in between crops.",
84 | "A centered satellite photo of Permanent Crop Land would look like a large green or brown field with small patches of bare earth in between."
85 | ],
86 | "Residential Buildings": [
87 | "A centered satellite photo of Residential Buildings would look like a city with tall buildings in the center and smaller buildings on the outskirts.",
88 | "A centered satellite photo of Residential Buildings would look like a city with large buildings and concrete roads. A centered satellite photo of a Commercial Harbor would look like a harbor with many boats and a few warehouses.",
89 | "A centered satellite photo of Residential Buildings would look like many small rectangular buildings that are close together with some green space in between them.",
90 | "A centered satellite photo of Residential Buildings would look like a lot of small buildings close together with some green space in between them.",
91 | "A centered satellite photo of Residential Buildings would look like a city with areas of green trees and parks throughout.",
92 | "A centered satellite photo of Residential Buildings would look like a city with tall buildings in the center and lower buildings or houses on the outskirts.",
93 | "A centered satellite photo of Residential Buildings would look like a bunch of small squares with a variety of colors.",
94 | "A centered satellite photo of Residential Buildings would look like a large number of small, square or rectangular shaped buildings with large open spaces in between.",
95 | "A centered satellite photo of Residential Buildings would look like a large number of small, square or rectangular buildings with small patches of green or bare earth in between.",
96 | "A centered satellite photo of Residential Buildings would look like a small city with many houses and buildings."
97 | ],
98 | "River": [
99 | "A centered satellite photo of River Delta would look like a large mass of water with small islands or patches of land in between.",
100 | "A centered satellite photo of River would look like many small streams or rivers flowing through a larger body of water.",
101 | "A centered satellite photo of River would look like a long, thin blue line with small tributaries branching off of it.",
102 | "A centered satellite photo of River would look like a thin blue line winding through a larger green area.",
103 | "A centered satellite photo of River would look like a long, thin blue or green line winding its way through a landscape.",
104 | "A centered satellite photo of River would look like a large blue or green body of water with smaller tributaries feeding into it.",
105 | "A centered satellite photo of River would look like a large blue body of water with small patches of green or brown land on either side.",
106 | "A centered satellite photo of River Delta would look like a series of branching streams or rivers flowing into a larger body of water.",
107 | "A centered satellite photo of River would look like a long, thin body of water with trees or other landforms surrounding it.",
108 | "A centered satellite photo of River Delta would look like a large body of water with many small waterways flowing into it."
109 | ],
110 | "Sea or Lake": [
111 | "A centered satellite photo of Sea or Lake would look like a large blue circle with small patches of green, white, or brown around the edge.",
112 | "A centered satellite photo of Sea or Lake Ice would look like a large white or blue field with small patches of ocean water in between.",
113 | "A centered satellite photo of Sea or Lake would look like a large dark blue body of water with small white or light-colored areas around the edge.",
114 | "A centered satellite photo of Sea or Lake Ice would look like a large body of white with small patches of blue in between.",
115 | "A centered satellite photo of Sea or Lake ice would look like a large white or light blue field with small patches of dark blue or black in between.",
116 | "A centered satellite photo of Sea or Lake would look like a large blue or green body of water with small islands in it.",
117 | "A centered satellite photo of Sea or Lake would look like a large dark blue body with small areas of whitecaps where the waves are crashing.",
118 | "A centered satellite photo of Sea or Lake ice would look like large white fields with small patches of dark water in between.",
119 | "A centered satellite photo of Sea or Lake ice would look like large white areas with smaller areas of dark water in between.",
120 | "A centered satellite photo of Sea or Lake Ice would look like a large white or light blue area with bits of dark blue in the middle."
121 | ]
122 | }
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import argparse
4 | import yaml
5 | from tqdm import tqdm
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.nn as nn
10 | import torchvision.transforms as transforms
11 | from torchvision import models as torchvision_models
12 |
13 | from datasets import build_dataset
14 | from datasets.utils import build_data_loader
15 | import clip
16 | from utils import *
17 | import dino.utils as utils
18 | import itertools
19 | import json
20 |
21 | def get_arguments():
22 |
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument('--config', dest='config', help='settings of Tip-Adapter in yaml format')
25 | args = parser.parse_args()
26 |
27 | return args
28 |
29 | def run_ensemble_tip_dalle_adapter_F(cfg,
30 | clip_cache_keys,
31 | clip_cache_values,
32 | clip_val_features,
33 | clip_test_features,
34 | dino_cache_keys,
35 | dino_cache_values,
36 | dino_val_features,
37 | dino_test_features,
38 | val_labels,
39 | test_labels,
40 | clip_weights,
41 | clip_model,
42 | dino_model,
43 | train_loader_F,
44 | dalle_train_loader_F):
45 |
46 | # Enable the cached keys to be learnable
47 | clip_adapter = nn.Linear(clip_cache_keys.shape[0], clip_cache_keys.shape[1], bias=False).to(clip_model.dtype).cuda()
48 | clip_adapter.weight = nn.Parameter(clip_cache_keys.t())
49 | dino_adapter = nn.Linear(dino_cache_keys.shape[0], dino_cache_keys.shape[1], bias=False).to(clip_model.dtype).cuda()
50 | dino_adapter.weight = nn.Parameter(dino_cache_keys.t())
51 |
52 | optimizer = torch.optim.AdamW(
53 | itertools.chain(dino_adapter.parameters(), clip_adapter.parameters()),
54 | lr=cfg['lr'],
55 | eps=1e-4)
56 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg['train_epoch'] * len(train_loader_F))
57 |
58 | beta, alpha = cfg['init_beta'], cfg['init_alpha']
59 | best_acc, best_epoch = 0.0, 0
60 |
61 | for train_idx in range(cfg['train_epoch']):
62 | # Train
63 | clip_adapter.train()
64 | dino_adapter.train()
65 | correct_samples, all_samples = 0, 0
66 | loss_list = []
67 | print('Train Epoch: {:} / {:}'.format(train_idx, cfg['train_epoch']))
68 |
69 | # origin image
70 | for i, (images, target) in enumerate(tqdm(train_loader_F)):
71 | images, target = images.cuda(), target.cuda()
72 | with torch.no_grad():
73 | clip_image_features = clip_model.encode_image(images)
74 | clip_image_features /= clip_image_features.norm(dim=-1, keepdim=True)
75 | dino_image_features = dino_model(images)
76 | dino_image_features /= dino_image_features.norm(dim=-1, keepdim=True)
77 |
78 | clip_affinity = clip_adapter(clip_image_features)
79 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values
80 | dino_affinity = dino_adapter(dino_image_features).to(dino_cache_values.dtype)
81 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values
82 | clip_logits = 100. * clip_image_features @ clip_weights
83 |
84 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits])
85 | tip_logits = clip_logits + cache_logits * alpha
86 | loss = F.cross_entropy(tip_logits, target)
87 |
88 | acc = cls_acc(tip_logits, target)
89 | correct_samples += acc / 100 * len(tip_logits)
90 | all_samples += len(tip_logits)
91 | loss_list.append(loss.item())
92 |
93 | optimizer.zero_grad()
94 | loss.backward()
95 | optimizer.step()
96 | scheduler.step()
97 |
98 | # dalle image
99 | for i, (images, target) in enumerate(tqdm(dalle_train_loader_F)):
100 | images, target = images.cuda(), target.cuda()
101 | with torch.no_grad():
102 | clip_image_features = clip_model.encode_image(images)
103 | clip_image_features /= clip_image_features.norm(dim=-1, keepdim=True)
104 | dino_image_features = dino_model(images)
105 | dino_image_features /= dino_image_features.norm(dim=-1, keepdim=True)
106 |
107 | clip_affinity = clip_adapter(clip_image_features)
108 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values
109 | dino_affinity = dino_adapter(dino_image_features).to(dino_cache_values.dtype)
110 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values
111 | clip_logits = 100. * clip_image_features @ clip_weights
112 |
113 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits])
114 | tip_logits = clip_logits + cache_logits * alpha
115 | loss = F.cross_entropy(tip_logits, target)
116 |
117 | acc = cls_acc(tip_logits, target)
118 | correct_samples += acc / 100 * len(tip_logits)
119 | all_samples += len(tip_logits)
120 | loss_list.append(loss.item())
121 |
122 | optimizer.zero_grad()
123 | loss.backward()
124 | optimizer.step()
125 | scheduler.step()
126 |
127 | current_lr = scheduler.get_last_lr()[0]
128 | print('LR: {:.6f}, Acc: {:.4f} ({:}/{:}), Loss: {:.4f}'.format(current_lr, correct_samples / all_samples, correct_samples, all_samples, sum(loss_list)/len(loss_list)))
129 |
130 | # Eval
131 | clip_adapter.eval()
132 | dino_adapter.eval()
133 |
134 | clip_affinity = clip_adapter(clip_test_features)
135 | dino_affinity = dino_adapter(dino_test_features).to(dino_cache_values.dtype)
136 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values
137 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values
138 | clip_logits = 100. * clip_test_features @ clip_weights
139 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits])
140 | tip_logits = clip_logits + cache_logits * alpha
141 | acc = cls_acc(tip_logits, test_labels)
142 |
143 | print("**** CaFo's test accuracy: {:.2f}. ****\n".format(acc))
144 | if acc > best_acc:
145 | best_acc = acc
146 | best_epoch = train_idx
147 | torch.save(clip_adapter.weight, cfg['cache_dir'] + "/best_F_clip_adapter_" + str(cfg['shots']) + "shots.pt")
148 | torch.save(dino_adapter.weight, cfg['cache_dir'] + "/best_F_dino_adapter_" + str(cfg['shots']) + "shots.pt")
149 |
150 | clip_adapter.weight = torch.load(cfg['cache_dir'] + "/best_F_clip_adapter_" + str(cfg['shots']) + "shots.pt")
151 | dino_adapter.weight = torch.load(cfg['cache_dir'] + "/best_F_dino_adapter_" + str(cfg['shots']) + "shots.pt")
152 | print(f"**** After fine-tuning, CaFo's best test accuracy: {best_acc:.2f}, at epoch: {best_epoch}. ****\n")
153 |
154 | print("\n-------- Searching hyperparameters on the val set. --------")
155 |
156 | # Search Hyperparameters
157 |     best_beta, best_alpha = search_ensemble_hp(cfg, clip_cache_keys, clip_cache_values, clip_val_features, dino_cache_keys, dino_cache_values, dino_val_features, val_labels, clip_weights)
158 |
159 | print("\n-------- Evaluating on the test set. --------")
160 |
161 | clip_affinity = clip_adapter(clip_test_features)
162 | dino_affinity = dino_adapter(dino_test_features).to(dino_cache_values.dtype)
163 | clip_cache_logits = ((-1) * (best_beta - best_beta * clip_affinity)).exp() @ clip_cache_values
164 | dino_cache_logits = ((-1) * (best_beta - best_beta * dino_affinity)).exp() @ dino_cache_values
165 |
166 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits])
167 | tip_logits = clip_logits + cache_logits * best_alpha
168 | acc = cls_acc(tip_logits, test_labels)
169 | print("**** CaFo's test accuracy: {:.2f}. ****\n".format(max(best_acc, acc)))
170 |
171 | def main():
172 |
173 | # Load config file
174 | args = get_arguments()
175 | assert (os.path.exists(args.config))
176 |
177 | cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
178 |
179 | cache_dir = os.path.join('./caches', cfg['dataset'])
180 | os.makedirs(cache_dir, exist_ok=True)
181 | cfg['cache_dir'] = cache_dir
182 |
183 | print("\nRunning configs.")
184 | print(cfg, "\n")
185 |
186 | # CLIP
187 | clip_model, preprocess = clip.load(cfg['clip_backbone'])
188 | clip_model.eval()
189 |
190 | # DINO
191 | dino_model = torchvision_models.__dict__[cfg['dino_backbone']](num_classes=0)
192 | dino_model.fc = nn.Identity()
193 | dino_model.cuda()
194 |     utils.load_pretrained_weights(dino_model, "dino/dino_resnet50_pretrain.pth", "teacher", "vit_small", 16)
195 | dino_model.eval()
196 |
197 | # Prepare dataset
198 | random.seed(1)
199 | torch.manual_seed(1)
200 |
201 | print("Preparing dataset.")
202 | dataset = build_dataset(cfg['dataset'], cfg['root_path'], cfg['shots'])
203 |
204 | val_loader = build_data_loader(data_source=dataset.val, batch_size=64, is_train=False, tfm=preprocess, shuffle=False)
205 | test_loader = build_data_loader(data_source=dataset.test, batch_size=64, is_train=False, tfm=preprocess, shuffle=False)
206 |
207 |     train_transform = transforms.Compose([
208 | transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
209 | transforms.RandomHorizontalFlip(p=0.5),
210 | transforms.ToTensor(),
211 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
212 | ])
213 |
214 |     train_loader_cache = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_transform, is_train=True, shuffle=False)
215 |     train_loader_F = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_transform, is_train=True, shuffle=True)
216 |
217 | dalle_dataset = build_dataset(cfg['dalle_dataset'], cfg['root_path'], cfg['dalle_shots'])
218 |     dalle_train_loader_cache = build_data_loader(data_source=dalle_dataset.train_x, batch_size=256, tfm=train_transform, is_train=True, shuffle=False)
219 |     dalle_train_loader_F = build_data_loader(data_source=dalle_dataset.train_x, batch_size=256, tfm=train_transform, is_train=True, shuffle=True)
220 |
221 | with open(cfg['gpt3_prompt_file']) as f:
222 | gpt3_prompt = json.load(f)
223 |
224 | # Textual features
225 | print("\nGetting textual features as CLIP's classifier.")
226 | #clip_weights = clip_classifier(dataset.classnames, dataset.template, clip_model)
227 | clip_weights = gpt_clip_classifier(dataset.classnames, gpt3_prompt, clip_model, dataset.template)
228 |
229 | # Construct the cache model by few-shot training set
230 | print("\nConstructing cache model by few-shot visual features and labels.")
231 | print("\nConstructing CLIP cache model.")
232 | clip_cache_keys, clip_cache_values = build_clip_cache_model(cfg, clip_model, train_loader_cache)
233 | print("\nConstructing DINO cache model.")
234 | dino_cache_keys, dino_cache_values = build_dino_cache_model(cfg, dino_model, train_loader_cache)
235 |
236 | print("\nConstructing cache model by dalle image.")
237 | print("\nConstructing CLIP cache model.")
238 | clip_dalle_cache_keys, clip_dalle_cache_values = build_clip_dalle_cache_model(cfg, clip_model, dalle_train_loader_cache)
239 | print("\nConstructing DINO cache model.")
240 | dino_dalle_cache_keys, dino_dalle_cache_values = build_dino_dalle_cache_model(cfg, dino_model, dalle_train_loader_cache)
241 |
242 | # Pre-load val features
243 | print("\nLoading visual features and labels from val set.")
244 | print("\nLoading CLIP feature.")
245 | val_clip_features, val_labels = pre_CLIP_load_features(cfg, "val", clip_model, val_loader)
246 | print("\nLoading DINO feature.")
247 | val_dino_features, val_labels = pre_DINO_load_features(cfg, "val", dino_model, val_loader)
248 |
249 | # Pre-load test features
250 | print("\nLoading visual features and labels from test set.")
251 | print("\nLoading CLIP feature.")
252 | test_clip_features, test_labels = pre_CLIP_load_features(cfg, "test", clip_model, test_loader)
253 | print("\nLoading DINO feature.")
254 | test_dino_features, test_labels = pre_DINO_load_features(cfg, "test", dino_model, test_loader)
255 |
256 | # ------------------------------------------ Tip-Adapter-F ------------------------------------------
257 |
258 | run_ensemble_tip_dalle_adapter_F(cfg,
259 | torch.cat((clip_cache_keys, clip_dalle_cache_keys), dim=1),
260 | torch.cat((clip_cache_values, clip_dalle_cache_values), dim=0),
261 | val_clip_features,
262 | test_clip_features,
263 | torch.cat((dino_cache_keys, dino_dalle_cache_keys), dim=1),
264 | torch.cat((dino_cache_values, dino_dalle_cache_values), dim=0),
265 | val_dino_features,
266 | test_dino_features,
267 | val_labels,
268 | test_labels,
269 | clip_weights,
270 | clip_model,
271 | dino_model,
272 | train_loader_F,
273 | dalle_train_loader_F)
274 |
275 | if __name__ == '__main__':
276 | main()
--------------------------------------------------------------------------------
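main.py is configured entirely through the YAML file passed via --config (see get_arguments above): the dataset, the CLIP/DINO backbones, the number of shots, and the adapter hyperparameters all come from that file. A typical invocation would therefore look like the line below, where the config path is illustrative and should point at one of the files under configs/:

    python main.py --config configs/eurosat/16shot.yaml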
/main_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import argparse
4 | import yaml
5 | from tqdm import tqdm
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.nn as nn
10 | import torchvision.transforms as transforms
11 | from torchvision import models as torchvision_models
12 |
13 | from datasets.imagenet import ImageNet
14 | from datasets import build_dataset
15 | from datasets.utils import build_data_loader
16 | import clip
17 | from utils import *
18 | import dino.utils as utils
19 | import itertools
20 | import json
21 |
22 |
23 | def get_arguments():
24 |
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('--config', dest='config', help='settings of Tip-Adapter in yaml format')
27 | args = parser.parse_args()
28 |
29 | return args
30 |
31 | def run_ensemble_tip_dalle_adapter_F(cfg,
32 | clip_cache_keys,
33 | clip_cache_values,
34 | clip_test_features,
35 | dino_cache_keys,
36 | dino_cache_values,
37 | dino_test_features,
38 | test_labels,
39 | clip_weights,
40 | clip_model,
41 | dino_model,
42 | train_loader_F,
43 | dalle_train_loader_F):
44 |
45 | # Enable the cached keys to be learnable
46 | clip_adapter = nn.Linear(clip_cache_keys.shape[0], clip_cache_keys.shape[1], bias=False).to(clip_model.dtype).cuda()
47 | clip_adapter.weight = nn.Parameter(clip_cache_keys.t())
48 | dino_adapter = nn.Linear(dino_cache_keys.shape[0], dino_cache_keys.shape[1], bias=False).to(clip_model.dtype).cuda()
49 | dino_adapter.weight = nn.Parameter(dino_cache_keys.t())
50 |
51 | optimizer = torch.optim.AdamW(
52 | itertools.chain(dino_adapter.parameters(), clip_adapter.parameters()),
53 | lr=cfg['lr'],
54 | eps=1e-4)
55 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg['train_epoch'] * len(train_loader_F))
56 |
57 | beta, alpha = cfg['init_beta'], cfg['init_alpha']
58 | best_acc, best_epoch = 0.0, 0
59 |
60 | for train_idx in range(cfg['train_epoch']):
61 | # Train
62 | clip_adapter.train()
63 | dino_adapter.train()
64 | correct_samples, all_samples = 0, 0
65 | loss_list = []
66 | print('Train Epoch: {:} / {:}'.format(train_idx, cfg['train_epoch']))
67 |
68 | # origin image
69 | for i, (images, target) in enumerate(tqdm(train_loader_F)):
70 | images, target = images.cuda(), target.cuda()
71 | with torch.no_grad():
72 | clip_image_features = clip_model.encode_image(images)
73 | clip_image_features /= clip_image_features.norm(dim=-1, keepdim=True)
74 | dino_image_features = dino_model(images)
75 | dino_image_features /= dino_image_features.norm(dim=-1, keepdim=True)
76 |
77 | clip_affinity = clip_adapter(clip_image_features)
78 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values
79 | dino_affinity = dino_adapter(dino_image_features).to(dino_cache_values.dtype)
80 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values
81 | clip_logits = 100. * clip_image_features @ clip_weights
82 |
83 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits])
84 | tip_logits = clip_logits + cache_logits * alpha
85 | loss = F.cross_entropy(tip_logits, target)
86 |
87 | acc = cls_acc(tip_logits, target)
88 | correct_samples += acc / 100 * len(tip_logits)
89 | all_samples += len(tip_logits)
90 | loss_list.append(loss.item())
91 |
92 | optimizer.zero_grad()
93 | loss.backward()
94 | optimizer.step()
95 | scheduler.step()
96 |
97 | # dalle image
98 | for i, (images, target) in enumerate(tqdm(dalle_train_loader_F)):
99 | images, target = images.cuda(), target.cuda()
100 | with torch.no_grad():
101 | clip_image_features = clip_model.encode_image(images)
102 | clip_image_features /= clip_image_features.norm(dim=-1, keepdim=True)
103 | dino_image_features = dino_model(images)
104 | dino_image_features /= dino_image_features.norm(dim=-1, keepdim=True)
105 |
106 | clip_affinity = clip_adapter(clip_image_features)
107 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values
108 | dino_affinity = dino_adapter(dino_image_features).to(dino_cache_values.dtype)
109 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values
110 | clip_logits = 100. * clip_image_features @ clip_weights
111 |
112 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits])
113 | tip_logits = clip_logits + cache_logits * alpha
114 | loss = F.cross_entropy(tip_logits, target)
115 |
116 | acc = cls_acc(tip_logits, target)
117 | correct_samples += acc / 100 * len(tip_logits)
118 | all_samples += len(tip_logits)
119 | loss_list.append(loss.item())
120 |
121 | optimizer.zero_grad()
122 | loss.backward()
123 | optimizer.step()
124 | scheduler.step()
125 |
126 | current_lr = scheduler.get_last_lr()[0]
127 | print('LR: {:.6f}, Acc: {:.4f} ({:}/{:}), Loss: {:.4f}'.format(current_lr, correct_samples / all_samples, correct_samples, all_samples, sum(loss_list)/len(loss_list)))
128 |
129 | # Eval
130 | clip_adapter.eval()
131 | dino_adapter.eval()
132 |
133 | clip_affinity = clip_adapter(clip_test_features)
134 | dino_affinity = dino_adapter(dino_test_features).to(dino_cache_values.dtype)
135 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values
136 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values
137 | clip_logits = 100. * clip_test_features @ clip_weights
138 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits])
139 | tip_logits = clip_logits + cache_logits * alpha
140 | acc = cls_acc(tip_logits, test_labels)
141 |
142 | print("**** CaFo's test accuracy: {:.2f}. ****\n".format(acc))
143 | if acc > best_acc:
144 | best_acc = acc
145 | best_epoch = train_idx
146 | torch.save(clip_adapter.weight, cfg['cache_dir'] + "/best_F_clip_adapter_" + str(cfg['shots']) + "shots.pt")
147 | torch.save(dino_adapter.weight, cfg['cache_dir'] + "/best_F_dino_adapter_" + str(cfg['shots']) + "shots.pt")
148 |
149 | clip_adapter.weight = torch.load(cfg['cache_dir'] + "/best_F_clip_adapter_" + str(cfg['shots']) + "shots.pt")
150 | dino_adapter.weight = torch.load(cfg['cache_dir'] + "/best_F_dino_adapter_" + str(cfg['shots']) + "shots.pt")
151 | print(f"**** After fine-tuning, CaFo's best test accuracy: {best_acc:.2f}, at epoch: {best_epoch}. ****\n")
152 |
153 | del clip_logits, tip_logits, cache_logits, clip_cache_logits, dino_cache_logits, clip_affinity, dino_affinity
154 | # Search Hyperparameters
155 | # _ = search_hp(cfg, affinity, clip_cache_values, clip_test_features, test_labels, clip_weights, clip_adapter=adapter)
156 | best_beta, best_alpha = search_ensemble_hp(cfg, clip_cache_keys, clip_cache_values, clip_test_features, dino_cache_keys, dino_cache_values, dino_test_features, test_labels, clip_weights, clip_adapter=clip_adapter, dino_adapter=dino_adapter)
157 | clip_affinity = clip_adapter(clip_test_features)
158 | dino_affinity = dino_adapter(dino_test_features).to(dino_cache_values.dtype)
159 | clip_cache_logits = ((-1) * (best_beta - best_beta * clip_affinity)).exp() @ clip_cache_values
160 | dino_cache_logits = ((-1) * (best_beta - best_beta * dino_affinity)).exp() @ dino_cache_values
161 | clip_logits = 100. * clip_test_features @ clip_weights
162 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits])
163 | tip_logits = clip_logits + cache_logits * best_alpha
164 |     print("Saving the fused test logits.")
165 | torch.save(tip_logits, cfg['cache_dir'] + "/best_tip_dino_dalle_logits_" + str(cfg['shots']) + "shots.pt")
166 |
167 | def main():
168 |
169 | # Load config file
170 | args = get_arguments()
171 | assert (os.path.exists(args.config))
172 |
173 | cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
174 |
175 | cache_dir = os.path.join('./caches', cfg['dataset'])
176 | os.makedirs(cache_dir, exist_ok=True)
177 | cfg['cache_dir'] = cache_dir
178 |
179 | print("\nRunning configs.")
180 | print(cfg, "\n")
181 |
182 | # CLIP
183 | clip_model, preprocess = clip.load(cfg['clip_backbone'])
184 | clip_model.eval()
185 |
186 | # DINO
187 | dino_model = torchvision_models.__dict__[cfg['dino_backbone']](num_classes=0)
188 | dino_model.fc = nn.Identity()
189 | dino_model.cuda()
190 |     utils.load_pretrained_weights(dino_model, "dino/dino_resnet50_pretrain.pth", "teacher", "vit_small", 16)
191 | dino_model.eval()
192 |
193 | # ImageNet dataset
194 | random.seed(2)
195 | torch.manual_seed(1)
196 |
197 | print("Preparing ImageNet dataset.")
198 | imagenet = ImageNet(cfg['root_path'], cfg['shots'], preprocess)
199 |
200 | test_loader = torch.utils.data.DataLoader(imagenet.test, batch_size=64, num_workers=8, shuffle=False)
201 |
202 | train_loader_cache = torch.utils.data.DataLoader(imagenet.train, batch_size=256, num_workers=8, shuffle=False)
203 | train_loader_F = torch.utils.data.DataLoader(imagenet.train, batch_size=256, num_workers=8, shuffle=True)
204 |
205 | dalle_dataset = build_dataset(cfg['dalle_dataset'], cfg['root_path'], cfg['dalle_shots'])
206 |     train_transform = transforms.Compose([
207 | transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
208 | transforms.RandomHorizontalFlip(p=0.5),
209 | transforms.ToTensor(),
210 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
211 | ])
212 |     dalle_train_loader_cache = build_data_loader(data_source=dalle_dataset.train_x, batch_size=256, tfm=train_transform, is_train=True, shuffle=False)
213 |     dalle_train_loader_F = build_data_loader(data_source=dalle_dataset.train_x, batch_size=256, tfm=train_transform, is_train=True, shuffle=True)
214 |
215 | with open(cfg['gpt3_prompt_file']) as f:
216 | gpt3_prompt = json.load(f)
217 |
218 | # Textual features
219 | print("Getting textual features as CLIP's classifier.")
220 | clip_weights = gpt_clip_classifier(imagenet.classnames, gpt3_prompt, clip_model, imagenet.template)
221 |
222 |
223 | # Construct the cache model by few-shot training set
224 | print("\nConstructing cache model by few-shot visual features and labels.")
225 | print("\nConstructing CLIP cache model.")
226 | clip_cache_keys, clip_cache_values = build_clip_cache_model(cfg, clip_model, train_loader_cache)
227 | print("\nConstructing DINO cache model.")
228 | dino_cache_keys, dino_cache_values = build_dino_cache_model(cfg, dino_model, train_loader_cache)
229 |
230 | print("\nConstructing cache model by dalle image.")
231 | print("\nConstructing CLIP cache model.")
232 | clip_dalle_cache_keys, clip_dalle_cache_values = build_clip_dalle_cache_model(cfg, clip_model, dalle_train_loader_cache)
233 | print("\nConstructing DINO cache model.")
234 | dino_dalle_cache_keys, dino_dalle_cache_values = build_dino_dalle_cache_model(cfg, dino_model, dalle_train_loader_cache)
235 |
236 | # Pre-load test features
237 | print("\nLoading visual features and labels from test set.")
238 | print("\nLoading CLIP feature.")
239 | test_clip_features, test_labels = pre_CLIP_load_features(cfg, "test", clip_model, test_loader)
240 | print("\nLoading DINO feature.")
241 | test_dino_features, test_labels = pre_DINO_load_features(cfg, "test", dino_model, test_loader)
242 |
243 | # ------------------------------------------ Tip-Adapter-F ------------------------------------------
244 |
245 | run_ensemble_tip_dalle_adapter_F(cfg,
246 | torch.cat((clip_cache_keys, clip_dalle_cache_keys), dim=1),
247 | torch.cat((clip_cache_values, clip_dalle_cache_values), dim=0),
248 | test_clip_features,
249 | torch.cat((dino_cache_keys, dino_dalle_cache_keys), dim=1),
250 | torch.cat((dino_cache_values, dino_dalle_cache_values), dim=0),
251 | test_dino_features,
252 | test_labels,
253 | clip_weights,
254 | clip_model,
255 | dino_model,
256 | train_loader_F,
257 | dalle_train_loader_F)
258 |
259 | if __name__ == '__main__':
260 | main()
--------------------------------------------------------------------------------
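main_imagenet.py follows the same flow as main.py but has no val split: search_ensemble_hp is run directly on the cached test features. Both entry points expect the --config YAML to resolve to a flat dictionary; the sketch below lists the keys actually read by main.py, main_imagenet.py and utils.py (further down). It is only an illustration of the expected structure, with placeholder values; the real values live in configs/<dataset>/<k>shot.yaml and may differ.

    # Illustrative cfg dict as produced by yaml.load(); key names come from the
    # scripts above, values are placeholders only.
    cfg = {
        'dataset': 'eurosat',                # name passed to build_dataset()
        'dalle_dataset': 'dalle_eurosat',    # DALL-E-generated counterpart dataset
        'root_path': './data',               # dataset root (placeholder)
        'shots': 16,                         # few-shot images per class
        'dalle_shots': 1,                    # generated images per class (placeholder)
        'clip_backbone': 'RN50',             # any backbone accepted by clip.load()
        'dino_backbone': 'resnet50',         # torchvision model name for the DINO encoder
        'gpt3_prompt_file': 'gpt_file/eurosat_prompt.json',
        'lr': 0.001,                         # AdamW learning rate (placeholder)
        'train_epoch': 20,
        'augment_epoch': 10,                 # augmentation passes when building the cache
        'init_beta': 1.0,                    # sharpness of the cache affinity
        'init_alpha': 1.0,                   # residual ratio of the cache logits
        'load_cache': False,                 # reuse cached keys/values from cache_dir
        'load_pre_feat': False,              # reuse pre-extracted val/test features
        'search_hp': True,                   # grid-search beta/alpha after training
        'search_scale': [7, 3],              # upper bounds of the beta/alpha grids (placeholder)
        'search_step': [200, 20],            # grid resolutions (placeholder)
    }
    # 'cache_dir' is not a config key: main() sets it to ./caches/<dataset>.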
/requirements.txt:
--------------------------------------------------------------------------------
1 | flake8==3.7.9
2 | yapf==0.29.0
3 | isort==4.3.21
4 | yacs
5 | gdown
6 | tb-nightly
7 | future
8 | scipy
9 | scikit-learn
10 | tqdm
11 | ftfy
12 | regex
13 | wilds==1.2.2
14 | tabulate
--------------------------------------------------------------------------------
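These pins cover the dataset tooling and CLIP's tokenizer dependencies (ftfy, regex); torch, torchvision and pyyaml, which the scripts import, are not listed here, so they are presumably expected to be installed separately to match the local CUDA setup. A minimal environment sketch:

    pip install -r requirements.txt
    pip install torch torchvision pyyaml   # versions chosen to match your CUDA toolkit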
/utils.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 |
7 | import clip
8 |
9 |
10 | def cls_acc(output, target, topk=1):
11 | pred = output.topk(topk, 1, True, True)[1].t()
12 | correct = pred.eq(target.view(1, -1).expand_as(pred))
13 | acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
14 | acc = 100 * acc / target.shape[0]
15 | return acc
16 |
17 | def gpt_clip_classifier(classnames, gpt_prompts, clip_model, template):
18 | with torch.no_grad():
19 | clip_weights = []
20 | for classname in classnames:
21 | # Tokenize the prompts
22 | classname = classname.replace('_', ' ')
23 | texts = []
24 | for t in gpt_prompts[classname]:
25 | texts.append(t)
26 | texts = clip.tokenize(texts).cuda()
27 | # prompt ensemble for ImageNet
28 | class_embeddings = clip_model.encode_text(texts)
29 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
30 | class_embedding = class_embeddings.mean(dim=0)
31 | class_embedding /= class_embedding.norm()
32 | clip_weights.append(class_embedding)
33 |
34 | clip_weights = torch.stack(clip_weights, dim=1).cuda()
35 | return clip_weights
36 |
37 |
38 | def build_clip_cache_model(cfg, clip_model, train_loader_cache):
39 |
40 | if cfg['load_cache'] == False:
41 | cache_keys = []
42 | cache_values = []
43 |
44 | with torch.no_grad():
45 | # Data augmentation for the cache model
46 | for augment_idx in range(cfg['augment_epoch']):
47 | train_features = []
48 |
49 | print('Augment Epoch: {:} / {:}'.format(augment_idx, cfg['augment_epoch']))
50 | for i, (images, target) in enumerate(tqdm(train_loader_cache)):
51 | images = images.cuda()
52 | image_features = clip_model.encode_image(images)
53 | train_features.append(image_features)
54 | if augment_idx == 0:
55 | target = target.cuda()
56 | cache_values.append(target)
57 | cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0))
58 |
59 | cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0)
60 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True)
61 | cache_keys = cache_keys.permute(1, 0)
62 | cache_values = F.one_hot(torch.cat(cache_values, dim=0)).half()
63 |
64 | torch.save(cache_keys, cfg['cache_dir'] + '/clip_keys_' + str(cfg['shots']) + "shots.pt")
65 | torch.save(cache_values, cfg['cache_dir'] + '/clip_values_' + str(cfg['shots']) + "shots.pt")
66 |
67 | else:
68 | cache_keys = torch.load(cfg['cache_dir'] + '/clip_keys_' + str(cfg['shots']) + "shots.pt")
69 | cache_values = torch.load(cfg['cache_dir'] + '/clip_values_' + str(cfg['shots']) + "shots.pt")
70 |
71 | return cache_keys, cache_values
72 |
73 | def build_dino_cache_model(cfg, dino_model, train_loader_cache):
74 |
75 | if cfg['load_cache'] == False:
76 | cache_keys = []
77 | cache_values = []
78 |
79 | with torch.no_grad():
80 | # Data augmentation for the cache model
81 | for augment_idx in range(cfg['augment_epoch']):
82 | train_features = []
83 |
84 | print('Augment Epoch: {:} / {:}'.format(augment_idx, cfg['augment_epoch']))
85 | for i, (images, target) in enumerate(tqdm(train_loader_cache)):
86 | images = images.cuda()
87 | image_features = dino_model(images)
88 | train_features.append(image_features)
89 | if augment_idx == 0:
90 | target = target.cuda()
91 | cache_values.append(target)
92 | cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0))
93 |
94 | cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0)
95 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True)
96 | cache_keys = cache_keys.permute(1, 0)
97 | cache_values = F.one_hot(torch.cat(cache_values, dim=0)).half()
98 |
99 | torch.save(cache_keys, cfg['cache_dir'] + '/dino_keys_' + str(cfg['shots']) + "shots.pt")
100 | torch.save(cache_values, cfg['cache_dir'] + '/dino_values_' + str(cfg['shots']) + "shots.pt")
101 |
102 | else:
103 | cache_keys = torch.load(cfg['cache_dir'] + '/dino_keys_' + str(cfg['shots']) + "shots.pt")
104 | cache_values = torch.load(cfg['cache_dir'] + '/dino_values_' + str(cfg['shots']) + "shots.pt")
105 |
106 | return cache_keys, cache_values
107 |
108 | def build_clip_dalle_cache_model(cfg, clip_model, train_loader_cache):
109 |
110 | if cfg['load_cache'] == False:
111 | cache_keys = []
112 | cache_values = []
113 |
114 | with torch.no_grad():
115 | # Data augmentation for the cache model
116 | for augment_idx in range(cfg['augment_epoch']):
117 | train_features = []
118 |
119 | print('Augment Epoch: {:} / {:}'.format(augment_idx, cfg['augment_epoch']))
120 | for i, (images, target) in enumerate(tqdm(train_loader_cache)):
121 | images = images.cuda()
122 | image_features = clip_model.encode_image(images)
123 | train_features.append(image_features)
124 | if augment_idx == 0:
125 | target = target.cuda()
126 | cache_values.append(target)
127 | cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0))
128 |
129 | cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0)
130 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True)
131 | cache_keys = cache_keys.permute(1, 0)
132 | cache_values = F.one_hot(torch.cat(cache_values, dim=0)).half()
133 |
134 | torch.save(cache_keys, cfg['cache_dir'] + '/clip_dalle_keys_' + str(cfg['dalle_shots']) + "shots.pt")
135 | torch.save(cache_values, cfg['cache_dir'] + '/clip_dalle_values_' + str(cfg['dalle_shots']) + "shots.pt")
136 |
137 | else:
138 | cache_keys = torch.load(cfg['cache_dir'] + '/clip_dalle_keys_' + str(cfg['dalle_shots']) + "shots.pt")
139 | cache_values = torch.load(cfg['cache_dir'] + '/clip_dalle_values_' + str(cfg['dalle_shots']) + "shots.pt")
140 |
141 | return cache_keys, cache_values
142 |
143 | def build_dino_dalle_cache_model(cfg, dino_model, train_loader_cache):
144 |
145 | if cfg['load_cache'] == False:
146 | cache_keys = []
147 | cache_values = []
148 |
149 | with torch.no_grad():
150 | # Data augmentation for the cache model
151 | for augment_idx in range(cfg['augment_epoch']):
152 | train_features = []
153 |
154 | print('Augment Epoch: {:} / {:}'.format(augment_idx, cfg['augment_epoch']))
155 | for i, (images, target) in enumerate(tqdm(train_loader_cache)):
156 | images = images.cuda()
157 | image_features = dino_model(images)
158 | train_features.append(image_features)
159 | if augment_idx == 0:
160 | target = target.cuda()
161 | cache_values.append(target)
162 | cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0))
163 |
164 | cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0)
165 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True)
166 | cache_keys = cache_keys.permute(1, 0)
167 | cache_values = F.one_hot(torch.cat(cache_values, dim=0)).half()
168 |
169 | torch.save(cache_keys, cfg['cache_dir'] + '/dino_dalle_keys_' + str(cfg['dalle_shots']) + "shots.pt")
170 | torch.save(cache_values, cfg['cache_dir'] + '/dino_dalle_values_' + str(cfg['dalle_shots']) + "shots.pt")
171 |
172 | else:
173 | cache_keys = torch.load(cfg['cache_dir'] + '/dino_dalle_keys_' + str(cfg['dalle_shots']) + "shots.pt")
174 | cache_values = torch.load(cfg['cache_dir'] + '/dino_dalle_values_' + str(cfg['dalle_shots']) + "shots.pt")
175 |
176 | return cache_keys, cache_values
177 |
178 |
179 | def pre_CLIP_load_features(cfg, split, clip_model, loader):
180 |
181 | if cfg['load_pre_feat'] == False:
182 | features, labels = [], []
183 |
184 | with torch.no_grad():
185 | for i, (images, target) in enumerate(tqdm(loader)):
186 | images, target = images.cuda(), target.cuda()
187 | image_features = clip_model.encode_image(images)
188 | image_features /= image_features.norm(dim=-1, keepdim=True)
189 | features.append(image_features)
190 | labels.append(target)
191 |
192 | features, labels = torch.cat(features), torch.cat(labels)
193 |
194 | torch.save(features, cfg['cache_dir'] + "/" + split + "_clip_f.pt")
195 | torch.save(labels, cfg['cache_dir'] + "/" + split + "_clip_l.pt")
196 |
197 | else:
198 | features = torch.load(cfg['cache_dir'] + "/" + split + "_clip_f.pt")
199 | labels = torch.load(cfg['cache_dir'] + "/" + split + "_clip_l.pt")
200 |
201 | return features, labels
202 |
203 |
204 | def pre_DINO_load_features(cfg, split, dino_model, loader):
205 |
206 | if cfg['load_pre_feat'] == False:
207 | features, labels = [], []
208 |
209 | with torch.no_grad():
210 | for i, (images, target) in enumerate(tqdm(loader)):
211 | images, target = images.cuda(), target.cuda()
212 | image_features = dino_model(images)
213 | image_features /= image_features.norm(dim=-1, keepdim=True)
214 | features.append(image_features)
215 | labels.append(target)
216 |
217 | features, labels = torch.cat(features), torch.cat(labels)
218 |
219 | torch.save(features, cfg['cache_dir'] + "/" + split + "_dino_f.pt")
220 | torch.save(labels, cfg['cache_dir'] + "/" + split + "_dino_l.pt")
221 |
222 | else:
223 | features = torch.load(cfg['cache_dir'] + "/" + split + "_dino_f.pt")
224 | labels = torch.load(cfg['cache_dir'] + "/" + split + "_dino_l.pt")
225 |
226 | return features, labels
227 |
228 |
229 | def search_hp(cfg, cache_keys, cache_values, features, labels, clip_weights, adapter=None):
230 |
231 | if cfg['search_hp'] == True:
232 |
233 | beta_list = [i * (cfg['search_scale'][0] - 0.1) / cfg['search_step'][0] + 0.1 for i in range(cfg['search_step'][0])]
234 | alpha_list = [i * (cfg['search_scale'][1] - 0.1) / cfg['search_step'][1] + 0.1 for i in range(cfg['search_step'][1])]
235 |
236 | best_acc = 0
237 | best_beta, best_alpha = 0, 0
238 |
239 | for beta in beta_list:
240 | for alpha in alpha_list:
241 | if adapter:
242 | affinity = adapter(features)
243 | else:
244 | affinity = features @ cache_keys
245 |
246 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values
247 | clip_logits = 100. * features @ clip_weights
248 | tip_logits = clip_logits + cache_logits * alpha
249 | acc = cls_acc(tip_logits, labels)
250 |
251 | if acc > best_acc:
252 | print("New best setting, beta: {:.2f}, alpha: {:.2f}; accuracy: {:.2f}".format(beta, alpha, acc))
253 | best_acc = acc
254 | best_beta = beta
255 | best_alpha = alpha
256 |
257 |         print("\nAfter searching, the best accuracy: {:.2f}.\n".format(best_acc))
258 |
259 | return best_beta, best_alpha
260 |
261 | def search_no_clip_hp(cfg, cache_keys, cache_values, features, labels, adapter=None):
262 |
263 | if cfg['search_hp'] == True:
264 |
265 | beta_list = [i * (cfg['search_scale'][0] - 0.1) / cfg['search_step'][0] + 0.1 for i in range(cfg['search_step'][0])]
266 | alpha_list = [i * (cfg['search_scale'][1] - 0.1) / cfg['search_step'][1] + 0.1 for i in range(cfg['search_step'][1])]
267 |
268 | best_acc = 0
269 | best_beta, best_alpha = 0, 0
270 |
271 | for beta in beta_list:
272 | for alpha in alpha_list:
273 | if adapter:
274 | affinity = adapter(features).to(torch.float16)
275 | else:
276 | affinity = features @ cache_keys
277 |
278 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values
279 | # clip_logits = 100. * features @ clip_weights
280 | # tip_logits = clip_logits + cache_logits * alpha
281 | tip_logits = cache_logits
282 | acc = cls_acc(tip_logits, labels)
283 |
284 | if acc > best_acc:
285 | print("New best setting, beta: {:.2f}, alpha: {:.2f}; accuracy: {:.2f}".format(beta, alpha, acc))
286 | best_acc = acc
287 | best_beta = beta
288 | best_alpha = alpha
289 |
290 |         print("\nAfter searching, the best accuracy: {:.2f}.\n".format(best_acc))
291 |
292 | return best_beta, best_alpha
293 |
294 |
295 | def search_ensemble_hp(cfg,
296 | clip_cache_keys,
297 | clip_cache_values,
298 | clip_features,
299 | dino_cache_keys,
300 | dino_cache_values,
301 | dino_features,
302 | labels,
303 | clip_weights,
304 | clip_adapter=None,
305 | dino_adapter=None):
306 |
307 | if cfg['search_hp'] == True:
308 |
309 | beta_list = [i * (cfg['search_scale'][0] - 0.1) / cfg['search_step'][0] + 0.1 for i in range(cfg['search_step'][0])]
310 | alpha_list = [i * (cfg['search_scale'][1] - 0.1) / cfg['search_step'][1] + 0.1 for i in range(cfg['search_step'][1])]
311 |
312 | best_acc = 0
313 | best_beta, best_alpha = 0, 0
314 |
315 | for beta in beta_list:
316 | for alpha in alpha_list:
317 | if clip_adapter:
318 | clip_affinity = clip_adapter(clip_features)
319 | dino_affinity = dino_adapter(dino_features).to(dino_cache_values)
320 | else:
321 | clip_affinity = clip_features @ clip_cache_keys
322 | dino_affinity = (dino_features @ dino_cache_keys).to(dino_cache_values)
323 |
324 | clip_cache_logits = ((-1) * (beta - beta * clip_affinity)).exp() @ clip_cache_values
325 | dino_cache_logits = ((-1) * (beta - beta * dino_affinity)).exp() @ dino_cache_values
326 | clip_logits = 100. * clip_features @ clip_weights
327 | cache_logits = logits_fuse(clip_logits, [clip_cache_logits, dino_cache_logits])
328 | tip_logits = clip_logits + cache_logits * alpha
329 | acc = cls_acc(tip_logits, labels)
330 |
331 | if acc > best_acc:
332 | print("New best setting, beta: {:.2f}, alpha: {:.2f}; accuracy: {:.2f}".format(beta, alpha, acc))
333 | best_acc = acc
334 | best_beta = beta
335 | best_alpha = alpha
336 |
337 |         print("\nAfter searching, the best accuracy: {:.2f}.\n".format(best_acc))
338 |         with open("best.txt","w") as f:
339 |             f.write("After searching, the best accuracy: {:.2f}.\n".format(best_acc))
340 | return best_beta, best_alpha
341 |
342 |
343 | # Fuse several logit sources, weighting each by its agreement with the CLIP zero-shot logits (the baseline distribution).
344 | def logits_fuse(zero_logtis, logits, normalize='mean'):
345 | # normalize logits
346 | softmax_fun = nn.Softmax(dim=1)
347 | if normalize == 'softmax':
348 | zero_logtis = softmax_fun(zero_logtis)
349 | elif normalize =='linear':
350 | zero_logtis /= torch.norm(zero_logtis, p=2, dim=1, keepdim=True)
351 | elif normalize == 'mean':
352 | logits_std = torch.std(zero_logtis, dim=1, keepdim=True)
353 | logits_mean = torch.mean(zero_logtis, dim=1, keepdim=True)
354 | zero_logtis = (zero_logtis - logits_mean) / logits_std
355 | else:
356 |         raise ValueError("unsupported normalize mode: {}".format(normalize))
357 | similarity_matrix = []
358 | normalize_logits = []
359 | for logit in logits:
360 | if normalize == 'softmax':
361 | current_normalize_logits = softmax_fun(logit)
362 | elif normalize =='linear':
363 | current_normalize_logits = logit / torch.norm(logit, p=2, dim=1, keepdim=True)
364 | elif normalize == 'mean':
365 | logits_std = torch.std(logit, dim=1, keepdim=True)
366 | logits_mean = torch.mean(logit, dim=1, keepdim=True)
367 | current_normalize_logits = (logit - logits_mean) / logits_std
368 | else:
369 |             raise ValueError("unsupported normalize mode: {}".format(normalize))
370 | current_similarity = current_normalize_logits * zero_logtis
371 | current_similarity = torch.sum(current_similarity, dim=1, keepdim=True)
372 | similarity_matrix.append(current_similarity)
373 | normalize_logits.append(current_normalize_logits)
374 | similarity_matrix = torch.stack(similarity_matrix, dim=-2)
375 | similarity_matrix = softmax_fun(similarity_matrix)
376 | normalize_logits = torch.stack(normalize_logits, dim=-2)
377 | result_logits = torch.sum(normalize_logits * similarity_matrix, dim=1)
378 |
379 | return result_logits
380 | def logits_fuse_s(zero_logtis, logits, normalize='mean'):
381 | # normalize logits
382 | softmax_fun = nn.Softmax(dim=1)
383 | if normalize == 'softmax':
384 | zero_logtis = softmax_fun(zero_logtis)
385 | elif normalize =='linear':
386 | zero_logtis /= torch.norm(zero_logtis, p=2, dim=1, keepdim=True)
387 | elif normalize == 'mean':
388 | logits_std = torch.std(zero_logtis, dim=1, keepdim=True)
389 | logits_mean = torch.mean(zero_logtis, dim=1, keepdim=True)
390 | zero_logtis = (zero_logtis - logits_mean) / logits_std
391 | else:
392 |         raise ValueError("unsupported normalize mode: {}".format(normalize))
393 | similarity_matrix = []
394 | normalize_logits = []
395 | for logit in logits:
396 | if normalize == 'softmax':
397 | current_normalize_logits = softmax_fun(logit)
398 | elif normalize =='linear':
399 | current_normalize_logits = logit / torch.norm(logit, p=2, dim=1, keepdim=True)
400 | elif normalize == 'mean':
401 | logits_std = torch.std(logit, dim=1, keepdim=True)
402 | logits_mean = torch.mean(logit, dim=1, keepdim=True)
403 | current_normalize_logits = (logit - logits_mean) / logits_std
404 | else:
405 |             raise ValueError("unsupported normalize mode: {}".format(normalize))
406 | current_similarity = current_normalize_logits * zero_logtis
407 | current_similarity = torch.sum(current_similarity, dim=1, keepdim=True)
408 | similarity_matrix.append(current_similarity)
409 | normalize_logits.append(current_normalize_logits)
410 | similarity_matrix = torch.stack(similarity_matrix, dim=-2)
411 | similarity_matrix = softmax_fun(similarity_matrix)
412 | count = 0
413 | for i in similarity_matrix:
414 | if i[0]>0.4 and i[0]<0.6:
415 | count += 1
416 | normalize_logits = torch.stack(normalize_logits, dim=-2)
417 | result_logits = torch.sum(normalize_logits * similarity_matrix, dim=1)
418 |
419 | return result_logits, count
420 |
--------------------------------------------------------------------------------
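To make the key-value cache and the adaptive ensemble in utils.py concrete, here is a self-contained toy sketch. It uses random tensors in place of real CLIP/DINO features, and beta/alpha stand in for the cfg values that search_ensemble_hp tunes; only logits_fuse is taken from this file.

    import torch
    import torch.nn.functional as F
    from utils import logits_fuse   # this file; assumes the repo root is on sys.path

    torch.manual_seed(0)
    B, D, C, N = 4, 8, 10, 160      # toy sizes: batch, feature dim, classes, cached shots*classes
    beta, alpha = 1.0, 1.0          # stand-ins for cfg['init_beta'], cfg['init_alpha']

    # Few-shot cache: L2-normalized keys (features) and one-hot values (labels).
    cache_keys = F.normalize(torch.randn(D, N), dim=0)
    cache_values = F.one_hot(torch.randint(0, C, (N,)), C).float()

    # Test features and CLIP's textual classifier (both L2-normalized).
    feats = F.normalize(torch.randn(B, D), dim=-1)
    clip_weights = F.normalize(torch.randn(D, C), dim=0)

    affinity = feats @ cache_keys                                           # [B, N] cosine affinities
    cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values   # cache-model term
    clip_logits = 100. * feats @ clip_weights                               # zero-shot CLIP term

    # A second cache source (DINO in the real pipeline); random logits here.
    dino_cache_logits = torch.randn(B, C)

    fused = logits_fuse(clip_logits, [cache_logits, dino_cache_logits])     # agreement-weighted ensemble
    tip_logits = clip_logits + fused * alpha
    print(tip_logits.shape)                                                 # torch.Size([4, 10])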