├── .gitignore ├── DATASET.md ├── README.md ├── Results.xlsx ├── cache_model.png ├── calch_load.ipynb ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── configs ├── caltech101.yaml ├── dtd.yaml ├── eurosat.yaml ├── fgvc.yaml ├── food101.yaml ├── imagenet.yaml ├── oxford_flowers.yaml ├── oxford_pets.yaml ├── stanford_cars.yaml ├── sun397.yaml └── ucf101.yaml ├── datasets ├── __init__.py ├── caltech101.py ├── dtd.py ├── eurosat.py ├── fgvc.py ├── food101.py ├── imagenet.py ├── oxford_flowers.py ├── oxford_pets.py ├── stanford_cars.py ├── sun397.py ├── ucf101.py └── utils.py ├── draw_curves.py ├── draw_curves_full.py ├── exp.log ├── main.py ├── main_curves ├── Caltech101.pdf ├── DTD.pdf ├── EuroSAT.pdf ├── FGVCAircraft.pdf ├── Flowers102.pdf ├── Food101.pdf ├── ImageNet.pdf ├── OxfordPets.pdf ├── SUN397.pdf ├── StanfordCars.pdf ├── UCF101.pdf └── average.pdf ├── main_curves_full ├── Caltech101.pdf ├── DTD.pdf ├── EuroSAT.pdf ├── FGVCAircraft.pdf ├── Flowers102.pdf ├── Food101.pdf ├── ImageNet.pdf ├── OxfordPets.pdf ├── SUN397.pdf ├── StanfordCars.pdf ├── UCF101.pdf └── average.pdf ├── main_imagenet.py ├── requirements.txt ├── scripts └── data.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Custom 132 | output/ 133 | debug.sh 134 | 135 | #by jason 136 | data 137 | *.pt 138 | Dassl.pytorch -------------------------------------------------------------------------------- /DATASET.md: -------------------------------------------------------------------------------- 1 | # How to install datasets 2 | 3 | We suggest putting all datasets under the same folder (say `$DATA`) to ease management and following the instructions below to organize datasets to avoid modifying the source code. The file structure looks like 4 | 5 | ``` 6 | $DATA/ 7 | |–– imagenet/ 8 | |–– caltech-101/ 9 | |–– oxford_pets/ 10 | |–– stanford_cars/ 11 | ``` 12 | 13 | If you have some datasets already installed somewhere else, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid duplicate download. 14 | 15 | Datasets list: 16 | - [ImageNet](#imagenet) 17 | - [Caltech101](#caltech101) 18 | - [OxfordPets](#oxfordpets) 19 | - [StanfordCars](#stanfordcars) 20 | - [Flowers102](#flowers102) 21 | - [Food101](#food101) 22 | - [FGVCAircraft](#fgvcaircraft) 23 | - [SUN397](#sun397) 24 | - [DTD](#dtd) 25 | - [EuroSAT](#eurosat) 26 | - [UCF101](#ucf101) 27 | 28 | The instructions to prepare each dataset are detailed below. To ensure reproducibility and fair comparison for future work, we utilize CoOp-style train/val/test splits for all datasets except ImageNet where the validation set is used as test set. 29 | 30 | ### ImageNet 31 | - Create a folder named `imagenet/` under `$DATA`. 32 | - Create `images/` under `imagenet/`. 33 | - Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. The directory structure should look like 34 | ``` 35 | imagenet/ 36 | |–– images/ 37 | | |–– train/ # contains 1,000 folders like n01440764, n01443537, etc. 38 | | |–– val/ 39 | ``` 40 | - If you had downloaded the ImageNet dataset before, you can create symbolic links to map the training and validation sets to `$DATA/imagenet/images`. 41 | - Download the `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb). 42 | 43 | ### Caltech101 44 | - Create a folder named `caltech-101/` under `$DATA`. 45 | - Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`. 46 | - Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`. 
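If you prefer to script the image download, the sketch below fetches and extracts the archive (it assumes the `DATA` environment variable points at your dataset root and that the URL above is still reachable; the split file from the Google Drive link still has to be downloaded manually):

```python
import os
import tarfile
import urllib.request

# Dataset root; the instructions above call this $DATA (assumed to be set in your shell).
data_root = os.environ.get("DATA", "./data")
target_dir = os.path.join(data_root, "caltech-101")
os.makedirs(target_dir, exist_ok=True)

url = "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz"
archive = os.path.join(target_dir, "101_ObjectCategories.tar.gz")
if not os.path.exists(archive):
    urllib.request.urlretrieve(url, archive)  # download the image archive
with tarfile.open(archive) as tar:
    tar.extractall(path=target_dir)  # creates $DATA/caltech-101/101_ObjectCategories/
```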
47 | 48 | The directory structure should look like 49 | ``` 50 | caltech-101/ 51 | |–– 101_ObjectCategories/ 52 | |–– split_zhou_Caltech101.json 53 | ``` 54 | 55 | ### OxfordPets 56 | - Create a folder named `oxford_pets/` under `$DATA`. 57 | - Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz. 58 | - Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz. 59 | - Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing). 60 | 61 | The directory structure should look like 62 | ``` 63 | oxford_pets/ 64 | |–– images/ 65 | |–– annotations/ 66 | |–– split_zhou_OxfordPets.json 67 | ``` 68 | 69 | ### StanfordCars 70 | - Create a folder named `stanford_cars/` under `$DATA`. 71 | - Download the train images http://ai.stanford.edu/~jkrause/car196/cars_train.tgz. 72 | - Download the test images http://ai.stanford.edu/~jkrause/car196/cars_test.tgz. 73 | - Download the train labels https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz. 74 | - Download the test labels http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat. 75 | - Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing). 76 | 77 | The directory structure should look like 78 | ``` 79 | stanford_cars/ 80 | |–– cars_test\ 81 | |–– cars_test_annos_withlabels.mat 82 | |–– cars_train\ 83 | |–– devkit\ 84 | |–– split_zhou_StanfordCars.json 85 | ``` 86 | 87 | ### Flowers102 88 | - Create a folder named `oxford_flowers/` under `$DATA`. 89 | - Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively. 90 | - Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing). 91 | - Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing). 92 | 93 | The directory structure should look like 94 | ``` 95 | oxford_flowers/ 96 | |–– cat_to_name.json 97 | |–– imagelabels.mat 98 | |–– jpg/ 99 | |–– split_zhou_OxfordFlowers.json 100 | ``` 101 | 102 | ### Food101 103 | - Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`. 104 | - Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing). 105 | 106 | The directory structure should look like 107 | ``` 108 | food-101/ 109 | |–– images/ 110 | |–– license_agreement.txt 111 | |–– meta/ 112 | |–– README.txt 113 | |–– split_zhou_Food101.json 114 | ``` 115 | 116 | ### FGVCAircraft 117 | - Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz. 118 | - Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`. 119 | - Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`. 120 | 121 | The directory structure should look like 122 | ``` 123 | fgvc_aircraft/ 124 | |–– images/ 125 | |–– ... # a bunch of .txt files 126 | ``` 127 | 128 | ### SUN397 129 | - Create a folder named `sun397/` under `$DATA`. 130 | - Download the images http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz. 
131 | - Download the partitions https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip.
132 | - Extract these files under `$DATA/sun397/`.
133 | - Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing).
134 | 
135 | The directory structure should look like
136 | ```
137 | sun397/
138 | |–– SUN397/
139 | |–– split_zhou_SUN397.json
140 | |–– ... # a bunch of .txt files
141 | ```
142 | 
143 | ### DTD
144 | - Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. This should lead to `$DATA/dtd/`.
145 | - Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing).
146 | 
147 | The directory structure should look like
148 | ```
149 | dtd/
150 | |–– images/
151 | |–– imdb/
152 | |–– labels/
153 | |–– split_zhou_DescribableTextures.json
154 | ```
155 | 
156 | ### EuroSAT
157 | - Create a folder named `eurosat/` under `$DATA`.
158 | - Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`.
159 | - Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing).
160 | 
161 | The directory structure should look like
162 | ```
163 | eurosat/
164 | |–– 2750/
165 | |–– split_zhou_EuroSAT.json
166 | ```
167 | 
168 | ### UCF101
169 | - Create a folder named `ucf101/` under `$DATA`.
170 | - Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames.
171 | - Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing).
172 | 
173 | The directory structure should look like
174 | ```
175 | ucf101/
176 | |–– UCF-101-midframes/
177 | |–– split_zhou_UCF101.json
178 | ```
179 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Prompt Tuning based Adapter for Vision-Language Model Adaption
2 | 
3 | 
4 | ### Step 1: Installation
5 | Create a conda environment and install dependencies:
6 | ```bash
7 | conda create -y -n torch180 python=3.8
8 | conda activate torch180
9 | pip3 install torch==1.8.2 torchvision==0.9.2 torchaudio==0.8.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cu111
10 | 
11 | pip install -r requirements.txt
12 | 
13 | ```
14 | 
15 | ### Step 2: Dataset
16 | Follow [DATASET.md](https://github.com/gaopengcuhk/Tip-Adapter/blob/main/DATASET.md) to install the datasets used in the paper, or run the following script (11 datasets, including ImageNet):
17 | ```bash
18 | bash scripts/data.sh
19 | ```
20 | 
21 | 
22 | ### Step 3: Prompt Download
23 | Download the pretrained prompts from this [link](https://drive.google.com/file/d/1bfCXO9iE3ys3__xnOrC6bHAVXVcFXkyW/view?usp=share_link)
24 | and decompress the archive under the folder `prompt_adapter/prompt_tensor_init`:
25 | ```bash
26 | tar -xvf prompt_tensor_init.tar
27 | ```
28 | 
29 | 
30 | ### Step 4: Change Configs
31 | 
32 | The running configurations can be modified in `configs/dataset.yaml`, including shot numbers, visual encoders, and hyperparameters.
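For the shot sweep described below, the `shots` value can also be set programmatically before each run. A minimal sketch (it assumes PyYAML is available in the environment and uses `configs/oxford_pets.yaml` purely as an example; note that dumping the file this way drops its comments):

```python
import subprocess
import yaml  # assumption: PyYAML is installed in the environment

config_path = "configs/oxford_pets.yaml"  # example config; pick the dataset you need

for shots in [1, 2, 4, 8, 16, 20]:
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    cfg["shots"] = shots  # key used by the shipped configs
    with open(config_path, "w") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)
    # one run per shot setting (see Step 5)
    subprocess.run(["python", "main.py", "--config", config_path], check=True)
```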
33 | 34 | For our evauation of 1shot, 2shots, 4shots, 8shots, 16shots, 20shots, YOU NEED to change the shots first and then running the follow script. 35 | 36 | Note that the default `load_cache` and `load_pre_feat` are `False` for the first running, which will store the cache model and val/test features in `configs/dataset/`. For later running, they can be set as `True` for faster hyperparamters tuning. 37 | 38 | 39 | ### Step 5: Running 40 | For ImageNet dataset: 41 | ```bash 42 | python main_imagenet.py --config configs/imagenet.yaml 43 | ``` 44 | For other 10 datasets: 45 | ```bash 46 | python main.py --config configs/oxford_pets.yaml 47 | ``` 48 | 49 | 50 | 51 | ## Acknowledgement 52 | This repo benefits from [Tip-Adapter](https://github.com/gaopengcuhk/Tip-Adapter) and [CoOp](https://github.com/KaiyangZhou/Dassl.pytorch). Thanks for their wonderful works. 53 | 54 | ## Citation 55 | ```bash 56 | @article{sun2023prompt, 57 | title={Prompt Tuning based Adapter for Vision-Language Model Adaption}, 58 | author={Sun, Jingchen and Qin, Jiayu and Lin, Zihao and Chen, Changyou}, 59 | journal={arXiv preprint arXiv:2303.15234}, 60 | year={2023} 61 | } 62 | ``` -------------------------------------------------------------------------------- /Results.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/Results.xlsx -------------------------------------------------------------------------------- /cache_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/cache_model.png -------------------------------------------------------------------------------- /calch_load.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 31, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "imagenet\n", 13 | "torch.Size([512, 16000])\n", 14 | "caltech101\n", 15 | "max: tensor(1000, device='cuda:0')\n", 16 | "torch.Size([512, 1600])\n", 17 | "food101\n", 18 | "max: tensor(1100, device='cuda:0')\n", 19 | "torch.Size([512, 1616])\n", 20 | "stanford_cars\n", 21 | "max: tensor(1201, device='cuda:0')\n", 22 | "torch.Size([512, 3136])\n", 23 | "oxford_pets\n", 24 | "max: tensor(1397, device='cuda:0')\n", 25 | "torch.Size([512, 592])\n", 26 | "oxford_flowers\n", 27 | "max: tensor(1434, device='cuda:0')\n", 28 | "torch.Size([512, 1632])\n", 29 | "fgvc\n", 30 | "max: tensor(1536, device='cuda:0')\n", 31 | "torch.Size([512, 1600])\n", 32 | "sun397\n", 33 | "max: tensor(1636, device='cuda:0')\n", 34 | "torch.Size([512, 6352])\n", 35 | "dtd\n", 36 | "max: tensor(2033, device='cuda:0')\n", 37 | "torch.Size([512, 752])\n", 38 | "eurosat\n", 39 | "max: tensor(2080, device='cuda:0')\n", 40 | "torch.Size([512, 160])\n", 41 | "ucf101\n", 42 | "max: tensor(2090, device='cuda:0')\n", 43 | "torch.Size([512, 1616])\n", 44 | "torch.Size([512, 35056]) torch.Size([35056])\n", 45 | "torch.Size([35056, 2191])\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "import torch\n", 51 | "import torch.nn.functional as F\n", 52 | "arr = ['imagenet','caltech101', 'food101', 'stanford_cars', 'oxford_pets', 'oxford_flowers', 'fgvc', 'sun397','dtd', 'eurosat', 'ucf101']\n", 53 | "#arr = ['dtd', 'eurosat', 
'oxford_pets']\n", 54 | "#ImageNet,Caltech101,Food101,StanfordCars,OxfordPets,OxfordFlowers,FGVCAircraft,SUN397,DescribableTextures,EuroSAT,UCF101\n", 55 | "keys = []\n", 56 | "val = []\n", 57 | "l= 0\n", 58 | "max = 0\n", 59 | "\n", 60 | "for i in range(len(arr)):\n", 61 | " print(arr[i])\n", 62 | " if i == 0:\n", 63 | " max = 0\n", 64 | " cache_values = torch.load('caches2/' + arr[i] + '/values_' + str(16) + \"shots.pt\")\n", 65 | " cache_values += max\n", 66 | " val.append(cache_values)\n", 67 | " \n", 68 | " else:\n", 69 | " cache_values_p = torch.load('caches2/' + arr[i-1] + '/values_' + str(16) + \"shots.pt\")\n", 70 | " cache_values = torch.load('caches2/' + arr[i] + '/values_' + str(16) + \"shots.pt\")\n", 71 | " max = max + torch.max(cache_values_p)+1\n", 72 | " print('max:',max)\n", 73 | " cache_values += max\n", 74 | " val.append(cache_values)\n", 75 | " \n", 76 | " cache_keys = torch.load('caches2/' + arr[i] + '/keys_' + str(16) + \"shots.pt\")\n", 77 | " keys.append(cache_keys)\n", 78 | " print(cache_keys.size())\n", 79 | "\n", 80 | "x = torch.cat(keys, dim=1)\n", 81 | "y = torch.cat(val, dim=0)\n", 82 | "z = F.one_hot(y).half()\n", 83 | "torch.save(x, 'caches2/' + '/multi_task_keys_' + str(16) + \"shots.pt\")\n", 84 | "torch.save(z, 'caches2/' + '/multi_task_values_' + str(16) + \"shots.pt\")\n", 85 | "print(x.size(),y.size())\n", 86 | "print(z.size())" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [] 95 | } 96 | ], 97 | "metadata": { 98 | "kernelspec": { 99 | "display_name": "torch171", 100 | "language": "python", 101 | "name": "torch171" 102 | }, 103 | "language_info": { 104 | "codemirror_mode": { 105 | "name": "ipython", 106 | "version": 3 107 | }, 108 | "file_extension": ".py", 109 | "mimetype": "text/x-python", 110 | "name": "python", 111 | "nbconvert_exporter": "python", 112 | "pygments_lexer": "ipython3", 113 | "version": "3.8.13" 114 | } 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 4 118 | } 119 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 
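# Released CLIP checkpoints. Each URL embeds the checkpoint's SHA256 digest as its second-to-last
# path component, which _download() below re-checks after fetching the file.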
29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _convert_image_to_rgb(image): 72 | return image.convert("RGB") 73 | 74 | 75 | def _transform(n_px): 76 | return Compose([ 77 | Resize(n_px, interpolation=BICUBIC), 78 | CenterCrop(n_px), 79 | _convert_image_to_rgb, 80 | ToTensor(), 81 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 82 | ]) 83 | 84 | 85 | def available_models() -> List[str]: 86 | """Returns the names of available CLIP models""" 87 | return list(_MODELS.keys()) 88 | 89 | 90 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 91 | """Load a CLIP model 92 | 93 | Parameters 94 | ---------- 95 | name : str 96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 97 | 98 | device : Union[str, torch.device] 99 | The device to put the loaded model 100 | 101 | jit : bool 102 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
103 | 104 | download_root: str 105 | path to download the model files; by default, it uses "~/.cache/clip" 106 | 107 | Returns 108 | ------- 109 | model : torch.nn.Module 110 | The CLIP model 111 | 112 | preprocess : Callable[[PIL.Image], torch.Tensor] 113 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 114 | """ 115 | if name in _MODELS: 116 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 117 | elif os.path.isfile(name): 118 | model_path = name 119 | else: 120 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 121 | 122 | try: 123 | # loading JIT archive 124 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 125 | state_dict = None 126 | except RuntimeError: 127 | # loading saved state dict 128 | if jit: 129 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 130 | jit = False 131 | state_dict = torch.load(model_path, map_location="cpu") 132 | 133 | if not jit: 134 | model = build_model(state_dict or model.state_dict()).to(device) 135 | if str(device) == "cpu": 136 | model.float() 137 | return model, _transform(model.visual.input_resolution) 138 | 139 | # patch the device names 140 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 141 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 142 | 143 | def patch_device(module): 144 | try: 145 | graphs = [module.graph] if hasattr(module, "graph") else [] 146 | except RuntimeError: 147 | graphs = [] 148 | 149 | if hasattr(module, "forward1"): 150 | graphs.append(module.forward1.graph) 151 | 152 | for graph in graphs: 153 | for node in graph.findAllNodes("prim::Constant"): 154 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 155 | node.copyAttributes(device_node) 156 | 157 | model.apply(patch_device) 158 | patch_device(model.encode_image) 159 | patch_device(model.encode_text) 160 | 161 | # patch dtype to float32 on CPU 162 | if str(device) == "cpu": 163 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 164 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 165 | float_node = float_input.node() 166 | 167 | def patch_float(module): 168 | try: 169 | graphs = [module.graph] if hasattr(module, "graph") else [] 170 | except RuntimeError: 171 | graphs = [] 172 | 173 | if hasattr(module, "forward1"): 174 | graphs.append(module.forward1.graph) 175 | 176 | for graph in graphs: 177 | for node in graph.findAllNodes("aten::to"): 178 | inputs = list(node.inputs()) 179 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 180 | if inputs[i].node()["value"] == 5: 181 | inputs[i].node().copyAttributes(float_node) 182 | 183 | model.apply(patch_float) 184 | patch_float(model.encode_image) 185 | patch_float(model.encode_text) 186 | 187 | model.float() 188 | 189 | return model, _transform(model.input_resolution.item()) 190 | 191 | 192 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 193 | """ 194 | Returns the tokenized representation of given input string(s) 195 | 196 | Parameters 197 | ---------- 198 | texts : Union[str, List[str]] 199 | An input string or a list of input strings to tokenize 200 | 201 | context_length : int 202 | The context length to use; all CLIP 
models use 77 as the context length 203 | 204 | truncate: bool 205 | Whether to truncate the text in case its encoding is longer than the context length 206 | 207 | Returns 208 | ------- 209 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 210 | """ 211 | if isinstance(texts, str): 212 | texts = [texts] 213 | 214 | sot_token = _tokenizer.encoder["<|startoftext|>"] 215 | eot_token = _tokenizer.encoder["<|endoftext|>"] 216 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 217 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 218 | 219 | for i, tokens in enumerate(all_tokens): 220 | if len(tokens) > context_length: 221 | if truncate: 222 | tokens = tokens[:context_length] 223 | tokens[-1] = eot_token 224 | else: 225 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 226 | result[i, :len(tokens)] = torch.tensor(tokens) 227 | 228 | return result 229 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] 
* x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | 
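# QuickGELU is the approximation x * sigmoid(1.702 * x) of GELU that the original CLIP weights
# were trained with; substituting nn.GELU would change the activations slightly.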
class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = x + self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 204 | super().__init__() 205 | self.input_resolution = input_resolution 206 | self.output_dim = output_dim 207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 208 | 209 | scale = width ** -0.5 210 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 212 | self.ln_pre = LayerNorm(width) 213 | 214 | self.transformer = Transformer(width, layers, heads) 215 | 216 | self.ln_post = LayerNorm(width) 217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 218 | 219 | def forward(self, x: torch.Tensor): 220 | x = self.conv1(x) # shape = [*, width, grid, grid] 221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 224 | x = x + self.positional_embedding.to(x.dtype) 225 | x = self.ln_pre(x) 226 | 227 | x = x.permute(1, 0, 2) # NLD -> LND 228 | x = self.transformer(x) 229 | x = x.permute(1, 0, 2) # LND -> NLD 230 | 231 | x = self.ln_post(x[:, 0, :]) 232 | 233 | if self.proj is not None: 234 | x = x @ self.proj 235 | 236 | return x 237 | 238 | 239 | class CLIP(nn.Module): 240 | def __init__(self, 241 | embed_dim: int, 242 | # vision 243 | image_resolution: int, 244 | vision_layers: Union[Tuple[int, int, int, int], int], 245 | vision_width: int, 246 | vision_patch_size: int, 247 | # text 248 | context_length: int, 249 | vocab_size: int, 250 | transformer_width: int, 251 | transformer_heads: int, 252 | transformer_layers: int 253 | ): 254 | super().__init__() 255 | 256 | self.context_length = context_length 257 | 258 | if isinstance(vision_layers, (tuple, list)): 
259 | vision_heads = vision_width * 32 // 64 260 | self.visual = ModifiedResNet( 261 | layers=vision_layers, 262 | output_dim=embed_dim, 263 | heads=vision_heads, 264 | input_resolution=image_resolution, 265 | width=vision_width 266 | ) 267 | else: 268 | vision_heads = vision_width // 64 269 | self.visual = VisionTransformer( 270 | input_resolution=image_resolution, 271 | patch_size=vision_patch_size, 272 | width=vision_width, 273 | layers=vision_layers, 274 | heads=vision_heads, 275 | output_dim=embed_dim 276 | ) 277 | 278 | self.transformer = Transformer( 279 | width=transformer_width, 280 | layers=transformer_layers, 281 | heads=transformer_heads, 282 | attn_mask=self.build_attention_mask() 283 | ) 284 | 285 | self.vocab_size = vocab_size 286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 288 | self.ln_final = LayerNorm(transformer_width) 289 | 290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 292 | 293 | self.initialize_parameters() 294 | 295 | def initialize_parameters(self): 296 | nn.init.normal_(self.token_embedding.weight, std=0.02) 297 | nn.init.normal_(self.positional_embedding, std=0.01) 298 | 299 | if isinstance(self.visual, ModifiedResNet): 300 | if self.visual.attnpool is not None: 301 | std = self.visual.attnpool.c_proj.in_features ** -0.5 302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 306 | 307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 308 | for name, param in resnet_block.named_parameters(): 309 | if name.endswith("bn3.weight"): 310 | nn.init.zeros_(param) 311 | 312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 313 | attn_std = self.transformer.width ** -0.5 314 | fc_std = (2 * self.transformer.width) ** -0.5 315 | for block in self.transformer.resblocks: 316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 320 | 321 | if self.text_projection is not None: 322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 323 | 324 | def build_attention_mask(self): 325 | # lazily create causal attention mask, with full attention between the vision tokens 326 | # pytorch uses additive attention mask; fill with -inf 327 | mask = torch.empty(self.context_length, self.context_length) 328 | mask.fill_(float("-inf")) 329 | mask.triu_(1) # zero out the lower diagonal 330 | return mask 331 | 332 | @property 333 | def dtype(self): 334 | return self.visual.conv1.weight.dtype 335 | 336 | def encode_image(self, image): 337 | return self.visual(image.type(self.dtype)) 338 | 339 | def encode_text(self, text): 340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 341 | 342 | x = x + self.positional_embedding.type(self.dtype) 343 | x = x.permute(1, 0, 2) # NLD -> LND 344 | x = self.transformer(x) 345 | x = x.permute(1, 0, 2) # LND -> NLD 346 | x = self.ln_final(x).type(self.dtype) 347 | 348 | # x.shape = 
[batch_size, n_ctx, transformer.width] 349 | # take features from the eot embedding (eot_token is the highest number in each sequence) 350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 351 | 352 | return x 353 | 354 | def forward(self, image, text): 355 | image_features = self.encode_image(image) 356 | text_features = self.encode_text(text) 357 | 358 | # normalized features 359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 361 | 362 | # cosine similarity as logits 363 | logit_scale = self.logit_scale.exp() 364 | logits_per_image = logit_scale * image_features @ text_features.t() 365 | logits_per_text = logits_per_image.t() 366 | 367 | # shape = [global_batch_size, global_batch_size] 368 | return logits_per_image, logits_per_text 369 | 370 | 371 | def convert_weights(model: nn.Module): 372 | """Convert applicable model parameters to fp16""" 373 | 374 | def _convert_weights_to_fp16(l): 375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 376 | l.weight.data = l.weight.data.half() 377 | if l.bias is not None: 378 | l.bias.data = l.bias.data.half() 379 | 380 | if isinstance(l, nn.MultiheadAttention): 381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 382 | tensor = getattr(l, attr) 383 | if tensor is not None: 384 | tensor.data = tensor.data.half() 385 | 386 | for name in ["text_projection", "proj"]: 387 | if hasattr(l, name): 388 | attr = getattr(l, name) 389 | if attr is not None: 390 | attr.data = attr.data.half() 391 | 392 | model.apply(_convert_weights_to_fp16) 393 | 394 | 395 | def build_model(state_dict: dict): 396 | vit = "visual.proj" in state_dict 397 | 398 | if vit: 399 | vision_width = state_dict["visual.conv1.weight"].shape[0] 400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 403 | image_resolution = vision_patch_size * grid_size 404 | else: 405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 406 | vision_layers = tuple(counts) 407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 409 | vision_patch_size = None 410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 411 | image_resolution = output_width * 32 412 | 413 | embed_dim = state_dict["text_projection"].shape[1] 414 | context_length = state_dict["positional_embedding"].shape[0] 415 | vocab_size = state_dict["token_embedding.weight"].shape[0] 416 | transformer_width = state_dict["ln_final.weight"].shape[0] 417 | transformer_heads = transformer_width // 64 418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 419 | 420 | model = CLIP( 421 | embed_dim, 422 | image_resolution, vision_layers, vision_width, vision_patch_size, 423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 424 | ) 425 | 426 | for key in ["input_resolution", "context_length", "vocab_size"]: 427 | if key in state_dict: 428 | del state_dict[key] 429 | 430 | convert_weights(model) 431 | 
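    # convert_weights casts Linear/Conv/attention parameters to fp16 so their dtypes match the
    # released checkpoints before the pretrained state_dict is loaded below.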
    model.load_state_dict(state_dict)
432 |     return model.eval()
433 | 
-------------------------------------------------------------------------------- /clip/simple_tokenizer.py: --------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 | 
6 | import ftfy
7 | import regex as re
8 | 
9 | 
10 | @lru_cache()
11 | def default_bpe():
12 |     return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 | 
14 | 
15 | @lru_cache()
16 | def bytes_to_unicode():
17 |     """
18 |     Returns list of utf-8 byte and a corresponding list of unicode strings.
19 |     The reversible bpe codes work on unicode strings.
20 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 |     And avoids mapping to whitespace/control characters the bpe code barfs on.
25 |     """
26 |     bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 |     cs = bs[:]
28 |     n = 0
29 |     for b in range(2**8):
30 |         if b not in bs:
31 |             bs.append(b)
32 |             cs.append(2**8+n)
33 |             n += 1
34 |     cs = [chr(n) for n in cs]
35 |     return dict(zip(bs, cs))
36 | 
37 | 
38 | def get_pairs(word):
39 |     """Return set of symbol pairs in a word.
40 |     Word is represented as tuple of symbols (symbols being variable-length strings).
41 |     """
42 |     pairs = set()
43 |     prev_char = word[0]
44 |     for char in word[1:]:
45 |         pairs.add((prev_char, char))
46 |         prev_char = char
47 |     return pairs
48 | 
49 | 
50 | def basic_clean(text):
51 |     text = ftfy.fix_text(text)
52 |     text = html.unescape(html.unescape(text))
53 |     return text.strip()
54 | 
55 | 
56 | def whitespace_clean(text):
57 |     text = re.sub(r'\s+', ' ', text)
58 |     text = text.strip()
59 |     return text
60 | 
61 | 
62 | class SimpleTokenizer(object):
63 |     def __init__(self, bpe_path: str = default_bpe()):
64 |         self.byte_encoder = bytes_to_unicode()
65 |         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 |         merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 |         merges = merges[1:49152-256-2+1]
68 |         merges = [tuple(merge.split()) for merge in merges]
69 |         vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 |         for merge in merges:
72 |             vocab.append(''.join(merge))
73 |         vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 |         self.encoder = dict(zip(vocab, range(len(vocab))))
75 |         self.decoder = {v: k for k, v in self.encoder.items()}
76 |         self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 |         self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 |         self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 | 
80 |     def bpe(self, token):
81 |         if token in self.cache:
82 |             return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 |         pairs = get_pairs(word)
85 | 
86 |         if not pairs:
87 |             return token+'</w>'
88 | 
89 |         while True:
90 |             bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 |             if bigram not in self.bpe_ranks:
92 |                 break
93 |             first, second = bigram
94 |             new_word = []
95 |             i = 0
96 |             while i < len(word):
97 |                 try:
98 |                     j = word.index(first, i)
99 |                     new_word.extend(word[i:j])
100 |                     i = j
101 |                 except:
102 |                     new_word.extend(word[i:])
103 |                     break
104 | 
105 |                 if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 |                     new_word.append(first+second)
107 |                     i += 2
108 |                 else:
109 |                     new_word.append(word[i])
110 |                     i += 1
111 |             new_word = tuple(new_word)
112 |             word = new_word
113 |             if len(word) == 1:
114 |                 break
115 |             else:
116 |                 pairs = get_pairs(word)
117 |         word = ' '.join(word)
118 |         self.cache[token] = word
119 |         return word
120 | 
121 |     def encode(self, text):
122 |         bpe_tokens = []
123 |         text = whitespace_clean(basic_clean(text)).lower()
124 |         for token in re.findall(self.pat, text):
125 |             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 |             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 |         return bpe_tokens
128 | 
129 |     def decode(self, tokens):
130 |         text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 |         return text
133 | 
-------------------------------------------------------------------------------- /configs/caltech101.yaml: --------------------------------------------------------------------------------
1 | # ------ root_path/dataset_name ------
2 | root_path: ''
3 | 
4 | 
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 | 
9 | # load_cache: True
10 | # load_pre_feat: True
11 | 
12 | 
13 | # ------ Hyperparamters ------
14 | search_hp: True
15 | # search_hp: False
16 | 
17 | search_scale: [12, 5]
18 | search_step: [200, 20]
19 | 
20 | init_beta: 1
21 | init_alpha: 3
22 | 
23 | 
24 | # ------ Basic Config ------
25 | dataset: 'caltech101'
26 | shots: 20
27 | backbone: 'ViT-B/16'
28 | 
29 | lr: 0.005
30 | augment_epoch: 10
31 | train_epoch: 20
-------------------------------------------------------------------------------- /configs/dtd.yaml: --------------------------------------------------------------------------------
1 | # ------ root_path/dataset_name ------
2 | root_path: ''
3 | 
4 | 
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 | 
9 | # load_cache: True
10 | # load_pre_feat: True
11 | 
12 | 
13 | # ------ Hyperparamters ------
14 | search_hp: True
15 | # search_hp: False
16 | 
17 | search_scale: [13, 13]
18 | search_step: [200, 20]
19 | 
20 | init_beta: 1
21 | init_alpha: 2
22 | 
23 | 
24 | # ------ Basic Config ------
25 | dataset: 'dtd'
26 | shots: 20
27 | backbone: 'ViT-B/16'
28 | 
29 | lr: 0.005
30 | augment_epoch: 10
31 | train_epoch: 20
-------------------------------------------------------------------------------- /configs/eurosat.yaml: --------------------------------------------------------------------------------
1 | # ------ root_path/dataset_name ------
2 | root_path: ''
3 | 
4 | 
5 | # ------ Load Cache and Features ------
6 | load_cache: False
7 | load_pre_feat: False
8 | 
9 | # load_cache: True
10 | # load_pre_feat: True
11 | 
12 | 
13 | # ------ Hyperparamters ------
14 | search_hp: True
15 | # search_hp: False
16 | 
17 | search_scale: [20, 20]
18 | search_step: [200, 200]
19 | 
20 | init_beta: 1
21 | init_alpha: 2
22 | 
23 | 
24 | # ------ Basic Config ------
25 | dataset: 'eurosat'
26 | shots: 20
27 | backbone: 'ViT-B/16'
28 | 
29 | lr: 0.005
30 | augment_epoch: 10
31 | train_epoch: 100
-------------------------------------------------------------------------------- /configs/fgvc.yaml: 
-------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [30, 30] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 5 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'fgvc' 26 | shots: 20 27 | backbone: 'ViT-B/16' 28 | 29 | lr: 0.001 30 | augment_epoch: 10 31 | train_epoch: 20 -------------------------------------------------------------------------------- /configs/food101.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [50, 50] 18 | search_step: [200, 200] 19 | 20 | init_beta: 1 21 | init_alpha: 1.17 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'food101' 26 | shots: 20 27 | backbone: 'ViT-B/16' 28 | 29 | lr: 0.005 30 | augment_epoch: 10 31 | train_epoch: 20 -------------------------------------------------------------------------------- /configs/imagenet.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.17 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'imagenet' 26 | shots: 20 27 | backbone: 'ViT-B/16' 28 | 29 | lr: 0.001 30 | augment_epoch: 10 31 | train_epoch: 20 -------------------------------------------------------------------------------- /configs/oxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [50, 50] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 10 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'oxford_flowers' 26 | shots: 20 27 | backbone: 'ViT-B/16' 28 | 29 | lr: 0.001 30 | augment_epoch: 10 31 | train_epoch: 20 -------------------------------------------------------------------------------- /configs/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [100, 100] 18 | search_step: [200, 200] 19 | 20 | 
init_beta: 1 21 | init_alpha: 1.17 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'oxford_pets' 26 | shots: 20 27 | backbone: 'ViT-B/16' 28 | 29 | lr: 0.005 30 | augment_epoch: 10 31 | train_epoch: 20 -------------------------------------------------------------------------------- /configs/stanford_cars.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [20, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 3 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'stanford_cars' 26 | shots: 20 27 | backbone: 'ViT-B/16' 28 | 29 | lr: 0.005 30 | augment_epoch: 10 31 | train_epoch: 20 -------------------------------------------------------------------------------- /configs/sun397.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [12, 10] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 1.17 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'sun397' 26 | shots: 20 27 | backbone: 'ViT-B/16' 28 | 29 | lr: 0.005 30 | augment_epoch: 10 31 | train_epoch: 20 -------------------------------------------------------------------------------- /configs/ucf101.yaml: -------------------------------------------------------------------------------- 1 | # ------ root_path/dataset_name ------ 2 | root_path: '' 3 | 4 | 5 | # ------ Load Cache and Features ------ 6 | load_cache: False 7 | load_pre_feat: False 8 | 9 | # load_cache: True 10 | # load_pre_feat: True 11 | 12 | 13 | # ------ Hyperparamters ------ 14 | search_hp: True 15 | # search_hp: False 16 | 17 | search_scale: [7, 3] 18 | search_step: [200, 20] 19 | 20 | init_beta: 1 21 | init_alpha: 3 22 | 23 | 24 | # ------ Basic Config ------ 25 | dataset: 'ucf101' 26 | shots: 20 27 | backbone: 'ViT-B/16' 28 | 29 | lr: 0.01 30 | augment_epoch: 10 31 | train_epoch: 20 -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .oxford_pets import OxfordPets 2 | from .eurosat import EuroSAT 3 | from .ucf101 import UCF101 4 | from .sun397 import SUN397 5 | from .caltech101 import Caltech101 6 | from .dtd import DescribableTextures 7 | from .fgvc import FGVCAircraft 8 | from .food101 import Food101 9 | from .oxford_flowers import OxfordFlowers 10 | from .stanford_cars import StanfordCars 11 | 12 | 13 | dataset_list = { 14 | "oxford_pets": OxfordPets, 15 | "eurosat": EuroSAT, 16 | "ucf101": UCF101, 17 | "sun397": SUN397, 18 | "caltech101": Caltech101, 19 | "dtd": DescribableTextures, 20 | "fgvc": FGVCAircraft, 21 | "food101": Food101, 22 | "oxford_flowers": OxfordFlowers, 23 | "stanford_cars": StanfordCars, 24 | } 25 | 26 | 27 | def build_dataset(dataset, root_path, shots): 28 | return dataset_list[dataset](root_path, shots) 
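# Example (hypothetical values; the split attributes come from DatasetBase in datasets/utils.py):
#   from datasets import build_dataset
#   dtd = build_dataset('dtd', root_path='/path/to/DATA', shots=16)
#   few_shot_train, val, test = dtd.train_x, dtd.val, dtd.test
# ImageNet is handled separately by main_imagenet.py, so it does not appear in dataset_list.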
-------------------------------------------------------------------------------- /datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ['a photo of a {}.'] 8 | 9 | 10 | class Caltech101(DatasetBase): 11 | 12 | dataset_dir = 'data/caltech-101' 13 | 14 | def __init__(self, root, num_shots): 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories') 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Caltech101.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 22 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 23 | 24 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from .utils import Datum, DatasetBase, listdir_nohidden 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['{} texture.'] 9 | 10 | 11 | class DescribableTextures(DatasetBase): 12 | 13 | dataset_dir = 'data/dtd' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'images') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_DescribableTextures.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 24 | 25 | super().__init__(train_x=train, val=val, test=test) 26 | 27 | @staticmethod 28 | def read_and_split_data( 29 | image_dir, 30 | p_trn=0.5, 31 | p_val=0.2, 32 | ignored=[], 33 | new_cnames=None 34 | ): 35 | # The data are supposed to be organized into the following structure 36 | # ============= 37 | # images/ 38 | # dog/ 39 | # cat/ 40 | # horse/ 41 | # ============= 42 | categories = listdir_nohidden(image_dir) 43 | categories = [c for c in categories if c not in ignored] 44 | categories.sort() 45 | 46 | p_tst = 1 - p_trn - p_val 47 | print(f'Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test') 48 | 49 | def _collate(ims, y, c): 50 | items = [] 51 | for im in ims: 52 | item = Datum( 53 | impath=im, 54 | label=y, # is already 0-based 55 | classname=c 56 | ) 57 | items.append(item) 58 | return items 59 | 60 | train, val, test = [], [], [] 61 | for label, category in enumerate(categories): 62 | category_dir = os.path.join(image_dir, category) 63 | images = listdir_nohidden(category_dir) 64 | images = [os.path.join(category_dir, im) for im in images] 65 | random.shuffle(images) 66 | n_total = len(images) 67 | n_train = round(n_total * p_trn) 68 | n_val = round(n_total * p_val) 69 | n_test = n_total - n_train - n_val 70 | assert n_train > 0 and n_val > 0 and n_test > 0 71 | 72 | if new_cnames is not None and category in new_cnames: 73 | category = new_cnames[category] 74 | 75 | train.extend(_collate(images[:n_train], label, category)) 76 | val.extend(_collate(images[n_train:n_train+n_val], label, category)) 77 | test.extend(_collate(images[n_train+n_val:], label, category)) 78 | 79 | return train, val, test 80 | 
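
`read_and_split_data` above is only needed when no pre-made split file exists; the constructor itself relies on `split_zhou_DescribableTextures.json`. For completeness, here is a hedged sketch of regenerating and caching a split with `OxfordPets.save_split` (defined later in datasets/oxford_pets.py); the paths are placeholders, and the released split should normally be kept for fair comparison.

```python
# Hypothetical split regeneration for DTD (placeholder paths; prefer the released split).
import random

from datasets.dtd import DescribableTextures
from datasets.oxford_pets import OxfordPets

random.seed(1)  # read_and_split_data shuffles each class folder, so fix the seed

image_dir = '/path/to/data_root/data/dtd/images'            # placeholder
split_path = '/path/to/data_root/data/dtd/split_new.json'   # placeholder; do not overwrite the released JSON

train, val, test = DescribableTextures.read_and_split_data(image_dir)
OxfordPets.save_split(train, val, test, split_path, path_prefix=image_dir)
```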
-------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ['a centered satellite photo of {}.'] 8 | 9 | 10 | NEW_CNAMES = { 11 | 'AnnualCrop': 'Annual Crop Land', 12 | 'Forest': 'Forest', 13 | 'HerbaceousVegetation': 'Herbaceous Vegetation Land', 14 | 'Highway': 'Highway or Road', 15 | 'Industrial': 'Industrial Buildings', 16 | 'Pasture': 'Pasture Land', 17 | 'PermanentCrop': 'Permanent Crop Land', 18 | 'Residential': 'Residential Buildings', 19 | 'River': 'River', 20 | 'SeaLake': 'Sea or Lake' 21 | } 22 | 23 | 24 | class EuroSAT(DatasetBase): 25 | 26 | dataset_dir = 'data/eurosat' 27 | 28 | def __init__(self, root, num_shots): 29 | self.dataset_dir = os.path.join(root, self.dataset_dir) 30 | self.image_dir = os.path.join(self.dataset_dir, '2750') 31 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_EuroSAT.json') 32 | 33 | self.template = template 34 | 35 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 36 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 37 | 38 | super().__init__(train_x=train, val=val, test=test) 39 | 40 | def update_classname(self, dataset_old): 41 | dataset_new = [] 42 | for item_old in dataset_old: 43 | cname_old = item_old.classname 44 | cname_new = NEW_CNAMES[cname_old] 45 | item_new = Datum( 46 | impath=item_old.impath, 47 | label=item_old.label, 48 | classname=cname_new 49 | ) 50 | dataset_new.append(item_new) 51 | return dataset_new 52 | -------------------------------------------------------------------------------- /datasets/fgvc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | 6 | template = ['a photo of a {}, a type of aircraft.'] 7 | 8 | 9 | class FGVCAircraft(DatasetBase): 10 | 11 | dataset_dir = 'data/fgvc_aircraft' 12 | 13 | def __init__(self, root, num_shots): 14 | 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, 'images') 17 | 18 | self.template = template 19 | 20 | classnames = [] 21 | with open(os.path.join(self.dataset_dir, 'variants.txt'), 'r') as f: 22 | lines = f.readlines() 23 | for line in lines: 24 | classnames.append(line.strip()) 25 | cname2lab = {c: i for i, c in enumerate(classnames)} 26 | 27 | train = self.read_data(cname2lab, 'images_variant_train.txt') 28 | val = self.read_data(cname2lab, 'images_variant_val.txt') 29 | test = self.read_data(cname2lab, 'images_variant_test.txt') 30 | 31 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 32 | 33 | super().__init__(train_x=train, val=val, test=test) 34 | 35 | def read_data(self, cname2lab, split_file): 36 | filepath = os.path.join(self.dataset_dir, split_file) 37 | items = [] 38 | 39 | with open(filepath, 'r') as f: 40 | lines = f.readlines() 41 | for line in lines: 42 | line = line.strip().split(' ') 43 | imname = line[0] + '.jpg' 44 | classname = ' '.join(line[1:]) 45 | impath = os.path.join(self.image_dir, imname) 46 | label = cname2lab[classname] 47 | item = Datum( 48 | impath=impath, 49 | label=label, 50 | classname=classname 51 | ) 52 | items.append(item) 53 | 54 | return items
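
Every dataset module defines a one-element `template` list such as `'a photo of a {}, a type of aircraft.'` above. These templates and the class names are what eventually become the `clip_weights` text classifier that main.py multiplies against image features. The helper below is a minimal, assumed sketch of that step (the repo's actual implementation lives in the root-level utils.py, which is not shown in this section); it relies only on the vendored CLIP package.

```python
# Assumed sketch: build a zero-shot text classifier from a dataset's template and
# class names. Column i holds the normalized text embedding of class i, so
# image_features @ clip_weights yields one logit per class.
import torch
import clip

@torch.no_grad()
def build_clip_weights(classnames, template, clip_model):
    weights = []
    for classname in classnames:
        prompts = [t.format(classname.replace('_', ' ')) for t in template]
        tokens = clip.tokenize(prompts).to(next(clip_model.parameters()).device)
        embeddings = clip_model.encode_text(tokens)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        class_embedding = embeddings.mean(dim=0)          # average over prompts
        class_embedding = class_embedding / class_embedding.norm()
        weights.append(class_embedding)
    return torch.stack(weights, dim=1)  # shape: [feature_dim, num_classes]
```

Stacking along `dim=1` matches the `100. * features @ clip_weights` pattern used in main.py.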
-------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | from .oxford_pets import OxfordPets 5 | 6 | 7 | template = ['a photo of {}, a type of food.'] 8 | 9 | 10 | class Food101(DatasetBase): 11 | 12 | dataset_dir = 'data/food-101' 13 | 14 | def __init__(self, root, num_shots): 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, 'images') 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Food101.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 22 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 23 | 24 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | from collections import defaultdict 5 | 6 | import torch 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | 10 | 11 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 12 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 13 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 14 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 15 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 16 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 17 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 18 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 19 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 20 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 21 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 22 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 23 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 24 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 25 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 26 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 27 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 28 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 29 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 30 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 31 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 32 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 33 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 34 | "great egret", "bittern 
bird", "crane bird", "limpkin", "common gallinule", "American coot", 35 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 36 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 37 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 38 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 39 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 40 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 41 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 42 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 43 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 44 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 45 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 46 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 47 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 48 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 49 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 50 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 51 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 52 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 53 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 54 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 55 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 56 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 57 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 58 | "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 59 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 60 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 61 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 62 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 63 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 64 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 65 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 66 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 67 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 68 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 69 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 70 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 71 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 72 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 73 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 74 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 75 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 76 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 77 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 78 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 79 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 80 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 81 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 82 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 83 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 84 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 85 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 86 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 87 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 88 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 89 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 90 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 91 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 92 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 93 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 94 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 95 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 96 | "car 
wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 97 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 98 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 99 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 100 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 101 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 102 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 103 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 104 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 105 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 106 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 107 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 108 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 109 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 110 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 111 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 112 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 113 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 114 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 115 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 116 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 117 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 118 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 119 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 120 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 121 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 122 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 123 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 124 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 125 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 126 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 127 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 128 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 129 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 130 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 131 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 132 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", 
"Pickelhaube", 133 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 134 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 135 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 136 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 137 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 138 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 139 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 140 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 141 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 142 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 143 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 144 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 145 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 146 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 147 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 148 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 149 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 150 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 151 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 152 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 153 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 154 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 155 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 156 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 157 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 158 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 159 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 160 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 161 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 162 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 163 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 164 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 165 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 166 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 167 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 168 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 169 | "lemon", "fig", "pineapple", "banana", 
"jackfruit", "cherimoya (custard apple)", "pomegranate", 170 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 171 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 172 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 173 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 174 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 175 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 176 | 177 | imagenet_templates = ["itap of a {}.", 178 | "a bad photo of the {}.", 179 | "a origami {}.", 180 | "a photo of the large {}.", 181 | "a {} in a video game.", 182 | "art of the {}.", 183 | "a photo of the small {}."] 184 | 185 | 186 | class ImageNet(): 187 | 188 | dataset_dir = 'data/imagenet' 189 | 190 | def __init__(self, root, num_shots, preprocess): 191 | 192 | self.dataset_dir = os.path.join(root, self.dataset_dir) 193 | self.image_dir = os.path.join(self.dataset_dir, 'images') 194 | 195 | train_preprocess = transforms.Compose([ 196 | transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC), 197 | transforms.RandomHorizontalFlip(p=0.5), 198 | transforms.ToTensor(), 199 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 200 | ]) 201 | test_preprocess = preprocess 202 | 203 | self.train = torchvision.datasets.ImageNet(self.image_dir, split='train', transform=train_preprocess) 204 | self.val = torchvision.datasets.ImageNet(self.image_dir, split='val', transform=test_preprocess) 205 | self.test = torchvision.datasets.ImageNet(self.image_dir, split='val', transform=test_preprocess) 206 | 207 | self.template = imagenet_templates 208 | self.classnames = imagenet_classes 209 | 210 | split_by_label_dict = defaultdict(list) 211 | for i in range(len(self.train.imgs)): 212 | split_by_label_dict[self.train.targets[i]].append(self.train.imgs[i]) 213 | imgs = [] 214 | targets = [] 215 | 216 | for label, items in split_by_label_dict.items(): 217 | imgs = imgs + random.sample(items, num_shots) 218 | targets = targets + [label for i in range(num_shots)] 219 | self.train.imgs = imgs 220 | self.train.targets = targets 221 | self.train.samples = imgs -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from scipy.io import loadmat 4 | from collections import defaultdict 5 | 6 | from .oxford_pets import OxfordPets 7 | from .utils import Datum, DatasetBase, read_json 8 | 9 | 10 | template = ['a photo of a {}, a type of flower.'] 11 | 12 | 13 | class OxfordFlowers(DatasetBase): 14 | 15 | dataset_dir = 'data/oxford_flowers' 16 | 17 | def __init__(self, root, num_shots): 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, 'jpg') 20 | self.label_file = os.path.join(self.dataset_dir, 'imagelabels.mat') 21 | self.lab2cname_file = os.path.join(self.dataset_dir, 'cat_to_name.json') 22 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordFlowers.json') 23 | 24 | self.template = template 25 | 26 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 27 | train = self.generate_fewshot_dataset(train, 
num_shots=num_shots) 28 | 29 | super().__init__(train_x=train, val=val, test=test) 30 | 31 | def read_data(self): 32 | tracker = defaultdict(list) 33 | label_file = loadmat(self.label_file)['labels'][0] 34 | for i, label in enumerate(label_file): 35 | imname = f'image_{str(i + 1).zfill(5)}.jpg' 36 | impath = os.path.join(self.image_dir, imname) 37 | label = int(label) 38 | tracker[label].append(impath) 39 | 40 | print('Splitting data into 50% train, 20% val, and 30% test') 41 | 42 | def _collate(ims, y, c): 43 | items = [] 44 | for im in ims: 45 | item = Datum( 46 | impath=im, 47 | label=y-1, # convert to 0-based label 48 | classname=c 49 | ) 50 | items.append(item) 51 | return items 52 | 53 | lab2cname = read_json(self.lab2cname_file) 54 | train, val, test = [], [], [] 55 | for label, impaths in tracker.items(): 56 | random.shuffle(impaths) 57 | n_total = len(impaths) 58 | n_train = round(n_total * 0.5) 59 | n_val = round(n_total * 0.2) 60 | n_test = n_total - n_train - n_val 61 | assert n_train > 0 and n_val > 0 and n_test > 0 62 | cname = lab2cname[str(label)] 63 | train.extend(_collate(impaths[:n_train], label, cname)) 64 | val.extend(_collate(impaths[n_train:n_train+n_val], label, cname)) 65 | test.extend(_collate(impaths[n_train+n_val:], label, cname)) 66 | 67 | return train, val, test -------------------------------------------------------------------------------- /datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | from collections import defaultdict 5 | 6 | import torchvision.transforms as transforms 7 | 8 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 9 | 10 | 11 | template = ['a photo of a {}, a type of pet.'] 12 | 13 | 14 | class OxfordPets(DatasetBase): 15 | 16 | dataset_dir = 'data/oxford_pets' 17 | 18 | def __init__(self, root, num_shots): 19 | self.dataset_dir = os.path.join(root, self.dataset_dir) 20 | self.image_dir = os.path.join(self.dataset_dir, 'images') 21 | self.anno_dir = os.path.join(self.dataset_dir, 'annotations') 22 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordPets.json') 23 | 24 | self.template = template 25 | 26 | train, val, test = self.read_split(self.split_path, self.image_dir) 27 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 28 | 29 | super().__init__(train_x=train, val=val, test=test) 30 | 31 | def read_data(self, split_file): 32 | filepath = os.path.join(self.anno_dir, split_file) 33 | items = [] 34 | 35 | with open(filepath, 'r') as f: 36 | lines = f.readlines() 37 | for line in lines: 38 | line = line.strip() 39 | imname, label, species, _ = line.split(' ') 40 | breed = imname.split('_')[:-1] 41 | breed = '_'.join(breed) 42 | breed = breed.lower() 43 | imname += '.jpg' 44 | impath = os.path.join(self.image_dir, imname) 45 | label = int(label) - 1 # convert to 0-based index 46 | item = Datum( 47 | impath=impath, 48 | label=label, 49 | classname=breed 50 | ) 51 | items.append(item) 52 | 53 | return items 54 | 55 | @staticmethod 56 | def split_trainval(trainval, p_val=0.2): 57 | p_trn = 1 - p_val 58 | print(f'Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val') 59 | tracker = defaultdict(list) 60 | for idx, item in enumerate(trainval): 61 | label = item.label 62 | tracker[label].append(idx) 63 | 64 | train, val = [], [] 65 | for label, idxs in tracker.items(): 66 | n_val = round(len(idxs) * p_val) 67 | assert n_val > 0 68 | random.shuffle(idxs) 69 | for n, idx 
in enumerate(idxs): 70 | item = trainval[idx] 71 | if n < n_val: 72 | val.append(item) 73 | else: 74 | train.append(item) 75 | 76 | return train, val 77 | 78 | @staticmethod 79 | def save_split(train, val, test, filepath, path_prefix): 80 | def _extract(items): 81 | out = [] 82 | for item in items: 83 | impath = item.impath 84 | label = item.label 85 | classname = item.classname 86 | impath = impath.replace(path_prefix, '') 87 | if impath.startswith('/'): 88 | impath = impath[1:] 89 | out.append((impath, label, classname)) 90 | return out 91 | 92 | train = _extract(train) 93 | val = _extract(val) 94 | test = _extract(test) 95 | 96 | split = { 97 | 'train': train, 98 | 'val': val, 99 | 'test': test 100 | } 101 | 102 | write_json(split, filepath) 103 | print(f'Saved split to {filepath}') 104 | 105 | @staticmethod 106 | def read_split(filepath, path_prefix): 107 | def _convert(items): 108 | out = [] 109 | for impath, label, classname in items: 110 | impath = os.path.join(path_prefix, impath) 111 | item = Datum( 112 | impath=impath, 113 | label=int(label), 114 | classname=classname 115 | ) 116 | out.append(item) 117 | return out 118 | 119 | print(f'Reading split from {filepath}') 120 | split = read_json(filepath) 121 | train = _convert(split['train']) 122 | val = _convert(split['val']) 123 | test = _convert(split['test']) 124 | 125 | return train, val, test -------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | 4 | from .oxford_pets import OxfordPets 5 | from .utils import Datum, DatasetBase 6 | 7 | 8 | template = ['a photo of a {}.'] 9 | 10 | 11 | class StanfordCars(DatasetBase): 12 | 13 | dataset_dir = 'data/stanford_cars' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_StanfordCars.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 22 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 23 | 24 | super().__init__(train_x=train, val=val, test=test) 25 | 26 | def read_data(self, image_dir, anno_file, meta_file): 27 | anno_file = loadmat(anno_file)['annotations'][0] 28 | meta_file = loadmat(meta_file)['class_names'][0] 29 | items = [] 30 | 31 | for i in range(len(anno_file)): 32 | imname = anno_file[i]['fname'][0] 33 | impath = os.path.join(self.dataset_dir, image_dir, imname) 34 | label = anno_file[i]['class'][0, 0] 35 | label = int(label) - 1 # convert to 0-based index 36 | classname = meta_file[label][0] 37 | names = classname.split(' ') 38 | year = names.pop(-1) 39 | names.insert(0, year) 40 | classname = ' '.join(names) 41 | item = Datum( 42 | impath=impath, 43 | label=label, 44 | classname=classname 45 | ) 46 | items.append(item) 47 | 48 | return items -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['a photo of a {}.'] 9 | 10 | 11 | class SUN397(DatasetBase): 12 | 13 | dataset_dir = 'data/sun397' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, 
self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'SUN397') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_SUN397.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 24 | 25 | super().__init__(train_x=train, val=val, test=test) 26 | 27 | def read_data(self, cname2lab, text_file): 28 | text_file = os.path.join(self.dataset_dir, text_file) 29 | items = [] 30 | 31 | with open(text_file, 'r') as f: 32 | lines = f.readlines() 33 | for line in lines: 34 | imname = line.strip()[1:] # remove / 35 | classname = os.path.dirname(imname) 36 | label = cname2lab[classname] 37 | impath = os.path.join(self.image_dir, imname) 38 | 39 | names = classname.split('/')[1:] # remove 1st letter 40 | names = names[::-1] # put words like indoor/outdoor at first 41 | classname = ' '.join(names) 42 | 43 | item = Datum( 44 | impath=impath, 45 | label=label, 46 | classname=classname 47 | ) 48 | items.append(item) 49 | 50 | return items 51 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['a photo of a person doing {}.'] 9 | 10 | 11 | class UCF101(DatasetBase): 12 | 13 | dataset_dir = 'data/ucf101' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'UCF-101-midframes') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_UCF101.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 24 | 25 | super().__init__(train_x=train, val=val, test=test) 26 | 27 | def read_data(self, cname2lab, text_file): 28 | text_file = os.path.join(self.dataset_dir, text_file) 29 | items = [] 30 | 31 | with open(text_file, 'r') as f: 32 | lines = f.readlines() 33 | for line in lines: 34 | line = line.strip().split(' ')[0] # trainlist: filename, label 35 | action, filename = line.split('/') 36 | label = cname2lab[action] 37 | 38 | elements = re.findall('[A-Z][^A-Z]*', action) 39 | renamed_action = '_'.join(elements) 40 | 41 | filename = filename.replace('.avi', '.jpg') 42 | impath = os.path.join(self.image_dir, renamed_action, filename) 43 | 44 | item = Datum( 45 | impath=impath, 46 | label=label, 47 | classname=renamed_action 48 | ) 49 | items.append(item) 50 | 51 | return items 52 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import os.path as osp 4 | import tarfile 5 | import zipfile 6 | from collections import defaultdict 7 | import gdown 8 | import json 9 | import torch 10 | from torch.utils.data import Dataset as TorchDataset 11 | import torchvision.transforms as T 12 | from PIL import Image 13 | 14 | 15 | def read_json(fpath): 16 | """Read json file from a path.""" 17 | with open(fpath, 'r') as f: 18 | obj = json.load(f) 19 | return obj 20 | 21 | 22 | def write_json(obj, fpath): 23 | """Writes to a json file.""" 24 | if not
osp.exists(osp.dirname(fpath)): 25 | os.makedirs(osp.dirname(fpath)) 26 | with open(fpath, 'w') as f: 27 | json.dump(obj, f, indent=4, separators=(',', ': ')) 28 | 29 | 30 | def read_image(path): 31 | """Read image from path using ``PIL.Image``. 32 | 33 | Args: 34 | path (str): path to an image. 35 | 36 | Returns: 37 | PIL image 38 | """ 39 | if not osp.exists(path): 40 | raise IOError('No file exists at {}'.format(path)) 41 | 42 | while True: 43 | try: 44 | img = Image.open(path).convert('RGB') 45 | return img 46 | except IOError: 47 | print( 48 | 'Cannot read image from {}, ' 49 | 'probably due to heavy IO. Will re-try'.format(path) 50 | ) 51 | 52 | 53 | def listdir_nohidden(path, sort=False): 54 | """List non-hidden items in a directory. 55 | 56 | Args: 57 | path (str): directory path. 58 | sort (bool): sort the items. 59 | """ 60 | items = [f for f in os.listdir(path) if not f.startswith('.') and 'sh' not in f] 61 | if sort: 62 | items.sort() 63 | return items 64 | 65 | 66 | class Datum: 67 | """Data instance which defines the basic attributes. 68 | 69 | Args: 70 | impath (str): image path. 71 | label (int): class label. 72 | domain (int): domain label. 73 | classname (str): class name. 74 | """ 75 | 76 | def __init__(self, impath='', label=0, domain=-1, classname=''): 77 | assert isinstance(impath, str) 78 | assert isinstance(label, int) 79 | assert isinstance(domain, int) 80 | assert isinstance(classname, str) 81 | 82 | self._impath = impath 83 | self._label = label 84 | self._domain = domain 85 | self._classname = classname 86 | 87 | @property 88 | def impath(self): 89 | return self._impath 90 | 91 | @property 92 | def label(self): 93 | return self._label 94 | 95 | @property 96 | def domain(self): 97 | return self._domain 98 | 99 | @property 100 | def classname(self): 101 | return self._classname 102 | 103 | 104 | class DatasetBase: 105 | """A unified dataset class for 106 | 1) domain adaptation 107 | 2) domain generalization 108 | 3) semi-supervised learning 109 | """ 110 | dataset_dir = '' # the directory where the dataset is stored 111 | domains = [] # string names of all domains 112 | 113 | def __init__(self, train_x=None, train_u=None, val=None, test=None): 114 | self._train_x = train_x # labeled training data 115 | self._train_u = train_u # unlabeled training data (optional) 116 | self._val = val # validation data (optional) 117 | self._test = test # test data 118 | 119 | self._num_classes = self.get_num_classes(train_x) 120 | self._lab2cname, self._classnames = self.get_lab2cname(train_x) 121 | 122 | @property 123 | def train_x(self): 124 | return self._train_x 125 | 126 | @property 127 | def train_u(self): 128 | return self._train_u 129 | 130 | @property 131 | def val(self): 132 | return self._val 133 | 134 | @property 135 | def test(self): 136 | return self._test 137 | 138 | @property 139 | def lab2cname(self): 140 | return self._lab2cname 141 | 142 | @property 143 | def classnames(self): 144 | return self._classnames 145 | 146 | @property 147 | def num_classes(self): 148 | return self._num_classes 149 | 150 | def get_num_classes(self, data_source): 151 | """Count number of classes. 152 | 153 | Args: 154 | data_source (list): a list of Datum objects. 155 | """ 156 | label_set = set() 157 | for item in data_source: 158 | label_set.add(item.label) 159 | return max(label_set) + 1 160 | 161 | def get_lab2cname(self, data_source): 162 | """Get a label-to-classname mapping (dict). 163 | 164 | Args: 165 | data_source (list): a list of Datum objects. 
166 | """ 167 | container = set() 168 | for item in data_source: 169 | container.add((item.label, item.classname)) 170 | mapping = {label: classname for label, classname in container} 171 | labels = list(mapping.keys()) 172 | labels.sort() 173 | classnames = [mapping[label] for label in labels] 174 | return mapping, classnames 175 | 176 | def check_input_domains(self, source_domains, target_domains): 177 | self.is_input_domain_valid(source_domains) 178 | self.is_input_domain_valid(target_domains) 179 | 180 | def is_input_domain_valid(self, input_domains): 181 | for domain in input_domains: 182 | if domain not in self.domains: 183 | raise ValueError( 184 | 'Input domain must belong to {}, ' 185 | 'but got [{}]'.format(self.domains, domain) 186 | ) 187 | 188 | def download_data(self, url, dst, from_gdrive=True): 189 | if not osp.exists(osp.dirname(dst)): 190 | os.makedirs(osp.dirname(dst)) 191 | 192 | if from_gdrive: 193 | gdown.download(url, dst, quiet=False) 194 | else: 195 | raise NotImplementedError 196 | 197 | print('Extracting file ...') 198 | 199 | try: 200 | tar = tarfile.open(dst) 201 | tar.extractall(path=osp.dirname(dst)) 202 | tar.close() 203 | except: 204 | zip_ref = zipfile.ZipFile(dst, 'r') 205 | zip_ref.extractall(osp.dirname(dst)) 206 | zip_ref.close() 207 | 208 | print('File extracted to {}'.format(osp.dirname(dst))) 209 | 210 | def generate_fewshot_dataset( 211 | self, *data_sources, num_shots=-1, repeat=True 212 | ): 213 | """Generate a few-shot dataset (typically for the training set). 214 | 215 | This function is useful when one wants to evaluate a model 216 | in a few-shot learning setting where each class only contains 217 | a few number of images. 218 | 219 | Args: 220 | data_sources: each individual is a list containing Datum objects. 221 | num_shots (int): number of instances per class to sample. 222 | repeat (bool): repeat images if needed. 223 | """ 224 | if num_shots < 1: 225 | if len(data_sources) == 1: 226 | return data_sources[0] 227 | return data_sources 228 | 229 | print(f'Creating a {num_shots}-shot dataset') 230 | 231 | output = [] 232 | 233 | for data_source in data_sources: 234 | tracker = self.split_dataset_by_label(data_source) 235 | dataset = [] 236 | 237 | for label, items in tracker.items(): 238 | if len(items) >= num_shots: 239 | sampled_items = random.sample(items, num_shots) 240 | else: 241 | if repeat: 242 | sampled_items = random.choices(items, k=num_shots) 243 | else: 244 | sampled_items = items 245 | dataset.extend(sampled_items) 246 | 247 | output.append(dataset) 248 | 249 | if len(output) == 1: 250 | return output[0] 251 | 252 | return output 253 | 254 | def split_dataset_by_label(self, data_source): 255 | """Split a dataset, i.e. a list of Datum objects, 256 | into class-specific groups stored in a dictionary. 257 | 258 | Args: 259 | data_source (list): a list of Datum objects. 260 | """ 261 | output = defaultdict(list) 262 | 263 | for item in data_source: 264 | output[item.label].append(item) 265 | 266 | return output 267 | 268 | def split_dataset_by_domain(self, data_source): 269 | """Split a dataset, i.e. a list of Datum objects, 270 | into domain-specific groups stored in a dictionary. 271 | 272 | Args: 273 | data_source (list): a list of Datum objects. 
274 | """ 275 | output = defaultdict(list) 276 | 277 | for item in data_source: 278 | output[item.domain].append(item) 279 | 280 | return output 281 | 282 | 283 | class DatasetWrapper(TorchDataset): 284 | def __init__(self, data_source, input_size, transform=None, is_train=False, 285 | return_img0=False, k_tfm=1): 286 | self.data_source = data_source 287 | self.transform = transform # accept list (tuple) as input 288 | self.is_train = is_train 289 | # Augmenting an image K>1 times is only allowed during training 290 | self.k_tfm = k_tfm if is_train else 1 291 | self.return_img0 = return_img0 292 | 293 | if self.k_tfm > 1 and transform is None: 294 | raise ValueError( 295 | 'Cannot augment the image {} times ' 296 | 'because transform is None'.format(self.k_tfm) 297 | ) 298 | 299 | # Build transform that doesn't apply any data augmentation 300 | interp_mode = T.InterpolationMode.BICUBIC 301 | #interp_mode = T.BICUBIC 302 | to_tensor = [] 303 | to_tensor += [T.Resize(input_size, interpolation=interp_mode)] 304 | to_tensor += [T.ToTensor()] 305 | normalize = T.Normalize( 306 | mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) 307 | ) 308 | to_tensor += [normalize] 309 | self.to_tensor = T.Compose(to_tensor) 310 | 311 | def __len__(self): 312 | return len(self.data_source) 313 | 314 | def __getitem__(self, idx): 315 | item = self.data_source[idx] 316 | 317 | output = { 318 | 'label': item.label, 319 | 'domain': item.domain, 320 | 'impath': item.impath 321 | } 322 | 323 | img0 = read_image(item.impath) 324 | 325 | if self.transform is not None: 326 | if isinstance(self.transform, (list, tuple)): 327 | for i, tfm in enumerate(self.transform): 328 | img = self._transform_image(tfm, img0) 329 | keyname = 'img' 330 | if (i + 1) > 1: 331 | keyname += str(i + 1) 332 | output[keyname] = img 333 | else: 334 | img = self._transform_image(self.transform, img0) 335 | output['img'] = img 336 | 337 | if self.return_img0: 338 | output['img0'] = self.to_tensor(img0) 339 | 340 | return output['img'], output['label'] 341 | 342 | def _transform_image(self, tfm, img0): 343 | img_list = [] 344 | 345 | for k in range(self.k_tfm): 346 | img_list.append(tfm(img0)) 347 | 348 | img = img_list 349 | if len(img) == 1: 350 | img = img[0] 351 | 352 | return img 353 | 354 | 355 | def build_data_loader( 356 | data_source=None, 357 | batch_size=64, 358 | input_size=224, 359 | tfm=None, 360 | is_train=True, 361 | shuffle=False, 362 | dataset_wrapper=None 363 | ): 364 | 365 | if dataset_wrapper is None: 366 | dataset_wrapper = DatasetWrapper 367 | 368 | # Build data loader 369 | data_loader = torch.utils.data.DataLoader( 370 | dataset_wrapper(data_source, input_size=input_size, transform=tfm, is_train=is_train), 371 | batch_size=batch_size, 372 | num_workers=8, 373 | shuffle=shuffle, 374 | drop_last=False, 375 | pin_memory=(torch.cuda.is_available()) 376 | ) 377 | assert len(data_loader) > 0 378 | 379 | return data_loader 380 | -------------------------------------------------------------------------------- /draw_curves.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | save_dir = "main_curves" 8 | if not os.path.exists(save_dir): 9 | os.makedirs(save_dir) 10 | 11 | path = "Results.xlsx" # this is the excel file containing the results (like the one we released) 12 | file = pd.read_excel(path, sheet_name="imcls_fewshot") 13 | 14 | datasets = [ 15 | 
"OxfordPets", "Flowers102", "FGVCAircraft", "DTD", 16 | "EuroSAT", "StanfordCars", "Food101", "SUN397", 17 | "Caltech101", "UCF101", "ImageNet" 18 | ] 19 | 20 | shots = [1, 2, 4, 8, 16] 21 | 22 | COLORS = { 23 | "zs": "C4", 24 | "linear": "C4", 25 | "ours_v16_end": "C0", 26 | "ours_v16_mid": "C2", 27 | "ours_v16_end_csc": "C1", 28 | "ours_v16_mid_csc": "C3" 29 | } 30 | MS = 3 31 | ALPHA = 1 32 | plt.rcParams.update({"font.size": 12}) 33 | 34 | average = { 35 | "zs": 0., 36 | "ours_v16_end": np.array([0., 0., 0., 0., 0.]), 37 | "ours_v16_mid": np.array([0., 0., 0., 0., 0.]), 38 | "ours_v16_end_csc": np.array([0., 0., 0., 0., 0.]), 39 | "ours_v16_mid_csc": np.array([0., 0., 0., 0., 0.]), 40 | "linear": np.array([0., 0., 0., 0., 0.]) 41 | } 42 | 43 | for dataset in datasets: 44 | print(f"Processing {dataset} ...") 45 | 46 | zs = file[dataset][0] 47 | 48 | ours_v16_end = file[dataset][2:7] 49 | ours_v16_end = [float(num) for num in ours_v16_end] 50 | 51 | ours_v16_mid = file[dataset][7:12] 52 | ours_v16_mid = [float(num) for num in ours_v16_mid] 53 | 54 | ours_v16_end_csc = file[dataset][12:17] 55 | ours_v16_end_csc = [float(num) for num in ours_v16_end_csc] 56 | 57 | ours_v16_mid_csc = file[dataset][17:22] 58 | ours_v16_mid_csc = [float(num) for num in ours_v16_mid_csc] 59 | 60 | linear = file[dataset][22:27] 61 | linear = [float(num) for num in linear] 62 | 63 | average["zs"] += zs 64 | average["ours_v16_end"] += np.array(ours_v16_end) 65 | average["ours_v16_mid"] += np.array(ours_v16_mid) 66 | average["ours_v16_end_csc"] += np.array(ours_v16_end_csc) 67 | average["ours_v16_mid_csc"] += np.array(ours_v16_mid_csc) 68 | average["linear"] += np.array(linear) 69 | 70 | # Plot 71 | values = [zs] 72 | values += linear 73 | values += ours_v16_end 74 | values += ours_v16_mid 75 | values += ours_v16_end_csc 76 | values += ours_v16_mid_csc 77 | val_min, val_max = min(values), max(values) 78 | diff = val_max - val_min 79 | val_bot = val_min - diff*0.05 80 | val_top = val_max + diff*0.05 81 | 82 | fig, ax = plt.subplots() 83 | ax.set_facecolor("#EBEBEB") 84 | 85 | ax.set_xticks([0] + shots) 86 | ax.set_xticklabels([0] + shots) 87 | ax.set_xlabel("Number of labeled training examples per class") 88 | ax.set_ylabel("Score (%)") 89 | ax.grid(axis="x", color="white", linewidth=1) 90 | ax.axhline(zs, color="white", linewidth=1) 91 | ax.set_title(dataset) 92 | ax.set_ylim(val_bot, val_top) 93 | 94 | ax.plot( 95 | 0, zs, 96 | marker="*", 97 | markersize=MS*1.5, 98 | color=COLORS["zs"], 99 | alpha=ALPHA 100 | ) 101 | ax.plot( 102 | shots, ours_v16_end, 103 | marker="o", 104 | markersize=MS, 105 | color=COLORS["ours_v16_end"], 106 | label="CoOp", 107 | alpha=ALPHA 108 | ) 109 | ax.plot( 110 | shots, ours_v16_mid, 111 | marker="o", 112 | markersize=MS, 113 | color=COLORS["ours_v16_mid"], 114 | label="UPT", 115 | alpha=ALPHA 116 | ) 117 | ax.plot( 118 | shots, ours_v16_end_csc, 119 | marker="o", 120 | markersize=MS, 121 | color=COLORS["ours_v16_end_csc"], 122 | label="Tip-Adapter-F", 123 | alpha=ALPHA 124 | ) 125 | ax.plot( 126 | shots, ours_v16_mid_csc, 127 | marker="o", 128 | markersize=MS, 129 | color=COLORS["ours_v16_mid_csc"], 130 | label="Prompt-Adapter-F", 131 | alpha=ALPHA 132 | ) 133 | ax.plot( 134 | shots, linear, 135 | marker="o", 136 | markersize=MS, 137 | color=COLORS["linear"], 138 | label="Linear Probe CLIP", 139 | linestyle="dotted", 140 | alpha=ALPHA 141 | ) 142 | 143 | ax.text(-0.5, zs-diff*0.11, "Zero-shot\nCLIP", color=COLORS["zs"]) 144 | ax.legend(loc="lower right") 145 | 146 | 
fig.savefig(f"{save_dir}/{dataset}.pdf", bbox_inches="tight") 147 | 148 | 149 | # Plot 150 | average = {k: v/len(datasets) for k, v in average.items()} 151 | zs = average["zs"] 152 | linear = list(average["linear"]) 153 | ours_v16_end = list(average["ours_v16_end"]) 154 | ours_v16_mid = list(average["ours_v16_mid"]) 155 | ours_v16_end_csc = list(average["ours_v16_end_csc"]) 156 | ours_v16_mid_csc = list(average["ours_v16_mid_csc"]) 157 | 158 | values = [zs] 159 | values += linear 160 | values += ours_v16_end 161 | values += ours_v16_mid 162 | values += ours_v16_end_csc 163 | values += ours_v16_mid_csc 164 | val_min, val_max = min(values), max(values) 165 | diff = val_max - val_min 166 | val_bot = val_min - diff*0.05 167 | val_top = val_max + diff*0.05 168 | 169 | fig, ax = plt.subplots() 170 | ax.set_facecolor("#EBEBEB") 171 | 172 | ax.set_xticks([0] + shots) 173 | ax.set_xticklabels([0] + shots) 174 | ax.set_xlabel("Number of labeled training examples per class") 175 | ax.set_ylabel("Score (%)") 176 | ax.grid(axis="x", color="white", linewidth=1) 177 | ax.axhline(zs, color="white", linewidth=1) 178 | ax.set_title("Average over 11 datasets", fontweight="bold") 179 | ax.set_ylim(val_bot, val_top) 180 | 181 | ax.plot( 182 | 0, zs, 183 | marker="*", 184 | markersize=MS*1.5, 185 | color=COLORS["zs"], 186 | alpha=ALPHA 187 | ) 188 | ax.plot( 189 | shots, ours_v16_end, 190 | marker="o", 191 | markersize=MS, 192 | color=COLORS["ours_v16_end"], 193 | label="CoOp", 194 | alpha=ALPHA 195 | ) 196 | ax.plot( 197 | shots, ours_v16_mid, 198 | marker="o", 199 | markersize=MS, 200 | color=COLORS["ours_v16_mid"], 201 | label="UPT", 202 | alpha=ALPHA 203 | ) 204 | ax.plot( 205 | shots, ours_v16_end_csc, 206 | marker="o", 207 | markersize=MS, 208 | color=COLORS["ours_v16_end_csc"], 209 | label="Tip-Adapter-F", 210 | alpha=ALPHA 211 | ) 212 | ax.plot( 213 | shots, ours_v16_mid_csc, 214 | marker="o", 215 | markersize=MS, 216 | color=COLORS["ours_v16_mid_csc"], 217 | label="Prompt-Adapter-F", 218 | alpha=ALPHA 219 | ) 220 | ax.plot( 221 | shots, linear, 222 | marker="o", 223 | markersize=MS, 224 | color=COLORS["linear"], 225 | label="Linear Probe CLIP", 226 | linestyle="dotted", 227 | alpha=ALPHA 228 | ) 229 | 230 | ax.text(-0.5, zs-diff*0.11, "Zero-shot\nCLIP", color=COLORS["zs"]) 231 | ax.legend(loc="lower right") 232 | 233 | fig.savefig(f"{save_dir}/average.pdf", bbox_inches="tight") 234 | -------------------------------------------------------------------------------- /draw_curves_full.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | save_dir = "main_curves" 8 | if not os.path.exists(save_dir): 9 | os.makedirs(save_dir) 10 | 11 | path = "Results.xlsx" # this is the excel file containing the results (like the one we released) 12 | file = pd.read_excel(path, sheet_name="imcls_fewshot") 13 | 14 | datasets = [ 15 | "OxfordPets", "Flowers102", "FGVCAircraft", "DTD", 16 | "EuroSAT", "StanfordCars", "Food101", "SUN397", 17 | "Caltech101", "UCF101", "ImageNet" 18 | ] 19 | 20 | shots = [1, 2, 4, 8, 16] 21 | 22 | COLORS = { 23 | "zs": "C4", 24 | "linear": "C4", 25 | "CoOp": "C0", 26 | "UPT": "C2", 27 | "tip_adapter_f": "C1", 28 | "prompt_adapter_f": "C3", 29 | "tip_adapter": "C5", 30 | "prompt_adapter": "C6" 31 | } 32 | MS = 3 33 | ALPHA = 1 34 | plt.rcParams.update({"font.size": 12}) 35 | 36 | average = { 37 | "zs": 0., 38 | "CoOp": np.array([0., 0., 0., 0., 0.]), 39 | 
"UPT": np.array([0., 0., 0., 0., 0.]), 40 | "tip_adapter_f": np.array([0., 0., 0., 0., 0.]), 41 | "prompt_adapter_f": np.array([0., 0., 0., 0., 0.]), 42 | "linear": np.array([0., 0., 0., 0., 0.]), 43 | "tip_adapter": np.array([0., 0., 0., 0., 0.]), 44 | "prompt_adapter": np.array([0., 0., 0., 0., 0.]) 45 | 46 | } 47 | 48 | for dataset in datasets: 49 | print(f"Processing {dataset} ...") 50 | 51 | zs = file[dataset][0] 52 | 53 | CoOp = file[dataset][2:7] 54 | CoOp = [float(num) for num in CoOp] 55 | 56 | UPT = file[dataset][7:12] 57 | UPT = [float(num) for num in UPT] 58 | 59 | tip_adapter_f = file[dataset][12:17] 60 | tip_adapter_f = [float(num) for num in tip_adapter_f] 61 | 62 | prompt_adapter_f = file[dataset][17:22] 63 | prompt_adapter_f = [float(num) for num in prompt_adapter_f] 64 | 65 | linear = file[dataset][22:27] 66 | linear = [float(num) for num in linear] 67 | 68 | tip_adapter = file[dataset][27:32] 69 | tip_adapter = [float(num) for num in tip_adapter] 70 | 71 | prompt_adapter = file[dataset][32:37] 72 | prompt_adapter = [float(num) for num in prompt_adapter] 73 | 74 | average["zs"] += zs 75 | average["CoOp"] += np.array(CoOp) 76 | average["UPT"] += np.array(UPT) 77 | average["tip_adapter_f"] += np.array(tip_adapter_f) 78 | average["prompt_adapter_f"] += np.array(prompt_adapter_f) 79 | average["linear"] += np.array(linear) 80 | average["tip_adapter"] += np.array(tip_adapter) 81 | average["prompt_adapter"] += np.array(prompt_adapter) 82 | 83 | # Plot 84 | values = [zs] 85 | values += linear 86 | values += CoOp 87 | values += UPT 88 | values += tip_adapter_f 89 | values += prompt_adapter_f 90 | values += prompt_adapter 91 | values += tip_adapter 92 | 93 | 94 | val_min, val_max = min(values), max(values) 95 | diff = val_max - val_min 96 | val_bot = val_min - diff*0.05 97 | val_top = val_max + diff*0.05 98 | 99 | fig, ax = plt.subplots() 100 | ax.set_facecolor("#EBEBEB") 101 | 102 | ax.set_xticks([0] + shots) 103 | ax.set_xticklabels([0] + shots) 104 | ax.set_xlabel("Number of labeled training examples per class") 105 | ax.set_ylabel("Score (%)") 106 | ax.grid(axis="x", color="white", linewidth=1) 107 | ax.axhline(zs, color="white", linewidth=1) 108 | ax.set_title(dataset) 109 | ax.set_ylim(val_bot, val_top) 110 | 111 | ax.plot( 112 | 0, zs, 113 | marker="*", 114 | markersize=MS*1.5, 115 | color=COLORS["zs"], 116 | alpha=ALPHA 117 | ) 118 | ax.plot( 119 | shots, CoOp, 120 | marker="o", 121 | markersize=MS, 122 | color=COLORS["CoOp"], 123 | label="CoOp", 124 | alpha=ALPHA 125 | ) 126 | ax.plot( 127 | shots, UPT, 128 | marker="o", 129 | markersize=MS, 130 | color=COLORS["UPT"], 131 | label="UPT", 132 | alpha=ALPHA 133 | ) 134 | ax.plot( 135 | shots, tip_adapter, 136 | marker="o", 137 | markersize=MS, 138 | color=COLORS["tip_adapter"], 139 | label="Tip-Adapter", 140 | alpha=ALPHA 141 | ) 142 | 143 | ax.plot( 144 | shots, prompt_adapter, 145 | marker="o", 146 | markersize=MS, 147 | color=COLORS["prompt_adapter"], 148 | label="Prompt-Adapter", 149 | alpha=ALPHA 150 | ) 151 | ax.plot( 152 | shots, tip_adapter_f, 153 | marker="o", 154 | markersize=MS, 155 | color=COLORS["tip_adapter_f"], 156 | label="Tip-Adapter-F", 157 | alpha=ALPHA 158 | ) 159 | ax.plot( 160 | shots, prompt_adapter_f, 161 | marker="o", 162 | markersize=MS, 163 | color=COLORS["prompt_adapter_f"], 164 | label="Prompt-Adapter-F", 165 | alpha=ALPHA 166 | ) 167 | ax.plot( 168 | shots, linear, 169 | marker="o", 170 | markersize=MS, 171 | color=COLORS["linear"], 172 | label="Linear Probe CLIP", 173 | linestyle="dotted", 174 | 
alpha=ALPHA 175 | ) 176 | 177 | ax.text(-0.5, zs-diff*0.11, "Zero-shot\nCLIP", color=COLORS["zs"]) 178 | ax.legend(loc="lower right") 179 | 180 | fig.savefig(f"{save_dir}/{dataset}.pdf", bbox_inches="tight") 181 | 182 | 183 | # Plot 184 | average = {k: v/len(datasets) for k, v in average.items()} 185 | zs = average["zs"] 186 | linear = list(average["linear"]) 187 | CoOp = list(average["CoOp"]) 188 | UPT = list(average["UPT"]) 189 | tip_adapter_f = list(average["tip_adapter_f"]) 190 | prompt_adapter_f = list(average["prompt_adapter_f"]) 191 | tip_adapter = list(average["tip_adapter"]) 192 | prompt_adapter = list(average["prompt_adapter"]) 193 | 194 | 195 | values = [zs] 196 | values += linear 197 | values += CoOp 198 | values += UPT 199 | values += tip_adapter_f 200 | values += prompt_adapter_f 201 | values += tip_adapter 202 | values += prompt_adapter 203 | 204 | val_min, val_max = min(values), max(values) 205 | diff = val_max - val_min 206 | val_bot = val_min - diff*0.05 207 | val_top = val_max + diff*0.05 208 | 209 | fig, ax = plt.subplots() 210 | ax.set_facecolor("#EBEBEB") 211 | 212 | ax.set_xticks([0] + shots) 213 | ax.set_xticklabels([0] + shots) 214 | ax.set_xlabel("Number of labeled training examples per class") 215 | ax.set_ylabel("Score (%)") 216 | ax.grid(axis="x", color="white", linewidth=1) 217 | ax.axhline(zs, color="white", linewidth=1) 218 | ax.set_title("Average over 11 datasets", fontweight="bold") 219 | ax.set_ylim(val_bot, val_top) 220 | 221 | ax.plot( 222 | 0, zs, 223 | marker="*", 224 | markersize=MS*1.5, 225 | color=COLORS["zs"], 226 | alpha=ALPHA 227 | ) 228 | ax.plot( 229 | shots, CoOp, 230 | marker="o", 231 | markersize=MS, 232 | color=COLORS["CoOp"], 233 | label="CoOp", 234 | alpha=ALPHA 235 | ) 236 | ax.plot( 237 | shots, UPT, 238 | marker="o", 239 | markersize=MS, 240 | color=COLORS["UPT"], 241 | label="UPT", 242 | alpha=ALPHA 243 | ) 244 | ax.plot( 245 | shots, tip_adapter, 246 | marker="o", 247 | markersize=MS, 248 | color=COLORS["tip_adapter"], 249 | label="Tip-Adapter", 250 | alpha=ALPHA 251 | ) 252 | 253 | ax.plot( 254 | shots, prompt_adapter, 255 | marker="o", 256 | markersize=MS, 257 | color=COLORS["prompt_adapter"], 258 | label="Prompt-Adapter", 259 | alpha=ALPHA 260 | ) 261 | ax.plot( 262 | shots, tip_adapter_f, 263 | marker="o", 264 | markersize=MS, 265 | color=COLORS["tip_adapter_f"], 266 | label="Tip-Adapter-F", 267 | alpha=ALPHA 268 | ) 269 | ax.plot( 270 | shots, prompt_adapter_f, 271 | marker="o", 272 | markersize=MS, 273 | color=COLORS["prompt_adapter_f"], 274 | label="Prompt-Adapter-F", 275 | alpha=ALPHA 276 | ) 277 | ax.plot( 278 | shots, linear, 279 | marker="o", 280 | markersize=MS, 281 | color=COLORS["linear"], 282 | label="Linear Probe CLIP", 283 | linestyle="dotted", 284 | alpha=ALPHA 285 | ) 286 | 287 | ax.text(-0.5, zs-diff*0.11, "Zero-shot\nCLIP", color=COLORS["zs"]) 288 | ax.legend(loc="lower right") 289 | 290 | fig.savefig(f"{save_dir}/average.pdf", bbox_inches="tight") 291 | -------------------------------------------------------------------------------- /exp.log: -------------------------------------------------------------------------------- 1 | imagenet: 60.32 2 | 3 | 16 8 4 2 1 4 | Tip 62.01 61.45 60.98 60.96 60.70 5 | Tip-F 65.51 64.00 62.52 61.69 61.13 6 | 7 | 8 | sun397: 58.52 9 | 10 | 16 8 4 2 1 11 | Tip 66.85 65.62 64.15 62.70 61.30 12 | Tip-F 71.47 68.87 66.21 63.64 62.50 13 | 14 | 15 | oxford_pets: 85.83 16 | 17 | 16 8 4 2 1 18 | Tip 88.14 87.03 86.45 87.03 86.10 19 | Tip-F 89.70 88.09 87.54 87.03 87.00 20 | 21 | 22 | 
eurosat: 37.52 23 | 24 | 16 8 4 2 1 25 | Tip 70.54 67.95 65.32 61.68 54.38 26 | Tip-F 84.54 77.93 74.12 66.15 59.53 27 | 28 | 29 | caltech101: 85.92 30 | 31 | 16 8 4 2 1 32 | Tip 90.18 89.83 89.39 88.44 87.18 33 | Tip-F 92.86 91.44 90.56 89.74 89.33 34 | 35 | 36 | dtd: 42.20 37 | 38 | 16 8 4 2 1 39 | Tip 60.93 58.63 53.96 49.47 46.22 40 | Tip-F 66.55 62.71 57.39 53.72 49.65 41 | 42 | 43 | fgvc: 17.10 44 | 45 | 16 8 4 2 1 46 | Tip 29.76 25.59 22.41 21.21 19.05 47 | Tip-F 35.55 30.21 25.80 23.19 20.22 48 | 49 | 50 | food101: 77.32 51 | 52 | 16 8 4 2 1 53 | Tip 77.83 77.76 77.54 77.52 77.42 54 | Tip-F 79.43 78.64 78.24 77.81 77.51 55 | 56 | 57 | oxford_flowers: 66.02 58 | 59 | 16 8 4 2 1 60 | Tip 89.89 87.98 83.80 79.13 73.12 61 | Tip-F 94.80 91.51 88.83 82.30 79.98 62 | 63 | 64 | ucf101: 61.35 65 | 66 | 16 8 4 2 1 67 | Tip 70.58 68.68 66.46 64.74 62.60 68 | Tip-F 78.03 74.25 70.55 66.43 64.87 69 | 70 | 71 | stanford_cars: 55.74 72 | 73 | 16 8 4 2 1 74 | Tip 66.77 62.93 61.45 57.93 57.54 75 | Tip-F 75.74 69.25 64.57 61.50 58.86 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pydoc import cli 3 | import random 4 | import argparse 5 | import yaml 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | import torchvision.transforms as transforms 12 | 13 | from datasets import build_dataset 14 | from datasets.utils import build_data_loader 15 | import clip 16 | from utils import * 17 | 18 | 19 | def get_arguments(): 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--config', dest='config', help='settings of Tip-Adapter in yaml format') 23 | args = parser.parse_args() 24 | 25 | return args 26 | 27 | 28 | def run_tip_adapter(cfg, cache_keys, cache_values, val_features, val_labels, test_features, test_labels, clip_weights): 29 | 30 | print("\n-------- Searching hyperparameters on the val set. --------") 31 | 32 | # Zero-shot CLIP 33 | clip_logits = 100. * val_features @ clip_weights 34 | acc = cls_acc(clip_logits, val_labels) 35 | print("\n**** Zero-shot CLIP's val accuracy: {:.2f}. ****\n".format(acc)) 36 | 37 | # Tip-Adapter 38 | beta, alpha = cfg['init_beta'], cfg['init_alpha'] 39 | 40 | affinity = val_features @ cache_keys 41 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 42 | 43 | tip_logits = clip_logits + cache_logits * alpha 44 | acc = cls_acc(tip_logits, val_labels) 45 | print("**** Prompt-Adapter's val accuracy: {:.2f}. ****\n".format(acc)) 46 | 47 | # Search Hyperparameters 48 | best_beta, best_alpha = search_hp(cfg, cache_keys, cache_values, val_features, val_labels, clip_weights) 49 | 50 | 51 | print("\n-------- Evaluating on the test set. --------") 52 | 53 | # Zero-shot CLIP 54 | clip_logits = 100. * test_features @ clip_weights 55 | acc = cls_acc(clip_logits, test_labels) 56 | print("\n**** Zero-shot CLIP's test accuracy: {:.2f}. ****\n".format(acc)) 57 | 58 | # Tip-Adapter 59 | affinity = test_features @ cache_keys 60 | cache_logits = ((-1) * (best_beta - best_beta * affinity)).exp() @ cache_values 61 | 62 | tip_logits = clip_logits + cache_logits * best_alpha 63 | acc = cls_acc(tip_logits, test_labels) 64 | print("**** Prompt-Adapter's test accuracy: {:.2f}. 
****\n".format(acc)) 65 | 66 | 67 | def run_tip_adapter_F(cfg, cache_keys, cache_values, val_features, val_labels, test_features, test_labels, clip_weights, clip_model, train_loader_F): 68 | 69 | # Enable the cached keys to be learnable 70 | adapter = nn.Linear(cache_keys.shape[0], cache_keys.shape[1], bias=False).to(clip_model.dtype).cuda() 71 | adapter.weight = nn.Parameter(cache_keys.t()) 72 | 73 | optimizer = torch.optim.AdamW(adapter.parameters(), lr=cfg['lr'], eps=1e-4) 74 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg['train_epoch'] * len(train_loader_F)) 75 | 76 | beta, alpha = cfg['init_beta'], cfg['init_alpha'] 77 | best_acc, best_epoch = 0.0, 0 78 | 79 | for train_idx in range(cfg['train_epoch']): 80 | # Train 81 | adapter.train() 82 | correct_samples, all_samples = 0, 0 83 | loss_list = [] 84 | print('Train Epoch: {:} / {:}'.format(train_idx, cfg['train_epoch'])) 85 | 86 | for i, (images, target) in enumerate(tqdm(train_loader_F)): 87 | images, target = images.cuda(), target.cuda() 88 | #print('images:',images.size())#([256, 3, 224, 224]) 89 | with torch.no_grad(): 90 | image_features = clip_model.encode_image(images) 91 | image_features /= image_features.norm(dim=-1, keepdim=True) 92 | 93 | affinity = adapter(image_features) 94 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 95 | clip_logits = 100. * image_features @ clip_weights 96 | tip_logits = clip_logits + cache_logits * alpha 97 | 98 | loss = F.cross_entropy(tip_logits, target) 99 | #print('target:',target) # 100 | 101 | acc = cls_acc(tip_logits, target) 102 | correct_samples += acc / 100 * len(tip_logits) 103 | all_samples += len(tip_logits) 104 | loss_list.append(loss.item()) 105 | 106 | optimizer.zero_grad() 107 | loss.backward() 108 | optimizer.step() 109 | scheduler.step() 110 | 111 | current_lr = scheduler.get_last_lr()[0] 112 | print('LR: {:.6f}, Acc: {:.4f} ({:}/{:}), Loss: {:.4f}'.format(current_lr, correct_samples / all_samples, correct_samples, all_samples, sum(loss_list)/len(loss_list))) 113 | 114 | # Eval 115 | adapter.eval() 116 | 117 | affinity = adapter(test_features) 118 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 119 | clip_logits = 100. * test_features @ clip_weights 120 | tip_logits = clip_logits + cache_logits * alpha 121 | acc = cls_acc(tip_logits, test_labels) 122 | 123 | print("**** Prompt-Adapter-F's test accuracy: {:.2f}. ****\n".format(acc)) 124 | if acc > best_acc: 125 | best_acc = acc 126 | best_epoch = train_idx 127 | torch.save(adapter.weight, cfg['cache_dir'] + "/best_F_" + str(cfg['shots']) + "shots.pt") 128 | 129 | adapter.weight = torch.load(cfg['cache_dir'] + "/best_F_" + str(cfg['shots']) + "shots.pt") 130 | print(f"**** After fine-tuning, Prompt-Adapter-F's best test accuracy: {best_acc:.2f}, at epoch: {best_epoch}. ****\n") 131 | 132 | print("\n-------- Searching hyperparameters on the val set. --------") 133 | 134 | # Search Hyperparameters 135 | best_beta, best_alpha = search_hp(cfg, cache_keys, cache_values, val_features, val_labels, clip_weights, adapter=adapter) 136 | 137 | #best_beta, best_alpha = search_hp(cfg, cache_keys, cache_values, test_features, test_labels, clip_weights, adapter=adapter) 138 | 139 | print("\n-------- Evaluating on the test set. 
--------") 140 | 141 | affinity = adapter(test_features) 142 | cache_logits = ((-1) * (best_beta - best_beta * affinity)).exp() @ cache_values 143 | 144 | tip_logits = clip_logits + cache_logits * best_alpha 145 | acc = cls_acc(tip_logits, test_labels) 146 | print("**** Prompt-Adapter-F's test accuracy: {:.2f}. ****\n".format(max(best_acc, acc))) 147 | 148 | 149 | def main(): 150 | 151 | # Load config file 152 | args = get_arguments() 153 | assert (os.path.exists(args.config)) 154 | 155 | cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 156 | 157 | cache_dir = os.path.join('./caches', cfg['dataset']) 158 | os.makedirs(cache_dir, exist_ok=True) 159 | cfg['cache_dir'] = cache_dir 160 | 161 | print("\nRunning configs.") 162 | print(cfg, "\n") 163 | 164 | # CLIP 165 | clip_model, preprocess = clip.load(cfg['backbone']) 166 | clip_model.eval() 167 | 168 | # Prepare dataset 169 | random.seed(1) 170 | torch.manual_seed(1) 171 | 172 | print("Preparing dataset.") 173 | dataset = build_dataset(cfg['dataset'], cfg['root_path'], cfg['shots']) 174 | 175 | val_loader = build_data_loader(data_source=dataset.val, batch_size=64, is_train=False, tfm=preprocess, shuffle=False) 176 | test_loader = build_data_loader(data_source=dataset.test, batch_size=64, is_train=False, tfm=preprocess, shuffle=False) 177 | 178 | train_tranform = transforms.Compose([ 179 | transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC), 180 | transforms.RandomHorizontalFlip(p=0.5), 181 | transforms.ToTensor(), 182 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 183 | ]) 184 | 185 | train_loader_cache = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_tranform, is_train=True, shuffle=False) 186 | train_loader_F = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_tranform, is_train=True, shuffle=True) 187 | 188 | # Textual features 189 | print("\nGetting textual features as CLIP's classifier.") 190 | #clip_weights = clip_classifier(dataset.classnames, dataset.template, clip_model) 191 | 192 | path = str('./prompt_tensor_init/' + cfg['dataset'] +'_vit16.pt') 193 | clip_weights = torch.load(path, map_location='cuda') 194 | clip_weights = clip_weights.permute(1, 0) 195 | 196 | # Construct the cache model by few-shot training set 197 | print("\nConstructing cache model by few-shot visual features and labels.") 198 | cache_keys, cache_values = build_cache_model(cfg, clip_model, train_loader_cache) 199 | 200 | # Pre-load val features 201 | print("\nLoading visual features and labels from val set.") 202 | val_features, val_labels = pre_load_features(cfg, "val", clip_model, val_loader) 203 | # print('val_features:',val_features.size()) 204 | # print('val_labels:',val_labels.size()) 205 | 206 | # Pre-load test features 207 | print("\nLoading visual features and labels from test set.") 208 | test_features, test_labels = pre_load_features(cfg, "test", clip_model, test_loader) 209 | # print('test_features:',test_features.size()) 210 | # print('test_labels:',test_labels.size()) 211 | 212 | 213 | # ------------------------------------------ Tip-Adapter ------------------------------------------ 214 | run_tip_adapter(cfg, cache_keys, cache_values, val_features, val_labels, test_features, test_labels, clip_weights) 215 | 216 | # ------------------------------------------ Tip-Adapter-F ------------------------------------------ 217 | run_tip_adapter_F(cfg, cache_keys, cache_values, 
val_features, val_labels, test_features, test_labels, clip_weights, clip_model, train_loader_F) 218 | 219 | 220 | if __name__ == '__main__': 221 | main() -------------------------------------------------------------------------------- /main_curves/Caltech101.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves/Caltech101.pdf -------------------------------------------------------------------------------- /main_curves/DTD.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves/DTD.pdf -------------------------------------------------------------------------------- /main_curves/EuroSAT.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves/EuroSAT.pdf -------------------------------------------------------------------------------- /main_curves/FGVCAircraft.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves/FGVCAircraft.pdf -------------------------------------------------------------------------------- /main_curves/Flowers102.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves/Flowers102.pdf -------------------------------------------------------------------------------- /main_curves/Food101.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves/Food101.pdf -------------------------------------------------------------------------------- /main_curves/ImageNet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves/ImageNet.pdf -------------------------------------------------------------------------------- /main_curves/OxfordPets.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves/OxfordPets.pdf -------------------------------------------------------------------------------- /main_curves/SUN397.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves/SUN397.pdf -------------------------------------------------------------------------------- /main_curves/StanfordCars.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves/StanfordCars.pdf -------------------------------------------------------------------------------- /main_curves/UCF101.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves/UCF101.pdf -------------------------------------------------------------------------------- /main_curves/average.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves/average.pdf -------------------------------------------------------------------------------- /main_curves_full/Caltech101.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves_full/Caltech101.pdf -------------------------------------------------------------------------------- /main_curves_full/DTD.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves_full/DTD.pdf -------------------------------------------------------------------------------- /main_curves_full/EuroSAT.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves_full/EuroSAT.pdf -------------------------------------------------------------------------------- /main_curves_full/FGVCAircraft.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves_full/FGVCAircraft.pdf -------------------------------------------------------------------------------- /main_curves_full/Flowers102.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves_full/Flowers102.pdf -------------------------------------------------------------------------------- /main_curves_full/Food101.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves_full/Food101.pdf -------------------------------------------------------------------------------- /main_curves_full/ImageNet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves_full/ImageNet.pdf -------------------------------------------------------------------------------- /main_curves_full/OxfordPets.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves_full/OxfordPets.pdf -------------------------------------------------------------------------------- /main_curves_full/SUN397.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves_full/SUN397.pdf -------------------------------------------------------------------------------- /main_curves_full/StanfordCars.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves_full/StanfordCars.pdf -------------------------------------------------------------------------------- /main_curves_full/UCF101.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves_full/UCF101.pdf -------------------------------------------------------------------------------- /main_curves_full/average.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jingchensun/Prompt-Adapter/f9cd186a01198fbae6aa1bd90a7996410ac8b66a/main_curves_full/average.pdf -------------------------------------------------------------------------------- /main_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | import yaml 5 | from tqdm import tqdm 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | 11 | from datasets.imagenet import ImageNet 12 | import clip 13 | from utils import * 14 | 15 | 16 | def get_arguments(): 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--config', dest='config', help='settings of Tip-Adapter in yaml format') 20 | args = parser.parse_args() 21 | 22 | return args 23 | 24 | 25 | def run_tip_adapter(cfg, cache_keys, cache_values, test_features, test_labels, clip_weights): 26 | 27 | # Zero-shot CLIP 28 | clip_logits = 100. * test_features @ clip_weights 29 | acc = cls_acc(clip_logits, test_labels) 30 | print("\n**** Zero-shot CLIP's test accuracy: {:.2f}. ****\n".format(acc)) 31 | 32 | # Tip-Adapter 33 | beta, alpha = cfg['init_beta'], cfg['init_alpha'] 34 | 35 | affinity = test_features @ cache_keys 36 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 37 | 38 | tip_logits = clip_logits + cache_logits * alpha 39 | acc = cls_acc(tip_logits, test_labels) 40 | print("**** Prompt-Adapter's test accuracy: {:.2f}. 
****\n".format(acc)) 41 | 42 | # Search Hyperparameters 43 | _ = search_hp(cfg, cache_keys, cache_values, test_features, test_labels, clip_weights) 44 | 45 | 46 | def run_tip_adapter_F(cfg, cache_keys, cache_values, test_features, test_labels, clip_weights, clip_model, train_loader_F): 47 | 48 | # Enable the cached keys to be learnable 49 | adapter = nn.Linear(cache_keys.shape[0], cache_keys.shape[1], bias=False).to(clip_model.dtype).cuda() 50 | adapter.weight = nn.Parameter(cache_keys.t()) 51 | 52 | optimizer = torch.optim.AdamW(adapter.parameters(), lr=cfg['lr'], eps=1e-4) 53 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg['train_epoch'] * len(train_loader_F)) 54 | 55 | beta, alpha = cfg['init_beta'], cfg['init_alpha'] 56 | best_acc, best_epoch = 0.0, 0 57 | 58 | for train_idx in range(cfg['train_epoch']): 59 | # Train 60 | adapter.train() 61 | correct_samples, all_samples = 0, 0 62 | loss_list = [] 63 | print('Train Epoch: {:} / {:}'.format(train_idx, cfg['train_epoch'])) 64 | 65 | for i, (images, target) in enumerate(tqdm(train_loader_F)): 66 | images, target = images.cuda(), target.cuda() 67 | with torch.no_grad(): 68 | image_features = clip_model.encode_image(images) 69 | image_features /= image_features.norm(dim=-1, keepdim=True) 70 | 71 | affinity = adapter(image_features) 72 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 73 | clip_logits = 100. * image_features @ clip_weights 74 | tip_logits = clip_logits + cache_logits * alpha 75 | 76 | loss = F.cross_entropy(tip_logits, target) 77 | 78 | acc = cls_acc(tip_logits, target) 79 | correct_samples += acc / 100 * len(tip_logits) 80 | all_samples += len(tip_logits) 81 | loss_list.append(loss.item()) 82 | 83 | optimizer.zero_grad() 84 | loss.backward() 85 | optimizer.step() 86 | scheduler.step() 87 | 88 | current_lr = scheduler.get_last_lr()[0] 89 | print('LR: {:.6f}, Acc: {:.4f} ({:}/{:}), Loss: {:.4f}'.format(current_lr, correct_samples / all_samples, correct_samples, all_samples, sum(loss_list)/len(loss_list))) 90 | 91 | # Eval 92 | adapter.eval() 93 | 94 | affinity = adapter(test_features) 95 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 96 | clip_logits = 100. * test_features @ clip_weights 97 | tip_logits = clip_logits + cache_logits * alpha 98 | acc = cls_acc(tip_logits, test_labels) 99 | 100 | print("**** Prompt-Adapter-F's test accuracy: {:.2f}. ****\n".format(acc)) 101 | if acc > best_acc: 102 | best_acc = acc 103 | best_epoch = train_idx 104 | torch.save(adapter.weight, cfg['cache_dir'] + "/best_F_" + str(cfg['shots']) + "shots.pt") 105 | 106 | adapter.weight = torch.load(cfg['cache_dir'] + "/best_F_" + str(cfg['shots']) + "shots.pt") 107 | print(f"**** After fine-tuning, Prompt-Adapter-F's best test accuracy: {best_acc:.2f}, at epoch: {best_epoch}. 
****\n") 108 | 109 | # Search Hyperparameters 110 | _ = search_hp(cfg, affinity, cache_values, test_features, test_labels, clip_weights, adapter=adapter) 111 | 112 | 113 | def main(): 114 | 115 | # Load config file 116 | args = get_arguments() 117 | assert (os.path.exists(args.config)) 118 | 119 | cfg = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 120 | 121 | cache_dir = os.path.join('./caches2', cfg['dataset']) 122 | os.makedirs(cache_dir, exist_ok=True) 123 | cfg['cache_dir'] = cache_dir 124 | 125 | print("\nRunning configs.") 126 | print(cfg, "\n") 127 | 128 | # CLIP 129 | clip_model, preprocess = clip.load(cfg['backbone']) 130 | clip_model.eval() 131 | 132 | # ImageNet dataset 133 | random.seed(1) 134 | torch.manual_seed(1) 135 | 136 | print("Preparing ImageNet dataset.") 137 | imagenet = ImageNet(cfg['root_path'], cfg['shots'], preprocess) 138 | 139 | # test_loader = torch.utils.data.DataLoader(imagenet.test, batch_size=64, num_workers=8, shuffle=False) 140 | # train_loader_cache = torch.utils.data.DataLoader(imagenet.train, batch_size=256, num_workers=8, shuffle=False) 141 | # train_loader_F = torch.utils.data.DataLoader(imagenet.train, batch_size=256, num_workers=8, shuffle=True) 142 | 143 | test_loader = torch.utils.data.DataLoader(imagenet.test, batch_size=100, num_workers=8, shuffle=False) 144 | train_loader_cache = torch.utils.data.DataLoader(imagenet.train, batch_size=64, num_workers=8, shuffle=False) 145 | train_loader_F = torch.utils.data.DataLoader(imagenet.train, batch_size=64, num_workers=8, shuffle=True) 146 | 147 | # Textual features 148 | print("Getting textual features as CLIP's classifier.") 149 | #clip_weights = clip_classifier(imagenet.classnames, imagenet.template, clip_model) 150 | 151 | path = str('./prompt_tensor_init/' + cfg['dataset'] +'_vit16.pt') 152 | #print('path:', path) 153 | clip_weights = torch.load(path, map_location='cuda') 154 | clip_weights = clip_weights.permute(1, 0) 155 | 156 | # Construct the cache model by few-shot training set 157 | print("\nConstructing cache model by few-shot visual features and labels.") 158 | cache_keys, cache_values = build_cache_model(cfg, clip_model, train_loader_cache) 159 | 160 | # Pre-load test features 161 | print("\nLoading visual features and labels from test set.") 162 | test_features, test_labels = pre_load_features(cfg, "test", clip_model, test_loader) 163 | print('test_featurestest_features',test_features.size()) #[50000, 1024]) 164 | 165 | # ------------------------------------------ Tip-Adapter ------------------------------------------ 166 | run_tip_adapter(cfg, cache_keys, cache_values, test_features, test_labels, clip_weights) 167 | 168 | # ------------------------------------------ Tip-Adapter-F ------------------------------------------ 169 | run_tip_adapter_F(cfg, cache_keys, cache_values, test_features, test_labels, clip_weights, clip_model, train_loader_F) 170 | 171 | 172 | if __name__ == '__main__': 173 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flake8==3.7.9 2 | yapf==0.29.0 3 | isort==4.3.21 4 | yacs 5 | gdown 6 | tb-nightly 7 | future 8 | scipy 9 | scikit-learn 10 | tqdm 11 | ftfy 12 | regex 13 | wilds==1.2.2 14 | tabulate -------------------------------------------------------------------------------- /scripts/data.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | 3 | 
DATA=coop_data/ 4 | mkdir $DATA 5 | # DATA=/work/tianjun/few-shot-learning/prompt-moe/CoOp/data/ 6 | cd $DATA 7 | 8 | # pip install gdown 9 | 10 | mkdir -p caltech-101 11 | cd caltech-101 12 | # wget http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz 13 | wget https://data.caltech.edu/records/mzrjq-6wc02/files/caltech-101.zip 14 | unzip caltech-101.zip 15 | mv caltech-101/101_ObjectCategories.tar.gz . 16 | gdown 1hyarUivQE36mY6jSomru6Fjd-JzwcCzN 17 | tar -xvf 101_ObjectCategories.tar.gz 18 | cd $DATA 19 | 20 | mkdir -p oxford_pets 21 | cd oxford_pets 22 | wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz 23 | wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz 24 | gdown 1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs 25 | tar -xvf images.tar.gz 26 | tar -xvf annotations.tar.gz 27 | cd $DATA 28 | 29 | mkdir -p stanford_cars 30 | cd stanford_cars 31 | wget http://ai.stanford.edu/~jkrause/car196/cars_train.tgz 32 | wget http://ai.stanford.edu/~jkrause/car196/cars_test.tgz 33 | wget https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz 34 | wget http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat 35 | gdown 1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT 36 | tar -xvf cars_train.tgz 37 | tar -xvf cars_test.tgz 38 | tar -xvf car_devkit.tgz 39 | cd $DATA 40 | 41 | mkdir -p oxford_flowers 42 | cd oxford_flowers 43 | wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz 44 | wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat 45 | gdown 1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0 46 | gdown 1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT 47 | tar -xvf 102flowers.tgz 48 | cd $DATA 49 | 50 | 51 | wget http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz 52 | tar -xvf food-101.tar.gz 53 | cd food-101 54 | gdown 1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl 55 | cd $DATA 56 | 57 | wget https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz 58 | tar -xvf fgvc-aircraft-2013b.tar.gz 59 | mv fgvc-aircraft-2013b/data fgvc_aircraft 60 | cd $DATA 61 | 62 | mkdir -p sun397 63 | cd sun397 64 | wget http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz 65 | wget https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip 66 | gdown 1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq 67 | tar -xvf SUN397.tar.gz 68 | unzip Partitions.zip 69 | cd $DATA 70 | 71 | 72 | wget https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz 73 | tar -xvf dtd-r1.0.1.tar.gz 74 | cd dtd 75 | gdown 1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x 76 | cd $DATA 77 | 78 | mkdir -p eurosat 79 | cd eurosat 80 | wget http://madm.dfki.de/files/sentinel/EuroSAT.zip 81 | unzip EuroSAT.zip 82 | gdown 1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o 83 | cd $DATA 84 | 85 | mkdir -p ucf101 86 | cd ucf101 87 | gdown 10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O 88 | unzip UCF-101-midframes.zip 89 | gdown 1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y 90 | cd $DATA 91 | 92 | 93 | # mkdir -p imagenet/images 94 | # cd imagenet/images 95 | # ##1. Download the data 96 | # #get ILSVRC2012_img_val.tar (about 6.3 GB). MD5: 29b22e2961454d5413ddabcf34fc5622 97 | # wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar 98 | # #get ILSVRC2012_img_train.tar (about 138 GB). MD5: 1d675b47d978889d74fa0da5fadfb00e 99 | # wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar 100 | 101 | # ## 2. Extract the training data: 102 | # mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train 103 | # tar -xvf ILSVRC2012_img_train.tar && mv ILSVRC2012_img_train.tar ../ 104 | # find . 
-name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done 105 | # cd .. 106 | 107 | # ## 3. Extract the validation data and move images to subfolders: 108 | # mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar && mv ILSVRC2012_img_val.tar ../ 109 | # wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash 110 | # cd $DATA 111 | 112 | # ## 4. Move the classname.txt to /imagenet/images 113 | # cd ../scripts/ 114 | # mv classnames.txt ../coop_data/imagenet/images -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | import clip 8 | 9 | 10 | def cls_acc(output, target, topk=1): 11 | pred = output.topk(topk, 1, True, True)[1].t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 14 | acc = 100 * acc / target.shape[0] 15 | return acc 16 | 17 | 18 | def clip_classifier(classnames, template, clip_model): 19 | with torch.no_grad(): 20 | clip_weights = [] 21 | 22 | for classname in classnames: 23 | # Tokenize the prompts 24 | classname = classname.replace('_', ' ') 25 | texts = [t.format(classname) for t in template] 26 | #print('texts:',texts) #['a photo of a face.'] 27 | texts = clip.tokenize(texts).cuda() 28 | #print('texts:',texts.size()) #torch.Size([1, 77]) 29 | # prompt ensemble for ImageNet 30 | class_embeddings = clip_model.encode_text(texts) 31 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 32 | class_embedding = class_embeddings.mean(dim=0) 33 | class_embedding /= class_embedding.norm() 34 | clip_weights.append(class_embedding) 35 | print('clip_weights',clip_weights[0].size()) #torch.Size([1024]) 36 | clip_weights = torch.stack(clip_weights, dim=1).cuda() 37 | print('clip_weights',clip_weights.size()) #torch.Size([1024, 100]) 38 | return clip_weights 39 | 40 | 41 | def build_cache_model(cfg, clip_model, train_loader_cache): 42 | 43 | if cfg['load_cache'] == False: 44 | cache_keys = [] 45 | cache_values = [] 46 | 47 | with torch.no_grad(): 48 | # Data augmentation for the cache model 49 | for augment_idx in range(cfg['augment_epoch']): 50 | train_features = [] 51 | 52 | print('Augment Epoch: {:} / {:}'.format(augment_idx, cfg['augment_epoch'])) 53 | for i, (images, target) in enumerate(tqdm(train_loader_cache)): 54 | images = images.cuda() 55 | image_features = clip_model.encode_image(images) 56 | #print('image_features :',image_features.size()) #torch.Size([256, 512]) 57 | train_features.append(image_features) 58 | if augment_idx == 0: 59 | target = target.cuda() 60 | #print(target.size()) #torch.Size([256]) torch.Size([256]) torch.Size([240]) 61 | cache_values.append(target) 62 | #print('train_features:',len(train_features)) #3 63 | cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0)) 64 | 65 | #print('cache_keys1:',len(cache_keys)) #10 66 | #print('cache_keys1:',torch.cat(cache_keys, dim=0).size()) #torch.Size([10, 752, 512]) 67 | cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0) 68 | #print('cache_keys2:',cache_keys.size()) #torch.Size([752, 512]) 69 | cache_keys /= cache_keys.norm(dim=-1, keepdim=True) 70 | #print('cache_keys3:',cache_keys.size()) #torch.Size([752, 512]) 71 | 
cache_keys = cache_keys.permute(1, 0) 72 | #print('cache_keys4:',cache_keys.size()) #torch.Size([512, 752]) 73 | #print('cache_keys:',cache_keys.size()) 74 | #print('cache_values1:',torch.cat(cache_values, dim=0))#rch.Size([752]) 75 | cache_values = F.one_hot(torch.cat(cache_values, dim=0)).half() 76 | #cache_values = torch.cat(cache_values, dim=0) 77 | #print('cache_values2:',cache_values) #torch.Size([752, 47]) 78 | 79 | #assert(1==0) 80 | # torch.save(cache_keys, cfg['cache_dir'] + '/keys_' + str(cfg['shots']) + "shots.pt") 81 | # torch.save(cache_values, cfg['cache_dir'] + '/values_' + str(cfg['shots']) + "shots.pt") 82 | torch.save(cache_keys, cfg['cache_dir'] + '/keys_' + str(cfg['shots']) + "shots.pt") 83 | torch.save(cache_values, cfg['cache_dir'] + '/values_' + str(cfg['shots']) + "shots.pt") 84 | 85 | else: 86 | cache_keys = torch.load(cfg['cache_dir'] + '/keys_' + str(cfg['shots']) + "shots.pt") 87 | cache_values = torch.load(cfg['cache_dir'] + '/values_' + str(cfg['shots']) + "shots.pt") 88 | 89 | return cache_keys, cache_values 90 | 91 | 92 | def pre_load_features(cfg, split, clip_model, loader): 93 | 94 | if cfg['load_pre_feat'] == False: 95 | features, labels = [], [] 96 | 97 | with torch.no_grad(): 98 | for i, (images, target) in enumerate(tqdm(loader)): 99 | images, target = images.cuda(), target.cuda() 100 | image_features = clip_model.encode_image(images) 101 | image_features /= image_features.norm(dim=-1, keepdim=True) 102 | features.append(image_features) 103 | labels.append(target) 104 | 105 | features, labels = torch.cat(features), torch.cat(labels) 106 | 107 | torch.save(features, cfg['cache_dir'] + "/" + split + "_f.pt") 108 | torch.save(labels, cfg['cache_dir'] + "/" + split + "_l.pt") 109 | 110 | else: 111 | features = torch.load(cfg['cache_dir'] + "/" + split + "_f.pt") 112 | labels = torch.load(cfg['cache_dir'] + "/" + split + "_l.pt") 113 | 114 | return features, labels 115 | 116 | 117 | def search_hp(cfg, cache_keys, cache_values, features, labels, clip_weights, adapter=None): 118 | 119 | if cfg['search_hp'] == True: 120 | 121 | beta_list = [i * (cfg['search_scale'][0] - 0.1) / cfg['search_step'][0] + 0.1 for i in range(cfg['search_step'][0])] 122 | print('beta_list',beta_list) 123 | alpha_list = [i * (cfg['search_scale'][1] - 0.1) / cfg['search_step'][1] + 0.1 for i in range(cfg['search_step'][1])] 124 | print('alpha_list',alpha_list) 125 | 126 | best_acc = 0 127 | best_beta, best_alpha = 0, 0 128 | 129 | for beta in beta_list: 130 | for alpha in alpha_list: 131 | if adapter: 132 | affinity = adapter(features) 133 | else: 134 | affinity = features @ cache_keys 135 | 136 | cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values 137 | clip_logits = 100. * features @ clip_weights 138 | tip_logits = clip_logits + cache_logits * alpha 139 | acc = cls_acc(tip_logits, labels) 140 | 141 | if acc > best_acc: 142 | print("New best setting, beta: {:.2f}, alpha: {:.2f}; accuracy: {:.2f}".format(beta, alpha, acc)) 143 | best_acc = acc 144 | best_beta = beta 145 | best_alpha = alpha 146 | 147 | print("\nAfter searching, the best accuarcy: {:.2f}.\n".format(best_acc)) 148 | 149 | return best_beta, best_alpha 150 | --------------------------------------------------------------------------------
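For readers who want to sanity-check the scoring logic without downloading any data, below is a minimal, self-contained sketch of the cache-model computation that `run_tip_adapter()` in main.py and `search_hp()` in utils.py both perform. The expressions for `cache_logits` and `tip_logits` are copied from the repository; the tensor shapes, random features, and the `beta`/`alpha` values are stand-ins for real CLIP features and the config entries `init_beta`/`init_alpha`.

```
import torch

torch.manual_seed(0)
D, N, C, B = 512, 752, 47, 8   # feature dim, cached shots*classes, classes, test batch (DTD-like sizes)
beta, alpha = 5.5, 1.0         # stand-ins for cfg['init_beta'], cfg['init_alpha']

# Random stand-ins for L2-normalized CLIP features and weights
test_features = torch.randn(B, D)
test_features /= test_features.norm(dim=-1, keepdim=True)
cache_keys = torch.randn(D, N)                           # few-shot image features, transposed as in build_cache_model()
cache_keys /= cache_keys.norm(dim=0, keepdim=True)
cache_values = torch.eye(C)[torch.randint(0, C, (N,))]   # one-hot labels of the cached shots, [N, C]
clip_weights = torch.randn(D, C)                         # text classifier, [D, C]
clip_weights /= clip_weights.norm(dim=0, keepdim=True)

# Zero-shot CLIP logits plus the cache-model residual, exactly as in run_tip_adapter()
clip_logits = 100. * test_features @ clip_weights        # [B, C]
affinity = test_features @ cache_keys                    # [B, N]
cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values  # [B, C]
tip_logits = clip_logits + cache_logits * alpha

print(tip_logits.shape)  # torch.Size([8, 47])
```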