├── .gitignore ├── DATASET.md ├── README.md ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── coop-configs ├── caltech101.yaml ├── dtd.yaml ├── eurosat.yaml ├── fgvc.yaml ├── food101.yaml ├── imagenet.yaml ├── oxford_flowers.yaml ├── oxford_pets.yaml ├── stanford_cars.yaml ├── sun397.yaml └── ucf101.yaml ├── datasets ├── __init__.py ├── caltech101.py ├── dtd.py ├── eurosat.py ├── fgvc.py ├── food101.py ├── imagenet.py ├── oxford_flowers.py ├── oxford_pets.py ├── stanford_cars.py ├── sun397.py ├── ucf101.py └── utils.py ├── main_coop_vae.py ├── main_imagenet_coop_vae.py ├── requirements.txt ├── scripts └── coop_vae.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | __pycache__/ 4 | venv* 5 | *.pyc 6 | *.log 7 | .ipynb_checkpoints 8 | -------------------------------------------------------------------------------- /DATASET.md: -------------------------------------------------------------------------------- 1 | # How to install datasets 2 | 3 | We suggest putting all datasets under the same folder (say `$DATA`) to ease management, and following the instructions below to organize them so that the source code does not need to be modified. The file structure looks like 4 | 5 | ``` 6 | $DATA/ 7 | |–– imagenet/ 8 | |–– caltech-101/ 9 | |–– oxford_pets/ 10 | |–– stanford_cars/ 11 | ``` 12 | 13 | If some datasets are already installed elsewhere, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid downloading them again.
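As a minimal sketch of the symbolic-link approach, the hypothetical helper below (not part of this repository) links an existing copy of a dataset into `$DATA`. It assumes Python 3 and a POSIX file system; on Windows, creating symlinks may require extra privileges.

```python
from pathlib import Path


def link_existing_dataset(data_root: str, name: str, existing_path: str) -> Path:
    """Create <data_root>/<name> as a symbolic link to an existing dataset copy (hypothetical helper)."""
    target = Path(existing_path).expanduser().resolve()
    if not target.is_dir():
        raise FileNotFoundError(f"{target} is not a directory")

    link = Path(data_root).expanduser() / name
    link.parent.mkdir(parents=True, exist_ok=True)

    if link.is_symlink() or link.exists():
        # Never overwrite existing data; accept an entry that already resolves to the target.
        if link.resolve() == target:
            return link
        raise FileExistsError(f"{link} already exists and points elsewhere")

    link.symlink_to(target, target_is_directory=True)
    return link
```

For example, `link_existing_dataset("~/data", "imagenet", "/datasets/ILSVRC2012")` would create `~/data/imagenet` pointing at an existing ImageNet copy (both paths are placeholders). Calling it again with the same arguments is a no-op.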
14 | 15 | Datasets list: 16 | - [ImageNet](#imagenet) 17 | - [Caltech101](#caltech101) 18 | - [OxfordPets](#oxfordpets) 19 | - [StanfordCars](#stanfordcars) 20 | - [Flowers102](#flowers102) 21 | - [Food101](#food101) 22 | - [FGVCAircraft](#fgvcaircraft) 23 | - [SUN397](#sun397) 24 | - [DTD](#dtd) 25 | - [EuroSAT](#eurosat) 26 | - [UCF101](#ucf101) 27 | 28 | The instructions for preparing each dataset are detailed below. To ensure reproducibility and a fair comparison with future work, we use CoOp-style train/val/test splits for all datasets except ImageNet, whose validation set is used as the test set. 29 | 30 | ### ImageNet 31 | - Create a folder named `imagenet/` under `$DATA`. 32 | - Create `images/` under `imagenet/`. 33 | - Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. The directory structure should look like 34 | ``` 35 | imagenet/ 36 | |–– images/ 37 | | |–– train/ # contains 1,000 folders like n01440764, n01443537, etc. 38 | | |–– val/ 39 | ``` 40 | - If you have already downloaded ImageNet, you can instead create symbolic links that map the training and validation sets to `$DATA/imagenet/images`. 41 | - Download `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb). 42 | 43 | ### Caltech101 44 | - Create a folder named `caltech-101/` under `$DATA`. 45 | - Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`. 46 | - Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`.
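The `split_zhou_*.json` files carry the CoOp-style train/val/test splits. The sketch below reads one, assuming the format consumed by CoOp's split reader: a JSON object whose `train`, `val` and `test` lists hold `[relative_image_path, label, classname]` triples. `Datum` and `read_split` are illustrative stand-ins, not the loaders in this repository's `datasets/` package.

```python
import json
import os
from typing import Dict, List, NamedTuple


class Datum(NamedTuple):
    impath: str      # absolute (or data-root-relative) image path
    label: int       # integer class index
    classname: str   # human-readable class name


def read_split(split_path: str, image_dir: str) -> Dict[str, List[Datum]]:
    """Load a CoOp-style split file (assumed format: [relative_path, label, classname] triples)."""
    with open(split_path) as f:
        split = json.load(f)
    out: Dict[str, List[Datum]] = {}
    for subset in ("train", "val", "test"):
        # Prefix each relative path with the folder that actually holds the images.
        out[subset] = [
            Datum(os.path.join(image_dir, impath), int(label), classname)
            for impath, label, classname in split[subset]
        ]
    return out
```

For Caltech101, `image_dir` would presumably be `$DATA/caltech-101/101_ObjectCategories`, since the relative paths in the split file are assumed to start at the class folders.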
47 | 48 | The directory structure should look like 49 | ``` 50 | caltech-101/ 51 | |–– 101_ObjectCategories/ 52 | |–– split_zhou_Caltech101.json 53 | ``` 54 | 55 | ### OxfordPets 56 | - Create a folder named `oxford_pets/` under `$DATA`. 57 | - Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz. 58 | - Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz. 59 | - Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing). 60 | 61 | The directory structure should look like 62 | ``` 63 | oxford_pets/ 64 | |–– images/ 65 | |–– annotations/ 66 | |–– split_zhou_OxfordPets.json 67 | ``` 68 | 69 | ### StanfordCars 70 | - Create a folder named `stanford_cars/` under `$DATA`. 71 | - Download the train images from http://ai.stanford.edu/~jkrause/car196/cars_train.tgz. 72 | - Download the test images from http://ai.stanford.edu/~jkrause/car196/cars_test.tgz. 73 | - Download the train labels from https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz. 74 | - Download the test labels from http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat. 75 | - Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing). 76 | 77 | The directory structure should look like 78 | ``` 79 | stanford_cars/ 80 | |–– cars_test/ 81 | |–– cars_test_annos_withlabels.mat 82 | |–– cars_train/ 83 | |–– devkit/ 84 | |–– split_zhou_StanfordCars.json 85 | ``` 86 | 87 | ### Flowers102 88 | - Create a folder named `oxford_flowers/` under `$DATA`. 89 | - Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat, respectively. 90 | - Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing).
91 | - Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing). 92 | 93 | The directory structure should look like 94 | ``` 95 | oxford_flowers/ 96 | |–– cat_to_name.json 97 | |–– imagelabels.mat 98 | |–– jpg/ 99 | |–– split_zhou_OxfordFlowers.json 100 | ``` 101 | 102 | ### Food101 103 | - Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`. 104 | - Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing). 105 | 106 | The directory structure should look like 107 | ``` 108 | food-101/ 109 | |–– images/ 110 | |–– license_agreement.txt 111 | |–– meta/ 112 | |–– README.txt 113 | |–– split_zhou_Food101.json 114 | ``` 115 | 116 | ### FGVCAircraft 117 | - Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz. 118 | - Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`. 119 | - Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`. 120 | 121 | The directory structure should look like 122 | ``` 123 | fgvc_aircraft/ 124 | |–– images/ 125 | |–– ... # a bunch of .txt files 126 | ``` 127 | 128 | ### SUN397 129 | - Create a folder named `sun397/` under `$DATA`. 130 | - Download the images from http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz. 131 | - Download the partitions from https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip. 132 | - Extract these files under `$DATA/sun397/`. 133 | - Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing). 134 | 135 | The directory structure should look like 136 | ``` 137 | sun397/ 138 | |–– SUN397/ 139 | |–– split_zhou_SUN397.json 140 | |–– ...
# a bunch of .txt files 141 | ``` 142 | 143 | ### DTD 144 | - Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. This should lead to `$DATA/dtd/`. 145 | - Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing). 146 | 147 | The directory structure should look like 148 | ``` 149 | dtd/ 150 | |–– images/ 151 | |–– imdb/ 152 | |–– labels/ 153 | |–– split_zhou_DescribableTextures.json 154 | ``` 155 | 156 | ### EuroSAT 157 | - Create a folder named `eurosat/` under `$DATA`. 158 | - Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`. 159 | - Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing). 160 | 161 | The directory structure should look like 162 | ``` 163 | eurosat/ 164 | |–– 2750/ 165 | |–– split_zhou_EuroSAT.json 166 | ``` 167 | 168 | ### UCF101 169 | - Create a folder named `ucf101/` under `$DATA`. 170 | - Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames. 171 | - Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing). 
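Before training, it can help to confirm that `$DATA` matches the directory trees in this file. The sketch below is a hypothetical checker (`EXPECTED_LAYOUT` and `missing_entries` are not part of this repository); its entry lists are transcribed from those trees and deliberately skip optional items such as `license_agreement.txt` or the loose `.txt` files.

```python
from pathlib import Path
from typing import Dict, List

# Required entries per dataset folder, transcribed from the directory trees in DATASET.md.
EXPECTED_LAYOUT: Dict[str, List[str]] = {
    "imagenet": ["images/train", "images/val", "classnames.txt"],
    "caltech-101": ["101_ObjectCategories", "split_zhou_Caltech101.json"],
    "oxford_pets": ["images", "annotations", "split_zhou_OxfordPets.json"],
    "stanford_cars": ["cars_train", "cars_test", "cars_test_annos_withlabels.mat",
                      "devkit", "split_zhou_StanfordCars.json"],
    "oxford_flowers": ["jpg", "imagelabels.mat", "cat_to_name.json",
                       "split_zhou_OxfordFlowers.json"],
    "food-101": ["images", "meta", "split_zhou_Food101.json"],
    "fgvc_aircraft": ["images"],
    "sun397": ["SUN397", "split_zhou_SUN397.json"],
    "dtd": ["images", "labels", "split_zhou_DescribableTextures.json"],
    "eurosat": ["2750", "split_zhou_EuroSAT.json"],
    "ucf101": ["UCF-101-midframes", "split_zhou_UCF101.json"],
}


def missing_entries(data_root: str) -> Dict[str, List[str]]:
    """Return, for each dataset folder, the expected entries not found under data_root."""
    root = Path(data_root).expanduser()
    report: Dict[str, List[str]] = {}
    for dataset, entries in EXPECTED_LAYOUT.items():
        # Path.exists() follows symlinks, so a dangling link is reported as missing.
        missing = [entry for entry in entries if not (root / dataset / entry).exists()]
        if missing:
            report[dataset] = missing
    return report
```

An empty result means every listed entry was found; datasets you do not plan to use can simply be ignored in the report.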
172 | 173 | The directory structure should look like 174 | ``` 175 | ucf101/ 176 | |–– UCF-101-midframes/ 177 | |–– split_zhou_UCF101.json 178 | ``` 179 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improving Zero-Shot Generalization for CLIP with Synthesized Prompts 2 | 3 | Official implementation of [Improving Zero-Shot Generalization for CLIP with Synthesized Prompts](https://arxiv.org/abs/2307.07397). 4 | 5 | This paper has been accepted by **ICCV 2023**. 6 | 7 | ## Requirements 8 | ### Installation 9 | Create a conda environment and install the dependencies: 10 | ``` 11 | conda create -n ship python=3.9 12 | conda activate ship 13 | 14 | pip install -r requirements.txt 15 | 16 | # Install the versions of torch and torchvision that match your CUDA setup 17 | conda install pytorch torchvision cudatoolkit 18 | ``` 19 | 20 | ### Dataset 21 | Follow [DATASET.md](DATASET.md) to prepare ImageNet and the other 10 datasets used by CoOp. 22 | 23 | ## Get Started 24 | ### Configs 25 | The running configurations, including the number of shots, the visual encoder, and hyperparameters, can be modified in `coop-configs/dataset.yaml`, where `dataset` is one of the config names in `coop-configs/` (e.g. `caltech101`). 26 | 27 | ### Running 28 | For the ImageNet dataset: 29 | ```bash 30 | CUDA_VISIBLE_DEVICES=0 python main_imagenet_coop_vae.py --config coop-configs/imagenet.yaml 31 | ``` 32 | For the other 10 datasets: 33 | ```bash 34 | CUDA_VISIBLE_DEVICES=0 python main_coop_vae.py --config coop-configs/dataset.yaml 35 | ``` 36 | 37 | ## Acknowledgement 38 | 39 | This repo benefits from [CLIP](https://github.com/openai/CLIP), [CoOp](https://github.com/KaiyangZhou/Dassl.pytorch) and [Tip-Adapter](https://github.com/gaopengcuhk/Tip-Adapter). Thanks for their wonderful work.
40 | 41 | ## Citation 42 | 43 | ``` 44 | @inproceedings{wang2023improving, 45 | title={Improving Zero-Shot Generalization for CLIP with Synthesized Prompts}, 46 | author={Wang, Zhengbo and Liang, Jian and He, Ran and Xu, Nan and Wang, Zilei and Tan, Tieniu}, 47 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 48 | month={October}, 49 | year={2023}, 50 | pages={3032-3042} 51 | } 52 | ``` 53 | 54 | ## Contact 55 | 56 | If you have any questions, feel free to contact zhengbowang@mail.ustc.edu.cn. 57 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrflogs/SHIP/53b3fcba9e4337f373c19e1cfdb98477d683ea18/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib.request 4 | import warnings 5 | from typing import Any, Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__ < "1.7.1": # TorchVersion compares semantically; lists of str would order "10" before "7" 23 | warnings.warn("PyTorch version
1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | 
output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match") 67 | 68 | return download_target 69 | 70 | 71 | def _convert_image_to_rgb(image): 72 | return image.convert("RGB") 73 | 74 | 75 | def _transform(n_px): 76 | return Compose([ 77 | Resize(n_px, interpolation=BICUBIC), 78 | CenterCrop(n_px), 79 | _convert_image_to_rgb, 80 | ToTensor(), 81 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 82 | ]) 83 | 84 | 85 | def available_models() -> List[str]: 86 | """Returns the names of available CLIP models""" 87 | return list(_MODELS.keys()) 88 | 89 | 90 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 91 | """Load a CLIP model 92 | 93 | Parameters 94 | ---------- 95 | name : str 96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 97 | 98 | device : Union[str, torch.device] 99 | The device on which to put the loaded model 100 | 101 | jit : bool 102 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
103 | 104 | download_root: str 105 | path to download the model files; by default, it uses "~/.cache/clip" 106 | 107 | Returns 108 | ------- 109 | model : torch.nn.Module 110 | The CLIP model 111 | 112 | preprocess : Callable[[PIL.Image], torch.Tensor] 113 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 114 | """ 115 | if name in _MODELS: 116 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 117 | elif os.path.isfile(name): 118 | model_path = name 119 | else: 120 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 121 | 122 | try: 123 | # loading JIT archive 124 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 125 | state_dict = None 126 | except RuntimeError: 127 | # loading saved state dict 128 | if jit: 129 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 130 | jit = False 131 | state_dict = torch.load(model_path, map_location="cpu") 132 | 133 | if not jit: 134 | model = build_model(state_dict or model.state_dict()).to(device) 135 | if str(device) == "cpu": 136 | model.float() 137 | return model, _transform(model.visual.input_resolution) 138 | 139 | # patch the device names 140 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 141 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 142 | 143 | def patch_device(module): 144 | try: 145 | graphs = [module.graph] if hasattr(module, "graph") else [] 146 | except RuntimeError: 147 | graphs = [] 148 | 149 | if hasattr(module, "forward1"): 150 | graphs.append(module.forward1.graph) 151 | 152 | for graph in graphs: 153 | for node in graph.findAllNodes("prim::Constant"): 154 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 155 | node.copyAttributes(device_node) 156 | 
157 | model.apply(patch_device) 158 | patch_device(model.encode_image) 159 | patch_device(model.encode_text) 160 | 161 | # patch dtype to float32 on CPU 162 | if str(device) == "cpu": 163 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 164 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 165 | float_node = float_input.node() 166 | 167 | def patch_float(module): 168 | try: 169 | graphs = [module.graph] if hasattr(module, "graph") else [] 170 | except RuntimeError: 171 | graphs = [] 172 | 173 | if hasattr(module, "forward1"): 174 | graphs.append(module.forward1.graph) 175 | 176 | for graph in graphs: 177 | for node in graph.findAllNodes("aten::to"): 178 | inputs = list(node.inputs()) 179 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 180 | if inputs[i].node()["value"] == 5: 181 | inputs[i].node().copyAttributes(float_node) 182 | 183 | model.apply(patch_float) 184 | patch_float(model.encode_image) 185 | patch_float(model.encode_text) 186 | 187 | model.float() 188 | 189 | return model, _transform(model.input_resolution.item()) 190 | 191 | 192 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 193 | """ 194 | Returns the tokenized representation of given input string(s) 195 | 196 | Parameters 197 | ---------- 198 | texts : Union[str, List[str]] 199 | An input string or a list of input strings to tokenize 200 | 201 | context_length : int 202 | The context length to use; all CLIP models use 77 as the context length 203 | 204 | truncate: bool 205 | Whether to truncate the text in case its encoding is longer than the context length 206 | 207 | Returns 208 | ------- 209 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 210 | """ 211 | if isinstance(texts, str): 212 | texts = [texts] 213 | 214 | sot_token = _tokenizer.encoder["<|startoftext|>"] 215 | eot_token = 
_tokenizer.encoder["<|endoftext|>"] 216 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 217 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 218 | 219 | for i, tokens in enumerate(all_tokens): 220 | if len(tokens) > context_length: 221 | if truncate: 222 | tokens = tokens[:context_length] 223 | tokens[-1] = eot_token 224 | else: 225 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 226 | result[i, :len(tokens)] = torch.tensor(tokens) 227 | 228 | return result 229 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * 
x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = 
x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = x + self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__(self, 
input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 204 | super().__init__() 205 | self.input_resolution = input_resolution 206 | self.output_dim = output_dim 207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 208 | 209 | scale = width ** -0.5 210 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 212 | self.ln_pre = LayerNorm(width) 213 | 214 | self.transformer = Transformer(width, layers, heads) 215 | 216 | self.ln_post = LayerNorm(width) 217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 218 | 219 | def forward(self, x: torch.Tensor): 220 | x = self.conv1(x) # shape = [*, width, grid, grid] 221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 224 | x = x + self.positional_embedding.to(x.dtype) 225 | x = self.ln_pre(x) 226 | 227 | x = x.permute(1, 0, 2) # NLD -> LND 228 | x = self.transformer(x) 229 | x = x.permute(1, 0, 2) # LND -> NLD 230 | 231 | x = self.ln_post(x[:, 0, :]) 232 | 233 | if self.proj is not None: 234 | x = x @ self.proj 235 | 236 | return x 237 | 238 | 239 | class CLIP(nn.Module): 240 | def __init__(self, 241 | embed_dim: int, 242 | # vision 243 | image_resolution: int, 244 | vision_layers: Union[Tuple[int, int, int, int], int], 245 | vision_width: int, 246 | vision_patch_size: int, 247 | # text 248 | context_length: int, 249 | vocab_size: int, 250 | transformer_width: int, 251 | transformer_heads: int, 252 | transformer_layers: int 253 | ): 254 | super().__init__() 255 | 256 | self.context_length = context_length 257 | 258 | 
if isinstance(vision_layers, (tuple, list)): 259 | vision_heads = vision_width * 32 // 64 260 | self.visual = ModifiedResNet( 261 | layers=vision_layers, 262 | output_dim=embed_dim, 263 | heads=vision_heads, 264 | input_resolution=image_resolution, 265 | width=vision_width 266 | ) 267 | else: 268 | vision_heads = vision_width // 64 269 | self.visual = VisionTransformer( 270 | input_resolution=image_resolution, 271 | patch_size=vision_patch_size, 272 | width=vision_width, 273 | layers=vision_layers, 274 | heads=vision_heads, 275 | output_dim=embed_dim 276 | ) 277 | 278 | self.transformer = Transformer( 279 | width=transformer_width, 280 | layers=transformer_layers, 281 | heads=transformer_heads, 282 | attn_mask=self.build_attention_mask() 283 | ) 284 | 285 | self.vocab_size = vocab_size 286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 288 | self.ln_final = LayerNorm(transformer_width) 289 | 290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 292 | 293 | self.initialize_parameters() 294 | 295 | def initialize_parameters(self): 296 | nn.init.normal_(self.token_embedding.weight, std=0.02) 297 | nn.init.normal_(self.positional_embedding, std=0.01) 298 | 299 | if isinstance(self.visual, ModifiedResNet): 300 | if self.visual.attnpool is not None: 301 | std = self.visual.attnpool.c_proj.in_features ** -0.5 302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 306 | 307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 308 | for name, param in resnet_block.named_parameters(): 309 | if 
name.endswith("bn3.weight"): 310 | nn.init.zeros_(param) 311 | 312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 313 | attn_std = self.transformer.width ** -0.5 314 | fc_std = (2 * self.transformer.width) ** -0.5 315 | for block in self.transformer.resblocks: 316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 320 | 321 | if self.text_projection is not None: 322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 323 | 324 | def build_attention_mask(self): 325 | # lazily create causal attention mask, with full attention between the vision tokens 326 | # pytorch uses additive attention mask; fill with -inf 327 | mask = torch.empty(self.context_length, self.context_length) 328 | mask.fill_(float("-inf")) 329 | mask.triu_(1) # zero out the lower diagonal 330 | return mask 331 | 332 | @property 333 | def dtype(self): 334 | return self.visual.conv1.weight.dtype 335 | 336 | def encode_image(self, image): 337 | return self.visual(image.type(self.dtype)) 338 | 339 | def encode_text(self, text): 340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 341 | 342 | x = x + self.positional_embedding.type(self.dtype) 343 | x = x.permute(1, 0, 2) # NLD -> LND 344 | x = self.transformer(x) 345 | x = x.permute(1, 0, 2) # LND -> NLD 346 | x = self.ln_final(x).type(self.dtype) 347 | 348 | # x.shape = [batch_size, n_ctx, transformer.width] 349 | # take features from the eot embedding (eot_token is the highest number in each sequence) 350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 351 | 352 | return x 353 | 354 | def forward(self, image, text): 355 | image_features = self.encode_image(image) 356 | text_features = self.encode_text(text) 357 | 358 | # normalized features 
359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 361 | 362 | # cosine similarity as logits 363 | logit_scale = self.logit_scale.exp() 364 | logits_per_image = logit_scale * image_features @ text_features.t() 365 | logits_per_text = logits_per_image.t() 366 | 367 | # shape = [global_batch_size, global_batch_size] 368 | return logits_per_image, logits_per_text 369 | 370 | 371 | def convert_weights(model: nn.Module): 372 | """Convert applicable model parameters to fp16""" 373 | 374 | def _convert_weights_to_fp16(l): 375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 376 | l.weight.data = l.weight.data.half() 377 | if l.bias is not None: 378 | l.bias.data = l.bias.data.half() 379 | 380 | if isinstance(l, nn.MultiheadAttention): 381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 382 | tensor = getattr(l, attr) 383 | if tensor is not None: 384 | tensor.data = tensor.data.half() 385 | 386 | for name in ["text_projection", "proj"]: 387 | if hasattr(l, name): 388 | attr = getattr(l, name) 389 | if attr is not None: 390 | attr.data = attr.data.half() 391 | 392 | model.apply(_convert_weights_to_fp16) 393 | 394 | 395 | def build_model(state_dict: dict): 396 | vit = "visual.proj" in state_dict 397 | 398 | if vit: 399 | vision_width = state_dict["visual.conv1.weight"].shape[0] 400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 403 | image_resolution = vision_patch_size * grid_size 404 | else: 405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 406 | vision_layers = tuple(counts) 407 | vision_width = 
state_dict["visual.layer1.0.conv1.weight"].shape[0] 408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 409 | vision_patch_size = None 410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 411 | image_resolution = output_width * 32 412 | 413 | embed_dim = state_dict["text_projection"].shape[1] 414 | context_length = state_dict["positional_embedding"].shape[0] 415 | vocab_size = state_dict["token_embedding.weight"].shape[0] 416 | transformer_width = state_dict["ln_final.weight"].shape[0] 417 | transformer_heads = transformer_width // 64 418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 419 | 420 | model = CLIP( 421 | embed_dim, 422 | image_resolution, vision_layers, vision_width, vision_patch_size, 423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 424 | ) 425 | 426 | for key in ["input_resolution", "context_length", "vocab_size"]: 427 | if key in state_dict: 428 | del state_dict[key] 429 | 430 | convert_weights(model) 431 | model.load_state_dict(state_dict) 432 | return model.eval() 433 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns a dict mapping each utf-8 byte to a corresponding unicode string. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a significant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | It also avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except ValueError: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /coop-configs/caltech101.yaml:
-------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: 'datasets' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparameters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 5] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 3 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'caltech101' 26 | shots: 16 27 | # backbone: 'RN50' 28 | backbone: 'ViT-B/16' 29 | 30 | lr: 0.001 31 | augment_epoch: 10 32 | train_epoch: 200 -------------------------------------------------------------------------------- /coop-configs/dtd.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: 'datasets' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparameters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [13, 13] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 2 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'dtd' 26 | shots: 16 27 | # backbone: 'RN50' 28 | backbone: 'ViT-B/16' 29 | 30 | lr: 0.001 31 | augment_epoch: 10 32 | train_epoch: 200 -------------------------------------------------------------------------------- /coop-configs/eurosat.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: 'datasets' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparameters ------ 14 | search_hp: True 15 | # search_hp: False 16
| 17 | search_scale: [12, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 2 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'eurosat' 26 | shots: 16 27 | # backbone: 'RN50' 28 | backbone: 'ViT-B/16' 29 | 30 | lr: 0.001 31 | augment_epoch: 10 32 | train_epoch: 200 -------------------------------------------------------------------------------- /coop-configs/fgvc.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: 'datasets' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparameters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [30, 30] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 5 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'fgvc' 26 | shots: 16 27 | # backbone: 'RN50' 28 | backbone: 'ViT-B/16' 29 | 30 | lr: 0.001 31 | augment_epoch: 10 32 | train_epoch: 200 -------------------------------------------------------------------------------- /coop-configs/food101.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: './datasets/' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparameters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [10, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.17 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'food101' 26 | shots: 16 27 | # backbone: 'RN50' 28 | backbone: 'ViT-B/16' 29 | 30 | lr: 0.001 31 | augment_epoch: 10 32 | train_epoch: 200 --------------------------------------------------------------------------------
/coop-configs/imagenet.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: 'datasets' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparameters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.17 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'imagenet' 26 | shots: 16 27 | # backbone: 'RN50' 28 | backbone: 'ViT-B/16' 29 | 30 | lr: 0.001 31 | augment_epoch: 10 32 | train_epoch: 20 -------------------------------------------------------------------------------- /coop-configs/oxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: 'datasets' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparameters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [50, 50] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 10 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'oxford_flowers' 26 | shots: 16 27 | # backbone: 'RN50' 28 | backbone: 'ViT-B/16' 29 | 30 | lr: 0.001 31 | augment_epoch: 10 32 | train_epoch: 200 -------------------------------------------------------------------------------- /coop-configs/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: 'datasets' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparameters
------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.17 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'oxford_pets' 26 | shots: 16 27 | # backbone: 'RN50' 28 | backbone: 'ViT-B/16' 29 | 30 | lr: 0.001 31 | augment_epoch: 10 32 | train_epoch: 200 -------------------------------------------------------------------------------- /coop-configs/stanford_cars.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: 'datasets' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparameters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [20, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 3 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'stanford_cars' 26 | shots: 16 27 | # backbone: 'RN50' 28 | backbone: 'ViT-B/16' 29 | 30 | lr: 0.001 31 | augment_epoch: 10 32 | train_epoch: 200 -------------------------------------------------------------------------------- /coop-configs/sun397.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: 'datasets' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparameters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.17 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'sun397' 26 | shots: 16 27 | # backbone: 'RN50' 28 | backbone: 'ViT-B/16' 29 | 30 | lr: 0.001 31 | augment_epoch: 10 32 | train_epoch: 200
-------------------------------------------------------------------------------- /coop-configs/ucf101.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: 'datasets' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparameters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 3 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'ucf101' 26 | shots: 16 27 | # backbone: 'RN50' 28 | backbone: 'ViT-B/16' 29 | 30 | lr: 0.001 31 | augment_epoch: 10 32 | train_epoch: 200 -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .oxford_pets import OxfordPets 2 | from .eurosat import EuroSAT 3 | from .ucf101 import UCF101 4 | from .sun397 import SUN397 5 | from .caltech101 import Caltech101 6 | from .dtd import DescribableTextures 7 | from .fgvc import FGVCAircraft 8 | from .food101 import Food101 9 | from .oxford_flowers import OxfordFlowers 10 | from .stanford_cars import StanfordCars 11 | 12 | 13 | dataset_list = { 14 | "oxford_pets": OxfordPets, 15 | "eurosat": EuroSAT, 16 | "ucf101": UCF101, 17 | "sun397": SUN397, 18 | "caltech101": Caltech101, 19 | "dtd": DescribableTextures, 20 | "fgvc": FGVCAircraft, 21 | "food101": Food101, 22 | "oxford_flowers": OxfordFlowers, 23 | "stanford_cars": StanfordCars, 24 | } 25 | 26 | 27 | def build_dataset(cfg, dataset, root_path, shots): 28 | return dataset_list[dataset](cfg, root_path, shots) -------------------------------------------------------------------------------- /datasets/caltech101.py:
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ['a photo of a {}.'] 8 | 9 | 10 | class Caltech101(DatasetBase): 11 | 12 | dataset_dir = 'caltech-101' 13 | 14 | def __init__(self, cfg, root, num_shots): 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories') 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Caltech101.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 22 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 23 | 24 | # 25 | subsample = cfg['subsample_classes'] 26 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 27 | 28 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from .utils import Datum, DatasetBase, listdir_nohidden 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['{} texture.'] 9 | 10 | 11 | class DescribableTextures(DatasetBase): 12 | 13 | dataset_dir = 'dtd' 14 | 15 | def __init__(self, cfg, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'images') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_DescribableTextures.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 24 | 25 | # 26 | subsample = cfg['subsample_classes'] 27 | train, val, test = OxfordPets.subsample_classes(train, val, 
test, subsample=subsample) 28 | 29 | super().__init__(train_x=train, val=val, test=test) 30 | 31 | @staticmethod 32 | def read_and_split_data( 33 | image_dir, 34 | p_trn=0.5, 35 | p_val=0.2, 36 | ignored=[], 37 | new_cnames=None 38 | ): 39 | # The data are supposed to be organized into the following structure 40 | # ============= 41 | # images/ 42 | # dog/ 43 | # cat/ 44 | # horse/ 45 | # ============= 46 | categories = listdir_nohidden(image_dir) 47 | categories = [c for c in categories if c not in ignored] 48 | categories.sort() 49 | 50 | p_tst = 1 - p_trn - p_val 51 | print(f'Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test') 52 | 53 | def _collate(ims, y, c): 54 | items = [] 55 | for im in ims: 56 | item = Datum( 57 | impath=im, 58 | label=y, # is already 0-based 59 | classname=c 60 | ) 61 | items.append(item) 62 | return items 63 | 64 | train, val, test = [], [], [] 65 | for label, category in enumerate(categories): 66 | category_dir = os.path.join(image_dir, category) 67 | images = listdir_nohidden(category_dir) 68 | images = [os.path.join(category_dir, im) for im in images] 69 | random.shuffle(images) 70 | n_total = len(images) 71 | n_train = round(n_total * p_trn) 72 | n_val = round(n_total * p_val) 73 | n_test = n_total - n_train - n_val 74 | assert n_train > 0 and n_val > 0 and n_test > 0 75 | 76 | if new_cnames is not None and category in new_cnames: 77 | category = new_cnames[category] 78 | 79 | train.extend(_collate(images[:n_train], label, category)) 80 | val.extend(_collate(images[n_train:n_train+n_val], label, category)) 81 | test.extend(_collate(images[n_train+n_val:], label, category)) 82 | 83 | return train, val, test 84 | -------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | from 
.oxford_pets import OxfordPets 5 | 6 | 7 | template = ['a centered satellite photo of {}.'] 8 | 9 | 10 | NEW_CNAMES = { 11 | 'AnnualCrop': 'Annual Crop Land', 12 | 'Forest': 'Forest', 13 | 'HerbaceousVegetation': 'Herbaceous Vegetation Land', 14 | 'Highway': 'Highway or Road', 15 | 'Industrial': 'Industrial Buildings', 16 | 'Pasture': 'Pasture Land', 17 | 'PermanentCrop': 'Permanent Crop Land', 18 | 'Residential': 'Residential Buildings', 19 | 'River': 'River', 20 | 'SeaLake': 'Sea or Lake' 21 | } 22 | 23 | 24 | class EuroSAT(DatasetBase): 25 | 26 | dataset_dir = 'eurosat' 27 | 28 | def __init__(self, cfg, root, num_shots): 29 | self.dataset_dir = os.path.join(root, self.dataset_dir) 30 | self.image_dir = os.path.join(self.dataset_dir, '2750') 31 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_EuroSAT.json') 32 | 33 | self.template = template 34 | 35 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 36 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 37 | 38 | # 39 | subsample = cfg['subsample_classes'] 40 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 41 | 42 | super().__init__(train_x=train, val=val, test=test) 43 | 44 | def update_classname(self, dataset_old): 45 | dataset_new = [] 46 | for item_old in dataset_old: 47 | cname_old = item_old.classname 48 | cname_new = NEW_CNAMES[cname_old] 49 | item_new = Datum( 50 | impath=item_old.impath, 51 | label=item_old.label, 52 | classname=cname_new 53 | ) 54 | dataset_new.append(item_new) 55 | return dataset_new 56 | -------------------------------------------------------------------------------- /datasets/fgvc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | from .oxford_pets import OxfordPets 5 | 6 | template = ['a photo of a {}, a type of aircraft.'] 7 | 8 | 9 | class
FGVCAircraft(DatasetBase): 10 | 11 | dataset_dir = 'fgvc_aircraft' 12 | 13 | def __init__(self, cfg, root, num_shots): 14 | 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, 'images') 17 | 18 | self.template = template 19 | 20 | classnames = [] 21 | with open(os.path.join(self.dataset_dir, 'variants.txt'), 'r') as f: 22 | lines = f.readlines() 23 | for line in lines: 24 | classnames.append(line.strip()) 25 | cname2lab = {c: i for i, c in enumerate(classnames)} 26 | 27 | train = self.read_data(cname2lab, 'images_variant_train.txt') 28 | val = self.read_data(cname2lab, 'images_variant_val.txt') 29 | test = self.read_data(cname2lab, 'images_variant_test.txt') 30 | 31 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 32 | 33 | # 34 | subsample = cfg['subsample_classes'] 35 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 36 | 37 | super().__init__(train_x=train, val=val, test=test) 38 | 39 | def read_data(self, cname2lab, split_file): 40 | filepath = os.path.join(self.dataset_dir, split_file) 41 | items = [] 42 | 43 | with open(filepath, 'r') as f: 44 | lines = f.readlines() 45 | for line in lines: 46 | line = line.strip().split(' ') 47 | imname = line[0] + '.jpg' 48 | classname = ' '.join(line[1:]) 49 | impath = os.path.join(self.image_dir, imname) 50 | label = cname2lab[classname] 51 | item = Datum( 52 | impath=impath, 53 | label=label, 54 | classname=classname 55 | ) 56 | items.append(item) 57 | 58 | return items -------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ['a photo of {}, a type of food.'] 8 | 9 | 10 | class Food101(DatasetBase): 11 | 12 | 
dataset_dir = 'food-101' 13 | 14 | def __init__(self, cfg, root, num_shots): 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, 'images') 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Food101.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 22 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 23 | 24 | # 25 | subsample = cfg['subsample_classes'] 26 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 27 | 28 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | from collections import defaultdict 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | 11 | from .oxford_pets import OxfordPets 12 | from .utils import Datum, DatasetBase 13 | 14 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 15 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 16 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 17 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 18 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 19 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 20 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 21 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 22 | "Gila monster", "European green lizard", "chameleon", 
"Komodo dragon", "Nile crocodile", 23 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 24 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 25 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 26 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 27 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 28 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 29 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 30 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 31 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 32 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 33 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 34 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 35 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 36 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 37 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 38 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 39 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 40 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 41 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 42 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 43 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish 
Wolfhound", "Italian Greyhound", 44 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 45 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 46 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 47 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 48 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 49 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 50 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 51 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 52 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 53 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 54 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 55 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 56 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 57 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 58 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 59 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 60 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 61 | "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 62 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 63 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 64 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 65 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 66 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 67 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 68 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 69 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 70 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 71 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 72 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 73 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 74 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 75 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 76 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 77 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 78 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 79 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 80 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 81 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 82 | "langur", "black-and-white colobus", "proboscis 
monkey", "marmoset", "white-headed capuchin", 83 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 84 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 85 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 86 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 87 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 88 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 89 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 90 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 91 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 92 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 93 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 94 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 95 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 96 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 97 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 98 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 99 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 100 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 101 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 102 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 103 | "coffee mug", "coffeemaker", "spiral or coil", 
"combination lock", "computer keyboard", 104 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 105 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 106 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 107 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 108 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 109 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 110 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 111 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 112 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 113 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 114 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 115 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 116 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 117 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 118 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 119 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 120 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 121 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 122 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 123 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 124 | "mailbox", "tights", 
"one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 125 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 126 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 127 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 128 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 129 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 130 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 131 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 132 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 133 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 134 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 135 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", 136 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 137 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 138 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 139 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 140 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 141 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 142 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 143 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 144 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", 
"sandal", 145 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 146 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 147 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 148 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 149 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 150 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 151 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 152 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 153 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 154 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 155 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 156 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 157 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 158 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 159 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 160 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 161 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 162 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 163 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 164 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 165 | "hair wig", "window screen", "window shade", "Windsor tie", 
"wine bottle", "airplane wing", 166 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 167 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 168 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 169 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 170 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 171 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 172 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", 173 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 174 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 175 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 176 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 177 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 178 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 179 | 180 | imagenet_templates = ["itap of a {}.", 181 | "a bad photo of the {}.", 182 | "a origami {}.", 183 | "a photo of the large {}.", 184 | "a {} in a video game.", 185 | "art of the {}.", 186 | "a photo of the small {}."] 187 | 188 | def listdir_nohidden(path, sort=False): 189 | """List non-hidden items in a directory. 190 | 191 | Args: 192 | path (str): directory path. 193 | sort (bool): sort the items. 
194 | """ 195 | items = [f for f in os.listdir(path) if not f.startswith('.')] 196 | if sort: 197 | items.sort() 198 | return items 199 | 200 | class ImageNet(): 201 | 202 | dataset_dir = 'imagenet' 203 | 204 | def __init__(self, cfg, root, num_shots, preprocess): 205 | 206 | self.dataset_dir = os.path.join(root, self.dataset_dir) 207 | self.image_dir = os.path.join(self.dataset_dir, 'images') 208 | 209 | train_preprocess = transforms.Compose([ 210 | transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), 211 | transforms.RandomHorizontalFlip(p=0.5), 212 | transforms.ToTensor(), 213 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 214 | ]) 215 | test_preprocess = preprocess 216 | self.train = torchvision.datasets.ImageFolder(os.path.join(self.image_dir, 'train'), transform=train_preprocess) 217 | self.test = torchvision.datasets.ImageFolder(os.path.join(self.image_dir, 'val'), transform=test_preprocess) 218 | 219 | self.template = imagenet_templates 220 | 221 | subsample = cfg['subsample_classes'] 222 | 223 | n = len(imagenet_classes) 224 | # Divide classes into two halves 225 | m = math.ceil(n / 2) 226 | if subsample == 'all': 227 | self.classnames = imagenet_classes 228 | # train 229 | split_by_label_dict = defaultdict(list) 230 | for i in range(len(self.train.imgs)): 231 | split_by_label_dict[self.train.targets[i]].append(self.train.imgs[i]) 232 | imgs = [] 233 | for label, items in split_by_label_dict.items(): 234 | imgs = imgs + random.sample(items, num_shots) 235 | self.train.samples = imgs 236 | elif subsample == 'base': 237 | self.classnames = imagenet_classes[:m] 238 | # train 239 | split_by_label_dict = defaultdict(list) 240 | for i in range(len(self.train.imgs)): 241 | split_by_label_dict[self.train.targets[i]].append(self.train.imgs[i]) 242 | imgs = [] 243 | for label, items in split_by_label_dict.items(): 244 | imgs = imgs + 
random.sample(items, num_shots) 245 | if label >= m - 1: 246 | break 247 | self.train.samples = imgs 248 | # test 249 | split_by_label_dict = defaultdict(list) 250 | for i in range(len(self.test.imgs)): 251 | split_by_label_dict[self.test.targets[i]].append(self.test.imgs[i]) 252 | imgs = [] 253 | targets = [] 254 | for label, items in split_by_label_dict.items(): 255 | imgs = imgs + items 256 | targets = targets + [label for i in range(len(items))] 257 | if label >= m - 1: 258 | break 259 | self.test.samples = imgs 260 | 261 | elif subsample == 'new': 262 | self.classnames = imagenet_classes[m:] 263 | # train 264 | split_by_label_dict = defaultdict(list) 265 | for i in range(len(self.train.imgs)): 266 | split_by_label_dict[self.train.targets[i]].append(self.train.imgs[i]) 267 | imgs = [] 268 | targets = [] 269 | for label, items in split_by_label_dict.items(): 270 | if label >= m: 271 | imgs = imgs + random.sample(items, num_shots) 272 | self.train.samples = imgs 273 | 274 | # test 275 | split_by_label_dict = defaultdict(list) 276 | for i in range(len(self.test.imgs)): 277 | split_by_label_dict[self.test.targets[i]].append(self.test.imgs[i]) 278 | imgs = [] 279 | for label, items in split_by_label_dict.items(): 280 | if label >= m: 281 | imgs = imgs + items 282 | self.test.samples = imgs 283 | -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from scipy.io import loadmat 4 | from collections import defaultdict 5 | 6 | from .oxford_pets import OxfordPets 7 | from .utils import Datum, DatasetBase, read_json 8 | 9 | 10 | template = ['a photo of a {}, a type of flower.'] 11 | 12 | 13 | class OxfordFlowers(DatasetBase): 14 | 15 | dataset_dir = 'oxford_flowers' 16 | 17 | def __init__(self, cfg, root, num_shots): 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = 
os.path.join(self.dataset_dir, 'jpg') 20 | self.label_file = os.path.join(self.dataset_dir, 'imagelabels.mat') 21 | self.lab2cname_file = os.path.join(self.dataset_dir, 'cat_to_name.json') 22 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordFlowers.json') 23 | 24 | self.template = template 25 | 26 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 27 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 28 | 29 | # 30 | subsample = cfg['subsample_classes'] 31 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 32 | 33 | super().__init__(train_x=train, val=val, test=test) 34 | 35 | def read_data(self): 36 | tracker = defaultdict(list) 37 | label_file = loadmat(self.label_file)['labels'][0] 38 | for i, label in enumerate(label_file): 39 | imname = f'image_{str(i + 1).zfill(5)}.jpg' 40 | impath = os.path.join(self.image_dir, imname) 41 | label = int(label) 42 | tracker[label].append(impath) 43 | 44 | print('Splitting data into 50% train, 20% val, and 30% test') 45 | 46 | def _collate(ims, y, c): 47 | items = [] 48 | for im in ims: 49 | item = Datum( 50 | impath=im, 51 | label=y-1, # convert to 0-based label 52 | classname=c 53 | ) 54 | items.append(item) 55 | return items 56 | 57 | lab2cname = read_json(self.lab2cname_file) 58 | train, val, test = [], [], [] 59 | for label, impaths in tracker.items(): 60 | random.shuffle(impaths) 61 | n_total = len(impaths) 62 | n_train = round(n_total * 0.5) 63 | n_val = round(n_total * 0.2) 64 | n_test = n_total - n_train - n_val 65 | assert n_train > 0 and n_val > 0 and n_test > 0 66 | cname = lab2cname[str(label)] 67 | train.extend(_collate(impaths[:n_train], label, cname)) 68 | val.extend(_collate(impaths[n_train:n_train+n_val], label, cname)) 69 | test.extend(_collate(impaths[n_train+n_val:], label, cname)) 70 | 71 | return train, val, test -------------------------------------------------------------------------------- 
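
`OxfordFlowers` delegates its base/new split to `OxfordPets.subsample_classes`, which sorts the unique training labels, keeps the first `m = ceil(n / 2)` as base classes and the rest as new classes, then remaps the kept labels to contiguous 0-based ids. The sketch below isolates only that label bookkeeping; `split_base_new` is a hypothetical helper that does not exist in the repository, and it does not copy `Datum` objects the way the real method does.

```python
import math


def split_base_new(labels):
    """Hypothetical helper: split labels into base/new halves and relabel each half from 0."""
    unique = sorted(set(labels))
    m = math.ceil(len(unique) / 2)  # base gets the extra class when n is odd
    base, new = unique[:m], unique[m:]
    # Contiguous 0-based ids within each half, mirroring the `relabeler` dict
    base_relabel = {y: i for i, y in enumerate(base)}
    new_relabel = {y: i for i, y in enumerate(new)}
    return base_relabel, new_relabel


# Oxford Pets has 37 breeds (labels 0..36).
base_relabel, new_relabel = split_base_new(range(37))
print(len(base_relabel), len(new_relabel))  # 19 18
print(new_relabel[19])  # 0
```

With an odd class count the base half is one class larger, and the first new class (original label 19 here) becomes label 0, so a classifier head sized for the new half indexes correctly.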
/datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | from collections import defaultdict 5 | 6 | import torchvision.transforms as transforms 7 | 8 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 9 | 10 | 11 | template = ['a photo of a {}, a type of pet.'] 12 | 13 | 14 | class OxfordPets(DatasetBase): 15 | 16 | dataset_dir = 'oxford_pets' 17 | 18 | def __init__(self, cfg, root, num_shots): 19 | self.dataset_dir = os.path.join(root, self.dataset_dir) 20 | self.image_dir = os.path.join(self.dataset_dir, 'images') 21 | self.anno_dir = os.path.join(self.dataset_dir, 'annotations') 22 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordPets.json') 23 | 24 | self.template = template 25 | 26 | train, val, test = self.read_split(self.split_path, self.image_dir) 27 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 28 | 29 | # 30 | subsample = cfg['subsample_classes'] 31 | train, val, test = self.subsample_classes(train, val, test, subsample=subsample) 32 | 33 | super().__init__(train_x=train, val=val, test=test) 34 | 35 | def read_data(self, split_file): 36 | filepath = os.path.join(self.anno_dir, split_file) 37 | items = [] 38 | 39 | with open(filepath, 'r') as f: 40 | lines = f.readlines() 41 | for line in lines: 42 | line = line.strip() 43 | imname, label, species, _ = line.split(' ') 44 | breed = imname.split('_')[:-1] 45 | breed = '_'.join(breed) 46 | breed = breed.lower() 47 | imname += '.jpg' 48 | impath = os.path.join(self.image_dir, imname) 49 | label = int(label) - 1 # convert to 0-based index 50 | item = Datum( 51 | impath=impath, 52 | label=label, 53 | classname=breed 54 | ) 55 | items.append(item) 56 | 57 | return items 58 | 59 | @staticmethod 60 | def split_trainval(trainval, p_val=0.2): 61 | p_trn = 1 - p_val 62 | print(f'Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val') 63 
| tracker = defaultdict(list) 64 | for idx, item in enumerate(trainval): 65 | label = item.label 66 | tracker[label].append(idx) 67 | 68 | train, val = [], [] 69 | for label, idxs in tracker.items(): 70 | n_val = round(len(idxs) * p_val) 71 | assert n_val > 0 72 | random.shuffle(idxs) 73 | for n, idx in enumerate(idxs): 74 | item = trainval[idx] 75 | if n < n_val: 76 | val.append(item) 77 | else: 78 | train.append(item) 79 | 80 | return train, val 81 | 82 | @staticmethod 83 | def save_split(train, val, test, filepath, path_prefix): 84 | def _extract(items): 85 | out = [] 86 | for item in items: 87 | impath = item.impath 88 | label = item.label 89 | classname = item.classname 90 | impath = impath.replace(path_prefix, '') 91 | if impath.startswith('/'): 92 | impath = impath[1:] 93 | out.append((impath, label, classname)) 94 | return out 95 | 96 | train = _extract(train) 97 | val = _extract(val) 98 | test = _extract(test) 99 | 100 | split = { 101 | 'train': train, 102 | 'val': val, 103 | 'test': test 104 | } 105 | 106 | write_json(split, filepath) 107 | print(f'Saved split to {filepath}') 108 | 109 | @staticmethod 110 | def read_split(filepath, path_prefix): 111 | def _convert(items): 112 | out = [] 113 | for impath, label, classname in items: 114 | impath = os.path.join(path_prefix, impath) 115 | item = Datum( 116 | impath=impath, 117 | label=int(label), 118 | classname=classname 119 | ) 120 | out.append(item) 121 | return out 122 | 123 | print(f'Reading split from {filepath}') 124 | split = read_json(filepath) 125 | train = _convert(split['train']) 126 | val = _convert(split['val']) 127 | test = _convert(split['test']) 128 | 129 | return train, val, test 130 | 131 | @staticmethod 132 | def subsample_classes(*args, subsample="all"): 133 | """Divide classes into two groups. The first group 134 | represents base classes while the second group represents 135 | new classes. 136 | 137 | Args: 138 | args: a list of datasets, e.g. train, val and test. 
139 | subsample (str): what classes to subsample. 140 | """ 141 | assert subsample in ["all", "base", "new"] 142 | 143 | if subsample == "all": 144 | return args 145 | 146 | dataset = args[0] 147 | labels = set() 148 | for item in dataset: 149 | labels.add(item.label) 150 | labels = list(labels) 151 | labels.sort() 152 | n = len(labels) 153 | # Divide classes into two halves 154 | m = math.ceil(n / 2) 155 | 156 | print(f"SUBSAMPLE {subsample.upper()} CLASSES!") 157 | if subsample == "base": 158 | selected = labels[:m] # take the first half 159 | elif subsample == "new": 160 | selected = labels[m:] # take the second half 161 | relabeler = {y: y_new for y_new, y in enumerate(selected)} 162 | 163 | output = [] 164 | for dataset in args: 165 | dataset_new = [] 166 | for item in dataset: 167 | if item.label not in selected: 168 | continue 169 | item_new = Datum( 170 | impath=item.impath, 171 | label=relabeler[item.label], 172 | classname=item.classname 173 | ) 174 | dataset_new.append(item_new) 175 | output.append(dataset_new) 176 | 177 | return output -------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | 4 | from .oxford_pets import OxfordPets 5 | from .utils import Datum, DatasetBase 6 | 7 | 8 | template = ['a photo of a {}.'] 9 | 10 | 11 | class StanfordCars(DatasetBase): 12 | 13 | dataset_dir = 'stanford_cars' 14 | 15 | def __init__(self, cfg, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_StanfordCars.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 22 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 23 | 24 | # 25 | subsample = cfg['subsample_classes'] 26 | train, val, test = 
OxfordPets.subsample_classes(train, val, test, subsample=subsample) 27 | 28 | super().__init__(train_x=train, val=val, test=test) 29 | 30 | def read_data(self, image_dir, anno_file, meta_file): 31 | anno_file = loadmat(anno_file)['annotations'][0] 32 | meta_file = loadmat(meta_file)['class_names'][0] 33 | items = [] 34 | 35 | for i in range(len(anno_file)): 36 | imname = anno_file[i]['fname'][0] 37 | impath = os.path.join(self.dataset_dir, image_dir, imname) 38 | label = anno_file[i]['class'][0, 0] 39 | label = int(label) - 1 # convert to 0-based index 40 | classname = meta_file[label][0] 41 | names = classname.split(' ') 42 | year = names.pop(-1) 43 | names.insert(0, year) 44 | classname = ' '.join(names) 45 | item = Datum( 46 | impath=impath, 47 | label=label, 48 | classname=classname 49 | ) 50 | items.append(item) 51 | 52 | return items -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['a photo of a {}.'] 9 | 10 | 11 | class SUN397(DatasetBase): 12 | 13 | dataset_dir = 'sun397' 14 | 15 | def __init__(self, cfg, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'SUN397') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_SUN397.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 24 | 25 | # 26 | subsample = cfg['subsample_classes'] 27 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 28 | 29 | super().__init__(train_x=train, val=val, test=test) 30 | 31 | def read_data(self, cname2lab, 
text_file): 32 | text_file = os.path.join(self.dataset_dir, text_file) 33 | items = [] 34 | 35 | with open(text_file, 'r') as f: 36 | lines = f.readlines() 37 | for line in lines: 38 | imname = line.strip()[1:] # remove leading / 39 | classname = os.path.dirname(imname) 40 | label = cname2lab[classname] 41 | impath = os.path.join(self.image_dir, imname) 42 | 43 | names = classname.split('/')[1:] # drop the leading single-letter folder 44 | names = names[::-1] # move qualifiers such as indoor/outdoor to the front 45 | classname = ' '.join(names) 46 | 47 | item = Datum( 48 | impath=impath, 49 | label=label, 50 | classname=classname 51 | ) 52 | items.append(item) 53 | 54 | return items 55 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['a photo of a person doing {}.'] 9 | 10 | 11 | class UCF101(DatasetBase): 12 | 13 | dataset_dir = 'ucf101' 14 | 15 | def __init__(self, cfg, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'UCF-101-midframes') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_UCF101.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 24 | 25 | subsample = cfg['subsample_classes'] 26 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 27 | 28 | super().__init__(train_x=train, val=val, test=test) 29 | 30 | def read_data(self, cname2lab, text_file): 31 | text_file = os.path.join(self.dataset_dir, text_file) 32 | items = [] 33 | 34 | with open(text_file, 'r') as f: 35 | lines = f.readlines() 36 | for line in lines: 37 | 
line = line.strip().split(' ')[0] # trainlist: filename, label 38 | action, filename = line.split('/') 39 | label = cname2lab[action] 40 | 41 | elements = re.findall('[A-Z][^A-Z]*', action) 42 | renamed_action = '_'.join(elements) 43 | 44 | filename = filename.replace('.avi', '.jpg') 45 | impath = os.path.join(self.image_dir, renamed_action, filename) 46 | 47 | item = Datum( 48 | impath=impath, 49 | label=label, 50 | classname=renamed_action 51 | ) 52 | items.append(item) 53 | 54 | return items 55 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import os.path as osp 4 | import tarfile 5 | import zipfile 6 | from collections import defaultdict 7 | import gdown 8 | import json 9 | import torch 10 | from torch.utils.data import Dataset as TorchDataset 11 | import torchvision.transforms as T 12 | from PIL import Image 13 | import clip 14 | 15 | def read_json(fpath): 16 | """Read json file from a path.""" 17 | with open(fpath, 'r') as f: 18 | obj = json.load(f) 19 | return obj 20 | 21 | 22 | def write_json(obj, fpath): 23 | """Writes to a json file.""" 24 | if not osp.exists(osp.dirname(fpath)): 25 | os.makedirs(osp.dirname(fpath)) 26 | with open(fpath, 'w') as f: 27 | json.dump(obj, f, indent=4, separators=(',', ': ')) 28 | 29 | 30 | def read_image(path): 31 | """Read image from path using ``PIL.Image``. 32 | 33 | Args: 34 | path (str): path to an image. 35 | 36 | Returns: 37 | PIL image 38 | """ 39 | if not osp.exists(path): 40 | raise IOError('No file exists at {}'.format(path)) 41 | 42 | while True: 43 | try: 44 | img = Image.open(path).convert('RGB') 45 | return img 46 | except IOError: 47 | print( 48 | 'Cannot read image from {}, ' 49 | 'probably due to heavy IO. Will re-try'.format(path) 50 | ) 51 | 52 | 53 | def listdir_nohidden(path, sort=False): 54 | """List non-hidden items in a directory. 
55 | 56 | Args: 57 | path (str): directory path. 58 | sort (bool): sort the items. 59 | """ 60 | items = [f for f in os.listdir(path) if not f.startswith('.') and not f.endswith('.sh')] 61 | if sort: 62 | items.sort() 63 | return items 64 | 65 | 66 | class Datum: 67 | """Data instance which defines the basic attributes. 68 | 69 | Args: 70 | impath (str): image path. 71 | label (int): class label. 72 | domain (int): domain label. 73 | classname (str): class name. 74 | """ 75 | 76 | def __init__(self, impath='', label=0, domain=-1, classname=''): 77 | assert isinstance(impath, str) 78 | assert isinstance(label, int) 79 | assert isinstance(domain, int) 80 | assert isinstance(classname, str) 81 | 82 | self._impath = impath 83 | self._label = label 84 | self._domain = domain 85 | self._classname = classname 86 | 87 | @property 88 | def impath(self): 89 | return self._impath 90 | 91 | @property 92 | def label(self): 93 | return self._label 94 | 95 | @property 96 | def domain(self): 97 | return self._domain 98 | 99 | @property 100 | def classname(self): 101 | return self._classname 102 | 103 | 104 | class DatasetBase: 105 | """A unified dataset class for 106 | 1) domain adaptation 107 | 2) domain generalization 108 | 3) semi-supervised learning 109 | """ 110 | dataset_dir = '' # the directory where the dataset is stored 111 | domains = [] # string names of all domains 112 | 113 | def __init__(self, train_x=None, train_u=None, val=None, test=None): 114 | self._train_x = train_x # labeled training data 115 | self._train_u = train_u # unlabeled training data (optional) 116 | self._val = val # validation data (optional) 117 | self._test = test # test data 118 | 119 | self._num_classes = self.get_num_classes(train_x) 120 | self._lab2cname, self._classnames = self.get_lab2cname(train_x) 121 | 122 | @property 123 | def train_x(self): 124 | return self._train_x 125 | 126 | @property 127 | def train_u(self): 128 | return self._train_u 129 | 130 | @property 131 | def val(self): 
132 | 
return self._val 133 | 134 | @property 135 | def test(self): 136 | return self._test 137 | 138 | @property 139 | def lab2cname(self): 140 | return self._lab2cname 141 | 142 | @property 143 | def classnames(self): 144 | return self._classnames 145 | 146 | @property 147 | def num_classes(self): 148 | return self._num_classes 149 | 150 | def get_num_classes(self, data_source): 151 | """Count the number of classes. 152 | 153 | Args: 154 | data_source (list): a list of Datum objects. 155 | """ 156 | label_set = set() 157 | for item in data_source: 158 | label_set.add(item.label) 159 | return max(label_set) + 1 160 | 161 | def get_lab2cname(self, data_source): 162 | """Get a label-to-classname mapping (dict). 163 | 164 | Args: 165 | data_source (list): a list of Datum objects. 166 | """ 167 | container = set() 168 | for item in data_source: 169 | container.add((item.label, item.classname)) 170 | mapping = {label: classname for label, classname in container} 171 | labels = list(mapping.keys()) 172 | labels.sort() 173 | classnames = [mapping[label] for label in labels] 174 | return mapping, classnames 175 | 176 | def check_input_domains(self, source_domains, target_domains): 177 | self.is_input_domain_valid(source_domains) 178 | self.is_input_domain_valid(target_domains) 179 | 180 | def is_input_domain_valid(self, input_domains): 181 | for domain in input_domains: 182 | if domain not in self.domains: 183 | raise ValueError( 184 | 'Input domain must belong to {}, ' 185 | 'but got [{}]'.format(self.domains, domain) 186 | ) 187 | 188 | def download_data(self, url, dst, from_gdrive=True): 189 | if not osp.exists(osp.dirname(dst)): 190 | os.makedirs(osp.dirname(dst)) 191 | 192 | if from_gdrive: 193 | gdown.download(url, dst, quiet=False) 194 | else: 195 | raise NotImplementedError 196 | 197 | print('Extracting file ...') 198 | 199 | try: 200 | tar = tarfile.open(dst) 201 | tar.extractall(path=osp.dirname(dst)) 202 | tar.close() 203 | except tarfile.ReadError: # not a tar 
archive, fall back to zip 204 | zip_ref = zipfile.ZipFile(dst, 
'r') 205 | zip_ref.extractall(osp.dirname(dst)) 206 | zip_ref.close() 207 | 208 | print('File extracted to {}'.format(osp.dirname(dst))) 209 | 210 | def generate_fewshot_dataset( 211 | self, *data_sources, num_shots=-1, repeat=True 212 | ): 213 | """Generate a few-shot dataset (typically for the training set). 214 | 215 | This function is useful when one wants to evaluate a model 216 | in a few-shot learning setting where each class only contains 217 | a few number of images. 218 | 219 | Args: 220 | data_sources: each individual is a list containing Datum objects. 221 | num_shots (int): number of instances per class to sample. 222 | repeat (bool): repeat images if needed. 223 | """ 224 | if num_shots < 1: 225 | if len(data_sources) == 1: 226 | return data_sources[0] 227 | return data_sources 228 | 229 | print(f'Creating a {num_shots}-shot dataset') 230 | 231 | output = [] 232 | 233 | for data_source in data_sources: 234 | tracker = self.split_dataset_by_label(data_source) 235 | dataset = [] 236 | 237 | for label, items in tracker.items(): 238 | if len(items) >= num_shots: 239 | sampled_items = random.sample(items, num_shots) 240 | else: 241 | if repeat: 242 | sampled_items = random.choices(items, k=num_shots) 243 | else: 244 | sampled_items = items 245 | dataset.extend(sampled_items) 246 | 247 | output.append(dataset) 248 | 249 | if len(output) == 1: 250 | return output[0] 251 | 252 | return output 253 | 254 | def split_dataset_by_label(self, data_source): 255 | """Split a dataset, i.e. a list of Datum objects, 256 | into class-specific groups stored in a dictionary. 257 | 258 | Args: 259 | data_source (list): a list of Datum objects. 260 | """ 261 | output = defaultdict(list) 262 | 263 | for item in data_source: 264 | output[item.label].append(item) 265 | 266 | return output 267 | 268 | def split_dataset_by_domain(self, data_source): 269 | """Split a dataset, i.e. a list of Datum objects, 270 | into domain-specific groups stored in a dictionary. 
        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.domain].append(item)

        return output


class DatasetWrapper(TorchDataset):
    def __init__(self, data_source, input_size, transform=None, is_train=False,
                 return_img0=False, k_tfm=1):
        self.data_source = data_source
        self.transform = transform  # accept list (tuple) as input
        self.is_train = is_train
        # Augmenting an image K>1 times is only allowed during training
        self.k_tfm = k_tfm if is_train else 1
        self.return_img0 = return_img0

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                'Cannot augment the image {} times '
                'because transform is None'.format(self.k_tfm)
            )

        # Build transform that doesn't apply any data augmentation
        interp_mode = T.InterpolationMode.BICUBIC
        to_tensor = []
        to_tensor += [T.Resize(input_size, interpolation=interp_mode)]
        to_tensor += [T.ToTensor()]
        normalize = T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
        )
        to_tensor += [normalize]
        self.to_tensor = T.Compose(to_tensor)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        item = self.data_source[idx]

        output = {
            'label': item.label,
            'domain': item.domain,
            'impath': item.impath
        }

        img0 = read_image(item.impath)

        if self.transform is not None:
            if isinstance(self.transform, (list, tuple)):
                for i, tfm in enumerate(self.transform):
                    img = self._transform_image(tfm, img0)
                    keyname = 'img'
                    if (i + 1) > 1:
                        keyname += str(i + 1)
                    output[keyname] = img
            else:
                img = self._transform_image(self.transform, img0)
                output['img'] = img

        if self.return_img0:
            output['img0'] = self.to_tensor(img0)

        # Only the (first) transformed image and its label are returned, so a
        # transform must be supplied; otherwise output['img'] is never set (KeyError).
        return output['img'], output['label']

    def _transform_image(self, tfm, img0):
        img_list = []

        for k in range(self.k_tfm):
            img_list.append(tfm(img0))

        img = img_list
        if len(img) == 1:
            img = img[0]

        return img


def build_data_loader(
    data_source=None,
    batch_size=64,
    input_size=224,
    tfm=None,
    is_train=True,
    shuffle=False,
    dataset_wrapper=None
):

    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapper

    # Build data loader
    data_loader = torch.utils.data.DataLoader(
        dataset_wrapper(data_source, input_size=input_size, transform=tfm, is_train=is_train),
        batch_size=batch_size,
        num_workers=8,
        shuffle=shuffle,
        drop_last=False,
        pin_memory=(torch.cuda.is_available())
    )
    assert len(data_loader) > 0

    return data_loader

--------------------------------------------------------------------------------
/main_coop_vae.py:
--------------------------------------------------------------------------------
import os
import random
import argparse
import yaml
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.transforms as transforms

from datasets import build_dataset
from datasets.utils import build_data_loader
import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from utils import *
from torch.autograd import Variable

_tokenizer = _Tokenizer()
train_tranform = transforms.Compose([
    transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275,
                               0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
])


def get_arguments():

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', dest='config', help='path to the YAML config file')
    args = parser.parse_args()

    return args

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class CoOp_PromptLearner(nn.Module):
    def __init__(self, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = 4
        ctx_init = 'a photo of a'  # caltech101

        # ctx_init = None
        self.dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]

        self.n_cls = n_cls
        self.n_ctx = n_ctx

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).cuda()
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(self.dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :].cuda()
            prompt_prefix = ctx_init
            self.n_ctx = n_ctx
        else:
            # random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=self.dtype).cuda()
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)
        self.prompt_prefix = prompt_prefix
        self.get_prefix_suffix_token(classnames, clip_model)


    def get_prefix_suffix_token(self, classnames, clip_model):
        prompt_prefix = self.prompt_prefix
        classnames = [name.replace("_", " ") for name in
                      classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).cuda()  # (n_cls, n_tkn)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(self.dtype)

        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + self.n_ctx :, :])  # CLS, EOS

        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens

    def forward(self):
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix

        prompts = torch.cat(
            [
                prefix,  # (n_cls, 1, dim)
                ctx,     # (n_cls, n_ctx, dim)
                suffix,  # (n_cls, *, dim)
            ],
            dim=1,
        )

        return prompts


def run_coop(cfg, text_encoder, prompt_learner, clip_weights, clip_model, netG=None):
    coop_prompt_learner = CoOp_PromptLearner(all_classnames, clip_model)
    optimizer = torch.optim.SGD(coop_prompt_learner.parameters(), lr=2e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg['train_epoch'] * len(train_loader_F))

    best_base_acc, best_new_acc, best_H = 0, 0, 0
    best_epoch = 0

    for train_idx in range(cfg['train_epoch']):
        # Train
        print('Train Epoch: {:} / {:}'.format(train_idx, cfg['train_epoch']))

        for i, (images, target) in enumerate(tqdm(train_loader_F)):
            images, target = images.cuda(), target.cuda()

            with torch.no_grad():
                image_features = clip_model.encode_image(images)
                image_features /= image_features.norm(dim=-1, keepdim=True)

            if netG is not None:
                # Synthesize one text-side feature per sample for a random new class.
                with torch.no_grad():
                    gen_target = torch.randint(len(base_classnames), len(all_classnames), (target.shape[0], )).cuda()
                    z = torch.randn([gen_target.shape[0], image_features.shape[1]]).cuda()
                    bias = netG(z)
                    prompt_learner.get_prefix_suffix_token(all_classnames, clip_model)  # rebuild prefix/suffix over base + new class names
                    prompts = prompt_learner(bias, gen_target)

                    tokenized_prompts = prompt_learner.tokenized_prompts
                    text_features = text_encoder(prompts, tokenized_prompts[gen_target])
                    gen_feature = text_features / text_features.norm(dim=-1, keepdim=True)
                image_features = torch.cat([image_features, gen_feature], dim=0).half()
                # Keep labels as integers (casting them to fp16 is only exact up to 2048).
                target = torch.cat([target, gen_target], dim=0)

            prompts = coop_prompt_learner()
            tokenized_prompts = coop_prompt_learner.tokenized_prompts

            text_features = text_encoder(prompts, tokenized_prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            logits = 100. * image_features.float() @ text_features.T.float()
            loss = F.cross_entropy(logits, target.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

        # Evaluation
        with torch.no_grad():
            prompts = coop_prompt_learner()
            tokenized_prompts = coop_prompt_learner.tokenized_prompts
            text_features = text_encoder(prompts, tokenized_prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # new
            clip_logits = 100. * test_features_new.float() @ text_features.T.float()[:, len(base_classnames):]
            new_acc = cls_acc(clip_logits, test_labels_new)

            # base
            clip_logits = 100. * test_features.float() @ text_features.T.float()[:, :len(base_classnames)]
            base_acc = cls_acc(clip_logits, test_labels)

        # H: harmonic mean of base and new accuracy
        H = 2 * base_acc * new_acc / (base_acc + new_acc)
        if H > best_H:
            best_base_acc = base_acc
            best_new_acc = new_acc
            best_H = H
            best_epoch = train_idx

        print("base acc:\t%.2f new acc:\t%.2f H:\t%.2f " % (base_acc, new_acc, H))

    print(f"**** After fine-tuning, CoOp's best base test accuracy: {best_base_acc:.2f}, at epoch: {best_epoch}. ****\n")
    print(f"**** After fine-tuning, CoOp's best new test accuracy: {best_new_acc:.2f}, at epoch: {best_epoch}. ****\n")
    print(f"**** After fine-tuning, CoOp's best H test accuracy: {best_H:.2f} ****\n")

    return best_base_acc, best_new_acc, best_H

class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).float()

        # Take the feature at the EOT token (the highest token id in each sequence).
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x


class PromptLearner(nn.Module):
    def __init__(self, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = 4
        ctx_init = None
        # ctx_init = 'a photo of a'
        self.dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]

        self.n_cls = n_cls
        self.n_ctx = n_ctx

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).cuda()
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(self.dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :].cuda()
            prompt_prefix = ctx_init
            self.n_ctx = n_ctx
        else:
            # random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=self.dtype).cuda()
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)
        # self.ctx = ctx_vectors  # No prompt learning.
        self.prompt_prefix = prompt_prefix
        self.get_prefix_suffix_token(classnames, clip_model)


    def get_prefix_suffix_token(self, classnames, clip_model):
        prompt_prefix = self.prompt_prefix
        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "."
                   for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).cuda()  # (n_cls, n_tkn)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(self.dtype)

        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + self.n_ctx :, :])  # CLS, EOS

        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens

    def forward(self, bias, target):
        prefix = self.token_prefix[target]
        suffix = self.token_suffix[target]
        ctx = self.ctx  # (n_ctx, ctx_dim)
        bias = bias.unsqueeze(1)  # (batch, 1, ctx_dim)
        ctx = ctx.unsqueeze(0)
        ctx_shifted = ctx + bias  # (batch, n_ctx, ctx_dim)
        prompts = torch.cat([prefix, ctx_shifted, suffix], dim=1)
        return prompts

def vae_loss(recon_x, x, mean, log_var, target, clip_weights):
    REC = (recon_x - x).pow(2).sum(1).mean()
    KLD = -0.5 * (1 + log_var - mean.pow(2) - log_var.exp()).sum(dim=1).mean()
    return (REC + 1 * KLD)

class Encoder(nn.Module):

    def __init__(self):
        super(Encoder, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(512 * 1, 2048),
            nn.ReLU(),
        )
        self.mean = nn.Linear(2048, 512)
        self.log_var = nn.Linear(2048, 512)

        self.apply(weights_init)

    def forward(self, x, a):
        # x = torch.cat([x, a], dim=1)
        x = self.net(x)
        mean = self.mean(x)
        log_var = self.log_var(x)
        return mean, log_var

class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()
        n_ctx = 4
        self.net = nn.Sequential(
            nn.Linear(512 * 1, 4096),
            nn.ReLU(),
            nn.Linear(4096, 512 * 1),
        )
        self.apply(weights_init)

    def forward(self, x):
        out = self.net(x)
        return out

def run_vae_generator(cfg, dataset, cache_keys, cache_values, clip_weights, clip_model):
    # need to evaluate new dataset.
    # CLIP
    clip_model, preprocess = clip.load(cfg['backbone'])
    clip_model.eval()
    for p in clip_model.parameters():
        p.requires_grad = False

    text_encoder = TextEncoder(clip_model).float().cuda()
    prompt_learner = PromptLearner(dataset.classnames, clip_model).float().cuda()

    # test on new dataset.
    # global base&new val/test data.
    global train_loader_F_new
    global val_features_new, val_labels_new
    global test_features_new, test_labels_new

    print("\nLoading visual features and labels from new test set.")
    cfg['subsample_classes'] = "new"
    dataset_new = build_dataset(cfg, cfg['dataset'], cfg['root_path'], cfg['shots'])
    val_loader_new = build_data_loader(data_source=dataset_new.val, batch_size=64, is_train=False, tfm=preprocess, shuffle=False)
    test_loader_new = build_data_loader(data_source=dataset_new.test, batch_size=64, is_train=False, tfm=preprocess, shuffle=False)

    val_features_new, val_labels_new = pre_load_features(cfg, "val", clip_model, val_loader_new)
    test_features_new, test_labels_new = pre_load_features(cfg, "test", clip_model, test_loader_new)

    train_loader_F_new = build_data_loader(data_source=dataset_new.train_x, batch_size=256, tfm=train_tranform, is_train=True, shuffle=True)
    print("\nGetting textual features as CLIP's classifier.")
    clip_weights_new = clip_classifier(dataset_new.classnames, dataset_new.template, clip_model.float())

    global base_classnames
    global new_classnames
    global all_classnames
    base_classnames = dataset.classnames
    new_classnames = dataset_new.classnames
    all_classnames = base_classnames + new_classnames

    # train VAE.
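    # Data flow of the VAE loop below, as read from this code. It assumes a CLIP
    # backbone whose image-embedding size and text-transformer width are both 512
    # (e.g. ViT-B/16), which is what the hard-coded 512-d Encoder/Generator expect:
    #   x = L2-normalized CLIP image feature               # (B, 512)
    #   mean, log_var = netE(x, ...)                       # each (B, 512); 2nd arg unused
    #   z = mean + exp(0.5 * log_var) * eps, eps ~ N(0, I) # reparameterization trick
    #   bias = netG(z)                                     # (B, 512)
    #   ctx_shifted = ctx[None] + bias[:, None]            # (B, n_ctx, 512): same shift for every context token
    #   recon = normalize(text_encoder([SOS] + ctx_shifted + [class name, EOS]))
    #   loss = batch mean of ||recon - x||^2 + KL(N(mean, exp(log_var)) || N(0, I))   # vae_loss
    # In run_coop, netG instead decodes z ~ N(0, I) for randomly drawn new-class
    # labels, giving synthetic features for classes whose images are not trained on.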
    netE = Encoder().cuda()
    netG = Generator().cuda()
    optimizerE = torch.optim.AdamW(netE.parameters(), lr=1e-3)
    optimizerG = torch.optim.AdamW(netG.parameters(), lr=1e-3)
    optimizerP = torch.optim.AdamW(prompt_learner.parameters(), lr=1e-3)

    best_base, best_new, best_H = 0.0, 0.0, 0.0

    num_vae_epochs = 50
    for train_idx in range(1, num_vae_epochs + 1):
        # Train
        netE.train()
        netG.train()

        loss_list = []
        print('Train Epoch: {:} / {:}'.format(train_idx, num_vae_epochs))

        for i, (images, target) in enumerate(tqdm(train_loader_F)):
            images, target = images.cuda(), target.cuda()
            with torch.no_grad():
                image_features = clip_model.encode_image(images).float()
                image_features /= image_features.norm(dim=-1, keepdim=True)

            text_features = clip_weights.T[target].float()
            netE.zero_grad()
            netG.zero_grad()
            mean, log_var = netE(image_features, text_features)
            std = torch.exp(0.5 * log_var)
            z = torch.randn(mean.shape).cuda()
            z = std * z + mean
            bias = netG(z)

            prompt_learner.get_prefix_suffix_token(base_classnames, clip_model)
            prompts = prompt_learner(bias, target)

            tokenized_prompts = prompt_learner.tokenized_prompts
            text_features = text_encoder(prompts, tokenized_prompts[target])
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            recon_features = text_features
            loss = vae_loss(recon_features, image_features, mean, log_var, target, clip_weights)

            optimizerP.zero_grad()
            loss.backward()
            loss_list.append(loss.item())
            optimizerE.step()
            optimizerG.step()
            optimizerP.step()

        print('Loss: {:.4f}'.format(sum(loss_list) / len(loss_list)))
        if train_idx % 10 == 0:

            # Evaluation.
            netE.eval()
            netG.eval()
            clip_weights_mix = torch.cat([clip_weights, clip_weights_new], dim=1)
            # run CoOp
            base, new, H = run_coop(cfg, text_encoder, prompt_learner, clip_weights_mix, clip_model, netG=netG)
            if H > best_H:
                best_base = base
                best_new = new
                best_H = H

    print("Evaluate on dataset:", cfg['dataset'])
    print("best base acc: %.2f" % best_base)
    print("best new acc: %.2f" % best_new)
    print("best H: %.2f" % best_H)

def main():
    # Load config file
    args = get_arguments()
    assert (os.path.exists(args.config))

    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    # Train on the base classes; the new classes are loaded in run_vae_generator().
    cfg['subsample_classes'] = "base"  # all, base or new

    cache_dir = os.path.join('./caches', cfg['dataset'])
    os.makedirs(cache_dir, exist_ok=True)
    cfg['cache_dir'] = cache_dir

    print("\nRunning configs.")
    print(cfg, "\n")

    # CLIP
    clip_model, preprocess = clip.load(cfg['backbone'])
    clip_model.eval()
    for p in clip_model.parameters():
        p.requires_grad = False

    # Prepare dataset
    random.seed(1)
    torch.manual_seed(1)

    global train_loader_F
    global val_features, val_labels
    global test_features, test_labels

    print("Preparing dataset.")
    dataset = build_dataset(cfg, cfg['dataset'], cfg['root_path'], cfg['shots'])
    # val / test sets
    val_loader = build_data_loader(data_source=dataset.val, batch_size=64, is_train=False, tfm=preprocess, shuffle=False)
    test_loader = build_data_loader(data_source=dataset.test, batch_size=64, is_train=False, tfm=preprocess, shuffle=False)
    # train set
    train_loader_cache = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_tranform, is_train=True, shuffle=False)
    train_loader_F = build_data_loader(data_source=dataset.train_x,
                                       batch_size=256, tfm=train_tranform, is_train=True, shuffle=True)

    # Textual features
    print("\nGetting textual features as CLIP's classifier.")
    clip_weights = clip_classifier(dataset.classnames, dataset.template, clip_model)

    # Build the cache model from the few-shot training set
    print("\nConstructing cache model by few-shot visual features and labels.")
    cache_keys, cache_values = build_cache_model(cfg, clip_model, train_loader_cache)

    # Pre-load val features
    print("\nLoading visual features and labels from val set.")
    val_features, val_labels = pre_load_features(cfg, "val", clip_model, val_loader)

    # Pre-load test features
    print("\nLoading visual features and labels from test set.")
    test_features, test_labels = pre_load_features(cfg, "test", clip_model, test_loader)

    # ------------------------------------------ build and train generative model ---------------------------------------
    run_vae_generator(cfg, dataset, cache_keys, cache_values, clip_weights, clip_model)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/main_imagenet_coop_vae.py:
--------------------------------------------------------------------------------
import os
import random
import argparse
import yaml
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.nn as nn

from datasets.imagenet import ImageNet
import clip
from utils import *
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()
def get_arguments():

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', dest='config', help='path to the YAML config file')
    args = parser.parse_args()

    return args

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class CoOp_PromptLearner(nn.Module):
    def __init__(self, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = 4
        ctx_init = 'a photo of a'  # caltech101
        # ctx_init = None
        self.dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]

        self.n_cls = n_cls
        self.n_ctx = n_ctx

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).cuda()
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(self.dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :].cuda()
            prompt_prefix = ctx_init
            self.n_ctx = n_ctx
        else:
            # random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=self.dtype).cuda()
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)
        self.prompt_prefix = prompt_prefix
        self.get_prefix_suffix_token(classnames, clip_model)


    def get_prefix_suffix_token(self, classnames, clip_model):
        prompt_prefix = self.prompt_prefix
        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "."
                   for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).cuda()  # (n_cls, n_tkn)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(self.dtype)

        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + self.n_ctx :, :])  # CLS, EOS

        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens

    def forward(self):
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix

        prompts = torch.cat(
            [
                prefix,  # (n_cls, 1, dim)
                ctx,     # (n_cls, n_ctx, dim)
                suffix,  # (n_cls, *, dim)
            ],
            dim=1,
        )

        return prompts


def run_coop(cfg, text_encoder, prompt_learner, clip_weights, clip_model, netG=None):
    coop_prompt_learner = CoOp_PromptLearner(all_classnames, clip_model)
    # optimizer = torch.optim.SGD(coop_prompt_learner.parameters(), lr=2e-3)
    optimizer = torch.optim.AdamW(coop_prompt_learner.parameters(), lr=1e-3, eps=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg['train_epoch'] * len(train_loader_F))

    best_base_acc, best_new_acc, best_H = 0, 0, 0
    best_epoch = 0

    for train_idx in range(cfg['train_epoch']):
        # Train
        print('Train Epoch: {:} / {:}'.format(train_idx, cfg['train_epoch']))

        for i, (images, target) in enumerate(tqdm(train_loader_F)):
            images, target = images.cuda(), target.cuda()

            with torch.no_grad():
                image_features = clip_model.encode_image(images)
                image_features /= image_features.norm(dim=-1, keepdim=True)

            if netG is not None:
                with torch.no_grad():
                    gen_target = torch.randint(len(base_classnames), len(all_classnames),
                                               (target.shape[0], )).cuda()
                    z = torch.randn([gen_target.shape[0], image_features.shape[1]]).cuda()
                    bias = netG(z)
                    prompt_learner.get_prefix_suffix_token(all_classnames, clip_model)  # rebuild prefix/suffix over base + new class names
                    prompts = prompt_learner(bias, gen_target)
                    tokenized_prompts = prompt_learner.tokenized_prompts
                    text_features = text_encoder(prompts, tokenized_prompts[gen_target])
                    gen_feature = text_features / text_features.norm(dim=-1, keepdim=True)
                image_features = torch.cat([image_features, gen_feature], dim=0).half()
                # Keep labels as integers (casting them to fp16 is only exact up to 2048).
                target = torch.cat([target, gen_target], dim=0)

            prompts = coop_prompt_learner()
            tokenized_prompts = coop_prompt_learner.tokenized_prompts

            text_features = text_encoder(prompts, tokenized_prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            logits = 100. * image_features.float() @ text_features.T.float()
            loss = F.cross_entropy(logits, target.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

        # Evaluation
        with torch.no_grad():
            prompts = coop_prompt_learner()
            tokenized_prompts = coop_prompt_learner.tokenized_prompts
            text_features = text_encoder(prompts, tokenized_prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # new
            clip_logits = 100. * test_features_new.float() @ text_features.T.float()[:, len(base_classnames):]
            new_acc = cls_acc(clip_logits, test_labels_new)

            # base
            clip_logits = 100. * test_features.float() @ text_features.T.float()[:, :len(base_classnames)]
            base_acc = cls_acc(clip_logits, test_labels)

        # H: harmonic mean of base and new accuracy
        H = 2 * base_acc * new_acc / (base_acc + new_acc)
        if H > best_H:
            best_base_acc = base_acc
            best_new_acc = new_acc
            best_H = H
            best_epoch = train_idx

        print("base acc:\t%.2f new acc:\t%.2f H:\t%.2f " % (base_acc, new_acc, H))

    print(f"**** After fine-tuning, CoOp's best base test accuracy: {best_base_acc:.2f}, at epoch: {best_epoch}. ****\n")
    print(f"**** After fine-tuning, CoOp's best new test accuracy: {best_new_acc:.2f}, at epoch: {best_epoch}. ****\n")
    print(f"**** After fine-tuning, CoOp's best H test accuracy: {best_H:.2f} ****\n")

    return best_base_acc, best_new_acc, best_H

class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts.half() + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x.float()).half()  # LayerNorm needs to be computed in fp32 for fp16 inputs
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x.float()).type(self.dtype)
        # Take the feature at the EOT token (the highest token id in each sequence).
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection.type(self.dtype)
        return x

class PromptLearner(nn.Module):
    def __init__(self, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = 4
        # ctx_init = 'a photo of a'
        ctx_init = None
        self.dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]

        self.n_cls = n_cls
        self.n_ctx = n_ctx

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).cuda()
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(self.dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :].cuda()
            prompt_prefix = ctx_init
            self.n_ctx = n_ctx
        else:
            # random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=self.dtype).cuda()
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)
        self.prompt_prefix = prompt_prefix
        self.get_prefix_suffix_token(classnames, clip_model)

    def get_prefix_suffix_token(self, classnames, clip_model):
        prompt_prefix = self.prompt_prefix
        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "."
for name in classnames] 243 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).cuda() # (n_cls, n_tkn) 244 | with torch.no_grad(): 245 | embedding = clip_model.token_embedding(tokenized_prompts).type(self.dtype) 246 | 247 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 248 | self.register_buffer("token_suffix", embedding[:, 1 + self.n_ctx :, :]) # CLS, EOS 249 | 250 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 251 | self.name_lens = name_lens 252 | 253 | def forward(self, bias, target): 254 | prefix = self.token_prefix[target] 255 | suffix = self.token_suffix[target] 256 | ctx = self.ctx # (n_ctx, ctx_dim) 257 | bias = bias.unsqueeze(1) # (batch, 1, ctx_dim) 258 | ctx = ctx.unsqueeze(0) # (1, n_ctx, ctx_dim) 259 | ctx_shifted = ctx + bias # (batch, n_ctx, ctx_dim) 260 | prompts = torch.cat([prefix, ctx_shifted, suffix], dim=1) 261 | return prompts 262 | 263 | def vae_loss(recon_x, x, mean, log_var, target, clip_weights): 264 | REC = (recon_x - x).pow(2).sum(1).mean() 265 | KLD = -0.5 * (1 + log_var - mean.pow(2) - log_var.exp()).sum(dim=1).mean() 266 | return (REC + 1 * KLD) 267 | 268 | class Encoder(nn.Module): 269 | 270 | def __init__(self): 271 | super(Encoder, self).__init__() 272 | self.net = nn.Sequential( 273 | nn.Linear(512 * 1, 4096), 274 | nn.ReLU(), 275 | ) 276 | self.mean = nn.Linear(4096, 512) 277 | self.log_var = nn.Linear(4096, 512) 278 | self.apply(weights_init) 279 | 280 | def forward(self, x, a): 281 | # x = torch.cat([x, a], dim=1) 282 | x = self.net(x) 283 | mean = self.mean(x) 284 | log_var = self.log_var(x) 285 | return mean, log_var 286 | 287 | class Generator(nn.Module): 288 | 289 | def __init__(self): 290 | super(Generator, self).__init__() 291 | self.net = nn.Sequential( 292 | nn.Linear(512 * 1, 4096), 293 | nn.LeakyReLU(0.2), 294 | nn.Linear(4096, 512) 295 | ) 296 | self.apply(weights_init) 297 | 298 | def forward(self, x): 299 | out = self.net(x) 300 | return out 301 | 302 | def 
run_vae_generator(cfg, clip_weights, clip_model): 303 | # CLIP 304 | for p in clip_model.parameters(): 305 | p.requires_grad = False 306 | 307 | text_encoder = TextEncoder(clip_model).float().cuda() 308 | prompt_learner = PromptLearner(all_classnames, clip_model).float().cuda() 309 | 310 | # train VAE. 311 | netE = Encoder().cuda() 312 | netG = Generator().cuda() 313 | optimizerE = torch.optim.AdamW(netE.parameters(), lr=1e-3) 314 | optimizerG = torch.optim.AdamW(netG.parameters(), lr=1e-3) 315 | optimizerP = torch.optim.AdamW(prompt_learner.parameters(), lr=1e-3) 316 | 317 | best_base, best_new, best_H = 0.0, 0.0, 0.0 318 | 319 | for train_idx in range(1, 10 + 1): 320 | # Train 321 | netE.train() 322 | netG.train() 323 | 324 | loss_list = [] 325 | print('Train Epoch: {:} / {:}'.format(train_idx, cfg['train_epoch'])) 326 | 327 | for i, (images, target) in enumerate(tqdm(train_loader_F)): 328 | images, target = images.cuda(), target.cuda() 329 | with torch.no_grad(): 330 | image_features = clip_model.encode_image(images).float() 331 | image_features /= image_features.norm(dim=-1, keepdim=True) 332 | 333 | text_features = clip_weights.T[target].float() 334 | netE.zero_grad() 335 | netG.zero_grad() 336 | mean, log_var = netE(image_features, text_features) 337 | std = torch.exp(0.5 * log_var) 338 | z = torch.randn(mean.shape).cuda() 339 | z = std * z + mean 340 | bias = netG(z) 341 | prompt_learner.get_prefix_suffix_token(base_classnames, clip_model) 342 | prompts = prompt_learner(bias, target) 343 | tokenized_prompts = prompt_learner.tokenized_prompts 344 | text_features = text_encoder(prompts, tokenized_prompts[target]) 345 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 346 | recon_features = text_features 347 | loss = vae_loss(recon_features, image_features, mean, log_var, target, clip_weights) 348 | optimizerP.zero_grad() 349 | loss.backward() 350 | loss_list.append(loss.item()) 351 | optimizerE.step() 352 | optimizerG.step() 353 | 
optimizerP.step() 354 | 355 | print('Loss: {:.4f}'.format(sum(loss_list)/len(loss_list))) 356 | if train_idx % 10 == 0: 357 | # Evaluation. 358 | netE.eval() 359 | netG.eval() 360 | clip_weights_mix = torch.cat([clip_weights, clip_weights_new], dim=1) 361 | base, new, H = run_coop(cfg, text_encoder, prompt_learner, clip_weights_mix, clip_model, netG=netG) 362 | if H > best_H: 363 | best_base = base 364 | best_new = new 365 | best_H = H 366 | best_epoch = train_idx 367 | print("base acc:\t%.2f new acc:\t%.2f H:\t%.2f " % (base, new, H)) 368 | 369 | print("Evaluate on dataset:", cfg['dataset']) 370 | print("best base acc: %.2f" % best_base) 371 | print("best new acc: %.2f" % best_new) 372 | print("best H: %.2f" % best_H) 373 | 374 | def main(): 375 | 376 | # Load config file 377 | args = get_arguments() 378 | assert (os.path.exists(args.config)) 379 | 380 | cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 381 | 382 | cache_dir = os.path.join('./caches', cfg['dataset']) 383 | os.makedirs(cache_dir, exist_ok=True) 384 | cfg['cache_dir'] = cache_dir 385 | 386 | print("\nRunning configs.") 387 | print(cfg, "\n") 388 | 389 | # CLIP 390 | clip_model, preprocess = clip.load(cfg['backbone']) 391 | clip_model.eval() 392 | 393 | # ImageNet dataset 394 | random.seed(1) 395 | torch.manual_seed(1) 396 | 397 | global test_features, test_labels, train_loader_F 398 | global test_features_new, test_labels_new, train_loader_F_new 399 | global base_classnames, new_classnames, all_classnames 400 | global clip_weights_new 401 | 402 | print("Preparing ImageNet dataset.") 403 | # base classses 404 | cfg['subsample_classes'] = 'base' # all/base/new 405 | imagenet = ImageNet(cfg, cfg['root_path'], cfg['shots'], preprocess) 406 | test_loader = torch.utils.data.DataLoader(imagenet.test, batch_size=64, num_workers=8, shuffle=False) 407 | train_loader_F = torch.utils.data.DataLoader(imagenet.train, batch_size=256, num_workers=8, shuffle=True) 408 | # Textual features 409 | 
print("Getting textual features as CLIP's classifier.") 410 | clip_weights = clip_classifier(imagenet.classnames, imagenet.template, clip_model) 411 | base_classnames = imagenet.classnames 412 | 413 | # Pre-load test features 414 | print("\nLoading visual features and labels from test set.") 415 | test_features, test_labels = pre_load_features(cfg, "test", clip_model, test_loader) 416 | 417 | # new classes 418 | cfg['subsample_classes'] = 'new' # all/base/new 419 | imagenet = ImageNet(cfg, cfg['root_path'], cfg['shots'], preprocess) 420 | test_loader = torch.utils.data.DataLoader(imagenet.test, batch_size=64, num_workers=8, shuffle=False) 421 | test_features_new, test_labels_new = pre_load_features(cfg, "test", clip_model, test_loader) 422 | test_labels_new = test_labels_new - 500 423 | clip_weights_new = clip_classifier(imagenet.classnames, imagenet.template, clip_model) 424 | new_classnames = imagenet.classnames 425 | 426 | all_classnames = base_classnames + new_classnames 427 | 428 | # 429 | run_vae_generator(cfg, clip_weights, clip_model) 430 | 431 | 432 | if __name__ == '__main__': 433 | main() 434 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flake8==3.7.9 2 | yapf==0.29.0 3 | isort==4.3.21 4 | yacs 5 | gdown 6 | tb-nightly 7 | future 8 | scipy 9 | scikit-learn 10 | tqdm 11 | ftfy 12 | regex 13 | wilds==1.2.2 14 | tabulate -------------------------------------------------------------------------------- /scripts/coop_vae.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_coop_vae.py --config coop-configs/caltech101.yaml 2 | CUDA_VISIBLE_DEVICES=0 python main_coop_vae.py --config coop-configs/oxford_pets.yaml 3 | CUDA_VISIBLE_DEVICES=0 python main_coop_vae.py --config coop-configs/stanford_cars.yaml 4 | CUDA_VISIBLE_DEVICES=0 python main_coop_vae.py 
--config coop-configs/oxford_flowers.yaml 5 | CUDA_VISIBLE_DEVICES=0 python main_coop_vae.py --config coop-configs/food101.yaml 6 | CUDA_VISIBLE_DEVICES=0 python main_coop_vae.py --config coop-configs/fgvc.yaml 7 | CUDA_VISIBLE_DEVICES=0 python main_coop_vae.py --config coop-configs/sun397.yaml 8 | CUDA_VISIBLE_DEVICES=0 python main_coop_vae.py --config coop-configs/dtd.yaml 9 | CUDA_VISIBLE_DEVICES=0 python main_coop_vae.py --config coop-configs/eurosat.yaml 10 | CUDA_VISIBLE_DEVICES=0 python main_coop_vae.py --config coop-configs/ucf101.yaml 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | import clip 8 | 9 | 10 | def cls_acc(output, target, topk=1): 11 | pred = output.topk(topk, 1, True, True)[1].t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 14 | acc = 100 * acc / target.shape[0] 15 | return acc 16 | 17 | 18 | def clip_classifier(classnames, template, clip_model, norm=True): 19 | with torch.no_grad(): 20 | clip_weights = [] 21 | 22 | for classname in classnames: 23 | # Tokenize the prompts 24 | classname = classname.replace('_', ' ') 25 | texts = [t.format(classname) for t in template] 26 | texts = clip.tokenize(texts, context_length=77).cuda() 27 | # prompt ensemble for ImageNet 28 | class_embeddings = clip_model.encode_text(texts) 29 | if norm: 30 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 31 | class_embedding = class_embeddings.mean(dim=0) 32 | if norm: 33 | class_embedding /= class_embedding.norm() 34 | 
clip_weights.append(class_embedding) 35 | 36 | clip_weights = torch.stack(clip_weights, dim=1).cuda() 37 | return clip_weights 38 | 39 | 40 | def build_cache_model(cfg, clip_model, train_loader_cache): 41 | 42 | if cfg['load_cache'] == False: 43 | cache_keys = [] 44 | cache_values = [] 45 | 46 | with torch.no_grad(): 47 | # Data augmentation for the cache model 48 | for augment_idx in range(cfg['augment_epoch']): 49 | train_features = [] 50 | 51 | print('Augment Epoch: {:} / {:}'.format(augment_idx, cfg['augment_epoch'])) 52 | for i, (images, target) in enumerate(tqdm(train_loader_cache)): 53 | images = images.cuda() 54 | image_features = clip_model.encode_image(images) 55 | train_features.append(image_features) 56 | if augment_idx == 0: 57 | target = target.cuda() 58 | cache_values.append(target) 59 | cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0)) 60 | 61 | # cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0) 62 | cache_keys = torch.cat(cache_keys, dim=0) 63 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True) 64 | cache_keys = cache_keys.mean(dim=0) 65 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True) 66 | cache_keys = cache_keys.permute(1, 0) 67 | cache_values = F.one_hot(torch.cat(cache_values, dim=0)).half() 68 | 69 | torch.save(cache_keys, cfg['cache_dir'] + '/keys_' + str(cfg['shots']) + "shots.pt") 70 | torch.save(cache_values, cfg['cache_dir'] + '/values_' + str(cfg['shots']) + "shots.pt") 71 | 72 | else: 73 | cache_keys = torch.load(cfg['cache_dir'] + '/keys_' + str(cfg['shots']) + "shots.pt") 74 | cache_values = torch.load(cfg['cache_dir'] + '/values_' + str(cfg['shots']) + "shots.pt") 75 | 76 | return cache_keys, cache_values 77 | 78 | 79 | def pre_load_features(cfg, split, clip_model, loader, norm=True): 80 | 81 | if cfg['load_pre_feat'] == False: 82 | features, labels = [], [] 83 | 84 | with torch.no_grad(): 85 | for i, (images, target) in enumerate(tqdm(loader)): 86 | 87 | images, target = images.cuda(), target.cuda() 88 | 
image_features = clip_model.encode_image(images) 89 | if norm: 90 | image_features /= image_features.norm(dim=-1, keepdim=True) 91 | features.append(image_features) 92 | labels.append(target) 93 | 94 | features, labels = torch.cat(features), torch.cat(labels) 95 | 96 | # torch.save(features, cfg['cache_dir'] + "/" + split + "_f.pt") 97 | # torch.save(labels, cfg['cache_dir'] + "/" + split + "_l.pt") 98 | 99 | else: 100 | features = torch.load(cfg['cache_dir'] + "/" + split + "_f.pt") 101 | labels = torch.load(cfg['cache_dir'] + "/" + split + "_l.pt") 102 | 103 | return features, labels 104 | 105 | 106 | def search_hp(cfg, cache_keys, cache_values, features, labels, clip_weights, adapter=None): 107 | 108 | if cfg['search_hp'] == True: 109 | 110 | beta_list = [i * (cfg['search_scale'][0] - 0.1) / cfg['search_step'][0] + 0.1 for i in range(cfg['search_step'][0])] 111 | alpha_list = [i * (cfg['search_scale'][1] - 0.1) / cfg['search_step'][1] + 0.1 for i in range(cfg['search_step'][1])] 112 | 113 | best_acc = 0 114 | best_beta, best_alpha = 0, 0 115 | 116 | for beta in beta_list: 117 | for alpha in alpha_list: 118 | if adapter: 119 | affinity = adapter(features) 120 | else: 121 | affinity = features @ cache_keys 122 | 123 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 124 | clip_logits = 100. * features @ clip_weights 125 | tip_logits = clip_logits + cache_logits * alpha 126 | acc = cls_acc(tip_logits, labels) 127 | 128 | if acc > best_acc: 129 | print("New best setting, beta: {:.2f}, alpha: {:.2f}; accuracy: {:.2f}".format(beta, alpha, acc)) 130 | best_acc = acc 131 | best_beta = beta 132 | best_alpha = alpha 133 | 134 | print("\nAfter searching, the best accuarcy: {:.2f}.\n".format(best_acc)) 135 | 136 | return best_beta, best_alpha 137 | --------------------------------------------------------------------------------
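To check the VAE arithmetic in `run_vae_generator` and `vae_loss` without a GPU, the sketch below restates the two formulas on plain Python lists: the reparameterization `z = mean + exp(0.5 * log_var) * eps` and the loss "squared L2 reconstruction error plus `-0.5 * sum(1 + log_var - mean^2 - exp(log_var))`". The names `reparameterize`, `gaussian_kl`, and `vae_objective` are hypothetical, and the sketch handles a single sample, whereas the PyTorch code averages both terms over the batch.

```python
import math
import random


def reparameterize(mean, log_var, rng=random):
    """Sample z = mean + std * eps with std = exp(0.5 * log_var) and eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean, log_var)]


def gaussian_kl(mean, log_var):
    """Closed-form KL(N(mean, diag(exp(log_var))) || N(0, I)), summed over dimensions."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mean, log_var))


def vae_objective(recon, target, mean, log_var, kl_weight=1.0):
    """Squared-L2 reconstruction error plus the weighted KL term, for one sample."""
    rec = sum((r - t) ** 2 for r, t in zip(recon, target))
    return rec + kl_weight * gaussian_kl(mean, log_var)
```

With `kl_weight=1.0` this mirrors the `REC + 1 * KLD` weighting in `vae_loss`; the KL term is zero exactly when the posterior is the standard normal (`mean = 0`, `log_var = 0`).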
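`cls_acc` reports top-k accuracy as a percentage, and the evaluation code keeps the epoch with the best harmonic mean `H = 2 * base * new / (base + new)` of base and new accuracy. A dependency-free sketch of both metrics follows; `topk_accuracy` and `harmonic_mean` are hypothetical names, ties between equal logits may be broken differently from `torch.topk`, and unlike the original `H` line this version returns 0.0 instead of dividing by zero when both accuracies are 0.

```python
def topk_accuracy(logits, labels, k=1):
    """Percentage of samples whose true label is among the k largest logits."""
    hits = 0
    for row, label in zip(logits, labels):
        # indices of the k largest scores in this row
        top = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        if label in top:
            hits += 1
    return 100.0 * hits / len(labels)


def harmonic_mean(base_acc, new_acc):
    """H = 2 * base * new / (base + new); 0.0 when both accuracies are 0."""
    if base_acc + new_acc == 0:
        return 0.0
    return 2 * base_acc * new_acc / (base_acc + new_acc)
```

H rewards balanced results: 80% base and 60% new accuracy give H of about 68.57, below their arithmetic mean of 70.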
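`search_hp` grid-searches two scalars for the cache-based blend that the code names `tip_logits`: `beta`, which controls how fast a cached sample's weight `exp(-beta * (1 - affinity))` decays as its similarity drops, and `alpha`, which scales the cache logits against the zero-shot CLIP logits. The sketch below restates the grid construction and the per-query logits on plain lists. `hp_grid`, `cache_logits`, and `blended_logits` are hypothetical names; `affinity` is assumed to hold one query's similarities to the N cached features, and `cache_values` an N x C one-hot label matrix like the one `build_cache_model` builds with `F.one_hot`.

```python
import math


def hp_grid(scale, steps):
    """Evenly spaced values starting at 0.1, built like beta_list/alpha_list in search_hp."""
    return [i * (scale - 0.1) / steps + 0.1 for i in range(steps)]


def cache_logits(affinity, cache_values, beta):
    """exp(-beta * (1 - affinity)) @ cache_values for a single query.

    affinity: similarity of the query to each cached feature (length N).
    cache_values: N rows of one-hot labels (N x C).
    """
    weights = [math.exp(-beta * (1.0 - a)) for a in affinity]
    num_classes = len(cache_values[0])
    return [sum(w * row[c] for w, row in zip(weights, cache_values)) for c in range(num_classes)]


def blended_logits(clip_logits, affinity, cache_values, alpha, beta):
    """clip_logits + alpha * cache_logits, as in the tip_logits line of search_hp."""
    cache = cache_logits(affinity, cache_values, beta)
    return [c + alpha * k for c, k in zip(clip_logits, cache)]
```

A larger `beta` makes the weights fall off faster, so only the most similar cached features contribute; a cached feature with `affinity = 1` always gets weight 1 regardless of `beta`.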