├── DATASETS.md ├── LICENSE ├── README.md ├── VPT.md ├── assets └── vpt.png ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── configs ├── datasets │ ├── caltech101.yaml │ ├── dtd.yaml │ ├── eurosat.yaml │ ├── fgvc_aircraft.yaml │ ├── food101.yaml │ ├── imagenet.yaml │ ├── imagenet_a.yaml │ ├── imagenet_r.yaml │ ├── imagenet_sketch.yaml │ ├── imagenetv2.yaml │ ├── oxford_flowers.yaml │ ├── oxford_pets.yaml │ ├── stanford_cars.yaml │ ├── sun397.yaml │ └── ucf101.yaml └── trainers │ ├── CoOp │ ├── rn101.yaml │ ├── rn101_ep50.yaml │ ├── rn50.yaml │ ├── rn50_ctxv1.yaml │ ├── rn50_ep100.yaml │ ├── rn50_ep50.yaml │ ├── rn50_ep50_ctxv1.yaml │ ├── rn50_val.yaml │ ├── vit_b16.yaml │ ├── vit_b16_ctxv1.yaml │ ├── vit_b16_ep100.yaml │ ├── vit_b16_ep100_ctxv1.yaml │ ├── vit_b16_ep50.yaml │ ├── vit_b16_ep50_ctxv1.yaml │ ├── vit_b32.yaml │ └── vit_b32_ep50.yaml │ └── VPT │ ├── vit_b16_c16_ep10_batch1.yaml │ ├── vit_b16_c4_ep10_batch1.yaml │ ├── vit_b16_c4_ep10_batch1_ctxv1.yaml │ └── vit_b16_c8_ep10_batch1.yaml ├── datasets ├── caltech101.py ├── dtd.py ├── eurosat.py ├── fgvc_aircraft.py ├── food101.py ├── imagenet.py ├── imagenet_a.py ├── imagenet_r.py ├── imagenet_sketch.py ├── imagenetv2.py ├── oxford_flowers.py ├── oxford_pets.py ├── stanford_cars.py ├── sun397.py └── ucf101.py ├── draw_curves.py ├── eval.sh ├── interpret_prompt.py ├── lpclip ├── README.md ├── feat_extractor.py ├── feat_extractor.sh ├── linear_probe.py └── linear_probe.sh ├── parse_test_res.py ├── requirements.txt ├── scripts ├── eval.sh ├── main.sh ├── vpt │ ├── base2new_test.sh │ ├── base2new_test_ablation.sh │ ├── base2new_test_arch.sh │ ├── base2new_test_imagenet.sh │ ├── base2new_test_imagenet_copy.sh │ ├── base2new_train.sh │ ├── base2new_train_ablation.sh │ ├── base2new_train_arch.sh │ ├── base2new_train_imagenet.sh │ ├── ctx_ablations.py │ ├── eval_cross_dataset.py │ ├── shots_ablation.py │ ├── train_cross_dataset.py │ ├── train_eval_caltech.py │ ├── train_eval_cars.py │ ├── train_eval_dtd.py │ ├── train_eval_eurosat.py │ ├── train_eval_fgvc.py │ ├── train_eval_flowers.py │ ├── train_eval_food101.py │ ├── train_eval_imagenet.py │ ├── train_eval_pets.py │ ├── train_eval_sun.py │ ├── train_eval_ucf101.py │ ├── vision_backbone_ablations.py │ ├── xd_test.sh │ └── xd_train.sh └── zeroshot.sh ├── train.py └── trainers ├── coop.py ├── imagenet_templates.py ├── vpt.py └── zsclip.py /DATASETS.md: -------------------------------------------------------------------------------- 1 | # How to install datasets 2 | 3 | We suggest putting all datasets under the same folder (say `$DATA`) to ease management and following the instructions below to organize datasets to avoid modifying the source code. The file structure looks like 4 | 5 | ``` 6 | $DATA/ 7 | |–– imagenet/ 8 | |–– caltech-101/ 9 | |–– oxford_pets/ 10 | |–– stanford_cars/ 11 | ``` 12 | 13 | If you have some datasets already installed somewhere else, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid duplicate download. 
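As an illustration only (the source path below is hypothetical; point it at wherever your existing copy actually lives), a dataset that is already downloaded elsewhere can be linked into `$DATA` with a few lines of Python; the same effect can of course be achieved with `ln -s` from the shell:

```python
import os
from pathlib import Path

# Root folder that all datasets live under; referred to as $DATA in these instructions.
data_root = Path(os.environ.get("DATA", "~/data")).expanduser()
data_root.mkdir(parents=True, exist_ok=True)

# Hypothetical location of an existing ImageNet download.
existing_imagenet = Path("/mnt/storage/imagenet")

# Create $DATA/imagenet as a symbolic link to avoid a duplicate download.
link = data_root / "imagenet"
if not link.exists():
    link.symlink_to(existing_imagenet, target_is_directory=True)
```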
14 | 15 | Datasets list: 16 | - [ImageNet](#imagenet) 17 | - [Caltech101](#caltech101) 18 | - [OxfordPets](#oxfordpets) 19 | - [StanfordCars](#stanfordcars) 20 | - [Flowers102](#flowers102) 21 | - [Food101](#food101) 22 | - [FGVCAircraft](#fgvcaircraft) 23 | - [SUN397](#sun397) 24 | - [DTD](#dtd) 25 | - [EuroSAT](#eurosat) 26 | - [UCF101](#ucf101) 27 | - [ImageNetV2](#imagenetv2) 28 | - [ImageNet-Sketch](#imagenet-sketch) 29 | - [ImageNet-A](#imagenet-a) 30 | - [ImageNet-R](#imagenet-r) 31 | 32 | The instructions to prepare each dataset are detailed below. To ensure reproducibility and fair comparison for future work, we provide fixed train/val/test splits for all datasets except ImageNet where the validation set is used as test set. The fixed splits are either from the original datasets (if available) or created by us. 33 | 34 | ### ImageNet 35 | - Create a folder named `imagenet/` under `$DATA`. 36 | - Create `images/` under `imagenet/`. 37 | - Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. The directory structure should look like 38 | ``` 39 | imagenet/ 40 | |–– images/ 41 | | |–– train/ # contains 1,000 folders like n01440764, n01443537, etc. 42 | | |–– val/ 43 | ``` 44 | - If you had downloaded the ImageNet dataset before, you can create symbolic links to map the training and validation sets to `$DATA/imagenet/images`. 45 | - Download the `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb). 46 | 47 | ### Caltech101 48 | - Create a folder named `caltech-101/` under `$DATA`. 49 | - Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`. 50 | - Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`. 51 | 52 | The directory structure should look like 53 | ``` 54 | caltech-101/ 55 | |–– 101_ObjectCategories/ 56 | |–– split_zhou_Caltech101.json 57 | ``` 58 | 59 | ### OxfordPets 60 | - Create a folder named `oxford_pets/` under `$DATA`. 61 | - Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz. 62 | - Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz. 63 | - Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing). 64 | 65 | The directory structure should look like 66 | ``` 67 | oxford_pets/ 68 | |–– images/ 69 | |–– annotations/ 70 | |–– split_zhou_OxfordPets.json 71 | ``` 72 | 73 | ### StanfordCars 74 | - Create a folder named `stanford_cars/` under `$DATA`. 75 | - Download the train images http://ai.stanford.edu/~jkrause/car196/cars_train.tgz. 76 | - Download the test images http://ai.stanford.edu/~jkrause/car196/cars_test.tgz. 77 | - Download the train labels https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz. 78 | - Download the test labels http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat. 79 | - Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing). 
80 | 81 | The directory structure should look like 82 | ``` 83 | stanford_cars/ 84 | |–– cars_test\ 85 | |–– cars_test_annos_withlabels.mat 86 | |–– cars_train\ 87 | |–– devkit\ 88 | |–– split_zhou_StanfordCars.json 89 | ``` 90 | 91 | ### Flowers102 92 | - Create a folder named `oxford_flowers/` under `$DATA`. 93 | - Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively. 94 | - Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing). 95 | - Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing). 96 | 97 | The directory structure should look like 98 | ``` 99 | oxford_flowers/ 100 | |–– cat_to_name.json 101 | |–– imagelabels.mat 102 | |–– jpg/ 103 | |–– split_zhou_OxfordFlowers.json 104 | ``` 105 | 106 | ### Food101 107 | - Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`. 108 | - Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing). 109 | 110 | The directory structure should look like 111 | ``` 112 | food-101/ 113 | |–– images/ 114 | |–– license_agreement.txt 115 | |–– meta/ 116 | |–– README.txt 117 | |–– split_zhou_Food101.json 118 | ``` 119 | 120 | ### FGVCAircraft 121 | - Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz. 122 | - Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`. 123 | - Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`. 124 | 125 | The directory structure should look like 126 | ``` 127 | fgvc_aircraft/ 128 | |–– images/ 129 | |–– ... # a bunch of .txt files 130 | ``` 131 | 132 | ### SUN397 133 | - Create a folder named `sun397/` under `$DATA`. 134 | - Download the images http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz. 135 | - Download the partitions https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip. 136 | - Extract these files under `$DATA/sun397/`. 137 | - Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing). 138 | 139 | The directory structure should look like 140 | ``` 141 | sun397/ 142 | |–– SUN397/ 143 | |–– split_zhou_SUN397.json 144 | |–– ... # a bunch of .txt files 145 | ``` 146 | 147 | ### DTD 148 | - Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. This should lead to `$DATA/dtd/`. 149 | - Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing). 150 | 151 | The directory structure should look like 152 | ``` 153 | dtd/ 154 | |–– images/ 155 | |–– imdb/ 156 | |–– labels/ 157 | |–– split_zhou_DescribableTextures.json 158 | ``` 159 | 160 | ### EuroSAT 161 | - Create a folder named `eurosat/` under `$DATA`. 162 | - Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`. 163 | - Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing). 
164 | 165 | The directory structure should look like 166 | ``` 167 | eurosat/ 168 | |–– 2750/ 169 | |–– split_zhou_EuroSAT.json 170 | ``` 171 | 172 | ### UCF101 173 | - Create a folder named `ucf101/` under `$DATA`. 174 | - Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames. 175 | - Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing). 176 | 177 | The directory structure should look like 178 | ``` 179 | ucf101/ 180 | |–– UCF-101-midframes/ 181 | |–– split_zhou_UCF101.json 182 | ``` 183 | 184 | ### ImageNetV2 185 | - Create a folder named `imagenetv2/` under `$DATA`. 186 | - Go to this github repo https://github.com/modestyachts/ImageNetV2. 187 | - Download the matched-frequency dataset from https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz and extract it to `$DATA/imagenetv2/`. 188 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenetv2/`. 189 | 190 | The directory structure should look like 191 | ``` 192 | imagenetv2/ 193 | |–– imagenetv2-matched-frequency-format-val/ 194 | |–– classnames.txt 195 | ``` 196 | 197 | ### ImageNet-Sketch 198 | - Download the dataset from https://github.com/HaohanWang/ImageNet-Sketch. 199 | - Extract the dataset to `$DATA/imagenet-sketch`. 200 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-sketch/`. 201 | 202 | The directory structure should look like 203 | ``` 204 | imagenet-sketch/ 205 | |–– images/ # contains 1,000 folders whose names have the format of n* 206 | |–– classnames.txt 207 | ``` 208 | 209 | ### ImageNet-A 210 | - Create a folder named `imagenet-adversarial/` under `$DATA`. 211 | - Download the dataset from https://github.com/hendrycks/natural-adv-examples and extract it to `$DATA/imagenet-adversarial/`. 212 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-adversarial/`. 213 | 214 | The directory structure should look like 215 | ``` 216 | imagenet-adversarial/ 217 | |–– imagenet-a/ # contains 200 folders whose names have the format of n* 218 | |–– classnames.txt 219 | ``` 220 | 221 | ### ImageNet-R 222 | - Create a folder named `imagenet-rendition/` under `$DATA`. 223 | - Download the dataset from https://github.com/hendrycks/imagenet-r and extract it to `$DATA/imagenet-rendition/`. 224 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-rendition/`. 
225 | 226 | The directory structure should look like 227 | ``` 228 | imagenet-rendition/ 229 | |–– imagenet-r/ # contains 200 folders whose names have the format of n* 230 | |–– classnames.txt 231 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Mohammad Mahdi Derakhshani 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Bayesian Prompt Learning for Image-Language Model Generalization 3 | ![Variational Prompt Learning](assets/vpt.png "VPT") 4 | 5 | This repo contains the codebase of the ICCV'23 paper [Bayesian Prompt Learning for Image-Language Model Generalization](https://arxiv.org/abs/2210.02390). 6 | 7 | ## How to Install 8 | This code is built on top of the awesome toolbox [Dassl.pytorch](https://github.com/KaiyangZhou/Dassl.pytorch) so you need to install the `dassl` environment first. Simply follow the instructions described [here](https://github.com/KaiyangZhou/Dassl.pytorch#installation) to install `dassl` as well as PyTorch. After that, run `pip install -r requirements.txt` under `VPT/` to install a few more packages required by [CLIP](https://github.com/openai/CLIP) (this should be done when `dassl` is activated). Then, you are ready to go. 9 | 10 | ## Dataset Preparation 11 | 12 | In this paper, we follow [DATASETS.md](DATASETS.md) to install the datasets. The task definition and few-shot learning setting are similar to following papers for fair comparison: 13 | * [Conditional Prompt Learning for Vision-Language Models](https://arxiv.org/abs/2203.05557), in CVPR, 2022. 14 | * [Learning to Prompt for Vision-Language Models](https://arxiv.org/abs/2109.01134), in IJCV, 2022. 15 | 16 | ## How to Run 17 | 18 | Click a paper below to see the detailed instructions on how to run the code to reproduce the results. 19 | 20 | * [Bayesian Prompt Learning for Image-Language Model Generalization](VPT.md) 21 | 22 | ## Results 23 | 24 | The raw numerical results can be found at this [google drive link](https://docs.google.com/spreadsheets/d/e/2PACX-1vSI_8GjWG7gbu_SjqVYfipeDP2ytVaSQqkINU1yEdgFB8gF27FwLXn2E_6c9N7hNWb-o2oB617vifh5/pubhtml). 
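As a quick sanity check of the installation described above, the sketch below (not a script shipped with this repo; the prompts and the random tensor are placeholders) loads a CLIP backbone through the bundled `clip` package and runs a dummy zero-shot forward pass:

```python
import torch

import clip  # the clip/ package bundled with this repo; run from the repo root

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

# Placeholder prompts; real runs build prompts from each dataset's class names.
texts = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)
# Stand-in for a preprocessed image; use `preprocess(pil_image)` on real data.
image = torch.randn(1, 3, 224, 224, device=device)

with torch.no_grad():
    logits_per_image, _ = model(image, texts)
    print(logits_per_image.softmax(dim=-1))
```

If this prints a 1x2 probability tensor, CLIP and its dependencies are set up correctly.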
25 | 26 | ## Citation 27 | If you use this code in your research, please cite the following paper 28 | 29 | ``` 30 | @article{derakhshani2023variational, 31 | title={Bayesian Prompt Learning for Image-Language Model Generalization}, 32 | author={Derakhshani, Mohammad Mahdi and Sanchez, Enrique and Bulat, Adrian and da Costa, Victor Guilherme Turrisi and Snoek, Cees GM and Tzimiropoulos, Georgios and Martinez, Brais}, 33 | journal={ICCV}, 34 | year={2023} 35 | } 36 | ``` 37 | -------------------------------------------------------------------------------- /VPT.md: -------------------------------------------------------------------------------- 1 | ## How to Run 2 | 3 | The running scripts are provided in `scripts/vpt/`. Make sure you set the `DATA` path in the scripts and run the commands under `VPT/scripts/vpt/`. 4 | 5 | ### Generalization From Base to New Classes 6 | 7 | This corresponds to the experiments in Section 4.1, i.e., Table 1 and Figure 2. 8 | 9 | You will need both `scripts/vpt/base2new_train.sh` and `scripts/vpt/base2new_test.sh`. The former trains a model on base classes while the latter evaluates the trained model on new classes. Both scripts take the following input arguments: 10 | * `DATASET` (takes as input a dataset name, like `imagenet` or `caltech101`; the valid names are the file names in `VPT/configs/datasets/`) 11 | * `SEED` (Seed number) 12 | * `GPUIDS` (List of GPU ids, provided as numbers separated by ",") 13 | * `L` (Number of Monte Carlo samples) 14 | * `EPOCHS` (Number of training epochs) 15 | 16 | To reduce the chance of mistakes when reproducing our results, we provide, for each dataset, a python script that takes `GPUIDS`, `L`, and `EPOCHS` as input arguments. 17 | 18 | Below we provide examples of how to train and evaluate the model on all datasets. 19 | 20 | ```bash 21 | # Caltech dataset 22 | python train_eval_caltech.py --gpuids 0,1,2,3,4,5,6,7 --l 20 --epochs 20 23 | 24 | # OxfordPets dataset 25 | python train_eval_pets.py --gpuids 0,1,2,3,4,5,6,7 --l 40 --epochs 20 26 | 27 | # StanfordCars dataset 28 | python train_eval_cars.py --gpuids 0,1,2,3,4,5,6,7 --l 20 --epochs 40 29 | 30 | # OxfordFlowers dataset 31 | python train_eval_flowers.py --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 40 32 | 33 | # Food101 dataset 34 | python train_eval_food101.py --gpuids 0,1,2,3,4,5,6,7 --l 20 --epochs 20 35 | 36 | # FGVCAircraft dataset 37 | python train_eval_fgvc.py --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 38 | 39 | # SUN397 dataset 40 | python train_eval_sun.py --gpuids 0,1,2,3,4,5,6,7 --l 20 --epochs 10 41 | 42 | # DTD dataset 43 | python train_eval_dtd.py --gpuids 0,1,2,3,4,5,6,7 --l 40 --epochs 10 44 | 45 | # EuroSAT dataset 46 | python train_eval_eurosat.py --gpuids 0,1,2,3,4,5,6,7 --l 20 --epochs 60 47 | 48 | # UCF101 dataset 49 | python train_eval_ucf101.py --gpuids 0,1,2,3,4,5,6,7 --l 5 --epochs 20 50 | 51 | # ImageNet dataset 52 | python train_eval_imagenet.py --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 53 | ``` 54 | 55 | When the evaluation is done, you can use `parse_test_res.py` to automatically calculate the average results.
For instance, after you finish the evaluation (including `base2new_train.sh` and `base2new_test.sh`) on ImageNet using the aforementioned commands, you would get 56 | 57 | ``` 58 | output 59 | |–– base2new/ 60 | | |–– test_new/ 61 | | | |–– imagenet/ 62 | | | | |–– shots_16/ 63 | | | | | |–– VPT/ 64 | | | | | | |–– vit_b16_c4_ep10_batch1_ctxv1/ 65 | | | | | | | |–– seed1/ 66 | | | | | | | |–– seed2/ 67 | | | | | | | |–– seed3/ 68 | | |–– train_base/ 69 | | | |–– imagenet/ 70 | | | | |–– shots_16/ 71 | | | | | |–– VPT/ 72 | | | | | | |–– vit_b16_c4_ep10_batch1_ctxv1/ 73 | | | | | | | |–– seed1/ 74 | | | | | | | |–– seed2/ 75 | | | | | | | |–– seed3/ 76 | ``` 77 | 78 | Then, to get the average performance on the base classes, run 79 | 80 | ```bash 81 | python parse_test_res.py output/base2new/train_base/imagenet/shots_16/VPT/vit_b16_c4_ep10_batch1_ctxv1 82 | ``` 83 | 84 | To get the average performance on the new classes, run 85 | 86 | ```bash 87 | python parse_test_res.py output/base2new/test_new/imagenet/shots_16/VPT/vit_b16_c4_ep10_batch1_ctxv1 --test-log 88 | ``` 89 | 90 | ### Cross-Dataset Transfer 91 | 92 | This corresponds to the experiments in Section 4.2, i.e., Table 2. 93 | 94 | The relevant scripts are `scripts/vpt/xd_train.sh` and `scripts/vpt/xd_test.sh`, where the `DATASET` variable is set to the default, namely `imagenet`. To train the model, run 95 | 96 | ```bash 97 | python train_cross_dataset.py --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 98 | ``` 99 | 100 | Then, evaluate the model on the other datasets, e.g., 101 | 102 | ```bash 103 | python train_cross_dataset.py --dname caltech101 --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 104 | python train_cross_dataset.py --dname dtd --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 105 | python train_cross_dataset.py --dname eurosat --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 106 | python train_cross_dataset.py --dname fgvc_aircraft --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 107 | python train_cross_dataset.py --dname food101 --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 108 | python train_cross_dataset.py --dname oxford_flowers --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 109 | python train_cross_dataset.py --dname oxford_pets --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 110 | python train_cross_dataset.py --dname stanford_cars --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 111 | python train_cross_dataset.py --dname sun397 --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 112 | python train_cross_dataset.py --dname ucf101 --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 113 | 114 | ``` 115 | 116 | ### Domain Generalization 117 | 118 | This corresponds to the experiments in Section 4.3, i.e., Table 3. 119 | 120 | The steps are similar to those discussed in "Cross-Dataset Transfer" except that you evaluate the model on the variants of ImageNet, i.e., `imagenetv2`, `imagenet_sketch`, `imagenet_a` and `imagenet_r`.
121 | 122 | ```bash 123 | python train_cross_dataset.py --dname imagenetv2 --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 124 | python train_cross_dataset.py --dname imagenet_sketch --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 125 | python train_cross_dataset.py --dname imagenet_a --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 126 | python train_cross_dataset.py --dname imagenet_r --gpuids 0,1,2,3,4,5,6,7 --l 10 --epochs 10 127 | 128 | ``` -------------------------------------------------------------------------------- /assets/vpt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saic-fi/Bayesian-Prompt-Learning/8a73fff156327e96f5a3b129aa52dc4dde426c27/assets/vpt.png -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saic-fi/Bayesian-Prompt-Learning/8a73fff156327e96f5a3b129aa52dc4dde426c27/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if 
os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px, interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | """Returns the names of available CLIP models""" 83 | return list(_MODELS.keys()) 84 | 85 | 86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 87 | """Load a CLIP model 88 | 89 | Parameters 90 | ---------- 91 | name : str 92 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 93 | 94 | device : Union[str, torch.device] 95 | The device to put the loaded model 96 | 97 | jit : bool 98 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 99 | 100 | Returns 101 | ------- 102 | model : torch.nn.Module 103 | The CLIP model 104 | 105 | preprocess : Callable[[PIL.Image], torch.Tensor] 106 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 107 | """ 108 | if name in _MODELS: 109 | model_path = _download(_MODELS[name]) 110 | elif os.path.isfile(name): 111 | model_path = name 112 | else: 113 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 114 | 115 | try: 116 | # loading JIT archive 117 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 118 | state_dict = None 119 | except RuntimeError: 120 | # loading saved state dict 121 | if jit: 122 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 123 | jit = False 124 | state_dict = torch.load(model_path, map_location="cpu") 125 | 126 | if not jit: 127 | model = build_model(state_dict or model.state_dict()).to(device) 128 | if str(device) == "cpu": 129 | model.float() 130 | return model, _transform(model.visual.input_resolution) 131 | 132 | # patch the device names 133 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 134 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 135 | 136 | def patch_device(module): 137 | try: 138 | graphs = [module.graph] if hasattr(module, "graph") else [] 139 | except RuntimeError: 140 | graphs = [] 141 | 142 | if hasattr(module, "forward1"): 143 | graphs.append(module.forward1.graph) 144 | 145 | for graph in graphs: 146 | for node in graph.findAllNodes("prim::Constant"): 147 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 148 | node.copyAttributes(device_node) 149 | 150 | model.apply(patch_device) 151 | patch_device(model.encode_image) 152 | patch_device(model.encode_text) 153 | 154 | # patch dtype to float32 on CPU 155 | if str(device) == "cpu": 156 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 157 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 158 | float_node = float_input.node() 159 | 160 | def patch_float(module): 161 | try: 162 | graphs = [module.graph] if hasattr(module, "graph") else [] 163 | except RuntimeError: 164 | graphs = [] 165 | 166 | if hasattr(module, "forward1"): 167 | graphs.append(module.forward1.graph) 168 | 169 | for graph in graphs: 170 | for node in graph.findAllNodes("aten::to"): 171 | inputs = list(node.inputs()) 172 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 173 | if inputs[i].node()["value"] == 5: 174 | inputs[i].node().copyAttributes(float_node) 175 | 176 | model.apply(patch_float) 177 | patch_float(model.encode_image) 178 | patch_float(model.encode_text) 179 | 180 | model.float() 181 | 182 | return model, _transform(model.input_resolution.item()) 183 | 184 | 185 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 186 | """ 187 | Returns the tokenized representation of given input string(s) 188 | 189 | Parameters 190 | ---------- 191 | texts : Union[str, List[str]] 192 | An input string or a list of input strings to tokenize 193 | 194 | context_length : int 195 | The context length to use; all CLIP models use 77 as the context length 196 | 197 | truncate: bool 198 | Whether to truncate the text in case its encoding is longer than the context length 199 | 200 | Returns 201 | ------- 202 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 203 | """ 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | 207 | sot_token = _tokenizer.encoder["<|startoftext|>"] 208 | eot_token = _tokenizer.encoder["<|endoftext|>"] 209 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 210 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 211 | 212 | for i, tokens in enumerate(all_tokens): 213 | if len(tokens) > context_length: 214 | if truncate: 215 | tokens = tokens[:context_length] 216 | tokens[-1] = eot_token 217 | else: 218 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 219 | result[i, 
:len(tokens)] = torch.tensor(tokens) 220 | 221 | return result 222 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's 
but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = x + 
self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 204 | super().__init__() 205 | self.input_resolution = input_resolution 206 | self.output_dim = output_dim 207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 208 | 209 | scale = width ** -0.5 210 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 212 | self.ln_pre = LayerNorm(width) 213 | 214 | self.transformer = Transformer(width, layers, heads) 215 | 216 | self.ln_post = LayerNorm(width) 217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 218 | 219 | def forward(self, x: torch.Tensor): 220 | x = self.conv1(x) # shape = [*, width, grid, grid] 221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 224 | x = x + self.positional_embedding.to(x.dtype) 225 | x = self.ln_pre(x) 226 | 227 | x = x.permute(1, 0, 2) # NLD -> LND 228 | x = self.transformer(x) 229 | x = x.permute(1, 0, 2) # LND -> NLD 230 | 231 | x = self.ln_post(x[:, 0, :]) 232 | 233 | if self.proj is not None: 234 | x = x @ self.proj 235 | 236 | return x 237 | 238 | 239 | class CLIP(nn.Module): 240 | def __init__(self, 241 | embed_dim: int, 242 | # vision 243 | image_resolution: int, 244 | vision_layers: Union[Tuple[int, int, int, int], int], 245 | vision_width: int, 246 | vision_patch_size: int, 247 | # text 248 | context_length: int, 249 | vocab_size: int, 250 | transformer_width: int, 251 | transformer_heads: int, 252 | transformer_layers: int 253 | ): 254 | super().__init__() 255 | 256 | self.context_length = context_length 257 | 258 | if isinstance(vision_layers, (tuple, list)): 259 | vision_heads = vision_width * 32 // 64 260 | self.visual = ModifiedResNet( 261 | layers=vision_layers, 262 | output_dim=embed_dim, 263 | heads=vision_heads, 264 | input_resolution=image_resolution, 265 | width=vision_width 266 | ) 267 | else: 268 | vision_heads = vision_width // 64 269 | self.visual = VisionTransformer( 270 | input_resolution=image_resolution, 271 | patch_size=vision_patch_size, 272 | width=vision_width, 273 | layers=vision_layers, 274 | heads=vision_heads, 275 | output_dim=embed_dim 276 | ) 277 | 278 | self.transformer = Transformer( 279 | width=transformer_width, 280 | layers=transformer_layers, 281 | heads=transformer_heads, 282 | attn_mask=self.build_attention_mask() 283 | ) 284 | 285 | self.vocab_size = vocab_size 286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 288 | 
self.ln_final = LayerNorm(transformer_width) 289 | 290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 292 | 293 | self.initialize_parameters() 294 | 295 | def initialize_parameters(self): 296 | nn.init.normal_(self.token_embedding.weight, std=0.02) 297 | nn.init.normal_(self.positional_embedding, std=0.01) 298 | 299 | if isinstance(self.visual, ModifiedResNet): 300 | if self.visual.attnpool is not None: 301 | std = self.visual.attnpool.c_proj.in_features ** -0.5 302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 306 | 307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 308 | for name, param in resnet_block.named_parameters(): 309 | if name.endswith("bn3.weight"): 310 | nn.init.zeros_(param) 311 | 312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 313 | attn_std = self.transformer.width ** -0.5 314 | fc_std = (2 * self.transformer.width) ** -0.5 315 | for block in self.transformer.resblocks: 316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 320 | 321 | if self.text_projection is not None: 322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 323 | 324 | def build_attention_mask(self): 325 | # lazily create causal attention mask, with full attention between the vision tokens 326 | # pytorch uses additive attention mask; fill with -inf 327 | mask = torch.empty(self.context_length, self.context_length) 328 | mask.fill_(float("-inf")) 329 | mask.triu_(1) # zero out the lower diagonal 330 | return mask 331 | 332 | @property 333 | def dtype(self): 334 | return self.visual.conv1.weight.dtype 335 | 336 | def encode_image(self, image): 337 | return self.visual(image.type(self.dtype)) 338 | 339 | def encode_text(self, text): 340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 341 | 342 | x = x + self.positional_embedding.type(self.dtype) 343 | x = x.permute(1, 0, 2) # NLD -> LND 344 | x = self.transformer(x) 345 | x = x.permute(1, 0, 2) # LND -> NLD 346 | x = self.ln_final(x).type(self.dtype) 347 | 348 | # x.shape = [batch_size, n_ctx, transformer.width] 349 | # take features from the eot embedding (eot_token is the highest number in each sequence) 350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 351 | 352 | return x 353 | 354 | def forward(self, image, text): 355 | image_features = self.encode_image(image) 356 | text_features = self.encode_text(text) 357 | 358 | # normalized features 359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 361 | 362 | # cosine similarity as logits 363 | logit_scale = self.logit_scale.exp() 364 | logits_per_image = logit_scale * image_features @ text_features.t() 365 | logits_per_text = logit_scale * text_features @ image_features.t() 366 | 367 | # shape = [global_batch_size, global_batch_size] 368 | return logits_per_image, logits_per_text 369 | 370 | 371 | def 
convert_weights(model: nn.Module): 372 | """Convert applicable model parameters to fp16""" 373 | 374 | def _convert_weights_to_fp16(l): 375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 376 | l.weight.data = l.weight.data.half() 377 | if l.bias is not None: 378 | l.bias.data = l.bias.data.half() 379 | 380 | if isinstance(l, nn.MultiheadAttention): 381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 382 | tensor = getattr(l, attr) 383 | if tensor is not None: 384 | tensor.data = tensor.data.half() 385 | 386 | for name in ["text_projection", "proj"]: 387 | if hasattr(l, name): 388 | attr = getattr(l, name) 389 | if attr is not None: 390 | attr.data = attr.data.half() 391 | 392 | model.apply(_convert_weights_to_fp16) 393 | 394 | 395 | def build_model(state_dict: dict): 396 | vit = "visual.proj" in state_dict 397 | 398 | if vit: 399 | vision_width = state_dict["visual.conv1.weight"].shape[0] 400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 403 | image_resolution = vision_patch_size * grid_size 404 | else: 405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 406 | vision_layers = tuple(counts) 407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 409 | vision_patch_size = None 410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 411 | image_resolution = output_width * 32 412 | 413 | embed_dim = state_dict["text_projection"].shape[1] 414 | context_length = state_dict["positional_embedding"].shape[0] 415 | vocab_size = state_dict["token_embedding.weight"].shape[0] 416 | transformer_width = state_dict["ln_final.weight"].shape[0] 417 | transformer_heads = transformer_width // 64 418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 419 | 420 | model = CLIP( 421 | embed_dim, 422 | image_resolution, vision_layers, vision_width, vision_patch_size, 423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 424 | ) 425 | 426 | for key in ["input_resolution", "context_length", "vocab_size"]: 427 | if key in state_dict: 428 | del state_dict[key] 429 | 430 | convert_weights(model) 431 | model.load_state_dict(state_dict) 432 | return model.eval() 433 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 
127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /configs/datasets/caltech101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Caltech101" 3 | -------------------------------------------------------------------------------- /configs/datasets/dtd.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "DescribableTextures" 3 | -------------------------------------------------------------------------------- /configs/datasets/eurosat.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "EuroSAT" 3 | -------------------------------------------------------------------------------- /configs/datasets/fgvc_aircraft.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "FGVCAircraft" 3 | -------------------------------------------------------------------------------- /configs/datasets/food101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Food101" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNet" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetA" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetR" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetSketch" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenetv2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetV2" 3 | -------------------------------------------------------------------------------- /configs/datasets/oxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordFlowers" -------------------------------------------------------------------------------- /configs/datasets/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordPets" -------------------------------------------------------------------------------- /configs/datasets/stanford_cars.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "StanfordCars" 3 | -------------------------------------------------------------------------------- /configs/datasets/sun397.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 
| NAME: "SUN397" 3 | -------------------------------------------------------------------------------- /configs/datasets/ucf101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "UCF101" 3 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn101.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn101_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | 
-------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_val.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 200 4 | TEST: 5 | BATCH_SIZE: 200 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | MODEL: 16 | BACKBONE: 17 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: 
["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep100_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: 
"constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b32.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b32_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /configs/trainers/VPT/vit_b16_c16_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 16 34 | CTX_INIT: "" 35 | PREC: "fp16" 
-------------------------------------------------------------------------------- /configs/trainers/VPT/vit_b16_c4_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/VPT/vit_b16_c4_ep10_batch1_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 1 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 40 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | VPT: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" 36 | L: 10 37 | VPT_TYPE: "coopvpt" -------------------------------------------------------------------------------- /configs/trainers/VPT/vit_b16_c8_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 8 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | IGNORED = ["BACKGROUND_Google", "Faces_easy"] 11 | NEW_CNAMES = { 12 | "airplanes": "airplane", 13 | "Faces": "face", 14 | "Leopards": "leopard", 15 | "Motorbikes": "motorbike", 16 | } 17 | 18 | 19 | @DATASET_REGISTRY.register() 20 | class Caltech101(DatasetBase): 21 | 22 | dataset_dir = "caltech-101" 23 | 24 | def __init__(self, cfg): 25 | root = 
os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 26 | self.dataset_dir = os.path.join(root, self.dataset_dir) 27 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories") 28 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json") 29 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 30 | mkdir_if_missing(self.split_fewshot_dir) 31 | 32 | if os.path.exists(self.split_path): 33 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 34 | else: 35 | train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 58 | 59 | super().__init__(train_x=train, val=val, test=test) 60 | -------------------------------------------------------------------------------- /datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import listdir_nohidden, mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class DescribableTextures(DatasetBase): 13 | 14 | dataset_dir = "dtd" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | train, val, test = self.read_and_split_data(self.image_dir) 28 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 29 | 30 | num_shots = cfg.DATASET.NUM_SHOTS 31 | if num_shots >= 1: 32 | seed = cfg.SEED 33 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 34 | 35 | if os.path.exists(preprocessed): 36 | print(f"Loading preprocessed few-shot data from {preprocessed}") 37 | with open(preprocessed, "rb") as file: 38 | data = pickle.load(file) 39 | train, val = data["train"], data["val"] 40 | else: 41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 43 | data = 
{"train": train, "val": val} 44 | print(f"Saving preprocessed few-shot data to {preprocessed}") 45 | with open(preprocessed, "wb") as file: 46 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 47 | 48 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 49 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 50 | 51 | super().__init__(train_x=train, val=val, test=test) 52 | 53 | @staticmethod 54 | def read_and_split_data(image_dir, p_trn=0.5, p_val=0.2, ignored=[], new_cnames=None): 55 | # The data are supposed to be organized into the following structure 56 | # ============= 57 | # images/ 58 | # dog/ 59 | # cat/ 60 | # horse/ 61 | # ============= 62 | categories = listdir_nohidden(image_dir) 63 | categories = [c for c in categories if c not in ignored] 64 | categories.sort() 65 | 66 | p_tst = 1 - p_trn - p_val 67 | print(f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test") 68 | 69 | def _collate(ims, y, c): 70 | items = [] 71 | for im in ims: 72 | item = Datum(impath=im, label=y, classname=c) # is already 0-based 73 | items.append(item) 74 | return items 75 | 76 | train, val, test = [], [], [] 77 | for label, category in enumerate(categories): 78 | category_dir = os.path.join(image_dir, category) 79 | images = listdir_nohidden(category_dir) 80 | images = [os.path.join(category_dir, im) for im in images] 81 | random.shuffle(images) 82 | n_total = len(images) 83 | n_train = round(n_total * p_trn) 84 | n_val = round(n_total * p_val) 85 | n_test = n_total - n_train - n_val 86 | assert n_train > 0 and n_val > 0 and n_test > 0 87 | 88 | if new_cnames is not None and category in new_cnames: 89 | category = new_cnames[category] 90 | 91 | train.extend(_collate(images[:n_train], label, category)) 92 | val.extend(_collate(images[n_train : n_train + n_val], label, category)) 93 | test.extend(_collate(images[n_train + n_val :], label, category)) 94 | 95 | return train, val, test 96 | -------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | NEW_CNAMES = { 11 | "AnnualCrop": "Annual Crop Land", 12 | "Forest": "Forest", 13 | "HerbaceousVegetation": "Herbaceous Vegetation Land", 14 | "Highway": "Highway or Road", 15 | "Industrial": "Industrial Buildings", 16 | "Pasture": "Pasture Land", 17 | "PermanentCrop": "Permanent Crop Land", 18 | "Residential": "Residential Buildings", 19 | "River": "River", 20 | "SeaLake": "Sea or Lake", 21 | } 22 | 23 | 24 | @DATASET_REGISTRY.register() 25 | class EuroSAT(DatasetBase): 26 | 27 | dataset_dir = "eurosat" 28 | 29 | def __init__(self, cfg): 30 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 31 | self.dataset_dir = os.path.join(root, self.dataset_dir) 32 | self.image_dir = os.path.join(self.dataset_dir, "2750") 33 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json") 34 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 35 | mkdir_if_missing(self.split_fewshot_dir) 36 | 37 | if os.path.exists(self.split_path): 38 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 39 | else: 40 | train, val, test = DTD.read_and_split_data(self.image_dir, 
new_cnames=NEW_CNAMES) 41 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 42 | 43 | num_shots = cfg.DATASET.NUM_SHOTS 44 | if num_shots >= 1: 45 | seed = cfg.SEED 46 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 47 | 48 | if os.path.exists(preprocessed): 49 | print(f"Loading preprocessed few-shot data from {preprocessed}") 50 | with open(preprocessed, "rb") as file: 51 | data = pickle.load(file) 52 | train, val = data["train"], data["val"] 53 | else: 54 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 55 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 56 | data = {"train": train, "val": val} 57 | print(f"Saving preprocessed few-shot data to {preprocessed}") 58 | with open(preprocessed, "wb") as file: 59 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 60 | 61 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 62 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 63 | 64 | super().__init__(train_x=train, val=val, test=test) 65 | 66 | def update_classname(self, dataset_old): 67 | dataset_new = [] 68 | for item_old in dataset_old: 69 | cname_old = item_old.classname 70 | cname_new = NEW_CLASSNAMES[cname_old] 71 | item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new) 72 | dataset_new.append(item_new) 73 | return dataset_new 74 | -------------------------------------------------------------------------------- /datasets/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class FGVCAircraft(DatasetBase): 12 | 13 | dataset_dir = "fgvc_aircraft" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "images") 19 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 20 | mkdir_if_missing(self.split_fewshot_dir) 21 | 22 | classnames = [] 23 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: 24 | lines = f.readlines() 25 | for line in lines: 26 | classnames.append(line.strip()) 27 | cname2lab = {c: i for i, c in enumerate(classnames)} 28 | 29 | train = self.read_data(cname2lab, "images_variant_train.txt") 30 | val = self.read_data(cname2lab, "images_variant_val.txt") 31 | test = self.read_data(cname2lab, "images_variant_test.txt") 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | subsample = 
cfg.DATASET.SUBSAMPLE_CLASSES 52 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 53 | 54 | super().__init__(train_x=train, val=val, test=test) 55 | 56 | def read_data(self, cname2lab, split_file): 57 | filepath = os.path.join(self.dataset_dir, split_file) 58 | items = [] 59 | 60 | with open(filepath, "r") as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | line = line.strip().split(" ") 64 | imname = line[0] + ".jpg" 65 | classname = " ".join(line[1:]) 66 | impath = os.path.join(self.image_dir, imname) 67 | label = cname2lab[classname] 68 | item = Datum(impath=impath, label=label, classname=classname) 69 | items.append(item) 70 | 71 | return items 72 | -------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class Food101(DatasetBase): 13 | 14 | dataset_dir = "food-101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | train, val, test = DTD.read_and_split_data(self.image_dir) 28 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 29 | 30 | num_shots = cfg.DATASET.NUM_SHOTS 31 | if num_shots >= 1: 32 | seed = cfg.SEED 33 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 34 | 35 | if os.path.exists(preprocessed): 36 | print(f"Loading preprocessed few-shot data from {preprocessed}") 37 | with open(preprocessed, "rb") as file: 38 | data = pickle.load(file) 39 | train, val = data["train"], data["val"] 40 | else: 41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 43 | data = {"train": train, "val": val} 44 | print(f"Saving preprocessed few-shot data to {preprocessed}") 45 | with open(preprocessed, "wb") as file: 46 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 47 | 48 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 49 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 50 | 51 | super().__init__(train_x=train, val=val, test=test) 52 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import OrderedDict 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import listdir_nohidden, mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNet(DatasetBase): 13 | 14 | dataset_dir = "imagenet" 15 | 16 | def __init__(self, cfg): 17 | 
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.preprocessed): 25 | with open(self.preprocessed, "rb") as f: 26 | preprocessed = pickle.load(f) 27 | train = preprocessed["train"] 28 | test = preprocessed["test"] 29 | else: 30 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 31 | classnames = self.read_classnames(text_file) 32 | train = self.read_data(classnames, "train") 33 | # Follow standard practice to perform evaluation on the val set 34 | # Also used as the val set (so evaluate the last-step model) 35 | test = self.read_data(classnames, "val") 36 | 37 | preprocessed = {"train": train, "test": test} 38 | with open(self.preprocessed, "wb") as f: 39 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL) 40 | 41 | num_shots = cfg.DATASET.NUM_SHOTS 42 | if num_shots >= 1: 43 | seed = cfg.SEED 44 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 45 | 46 | if os.path.exists(preprocessed): 47 | print(f"Loading preprocessed few-shot data from {preprocessed}") 48 | with open(preprocessed, "rb") as file: 49 | data = pickle.load(file) 50 | train = data["train"] 51 | else: 52 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 53 | data = {"train": train} 54 | print(f"Saving preprocessed few-shot data to {preprocessed}") 55 | with open(preprocessed, "wb") as file: 56 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 57 | 58 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 59 | train, test = OxfordPets.subsample_classes(train, test, subsample=subsample) 60 | 61 | super().__init__(train_x=train, val=test, test=test) 62 | 63 | @staticmethod 64 | def read_classnames(text_file): 65 | """Return a dictionary containing 66 | key-value pairs of : . 67 | """ 68 | classnames = OrderedDict() 69 | with open(text_file, "r") as f: 70 | lines = f.readlines() 71 | for line in lines: 72 | line = line.strip().split(" ") 73 | folder = line[0] 74 | classname = " ".join(line[1:]) 75 | classnames[folder] = classname 76 | return classnames 77 | 78 | def read_data(self, classnames, split_dir): 79 | split_dir = os.path.join(self.image_dir, split_dir) 80 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir()) 81 | items = [] 82 | 83 | for label, folder in enumerate(folders): 84 | imnames = listdir_nohidden(os.path.join(split_dir, folder)) 85 | classname = classnames[folder] 86 | for imname in imnames: 87 | impath = os.path.join(split_dir, folder, imname) 88 | item = Datum(impath=impath, label=label, classname=classname) 89 | items.append(item) 90 | 91 | return items 92 | -------------------------------------------------------------------------------- /datasets/imagenet_a.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetA(DatasetBase): 13 | """ImageNet-A(dversarial). 14 | 15 | This dataset is used for testing only. 
16 | """ 17 | 18 | dataset_dir = "imagenet-adversarial" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-a") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /datasets/imagenet_r.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetR(DatasetBase): 13 | """ImageNet-R(endition). 14 | 15 | This dataset is used for testing only. 16 | """ 17 | 18 | dataset_dir = "imagenet-rendition" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-r") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /datasets/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetSketch(DatasetBase): 11 | """ImageNet-Sketch. 12 | 13 | This dataset is used for testing only. 
14 | """ 15 | 16 | dataset_dir = "imagenet-sketch" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "images") 22 | 23 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 24 | classnames = ImageNet.read_classnames(text_file) 25 | 26 | data = self.read_data(classnames) 27 | 28 | super().__init__(train_x=data, test=data) 29 | 30 | def read_data(self, classnames): 31 | image_dir = self.image_dir 32 | folders = listdir_nohidden(image_dir, sort=True) 33 | items = [] 34 | 35 | for label, folder in enumerate(folders): 36 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 37 | classname = classnames[folder] 38 | for imname in imnames: 39 | impath = os.path.join(image_dir, folder, imname) 40 | item = Datum(impath=impath, label=label, classname=classname) 41 | items.append(item) 42 | 43 | return items 44 | -------------------------------------------------------------------------------- /datasets/imagenetv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetV2(DatasetBase): 11 | """ImageNetV2. 12 | 13 | This dataset is used for testing only. 14 | """ 15 | 16 | dataset_dir = "imagenetv2" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | image_dir = "imagenetv2-matched-frequency-format-val" 22 | self.image_dir = os.path.join(self.dataset_dir, image_dir) 23 | 24 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 25 | classnames = ImageNet.read_classnames(text_file) 26 | 27 | data = self.read_data(classnames) 28 | 29 | super().__init__(train_x=data, test=data) 30 | 31 | def read_data(self, classnames): 32 | image_dir = self.image_dir 33 | folders = list(classnames.keys()) 34 | items = [] 35 | 36 | for label in range(1000): 37 | class_dir = os.path.join(image_dir, str(label)) 38 | imnames = listdir_nohidden(class_dir) 39 | folder = folders[label] 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(class_dir, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | from scipy.io import loadmat 5 | from collections import defaultdict 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import read_json, mkdir_if_missing 9 | 10 | from .oxford_pets import OxfordPets 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class OxfordFlowers(DatasetBase): 15 | 16 | dataset_dir = "oxford_flowers" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "jpg") 22 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat") 23 | self.lab2cname_file = os.path.join(self.dataset_dir, 
"cat_to_name.json") 24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.split_path): 29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 30 | else: 31 | train, val, test = self.read_data() 32 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 54 | 55 | super().__init__(train_x=train, val=val, test=test) 56 | 57 | def read_data(self): 58 | tracker = defaultdict(list) 59 | label_file = loadmat(self.label_file)["labels"][0] 60 | for i, label in enumerate(label_file): 61 | imname = f"image_{str(i + 1).zfill(5)}.jpg" 62 | impath = os.path.join(self.image_dir, imname) 63 | label = int(label) 64 | tracker[label].append(impath) 65 | 66 | print("Splitting data into 50% train, 20% val, and 30% test") 67 | 68 | def _collate(ims, y, c): 69 | items = [] 70 | for im in ims: 71 | item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label 72 | items.append(item) 73 | return items 74 | 75 | lab2cname = read_json(self.lab2cname_file) 76 | train, val, test = [], [], [] 77 | for label, impaths in tracker.items(): 78 | random.shuffle(impaths) 79 | n_total = len(impaths) 80 | n_train = round(n_total * 0.5) 81 | n_val = round(n_total * 0.2) 82 | n_test = n_total - n_train - n_val 83 | assert n_train > 0 and n_val > 0 and n_test > 0 84 | cname = lab2cname[str(label)] 85 | train.extend(_collate(impaths[:n_train], label, cname)) 86 | val.extend(_collate(impaths[n_train : n_train + n_val], label, cname)) 87 | test.extend(_collate(impaths[n_train + n_val :], label, cname)) 88 | 89 | return train, val, test 90 | -------------------------------------------------------------------------------- /datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import math 4 | import random 5 | from collections import defaultdict 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import read_json, write_json, mkdir_if_missing 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class OxfordPets(DatasetBase): 13 | 14 | dataset_dir = "oxford_pets" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.anno_dir = os.path.join(self.dataset_dir, "annotations") 21 | 
self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json") 22 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 23 | mkdir_if_missing(self.split_fewshot_dir) 24 | 25 | if os.path.exists(self.split_path): 26 | train, val, test = self.read_split(self.split_path, self.image_dir) 27 | else: 28 | trainval = self.read_data(split_file="trainval.txt") 29 | test = self.read_data(split_file="test.txt") 30 | train, val = self.split_trainval(trainval) 31 | self.save_split(train, val, test, self.split_path, self.image_dir) 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 52 | train, val, test = self.subsample_classes(train, val, test, subsample=subsample) 53 | 54 | super().__init__(train_x=train, val=val, test=test) 55 | 56 | def read_data(self, split_file): 57 | filepath = os.path.join(self.anno_dir, split_file) 58 | items = [] 59 | 60 | with open(filepath, "r") as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | line = line.strip() 64 | imname, label, species, _ = line.split(" ") 65 | breed = imname.split("_")[:-1] 66 | breed = "_".join(breed) 67 | breed = breed.lower() 68 | imname += ".jpg" 69 | impath = os.path.join(self.image_dir, imname) 70 | label = int(label) - 1 # convert to 0-based index 71 | item = Datum(impath=impath, label=label, classname=breed) 72 | items.append(item) 73 | 74 | return items 75 | 76 | @staticmethod 77 | def split_trainval(trainval, p_val=0.2): 78 | p_trn = 1 - p_val 79 | print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val") 80 | tracker = defaultdict(list) 81 | for idx, item in enumerate(trainval): 82 | label = item.label 83 | tracker[label].append(idx) 84 | 85 | train, val = [], [] 86 | for label, idxs in tracker.items(): 87 | n_val = round(len(idxs) * p_val) 88 | assert n_val > 0 89 | random.shuffle(idxs) 90 | for n, idx in enumerate(idxs): 91 | item = trainval[idx] 92 | if n < n_val: 93 | val.append(item) 94 | else: 95 | train.append(item) 96 | 97 | return train, val 98 | 99 | @staticmethod 100 | def save_split(train, val, test, filepath, path_prefix): 101 | def _extract(items): 102 | out = [] 103 | for item in items: 104 | impath = item.impath 105 | label = item.label 106 | classname = item.classname 107 | impath = impath.replace(path_prefix, "") 108 | if impath.startswith("/"): 109 | impath = impath[1:] 110 | out.append((impath, label, classname)) 111 | return out 112 | 113 | train = _extract(train) 114 | val = _extract(val) 115 | test = _extract(test) 116 | 117 | split = {"train": train, "val": val, "test": test} 118 | 119 | write_json(split, filepath) 120 | print(f"Saved split to {filepath}") 121 | 122 | @staticmethod 123 | def read_split(filepath, path_prefix): 124 | def _convert(items): 125 | out = [] 126 | for 
impath, label, classname in items: 127 | impath = os.path.join(path_prefix, impath) 128 | item = Datum(impath=impath, label=int(label), classname=classname) 129 | out.append(item) 130 | return out 131 | 132 | print(f"Reading split from {filepath}") 133 | split = read_json(filepath) 134 | train = _convert(split["train"]) 135 | val = _convert(split["val"]) 136 | test = _convert(split["test"]) 137 | 138 | return train, val, test 139 | 140 | @staticmethod 141 | def subsample_classes(*args, subsample="all"): 142 | """Divide classes into two groups. The first group 143 | represents base classes while the second group represents 144 | new classes. 145 | 146 | Args: 147 | args: a list of datasets, e.g. train, val and test. 148 | subsample (str): what classes to subsample. 149 | """ 150 | assert subsample in ["all", "base", "new"] 151 | 152 | if subsample == "all": 153 | return args 154 | 155 | dataset = args[0] 156 | labels = set() 157 | for item in dataset: 158 | labels.add(item.label) 159 | labels = list(labels) 160 | labels.sort() 161 | n = len(labels) 162 | # Divide classes into two halves 163 | m = math.ceil(n / 2) 164 | 165 | print(f"SUBSAMPLE {subsample.upper()} CLASSES!") 166 | if subsample == "base": 167 | selected = labels[:m] # take the first half 168 | else: 169 | selected = labels[m:] # take the second half 170 | relabeler = {y: y_new for y_new, y in enumerate(selected)} 171 | 172 | output = [] 173 | for dataset in args: 174 | dataset_new = [] 175 | for item in dataset: 176 | if item.label not in selected: 177 | continue 178 | item_new = Datum( 179 | impath=item.impath, 180 | label=relabeler[item.label], 181 | classname=item.classname 182 | ) 183 | dataset_new.append(item_new) 184 | output.append(dataset_new) 185 | 186 | return output 187 | -------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from scipy.io import loadmat 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class StanfordCars(DatasetBase): 13 | 14 | dataset_dir = "stanford_cars" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json") 20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | if os.path.exists(self.split_path): 24 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 25 | else: 26 | trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat") 27 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat") 28 | meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat") 29 | trainval = self.read_data("cars_train", trainval_file, meta_file) 30 | test = self.read_data("cars_test", test_file, meta_file) 31 | train, val = OxfordPets.split_trainval(trainval) 32 | OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if 
os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 54 | 55 | super().__init__(train_x=train, val=val, test=test) 56 | 57 | def read_data(self, image_dir, anno_file, meta_file): 58 | anno_file = loadmat(anno_file)["annotations"][0] 59 | meta_file = loadmat(meta_file)["class_names"][0] 60 | items = [] 61 | 62 | for i in range(len(anno_file)): 63 | imname = anno_file[i]["fname"][0] 64 | impath = os.path.join(self.dataset_dir, image_dir, imname) 65 | label = anno_file[i]["class"][0, 0] 66 | label = int(label) - 1 # convert to 0-based index 67 | classname = meta_file[label][0] 68 | names = classname.split(" ") 69 | year = names.pop(-1) 70 | names.insert(0, year) 71 | classname = " ".join(names) 72 | item = Datum(impath=impath, label=label, classname=classname) 73 | items.append(item) 74 | 75 | return items 76 | -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SUN397(DatasetBase): 12 | 13 | dataset_dir = "sun397" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "SUN397") 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json") 20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | if os.path.exists(self.split_path): 24 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 25 | else: 26 | classnames = [] 27 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f: 28 | lines = f.readlines() 29 | for line in lines: 30 | line = line.strip()[1:] # remove / 31 | classnames.append(line) 32 | cname2lab = {c: i for i, c in enumerate(classnames)} 33 | trainval = self.read_data(cname2lab, "Training_01.txt") 34 | test = self.read_data(cname2lab, "Testing_01.txt") 35 | train, val = OxfordPets.split_trainval(trainval) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, 
num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 58 | 59 | super().__init__(train_x=train, val=val, test=test) 60 | 61 | def read_data(self, cname2lab, text_file): 62 | text_file = os.path.join(self.dataset_dir, text_file) 63 | items = [] 64 | 65 | with open(text_file, "r") as f: 66 | lines = f.readlines() 67 | for line in lines: 68 | imname = line.strip()[1:] # remove / 69 | classname = os.path.dirname(imname) 70 | label = cname2lab[classname] 71 | impath = os.path.join(self.image_dir, imname) 72 | 73 | names = classname.split("/")[1:] # remove 1st letter 74 | names = names[::-1] # put words like indoor/outdoor at first 75 | classname = " ".join(names) 76 | 77 | item = Datum(impath=impath, label=label, classname=classname) 78 | items.append(item) 79 | 80 | return items 81 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import re 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class UCF101(DatasetBase): 13 | 14 | dataset_dir = "ucf101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | cname2lab = {} 28 | filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt") 29 | with open(filepath, "r") as f: 30 | lines = f.readlines() 31 | for line in lines: 32 | label, classname = line.strip().split(" ") 33 | label = int(label) - 1 # conver to 0-based index 34 | cname2lab[classname] = label 35 | 36 | trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt") 37 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt") 38 | train, val = OxfordPets.split_trainval(trainval) 39 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 40 | 41 | num_shots = cfg.DATASET.NUM_SHOTS 42 | if num_shots >= 1: 43 | seed = cfg.SEED 44 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 45 | 46 | if os.path.exists(preprocessed): 47 | print(f"Loading preprocessed few-shot data from {preprocessed}") 48 | with open(preprocessed, "rb") as file: 49 | data = pickle.load(file) 50 | train, val = data["train"], data["val"] 51 | else: 52 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 53 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 54 | data = {"train": train, "val": val} 55 | print(f"Saving preprocessed few-shot data to 
{preprocessed}") 56 | with open(preprocessed, "wb") as file: 57 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 58 | 59 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 60 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 61 | 62 | super().__init__(train_x=train, val=val, test=test) 63 | 64 | def read_data(self, cname2lab, text_file): 65 | text_file = os.path.join(self.dataset_dir, text_file) 66 | items = [] 67 | 68 | with open(text_file, "r") as f: 69 | lines = f.readlines() 70 | for line in lines: 71 | line = line.strip().split(" ")[0] # trainlist: filename, label 72 | action, filename = line.split("/") 73 | label = cname2lab[action] 74 | 75 | elements = re.findall("[A-Z][^A-Z]*", action) 76 | renamed_action = "_".join(elements) 77 | 78 | filename = filename.replace(".avi", ".jpg") 79 | impath = os.path.join(self.image_dir, renamed_action, filename) 80 | 81 | item = Datum(impath=impath, label=label, classname=renamed_action) 82 | items.append(item) 83 | 84 | return items 85 | -------------------------------------------------------------------------------- /draw_curves.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | save_dir = "main_curves" 8 | if not os.path.exists(save_dir): 9 | os.makedirs(save_dir) 10 | 11 | path = "Results.xlsx" # this is the excel file containing the results (like the one we released) 12 | file = pd.read_excel(path, sheet_name="imcls_fewshot") 13 | 14 | datasets = [ 15 | "OxfordPets", "Flowers102", "FGVCAircraft", "DTD", 16 | "EuroSAT", "StanfordCars", "Food101", "SUN397", 17 | "Caltech101", "UCF101", "ImageNet" 18 | ] 19 | 20 | shots = [1, 2, 4, 8, 16] 21 | 22 | COLORS = { 23 | "zs": "C4", 24 | "linear": "C4", 25 | "ours_v16_end": "C0", 26 | "ours_v16_mid": "C2", 27 | "ours_v16_end_csc": "C1", 28 | "ours_v16_mid_csc": "C3" 29 | } 30 | MS = 3 31 | ALPHA = 1 32 | plt.rcParams.update({"font.size": 12}) 33 | 34 | average = { 35 | "zs": 0., 36 | "ours_v16_end": np.array([0., 0., 0., 0., 0.]), 37 | "ours_v16_mid": np.array([0., 0., 0., 0., 0.]), 38 | "ours_v16_end_csc": np.array([0., 0., 0., 0., 0.]), 39 | "ours_v16_mid_csc": np.array([0., 0., 0., 0., 0.]), 40 | "linear": np.array([0., 0., 0., 0., 0.]) 41 | } 42 | 43 | for dataset in datasets: 44 | print(f"Processing {dataset} ...") 45 | 46 | zs = file[dataset][0] 47 | 48 | ours_v16_end = file[dataset][2:7] 49 | ours_v16_end = [float(num) for num in ours_v16_end] 50 | 51 | ours_v16_mid = file[dataset][7:12] 52 | ours_v16_mid = [float(num) for num in ours_v16_mid] 53 | 54 | ours_v16_end_csc = file[dataset][12:17] 55 | ours_v16_end_csc = [float(num) for num in ours_v16_end_csc] 56 | 57 | ours_v16_mid_csc = file[dataset][17:22] 58 | ours_v16_mid_csc = [float(num) for num in ours_v16_mid_csc] 59 | 60 | linear = file[dataset][22:27] 61 | linear = [float(num) for num in linear] 62 | 63 | average["zs"] += zs 64 | average["ours_v16_end"] += np.array(ours_v16_end) 65 | average["ours_v16_mid"] += np.array(ours_v16_mid) 66 | average["ours_v16_end_csc"] += np.array(ours_v16_end_csc) 67 | average["ours_v16_mid_csc"] += np.array(ours_v16_mid_csc) 68 | average["linear"] += np.array(linear) 69 | 70 | # Plot 71 | values = [zs] 72 | values += linear 73 | values += ours_v16_end 74 | values += ours_v16_mid 75 | values += ours_v16_end_csc 76 | values += ours_v16_mid_csc 77 | val_min, val_max = min(values), max(values) 78 | diff = val_max - val_min 79 | 
val_bot = val_min - diff*0.05 80 | val_top = val_max + diff*0.05 81 | 82 | fig, ax = plt.subplots() 83 | ax.set_facecolor("#EBEBEB") 84 | 85 | ax.set_xticks([0] + shots) 86 | ax.set_xticklabels([0] + shots) 87 | ax.set_xlabel("Number of labeled training examples per class") 88 | ax.set_ylabel("Score (%)") 89 | ax.grid(axis="x", color="white", linewidth=1) 90 | ax.axhline(zs, color="white", linewidth=1) 91 | ax.set_title(dataset) 92 | ax.set_ylim(val_bot, val_top) 93 | 94 | ax.plot( 95 | 0, zs, 96 | marker="*", 97 | markersize=MS*1.5, 98 | color=COLORS["zs"], 99 | alpha=ALPHA 100 | ) 101 | ax.plot( 102 | shots, ours_v16_end, 103 | marker="o", 104 | markersize=MS, 105 | color=COLORS["ours_v16_end"], 106 | label="CLIP + CoOp ($M\!=\!16$, end)", 107 | alpha=ALPHA 108 | ) 109 | ax.plot( 110 | shots, ours_v16_mid, 111 | marker="o", 112 | markersize=MS, 113 | color=COLORS["ours_v16_mid"], 114 | label="CLIP + CoOp ($M\!=\!16$, mid)", 115 | alpha=ALPHA 116 | ) 117 | ax.plot( 118 | shots, ours_v16_end_csc, 119 | marker="o", 120 | markersize=MS, 121 | color=COLORS["ours_v16_end_csc"], 122 | label="CLIP + CoOp ($M\!=\!16$, end, CSC)", 123 | alpha=ALPHA 124 | ) 125 | ax.plot( 126 | shots, ours_v16_mid_csc, 127 | marker="o", 128 | markersize=MS, 129 | color=COLORS["ours_v16_mid_csc"], 130 | label="CLIP + CoOp ($M\!=\!16$, mid, CSC)", 131 | alpha=ALPHA 132 | ) 133 | ax.plot( 134 | shots, linear, 135 | marker="o", 136 | markersize=MS, 137 | color=COLORS["linear"], 138 | label="Linear probe CLIP", 139 | linestyle="dotted", 140 | alpha=ALPHA 141 | ) 142 | 143 | ax.text(-0.5, zs-diff*0.11, "Zero-shot\nCLIP", color=COLORS["zs"]) 144 | ax.legend(loc="lower right") 145 | 146 | fig.savefig(f"{save_dir}/{dataset}.pdf", bbox_inches="tight") 147 | 148 | 149 | # Plot 150 | average = {k: v/len(datasets) for k, v in average.items()} 151 | zs = average["zs"] 152 | linear = list(average["linear"]) 153 | ours_v16_end = list(average["ours_v16_end"]) 154 | ours_v16_mid = list(average["ours_v16_mid"]) 155 | ours_v16_end_csc = list(average["ours_v16_end_csc"]) 156 | ours_v16_mid_csc = list(average["ours_v16_mid_csc"]) 157 | 158 | values = [zs] 159 | values += linear 160 | values += ours_v16_end 161 | values += ours_v16_mid 162 | values += ours_v16_end_csc 163 | values += ours_v16_mid_csc 164 | val_min, val_max = min(values), max(values) 165 | diff = val_max - val_min 166 | val_bot = val_min - diff*0.05 167 | val_top = val_max + diff*0.05 168 | 169 | fig, ax = plt.subplots() 170 | ax.set_facecolor("#EBEBEB") 171 | 172 | ax.set_xticks([0] + shots) 173 | ax.set_xticklabels([0] + shots) 174 | ax.set_xlabel("Number of labeled training examples per class") 175 | ax.set_ylabel("Score (%)") 176 | ax.grid(axis="x", color="white", linewidth=1) 177 | ax.axhline(zs, color="white", linewidth=1) 178 | ax.set_title("Average over 11 datasets", fontweight="bold") 179 | ax.set_ylim(val_bot, val_top) 180 | 181 | ax.plot( 182 | 0, zs, 183 | marker="*", 184 | markersize=MS*1.5, 185 | color=COLORS["zs"], 186 | alpha=ALPHA 187 | ) 188 | ax.plot( 189 | shots, ours_v16_end, 190 | marker="o", 191 | markersize=MS, 192 | color=COLORS["ours_v16_end"], 193 | label="CLIP + CoOp ($M\!=\!16$, end)", 194 | alpha=ALPHA 195 | ) 196 | ax.plot( 197 | shots, ours_v16_mid, 198 | marker="o", 199 | markersize=MS, 200 | color=COLORS["ours_v16_mid"], 201 | label="CLIP + CoOp ($M\!=\!16$, mid)", 202 | alpha=ALPHA 203 | ) 204 | ax.plot( 205 | shots, ours_v16_end_csc, 206 | marker="o", 207 | markersize=MS, 208 | color=COLORS["ours_v16_end_csc"], 209 | label="CLIP + CoOp 
($M\!=\!16$, end, CSC)", 210 | alpha=ALPHA 211 | ) 212 | ax.plot( 213 | shots, ours_v16_mid_csc, 214 | marker="o", 215 | markersize=MS, 216 | color=COLORS["ours_v16_mid_csc"], 217 | label="CLIP + CoOp ($M\!=\!16$, mid, CSC)", 218 | alpha=ALPHA 219 | ) 220 | ax.plot( 221 | shots, linear, 222 | marker="o", 223 | markersize=MS, 224 | color=COLORS["linear"], 225 | label="Linear probe CLIP", 226 | linestyle="dotted", 227 | alpha=ALPHA 228 | ) 229 | 230 | ax.text(-0.5, zs-diff*0.11, "Zero-shot\nCLIP", color=COLORS["zs"]) 231 | ax.legend(loc="lower right") 232 | 233 | fig.savefig(f"{save_dir}/average.pdf", bbox_inches="tight") 234 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | dataset=$1 2 | mcmc=$2 3 | epochs=$3 4 | 5 | python parse_test_res.py ./ls/base2new/train_base/${dataset}/mcmc_${mcmc}_epochs_${epochs}/shots_16/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1/ 6 | python parse_test_res.py ./ls/base2new/test_new/${dataset}/mcmc_${mcmc}_epochs_${epochs}/shots_16/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1/ --test-log 7 | -------------------------------------------------------------------------------- /interpret_prompt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import torch 5 | 6 | from clip.simple_tokenizer import SimpleTokenizer 7 | from clip import clip 8 | 9 | 10 | def load_clip_to_cpu(backbone_name="RN50"): 11 | url = clip._MODELS[backbone_name] 12 | model_path = clip._download(url) 13 | 14 | try: 15 | # loading JIT archive 16 | model = torch.jit.load(model_path, map_location="cpu").eval() 17 | state_dict = None 18 | 19 | except RuntimeError: 20 | state_dict = torch.load(model_path, map_location="cpu") 21 | 22 | model = clip.build_model(state_dict or model.state_dict()) 23 | 24 | return model 25 | 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("fpath", type=str, help="Path to the learned prompt") 29 | parser.add_argument("topk", type=int, help="Select top-k similar words") 30 | args = parser.parse_args() 31 | 32 | fpath = args.fpath 33 | topk = args.topk 34 | 35 | assert os.path.exists(fpath) 36 | 37 | print(f"Return the top-{topk} matched words") 38 | 39 | tokenizer = SimpleTokenizer() 40 | clip_model = load_clip_to_cpu() 41 | token_embedding = clip_model.token_embedding.weight 42 | print(f"Size of token embedding: {token_embedding.shape}") 43 | 44 | prompt_learner = torch.load(fpath, map_location="cpu")["state_dict"] 45 | ctx = prompt_learner["ctx"] 46 | ctx = ctx.float() 47 | print(f"Size of context: {ctx.shape}") 48 | 49 | if ctx.dim() == 2: 50 | # Generic context 51 | distance = torch.cdist(ctx, token_embedding) 52 | print(f"Size of distance matrix: {distance.shape}") 53 | sorted_idxs = torch.argsort(distance, dim=1) 54 | sorted_idxs = sorted_idxs[:, :topk] 55 | 56 | for m, idxs in enumerate(sorted_idxs): 57 | words = [tokenizer.decoder[idx.item()] for idx in idxs] 58 | dist = [f"{distance[m, idx].item():.4f}" for idx in idxs] 59 | print(f"{m+1}: {words} {dist}") 60 | 61 | elif ctx.dim() == 3: 62 | # Class-specific context 63 | raise NotImplementedError 64 | -------------------------------------------------------------------------------- /lpclip/README.md: -------------------------------------------------------------------------------- 1 | # Linear Probe CLIP 2 | 3 | To run linear probe baselines, make sure that your current working directory is `lpclip/`. 
4 | 5 | Step 1: Extract Features using the CLIP Image Encoder 6 | ```bash 7 | sh feat_extractor.sh 8 | ``` 9 | 10 | Step 2: Train few-shot linear probe 11 | ```bash 12 | sh linear_probe.sh 13 | ``` 14 | 15 | We follow the instructions stated in the Appendix A3 (pp.38) of [the original CLIP paper](https://arxiv.org/pdf/2103.00020.pdf), with a careful hyperparameter sweep. 16 | 17 | Note: please pull the latest Dassl (version >= `606a2c6`). 18 | -------------------------------------------------------------------------------- /lpclip/feat_extractor.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import numpy as np 3 | import torch 4 | import sys 5 | 6 | sys.path.append(os.path.abspath("..")) 7 | 8 | from datasets.oxford_pets import OxfordPets 9 | from datasets.oxford_flowers import OxfordFlowers 10 | from datasets.fgvc_aircraft import FGVCAircraft 11 | from datasets.dtd import DescribableTextures 12 | from datasets.eurosat import EuroSAT 13 | from datasets.stanford_cars import StanfordCars 14 | from datasets.food101 import Food101 15 | from datasets.sun397 import SUN397 16 | from datasets.caltech101 import Caltech101 17 | from datasets.ucf101 import UCF101 18 | from datasets.imagenet import ImageNet 19 | from datasets.imagenetv2 import ImageNetV2 20 | from datasets.imagenet_sketch import ImageNetSketch 21 | from datasets.imagenet_a import ImageNetA 22 | from datasets.imagenet_r import ImageNetR 23 | 24 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 25 | from dassl.config import get_cfg_default 26 | from dassl.data.transforms import build_transform 27 | from dassl.data import DatasetWrapper 28 | 29 | import clip 30 | 31 | # import pdb; pdb.set_trace() 32 | 33 | 34 | def print_args(args, cfg): 35 | print("***************") 36 | print("** Arguments **") 37 | print("***************") 38 | optkeys = list(args.__dict__.keys()) 39 | optkeys.sort() 40 | for key in optkeys: 41 | print("{}: {}".format(key, args.__dict__[key])) 42 | print("************") 43 | print("** Config **") 44 | print("************") 45 | print(cfg) 46 | 47 | 48 | def reset_cfg(cfg, args): 49 | if args.root: 50 | cfg.DATASET.ROOT = args.root 51 | 52 | if args.output_dir: 53 | cfg.OUTPUT_DIR = args.output_dir 54 | 55 | if args.trainer: 56 | cfg.TRAINER.NAME = args.trainer 57 | 58 | if args.backbone: 59 | cfg.MODEL.BACKBONE.NAME = args.backbone 60 | 61 | if args.head: 62 | cfg.MODEL.HEAD.NAME = args.head 63 | 64 | 65 | def extend_cfg(cfg): 66 | """ 67 | Add new config variables. 68 | 69 | E.g. 70 | from yacs.config import CfgNode as CN 71 | cfg.TRAINER.MY_MODEL = CN() 72 | cfg.TRAINER.MY_MODEL.PARAM_A = 1. 73 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5 74 | cfg.TRAINER.MY_MODEL.PARAM_C = False 75 | """ 76 | from yacs.config import CfgNode as CN 77 | 78 | cfg.TRAINER.OURS = CN() 79 | cfg.TRAINER.OURS.N_CTX = 10 # number of context vectors 80 | cfg.TRAINER.OURS.CSC = False # class-specific context 81 | cfg.TRAINER.OURS.CTX_INIT = "" # initialize context vectors with given words 82 | cfg.TRAINER.OURS.WEIGHT_U = 0.1 # weight for the unsupervised loss 83 | 84 | 85 | def setup_cfg(args): 86 | cfg = get_cfg_default() 87 | extend_cfg(cfg) 88 | 89 | # 1. From the dataset config file 90 | if args.dataset_config_file: 91 | cfg.merge_from_file(args.dataset_config_file) 92 | 93 | # 2. From the method config file 94 | if args.config_file: 95 | cfg.merge_from_file(args.config_file) 96 | 97 | # 3. 
From input arguments 98 | reset_cfg(cfg, args) 99 | 100 | cfg.freeze() 101 | 102 | return cfg 103 | 104 | 105 | def main(args): 106 | cfg = setup_cfg(args) 107 | if cfg.SEED >= 0: 108 | print("Setting fixed seed: {}".format(cfg.SEED)) 109 | set_random_seed(cfg.SEED) 110 | setup_logger(cfg.OUTPUT_DIR) 111 | 112 | if torch.cuda.is_available() and cfg.USE_CUDA: 113 | torch.backends.cudnn.benchmark = True 114 | 115 | print_args(args, cfg) 116 | print("Collecting env info ...") 117 | print("** System info **\n{}\n".format(collect_env_info())) 118 | 119 | ###################################### 120 | # Setup DataLoader 121 | ###################################### 122 | dataset = eval(cfg.DATASET.NAME)(cfg) 123 | 124 | if args.split == "train": 125 | dataset_input = dataset.train_x 126 | elif args.split == "val": 127 | dataset_input = dataset.val 128 | else: 129 | dataset_input = dataset.test 130 | 131 | tfm_train = build_transform(cfg, is_train=False) 132 | data_loader = torch.utils.data.DataLoader( 133 | DatasetWrapper(cfg, dataset_input, transform=tfm_train, is_train=False), 134 | batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE, 135 | sampler=None, 136 | shuffle=False, 137 | num_workers=cfg.DATALOADER.NUM_WORKERS, 138 | drop_last=False, 139 | pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA), 140 | ) 141 | 142 | ######################################## 143 | # Setup Network 144 | ######################################## 145 | clip_model, _ = clip.load("RN50", "cuda", jit=False) 146 | clip_model.eval() 147 | ################################################################################################################### 148 | # Start Feature Extractor 149 | feature_list = [] 150 | label_list = [] 151 | train_dataiter = iter(data_loader) 152 | for train_step in range(1, len(train_dataiter) + 1): 153 | batch = next(train_dataiter) 154 | data = batch["img"].cuda() 155 | feature = clip_model.visual(data) 156 | feature = feature.cpu() 157 | for idx in range(len(data)): 158 | feature_list.append(feature[idx].tolist()) 159 | label_list.extend(batch["label"].tolist()) 160 | save_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASET.NAME) 161 | os.makedirs(save_dir, exist_ok=True) 162 | save_filename = f"{args.split}" 163 | np.savez( 164 | os.path.join(save_dir, save_filename), 165 | feature_list=feature_list, 166 | label_list=label_list, 167 | ) 168 | 169 | 170 | if __name__ == "__main__": 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument("--root", type=str, default="", help="path to dataset") 173 | parser.add_argument("--output-dir", type=str, default="", help="output directory") 174 | parser.add_argument("--config-file", type=str, default="", help="path to config file") 175 | parser.add_argument( 176 | "--dataset-config-file", 177 | type=str, 178 | default="", 179 | help="path to config file for dataset setup", 180 | ) 181 | parser.add_argument("--num-shot", type=int, default=1, help="number of shots") 182 | parser.add_argument("--split", type=str, choices=["train", "val", "test"], help="which split") 183 | parser.add_argument("--trainer", type=str, default="", help="name of trainer") 184 | parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone") 185 | parser.add_argument("--head", type=str, default="", help="name of head") 186 | parser.add_argument("--seed", type=int, default=-1, help="only positive value enables a fixed seed") 187 | parser.add_argument("--eval-only", action="store_true", help="evaluation only") 188 | args = parser.parse_args() 189 | 
main(args) 190 | -------------------------------------------------------------------------------- /lpclip/feat_extractor.sh: -------------------------------------------------------------------------------- 1 | # sh feat_extractor.sh 2 | DATA=/path/to/datasets 3 | OUTPUT='./clip_feat/' 4 | SEED=1 5 | 6 | # oxford_pets oxford_flowers fgvc_aircraft dtd eurosat stanford_cars food101 sun397 caltech101 ucf101 imagenet 7 | for DATASET in oxford_pets 8 | do 9 | for SPLIT in train val test 10 | do 11 | python feat_extractor.py \ 12 | --split ${SPLIT} \ 13 | --root ${DATA} \ 14 | --seed ${SEED} \ 15 | --dataset-config-file ../configs/datasets/${DATASET}.yaml \ 16 | --config-file ../configs/trainers/CoOp/rn50_val.yaml \ 17 | --output-dir ${OUTPUT} \ 18 | --eval-only 19 | done 20 | done 21 | -------------------------------------------------------------------------------- /lpclip/linear_probe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from sklearn.linear_model import LogisticRegression 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--dataset", type=str, default="", help="path to dataset") 8 | parser.add_argument("--num_step", type=int, default=8, help="number of steps") 9 | parser.add_argument("--num_run", type=int, default=10, help="number of runs") 10 | parser.add_argument("--feature_dir", type=str, default="clip_feat", help="feature dir path") 11 | args = parser.parse_args() 12 | 13 | dataset = args.dataset 14 | dataset_path = os.path.join(f"{args.feature_dir}", dataset) 15 | 16 | train_file = np.load(os.path.join(dataset_path, "train.npz")) 17 | train_feature, train_label = train_file["feature_list"], train_file["label_list"] 18 | val_file = np.load(os.path.join(dataset_path, "val.npz")) 19 | val_feature, val_label = val_file["feature_list"], val_file["label_list"] 20 | test_file = np.load(os.path.join(dataset_path, "test.npz")) 21 | test_feature, test_label = test_file["feature_list"], test_file["label_list"] 22 | 23 | os.makedirs("report", exist_ok=True) 24 | 25 | val_shot_list = {1: 1, 2: 2, 4: 4, 8: 4, 16: 4} 26 | 27 | for num_shot in [1, 2, 4, 8, 16]: 28 | test_acc_step_list = np.zeros([args.num_run, args.num_step]) 29 | for seed in range(1, args.num_run + 1): 30 | np.random.seed(seed) 31 | print(f"-- Seed: {seed} --------------------------------------------------------------") 32 | # Sampling 33 | all_label_list = np.unique(train_label) 34 | selected_idx_list = [] 35 | for label in all_label_list: 36 | label_collection = np.where(train_label == label)[0] 37 | selected_idx = np.random.choice(label_collection, size=num_shot, replace=False) 38 | selected_idx_list.extend(selected_idx) 39 | 40 | fewshot_train_feature = train_feature[selected_idx_list] 41 | fewshot_train_label = train_label[selected_idx_list] 42 | 43 | val_num_shot = val_shot_list[num_shot] 44 | val_selected_idx_list = [] 45 | for label in all_label_list: 46 | label_collection = np.where(val_label == label)[0] 47 | selected_idx = np.random.choice(label_collection, size=val_num_shot, replace=False) 48 | val_selected_idx_list.extend(selected_idx) 49 | 50 | fewshot_val_feature = val_feature[val_selected_idx_list] 51 | fewshot_val_label = val_label[val_selected_idx_list] 52 | 53 | # search initialization 54 | search_list = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6] 55 | acc_list = [] 56 | for c_weight in search_list: 57 | clf = LogisticRegression(solver="lbfgs", max_iter=1000, penalty="l2", 
C=c_weight).fit(fewshot_train_feature, fewshot_train_label) 58 | pred = clf.predict(fewshot_val_feature) 59 | acc_val = sum(pred == fewshot_val_label) / len(fewshot_val_label) 60 | acc_list.append(acc_val) 61 | 62 | print(acc_list, flush=True) 63 | 64 | # binary search 65 | peak_idx = np.argmax(acc_list) 66 | c_peak = search_list[peak_idx] 67 | c_left, c_right = 1e-1 * c_peak, 1e1 * c_peak 68 | 69 | def binary_search(c_left, c_right, seed, step, test_acc_step_list): 70 | clf_left = LogisticRegression(solver="lbfgs", max_iter=1000, penalty="l2", C=c_left).fit(fewshot_train_feature, fewshot_train_label) 71 | pred_left = clf_left.predict(fewshot_val_feature) 72 | acc_left = sum(pred_left == fewshot_val_label) / len(fewshot_val_label) 73 | print("Val accuracy (Left): {:.2f}".format(100 * acc_left), flush=True) 74 | 75 | clf_right = LogisticRegression(solver="lbfgs", max_iter=1000, penalty="l2", C=c_right).fit(fewshot_train_feature, fewshot_train_label) 76 | pred_right = clf_right.predict(fewshot_val_feature) 77 | acc_right = sum(pred_right == fewshot_val_label) / len(fewshot_val_label) 78 | print("Val accuracy (Right): {:.2f}".format(100 * acc_right), flush=True) 79 | 80 | # find maximum and update ranges 81 | if acc_left < acc_right: 82 | c_final = c_right 83 | clf_final = clf_right 84 | # range for the next step 85 | c_left = 0.5 * (np.log10(c_right) + np.log10(c_left)) 86 | c_right = np.log10(c_right) 87 | else: 88 | c_final = c_left 89 | clf_final = clf_left 90 | # range for the next step 91 | c_right = 0.5 * (np.log10(c_right) + np.log10(c_left)) 92 | c_left = np.log10(c_left) 93 | 94 | pred = clf_final.predict(test_feature) 95 | test_acc = 100 * sum(pred == test_label) / len(pred) 96 | print("Test Accuracy: {:.2f}".format(test_acc), flush=True) 97 | test_acc_step_list[seed - 1, step] = test_acc 98 | 99 | saveline = "{}, seed {}, {} shot, weight {}, test_acc {:.2f}\n".format(dataset, seed, num_shot, c_final, test_acc) 100 | with open( 101 | "./report/{}_s{}r{}_details.txt".format(args.feature_dir, args.num_step, args.num_run), 102 | "a+", 103 | ) as writer: 104 | writer.write(saveline) 105 | return ( 106 | np.power(10, c_left), 107 | np.power(10, c_right), 108 | seed, 109 | step, 110 | test_acc_step_list, 111 | ) 112 | 113 | for step in range(args.num_step): 114 | print( 115 | f"{dataset}, {num_shot} Shot, Round {step}: {c_left}/{c_right}", 116 | flush=True, 117 | ) 118 | c_left, c_right, seed, step, test_acc_step_list = binary_search(c_left, c_right, seed, step, test_acc_step_list) 119 | # save results of last step 120 | test_acc_list = test_acc_step_list[:, -1] 121 | acc_mean = np.mean(test_acc_list) 122 | acc_std = np.std(test_acc_list) 123 | save_line = "{}, {} Shot, Test acc stat: {:.2f} ({:.2f})\n".format(dataset, num_shot, acc_mean, acc_std) 124 | print(save_line, flush=True) 125 | with open( 126 | "./report/{}_s{}r{}.txt".format(args.feature_dir, args.num_step, args.num_run), 127 | "a+", 128 | ) as writer: 129 | writer.write(save_line) 130 | -------------------------------------------------------------------------------- /lpclip/linear_probe.sh: -------------------------------------------------------------------------------- 1 | feature_dir=clip_feat 2 | 3 | for DATASET in OxfordPets 4 | do 5 | python linear_probe.py \ 6 | --dataset ${DATASET} \ 7 | --feature_dir ${feature_dir} \ 8 | --num_step 8 \ 9 | --num_run 3 10 | done 11 | -------------------------------------------------------------------------------- /parse_test_res.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Goal 3 | --- 4 | 1. Read test results from log.txt files 5 | 2. Compute mean and std across different folders (seeds) 6 | 7 | Usage 8 | --- 9 | Assume the output files are saved under output/my_experiment, 10 | which contains results of different seeds, e.g., 11 | 12 | my_experiment/ 13 | seed1/ 14 | log.txt 15 | seed2/ 16 | log.txt 17 | seed3/ 18 | log.txt 19 | 20 | Run the following command from the root directory: 21 | 22 | $ python tools/parse_test_res.py output/my_experiment 23 | 24 | Add --ci95 to the argument if you wanna get 95% confidence 25 | interval instead of standard deviation: 26 | 27 | $ python tools/parse_test_res.py output/my_experiment --ci95 28 | 29 | If my_experiment/ has the following structure, 30 | 31 | my_experiment/ 32 | exp-1/ 33 | seed1/ 34 | log.txt 35 | ... 36 | seed2/ 37 | log.txt 38 | ... 39 | seed3/ 40 | log.txt 41 | ... 42 | exp-2/ 43 | ... 44 | exp-3/ 45 | ... 46 | 47 | Run 48 | 49 | $ python tools/parse_test_res.py output/my_experiment --multi-exp 50 | """ 51 | import re 52 | import numpy as np 53 | import os.path as osp 54 | import argparse 55 | from collections import OrderedDict, defaultdict 56 | 57 | from dassl.utils import check_isfile, listdir_nohidden 58 | 59 | 60 | def compute_ci95(res): 61 | return 1.96 * np.std(res) / np.sqrt(len(res)) 62 | 63 | 64 | def parse_function(*metrics, directory="", args=None, end_signal=None): 65 | print(f"Parsing files in {directory}") 66 | subdirs = listdir_nohidden(directory, sort=True) 67 | 68 | outputs = [] 69 | 70 | for subdir in subdirs: 71 | fpath = osp.join(directory, subdir, "log.txt") 72 | assert check_isfile(fpath) 73 | good_to_go = False 74 | output = OrderedDict() 75 | 76 | with open(fpath, "r") as f: 77 | lines = f.readlines() 78 | 79 | for line in lines: 80 | line = line.strip() 81 | if line == end_signal: 82 | good_to_go = True 83 | 84 | for metric in metrics: 85 | match = metric["regex"].search(line) 86 | if match and good_to_go: 87 | if "file" not in output: 88 | output["file"] = fpath 89 | num = float(match.group(1)) 90 | name = metric["name"] 91 | output[name] = num 92 | 93 | if output: 94 | outputs.append(output) 95 | assert len(outputs) > 0, f"Nothing found in {directory}" 96 | 97 | metrics_results = defaultdict(list) 98 | 99 | for output in outputs: 100 | msg = "" 101 | for key, value in output.items(): 102 | if isinstance(value, float): 103 | msg += f"{key}: {value:.2f}%. " 104 | else: 105 | msg += f"{key}: {value}. 
" 106 | if key != "file": 107 | metrics_results[key].append(value) 108 | print(msg) 109 | 110 | output_results = OrderedDict() 111 | 112 | print("===") 113 | print(f"Summary of directory: {directory}") 114 | for key, values in metrics_results.items(): 115 | avg = np.mean(values) 116 | std = compute_ci95(values) if args.ci95 else np.std(values) 117 | print(f"* {key}: {avg:.2f}% +- {std:.2f}%") 118 | output_results[key] = avg 119 | print("===") 120 | 121 | return output_results 122 | 123 | 124 | def main(args, end_signal): 125 | metric = { 126 | "name": args.keyword, 127 | "regex": re.compile(fr"\* {args.keyword}: ([\.\deE+-]+)%"), 128 | } 129 | 130 | if args.multi_exp: 131 | final_results = defaultdict(list) 132 | 133 | for directory in listdir_nohidden(args.directory, sort=True): 134 | directory = osp.join(args.directory, directory) 135 | results = parse_function( 136 | metric, directory=directory, args=args, end_signal=end_signal 137 | ) 138 | 139 | for key, value in results.items(): 140 | final_results[key].append(value) 141 | 142 | print("Average performance") 143 | for key, values in final_results.items(): 144 | avg = np.mean(values) 145 | print(f"* {key}: {avg:.2f}%") 146 | 147 | else: 148 | parse_function( 149 | metric, directory=args.directory, args=args, end_signal=end_signal 150 | ) 151 | 152 | 153 | if __name__ == "__main__": 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument("directory", type=str, help="path to directory") 156 | parser.add_argument( 157 | "--ci95", action="store_true", help=r"compute 95\% confidence interval" 158 | ) 159 | parser.add_argument("--test-log", action="store_true", help="parse test-only logs") 160 | parser.add_argument( 161 | "--multi-exp", action="store_true", help="parse multiple experiments" 162 | ) 163 | parser.add_argument( 164 | "--keyword", default="accuracy", type=str, help="which keyword to extract" 165 | ) 166 | args = parser.parse_args() 167 | 168 | end_signal = "Finish training" 169 | if args.test_log: 170 | end_signal = "=> result" 171 | 172 | main(args, end_signal) 173 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=CoOp 8 | SHOTS=16 9 | NCTX=16 10 | CSC=False 11 | CTP=end 12 | 13 | DATASET=$1 14 | CFG=$2 15 | 16 | for SEED in 1 2 3 17 | do 18 | python train.py \ 19 | --root ${DATA} \ 20 | --seed ${SEED} \ 21 | --trainer ${TRAINER} \ 22 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 23 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 24 | --output-dir output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/${DATASET}/seed${SEED} \ 25 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} \ 26 | --load-epoch 50 \ 27 | --eval-only \ 28 | TRAINER.COOP.N_CTX ${NCTX} \ 29 | TRAINER.COOP.CSC ${CSC} \ 30 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} 31 | done -------------------------------------------------------------------------------- /scripts/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 
4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=CoOp 8 | 9 | DATASET=$1 10 | CFG=$2 # config file 11 | CTP=$3 # class token position (end or middle) 12 | NCTX=$4 # number of context tokens 13 | SHOTS=$5 # number of shots (1, 2, 4, 8, 16) 14 | CSC=$6 # class-specific context (False or True) 15 | 16 | for SEED in 1 2 3 17 | do 18 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} 19 | if [ -d "$DIR" ]; then 20 | echo "Oops! The results exist at ${DIR} (so skip this job)" 21 | else 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | TRAINER.COOP.N_CTX ${NCTX} \ 30 | TRAINER.COOP.CSC ${CSC} \ 31 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 32 | DATASET.NUM_SHOTS ${SHOTS} 33 | fi 34 | done -------------------------------------------------------------------------------- /scripts/vpt/base2new_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=~/projects/prob-cocoop/data 7 | TRAINER=VPT 8 | 9 | DATASET=$1 10 | SEED=$2 11 | GPUIDS=$3 12 | L=$4 13 | EPOCHS=$5 14 | 15 | CFG=vit_b16_c4_ep10_batch1_ctxv1 16 | SHOTS=16 17 | LOADEP=$5 18 | SUB=new 19 | 20 | 21 | COMMON_DIR=${DATASET}/mcmc_${L}_epochs_${EPOCHS}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 22 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 23 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 24 | if [ -d "$DIR" ]; then 25 | echo "Oops! The results exist at ${DIR} (so skip this job)" 26 | else 27 | CUDA_VISIBLE_DEVICES=${GPUIDS} python train.py \ 28 | --root ${DATA} \ 29 | --seed ${SEED} \ 30 | --trainer ${TRAINER} \ 31 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 32 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 33 | --output-dir ${DIR} \ 34 | --model-dir ${MODEL_DIR} \ 35 | --load-epoch ${LOADEP} \ 36 | --eval-only \ 37 | DATASET.NUM_SHOTS ${SHOTS} \ 38 | DATASET.SUBSAMPLE_CLASSES ${SUB} \ 39 | OPTIM.MAX_EPOCH ${EPOCHS} \ 40 | TRAINER.VPT.L ${L} 41 | fi 42 | -------------------------------------------------------------------------------- /scripts/vpt/base2new_test_ablation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=~/projects/prob-cocoop/data 7 | TRAINER=VPT 8 | 9 | DATASET=$1 10 | SEED=$2 11 | GPU=$3 12 | 13 | CFG=vit_b16_c4_ep10_batch1_ctxv1 14 | SHOTS=$4 15 | CTX=$5 16 | 17 | L=$6 18 | EPOCHS=$7 19 | 20 | LOADEP=$7 21 | SUB=new 22 | 23 | 24 | COMMON_DIR=${DATASET}/shots_${SHOTS}_ctx_${CTX}/${TRAINER}/${CFG}/seed${SEED} 25 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 26 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 27 | if [ -d "$DIR" ]; then 28 | echo "Oops! 
The results exist at ${DIR} (so skip this job)" 29 | else 30 | CUDA_VISIBLE_DEVICES=${GPU} python train.py \ 31 | --root ${DATA} \ 32 | --seed ${SEED} \ 33 | --trainer ${TRAINER} \ 34 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 35 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 36 | --output-dir ${DIR} \ 37 | --model-dir ${MODEL_DIR} \ 38 | --load-epoch ${LOADEP} \ 39 | --eval-only \ 40 | DATASET.NUM_SHOTS ${SHOTS} \ 41 | DATASET.SUBSAMPLE_CLASSES ${SUB} \ 42 | TRAINER.COCOOP.N_CTX ${CTX} \ 43 | OPTIM.MAX_EPOCH ${EPOCHS} \ 44 | TRAINER.COCOOP.L ${L} 45 | 46 | fi 47 | 48 | -------------------------------------------------------------------------------- /scripts/vpt/base2new_test_arch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=~/projects/prob-cocoop/data 7 | TRAINER=VPT 8 | 9 | DATASET=$1 10 | SEED=$2 11 | GPU=$3 12 | 13 | CFG=vit_b16_c4_ep10_batch1_ctxv1 14 | SHOTS=$4 15 | CTX=$5 16 | 17 | L=$6 18 | EPOCHS=$7 19 | ARCH=$8 20 | 21 | LOADEP=$7 22 | SUB=new 23 | 24 | 25 | COMMON_DIR=${DATASET}/shots_${SHOTS}_ctx_${CTX}_arch_${ARCH}/${TRAINER}/${CFG}/seed${SEED} 26 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 27 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 28 | if [ -d "$DIR" ]; then 29 | echo "Oops! The results exist at ${DIR} (so skip this job)" 30 | else 31 | CUDA_VISIBLE_DEVICES=${GPU} python train.py \ 32 | --root ${DATA} \ 33 | --seed ${SEED} \ 34 | --trainer ${TRAINER} \ 35 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 36 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 37 | --output-dir ${DIR} \ 38 | --model-dir ${MODEL_DIR} \ 39 | --load-epoch ${LOADEP} \ 40 | --eval-only \ 41 | DATASET.NUM_SHOTS ${SHOTS} \ 42 | DATASET.SUBSAMPLE_CLASSES ${SUB} \ 43 | TRAINER.COCOOP.N_CTX ${CTX} \ 44 | OPTIM.MAX_EPOCH ${EPOCHS} \ 45 | TRAINER.COCOOP.L ${L} \ 46 | MODEL.BACKBONE.NAME ${ARCH} 47 | 48 | fi 49 | 50 | -------------------------------------------------------------------------------- /scripts/vpt/base2new_test_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=~/projects/prob-cocoop/data 7 | TRAINER=VPT 8 | 9 | DATASET=imagenet 10 | SEED=$1 11 | GPUIDS=$2 12 | L=$3 13 | EPOCHS=$4 14 | 15 | CFG=vit_b16_c4_ep10_batch1_ctxv1 16 | SHOTS=16 17 | LOADEP=$4 18 | SUB=new 19 | 20 | 21 | COMMON_DIR=${DATASET}/mcmc_${L}_epochs_${EPOCHS}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 22 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 23 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 24 | if [ -d "$DIR" ]; then 25 | echo "Oops! The results exist at ${DIR} (so skip this job)" 26 | else 27 | CUDA_VISIBLE_DEVICES=${GPUIDS} python train.py \ 28 | --root ${DATA} \ 29 | --seed ${SEED} \ 30 | --trainer ${TRAINER} \ 31 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 32 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 33 | --output-dir ${DIR} \ 34 | --model-dir ${MODEL_DIR} \ 35 | --load-epoch ${LOADEP} \ 36 | --eval-only \ 37 | DATASET.NUM_SHOTS ${SHOTS} \ 38 | DATASET.SUBSAMPLE_CLASSES ${SUB} \ 39 | OPTIM.MAX_EPOCH ${EPOCHS} \ 40 | TRAINER.VPT.L ${L} 41 | fi 42 | -------------------------------------------------------------------------------- /scripts/vpt/base2new_test_imagenet_copy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 
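# Variant of base2new_test_imagenet.sh that evaluates the trained ImageNet checkpoint on the *base* split
# (SUB=base below); results are written under output/base2new/train_base_eval/ instead of test_new/.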
4 | 5 | # custom config 6 | DATA=~/projects/prob-cocoop/data 7 | TRAINER=VPT 8 | 9 | DATASET=imagenet 10 | SEED=$1 11 | GPUIDS=$2 12 | L=$3 13 | EPOCHS=$4 14 | 15 | CFG=vit_b16_c4_ep10_batch1_ctxv1 16 | SHOTS=16 17 | LOADEP=$4 18 | SUB=base 19 | 20 | 21 | COMMON_DIR=${DATASET}/mcmc_${L}_epochs_${EPOCHS}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 22 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 23 | DIR=output/base2new/train_${SUB}_eval/${COMMON_DIR} 24 | if [ -d "$DIR" ]; then 25 | echo "Oops! The results exist at ${DIR} (so skip this job)" 26 | else 27 | CUDA_VISIBLE_DEVICES=${GPUIDS} python train.py \ 28 | --root ${DATA} \ 29 | --seed ${SEED} \ 30 | --trainer ${TRAINER} \ 31 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 32 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 33 | --output-dir ${DIR} \ 34 | --model-dir ${MODEL_DIR} \ 35 | --load-epoch ${LOADEP} \ 36 | --eval-only \ 37 | DATASET.NUM_SHOTS ${SHOTS} \ 38 | DATASET.SUBSAMPLE_CLASSES ${SUB} \ 39 | OPTIM.MAX_EPOCH ${EPOCHS} \ 40 | TRAINER.COCOOP.L ${L} 41 | fi 42 | -------------------------------------------------------------------------------- /scripts/vpt/base2new_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=~/projects/prob-cocoop/data 7 | TRAINER=VPT 8 | 9 | DATASET=$1 10 | SEED=$2 11 | GPUIDS=$3 12 | L=$4 13 | EPOCHS=$5 14 | 15 | CFG=vit_b16_c4_ep10_batch1_ctxv1 16 | SHOTS=16 17 | 18 | 19 | 20 | DIR=output/base2new/train_base/${DATASET}/mcmc_${L}_epochs_${EPOCHS}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 21 | if [ -d "$DIR" ]; then 22 | echo "Oops! The results exist at ${DIR} (so skip this job)" 23 | else 24 | CUDA_VISIBLE_DEVICES=${GPUIDS} python train.py \ 25 | --root ${DATA} \ 26 | --seed ${SEED} \ 27 | --trainer ${TRAINER} \ 28 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 29 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 30 | --output-dir ${DIR} \ 31 | DATASET.NUM_SHOTS ${SHOTS} \ 32 | DATASET.SUBSAMPLE_CLASSES base \ 33 | OPTIM.MAX_EPOCH ${EPOCHS} \ 34 | TRAINER.VPT.L ${L} 35 | fi 36 | -------------------------------------------------------------------------------- /scripts/vpt/base2new_train_ablation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=~/projects/prob-cocoop/data 7 | TRAINER=VPT 8 | 9 | DATASET=$1 10 | SEED=$2 11 | GPU=$3 12 | 13 | CFG=vit_b16_c4_ep10_batch1_ctxv1 14 | SHOTS=$4 15 | CTX=$5 16 | L=$6 17 | EPOCHS=$7 18 | 19 | 20 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_ctx_${CTX}/${TRAINER}/${CFG}/seed${SEED} 21 | if [ -d "$DIR" ]; then 22 | echo "Oops! The results exist at ${DIR} (so skip this job)" 23 | else 24 | CUDA_VISIBLE_DEVICES=$GPU python train.py \ 25 | --root ${DATA} \ 26 | --seed ${SEED} \ 27 | --trainer ${TRAINER} \ 28 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 29 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 30 | --output-dir ${DIR} \ 31 | DATASET.NUM_SHOTS ${SHOTS} \ 32 | DATASET.SUBSAMPLE_CLASSES base \ 33 | TRAINER.COCOOP.N_CTX ${CTX} \ 34 | OPTIM.MAX_EPOCH ${EPOCHS} \ 35 | TRAINER.COCOOP.L ${L} 36 | fi 37 | -------------------------------------------------------------------------------- /scripts/vpt/base2new_train_arch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 
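# Backbone-ablation variant of base2new_train.sh: the extra 8th argument (ARCH) overrides
# MODEL.BACKBONE.NAME, so base-class training can be repeated with other CLIP backbones
# (driven by vision_backbone_ablations.py, which passes RN50 and RN101).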
4 | 5 | # custom config 6 | DATA=~/projects/prob-cocoop/data 7 | TRAINER=VPT 8 | 9 | DATASET=$1 10 | SEED=$2 11 | GPU=$3 12 | 13 | CFG=vit_b16_c4_ep10_batch1_ctxv1 14 | SHOTS=$4 15 | CTX=$5 16 | L=$6 17 | EPOCHS=$7 18 | ARCH=$8 19 | 20 | 21 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_ctx_${CTX}_arch_${ARCH}/${TRAINER}/${CFG}/seed${SEED} 22 | if [ -d "$DIR" ]; then 23 | echo "Oops! The results exist at ${DIR} (so skip this job)" 24 | else 25 | CUDA_VISIBLE_DEVICES=$GPU python train.py \ 26 | --root ${DATA} \ 27 | --seed ${SEED} \ 28 | --trainer ${TRAINER} \ 29 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 30 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 31 | --output-dir ${DIR} \ 32 | DATASET.NUM_SHOTS ${SHOTS} \ 33 | DATASET.SUBSAMPLE_CLASSES base \ 34 | TRAINER.COCOOP.N_CTX ${CTX} \ 35 | OPTIM.MAX_EPOCH ${EPOCHS} \ 36 | TRAINER.COCOOP.L ${L} \ 37 | MODEL.BACKBONE.NAME ${ARCH} 38 | fi 39 | -------------------------------------------------------------------------------- /scripts/vpt/base2new_train_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=~/projects/prob-cocoop/data 7 | TRAINER=VPT 8 | 9 | DATASET=imagenet 10 | SEED=$1 11 | GPUIDS=$2 12 | L=$3 13 | EPOCHS=$4 14 | 15 | CFG=vit_b16_c4_ep10_batch1_ctxv1 16 | SHOTS=16 17 | 18 | 19 | 20 | DIR=output/base2new/train_base/${DATASET}/mcmc_${L}_epochs_${EPOCHS}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 21 | if [ -d "$DIR" ]; then 22 | echo "Oops! The results exist at ${DIR} (so skip this job)" 23 | else 24 | CUDA_VISIBLE_DEVICES=${GPUIDS} python train.py \ 25 | --root ${DATA} \ 26 | --seed ${SEED} \ 27 | --trainer ${TRAINER} \ 28 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 29 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 30 | --output-dir ${DIR} \ 31 | DATASET.NUM_SHOTS ${SHOTS} \ 32 | DATASET.SUBSAMPLE_CLASSES base \ 33 | OPTIM.MAX_EPOCH ${EPOCHS} \ 34 | TRAINER.VPT.L ${L} 35 | fi 36 | -------------------------------------------------------------------------------- /scripts/vpt/ctx_ablations.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | datasets = [ 5 | "caltech101", 6 | "dtd", 7 | "eurosat", 8 | "fgvc_aircraft", 9 | "food101", 10 | "oxford_flowers", 11 | "oxford_pets", 12 | "stanford_cars", 13 | "sun397", 14 | "ucf101", 15 | ] 16 | mc = [20, 40, 20, 10, 20, 10, 40, 20, 10, 5] 17 | epochs = [20, 10, 60, 10, 20, 40, 20, 40, 10, 20] 18 | seeds = [1] 19 | shots = [16] 20 | ctxs = [1, 2, 4, 8] 21 | GPOUIDS = "0,1,2,3,4,5,6" 22 | 23 | for dataset, l, epoch in zip(datasets, mc, epochs): 24 | for seed in seeds: 25 | for shot in shots: 26 | for ctx in ctxs: 27 | os.system( 28 | f"bash base2new_train_ablation.sh {dataset} {seed} {GPOUIDS} {shot} {ctx} {l} {epoch}" 29 | ) 30 | os.system( 31 | f"bash base2new_test_ablation.sh {dataset} {seed} {GPOUIDS} {shot} {ctx} {l} {epoch}" 32 | ) 33 | -------------------------------------------------------------------------------- /scripts/vpt/eval_cross_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 8 | parser.add_argument("--l", help="Number of monte carlo samples") 9 | parser.add_argument("--epochs", help="Number of 
training epochs") 10 | parser.add_argument("--dname", help="dataset name") 11 | args = parser.parse_args() 12 | 13 | for seed in [1, 2, 3]: 14 | os.system(f"bash xd_test.sh {args.dname} {seed} {args.l} {args.epochs} {args.gpuids}") 15 | -------------------------------------------------------------------------------- /scripts/vpt/shots_ablation.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | datasets = ["caltech101", "dtd", "eurosat", "fgvc_aircraft", "food101", "oxford_flowers", "oxford_pets", "stanford_cars", "sun397", "ucf101"] 5 | mc = [20,40,20,10,20,10,40,20,10,5] 6 | epochs = [20,10,60,10,20,40,20,40,10,20] 7 | 8 | 9 | seeds = [1] 10 | shots = [1,4,8] 11 | GPOUIDS = "0,1,2,3,4,5,6,7" 12 | 13 | for dataset, l, epoch in zip(datasets, mc, epochs): 14 | for seed in seeds: 15 | for shot in shots: 16 | os.system(f"bash base2new_train_ablation.sh {dataset} {seed} {GPOUIDS} {shot} 4 {l} {epoch}") 17 | os.system(f"bash base2new_test_ablation.sh {dataset} {seed} {GPOUIDS} {shot} 4 {l} {epoch}") 18 | -------------------------------------------------------------------------------- /scripts/vpt/train_cross_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 8 | parser.add_argument("--l", help="Number of monte carlo samples") 9 | parser.add_argument("--epochs", help="Number of training epochs") 10 | args = parser.parse_args() 11 | 12 | for seed in [1, 2, 3]: 13 | os.system(f"bash xd_train.sh {seed} {args.l} {args.epochs}") 14 | -------------------------------------------------------------------------------- /scripts/vpt/train_eval_caltech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 8 | parser.add_argument( 9 | "--l", help="Number of monte carlo samples" 10 | ) 11 | parser.add_argument( 12 | "--epochs", help="Number of training epochs" 13 | ) 14 | args = parser.parse_args() 15 | 16 | for seed in [1, 2, 3]: 17 | os.system(f"bash base2new_train.sh caltech101 {seed} {args.gpuids} {args.l} {args.epochs}") 18 | os.system(f"bash base2new_test.sh caltech101 {seed} {args.gpuids} {args.l} {args.epochs}") 19 | -------------------------------------------------------------------------------- /scripts/vpt/train_eval_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 8 | parser.add_argument( 9 | "--l", help="Number of monte carlo samples" 10 | ) 11 | parser.add_argument( 12 | "--epochs", help="Number of training epochs" 13 | ) 14 | args = parser.parse_args() 15 | 16 | for seed in [1, 2, 3]: 17 | os.system(f"bash base2new_train.sh stanford_cars {seed} {args.gpuids} {args.l} {args.epochs}") 18 | os.system(f"bash base2new_test.sh stanford_cars {seed} {args.gpuids} {args.l} {args.epochs}") 19 | -------------------------------------------------------------------------------- /scripts/vpt/train_eval_dtd.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 8 | parser.add_argument( 9 | "--l", help="Number of monte carlo samples" 10 | ) 11 | parser.add_argument( 12 | "--epochs", help="Number of training epochs" 13 | ) 14 | args = parser.parse_args() 15 | 16 | for seed in [1, 2, 3]: 17 | os.system(f"bash base2new_train.sh dtd {seed} {args.gpuids} {args.l} {args.epochs}") 18 | os.system(f"bash base2new_test.sh dtd {seed} {args.gpuids} {args.l} {args.epochs}") 19 | -------------------------------------------------------------------------------- /scripts/vpt/train_eval_eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 8 | parser.add_argument( 9 | "--l", help="Number of monte carlo samples" 10 | ) 11 | parser.add_argument( 12 | "--epochs", help="Number of training epochs" 13 | ) 14 | args = parser.parse_args() 15 | 16 | for seed in [1, 2, 3]: 17 | os.system(f"bash base2new_train.sh eurosat {seed} {args.gpuids} {args.l} {args.epochs}") 18 | os.system(f"bash base2new_test.sh eurosat {seed} {args.gpuids} {args.l} {args.epochs}") 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /scripts/vpt/train_eval_fgvc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 8 | parser.add_argument( 9 | "--l", help="Number of monte carlo samples" 10 | ) 11 | parser.add_argument( 12 | "--epochs", help="Number of training epochs" 13 | ) 14 | args = parser.parse_args() 15 | 16 | for seed in [1, 2, 3]: 17 | os.system(f"bash base2new_train.sh fgvc_aircraft {seed} {args.gpuids} {args.l} {args.epochs}") 18 | os.system(f"bash base2new_test.sh fgvc_aircraft {seed} {args.gpuids} {args.l} {args.epochs}") 19 | -------------------------------------------------------------------------------- /scripts/vpt/train_eval_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 8 | parser.add_argument( 9 | "--l", help="Number of monte carlo samples" 10 | ) 11 | parser.add_argument( 12 | "--epochs", help="Number of training epochs" 13 | ) 14 | args = parser.parse_args() 15 | 16 | for seed in [1, 2, 3]: 17 | os.system(f"bash base2new_train.sh oxford_flowers {seed} {args.gpuids} {args.l} {args.epochs}") 18 | os.system(f"bash base2new_test.sh oxford_flowers {seed} {args.gpuids} {args.l} {args.epochs}") 19 | 20 | -------------------------------------------------------------------------------- /scripts/vpt/train_eval_food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | 
parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 7 | parser.add_argument( 8 | "--l", help="Number of monte carlo samples" 9 | ) 10 | parser.add_argument( 11 | "--epochs", help="Number of training epochs" 12 | ) 13 | args = parser.parse_args() 14 | 15 | for seed in [1, 2, 3]: 16 | os.system(f"bash base2new_train.sh food101 {seed} {args.gpuids} {args.l} {args.epochs}") 17 | os.system(f"bash base2new_test.sh food101 {seed} {args.gpuids} {args.l} {args.epochs}") 18 | -------------------------------------------------------------------------------- /scripts/vpt/train_eval_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 7 | parser.add_argument( 8 | "--l", help="Number of monte carlo samples" 9 | ) 10 | parser.add_argument( 11 | "--epochs", help="Number of training epochs" 12 | ) 13 | args = parser.parse_args() 14 | 15 | for seed in [1, 2, 3]: 16 | os.system(f"bash base2new_train_imagenet.sh {seed} {args.gpuids} {args.l} {args.epoch}") 17 | os.system(f"bash base2new_test_imagenet.sh {seed} {args.gpuids} {args.l} {args.epoch}") -------------------------------------------------------------------------------- /scripts/vpt/train_eval_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 8 | parser.add_argument( 9 | "--l", help="Number of monte carlo samples" 10 | ) 11 | parser.add_argument( 12 | "--epochs", help="Number of training epochs" 13 | ) 14 | args = parser.parse_args() 15 | 16 | for seed in [1, 2, 3]: 17 | os.system(f"bash base2new_train.sh oxford_pets {seed} {args.gpuids} {args.l} {args.epochs}") 18 | os.system(f"bash base2new_test.sh oxford_pets {seed} {args.gpuids} {args.l} {args.epochs}") 19 | 20 | -------------------------------------------------------------------------------- /scripts/vpt/train_eval_sun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 7 | parser.add_argument( 8 | "--l", help="Number of monte carlo samples" 9 | ) 10 | parser.add_argument( 11 | "--epochs", help="Number of training epochs" 12 | ) 13 | args = parser.parse_args() 14 | 15 | for seed in [1, 2, 3]: 16 | os.system(f"bash base2new_train.sh sun397 {seed} {args.gpuids} {args.l} {args.epochs}") 17 | os.system(f"bash base2new_test.sh sun397 {seed} {args.gpuids} {args.l} {args.epochs}") 18 | -------------------------------------------------------------------------------- /scripts/vpt/train_eval_ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--gpuids", default="0,1,2,3,4,5,6,7", help="GPU ids to train model on") 8 | parser.add_argument( 9 | "--l", help="Number of monte carlo samples" 10 | ) 11 | parser.add_argument( 12 | "--epochs", help="Number of training epochs" 13 | ) 14 
| args = parser.parse_args() 15 | 16 | for seed in [1, 2, 3]: 17 | os.system(f"bash base2new_train.sh ucf101 {seed} {args.gpuids} {args.l} {args.epochs}") 18 | os.system(f"bash base2new_test.sh ucf101 {seed} {args.gpuids} {args.l} {args.epochs}") 19 | -------------------------------------------------------------------------------- /scripts/vpt/vision_backbone_ablations.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | datasets = ["caltech101", "dtd", "eurosat", "fgvc_aircraft", "food101", "oxford_flowers", "oxford_pets", "stanford_cars", "sun397", "ucf101"] 5 | mc = [20,40,20,10,20,10,40,20,10,5] 6 | epochs = [20,10,60,10,20,40,20,40,10,20] 7 | seeds = [1] 8 | shots = [16] 9 | GPOUIDS = "0,1,2,3,4,5,6,7" 10 | backbones = ['RN50', 'RN101'] 11 | 12 | for dataset, l, epoch in zip(datasets, mc, epochs): 13 | for backbone in backbones: 14 | os.system(f"bash base2new_train_arch.sh {dataset} 1 {GPOUIDS} 16 4 {l} {epoch} {backbone}") 15 | os.system(f"bash base2new_test_arch.sh {dataset} 1 {GPOUIDS} 16 4 {l} {epoch} {backbone}") 16 | -------------------------------------------------------------------------------- /scripts/vpt/xd_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=~/projects/prob-cocoop/data 7 | TRAINER=VPT 8 | 9 | DATASET=$1 10 | SEED=$2 11 | L=$3 12 | EPOCHS=$4 13 | LOADEP=$4 14 | GPUIDS=$5 15 | 16 | CFG=vit_b16_c4_ep10_batch1_ctxv1 17 | SHOTS=16 18 | 19 | 20 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED} 21 | if [ -d "$DIR" ]; then 22 | echo "Oops! The results exist at ${DIR} (so skip this job)" 23 | else 24 | CUDA_VISIBLE_DEVICES=${GPUIDS} python train.py \ 25 | --root ${DATA} \ 26 | --seed ${SEED} \ 27 | --trainer ${TRAINER} \ 28 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 29 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 30 | --output-dir ${DIR} \ 31 | --model-dir output/imagenet/mcmc_${L}_epochs_${EPOCHS}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \ 32 | --load-epoch ${LOADEP} \ 33 | --eval-only \ 34 | OPTIM.MAX_EPOCH ${EPOCHS} \ 35 | TRAINER.VPT.L ${L} 36 | fi 37 | -------------------------------------------------------------------------------- /scripts/vpt/xd_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA=~/projects/prob-cocoop/data 7 | TRAINER=VPT 8 | 9 | DATASET=imagenet 10 | SEED=$1 11 | L=$2 12 | EPOCHS=$3 13 | 14 | CFG=vit_b16_c4_ep10_batch1_ctxv1 15 | SHOTS=16 16 | 17 | 18 | DIR=output/${DATASET}/mcmc_${L}_epochs_${EPOCHS}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} 19 | if [ -d "$DIR" ]; then 20 | echo "Oops! The results exist at ${DIR} (so skip this job)" 21 | else 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | DATASET.NUM_SHOTS ${SHOTS} \ 30 | OPTIM.MAX_EPOCH ${EPOCHS} \ 31 | TRAINER.VPT.L ${L} 32 | fi 33 | -------------------------------------------------------------------------------- /scripts/zeroshot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 
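# Example usage (illustrative values): bash zeroshot.sh caltech101 vit_b16
#   $1 = dataset config name, $2 = CoOp config that selects the CLIP backbone (rn50, rn101, vit_b32 or vit_b16)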
4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=ZeroshotCLIP 8 | DATASET=$1 9 | CFG=$2 # rn50, rn101, vit_b32 or vit_b16 10 | 11 | python train.py \ 12 | --root ${DATA} \ 13 | --trainer ${TRAINER} \ 14 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 15 | --config-file configs/trainers/CoOp/${CFG}.yaml \ 16 | --output-dir output/${TRAINER}/${CFG}/${DATASET} \ 17 | --eval-only -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 5 | from dassl.config import get_cfg_default 6 | from dassl.engine import build_trainer 7 | 8 | # custom 9 | import datasets.oxford_pets 10 | import datasets.oxford_flowers 11 | import datasets.fgvc_aircraft 12 | import datasets.dtd 13 | import datasets.eurosat 14 | import datasets.stanford_cars 15 | import datasets.food101 16 | import datasets.sun397 17 | import datasets.caltech101 18 | import datasets.ucf101 19 | import datasets.imagenet 20 | 21 | import datasets.imagenet_sketch 22 | import datasets.imagenetv2 23 | import datasets.imagenet_a 24 | import datasets.imagenet_r 25 | 26 | import trainers.coop 27 | import trainers.vpt 28 | import trainers.zsclip 29 | 30 | import pdb 31 | 32 | 33 | def print_args(args, cfg): 34 | print("***************") 35 | print("** Arguments **") 36 | print("***************") 37 | optkeys = list(args.__dict__.keys()) 38 | optkeys.sort() 39 | for key in optkeys: 40 | print("{}: {}".format(key, args.__dict__[key])) 41 | print("************") 42 | print("** Config **") 43 | print("************") 44 | print(cfg) 45 | 46 | 47 | def reset_cfg(cfg, args): 48 | if args.root: 49 | cfg.DATASET.ROOT = args.root 50 | 51 | if args.output_dir: 52 | cfg.OUTPUT_DIR = args.output_dir 53 | 54 | if args.resume: 55 | cfg.RESUME = args.resume 56 | 57 | if args.seed: 58 | cfg.SEED = args.seed 59 | 60 | if args.source_domains: 61 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains 62 | 63 | if args.target_domains: 64 | cfg.DATASET.TARGET_DOMAINS = args.target_domains 65 | 66 | if args.transforms: 67 | cfg.INPUT.TRANSFORMS = args.transforms 68 | 69 | if args.trainer: 70 | cfg.TRAINER.NAME = args.trainer 71 | 72 | if args.backbone: 73 | cfg.MODEL.BACKBONE.NAME = args.backbone 74 | 75 | if args.head: 76 | cfg.MODEL.HEAD.NAME = args.head 77 | 78 | 79 | def extend_cfg(cfg): 80 | """ 81 | Add new config variables. 82 | 83 | E.g. 84 | from yacs.config import CfgNode as CN 85 | cfg.TRAINER.MY_MODEL = CN() 86 | cfg.TRAINER.MY_MODEL.PARAM_A = 1. 
87 |     cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
88 |     cfg.TRAINER.MY_MODEL.PARAM_C = False
89 |     """
90 |     from yacs.config import CfgNode as CN
91 |
92 |     cfg.TRAINER.COOP = CN()
93 |     cfg.TRAINER.COOP.N_CTX = 16  # number of context vectors
94 |     cfg.TRAINER.COOP.CSC = False  # class-specific context
95 |     cfg.TRAINER.COOP.CTX_INIT = ""  # initialization words
96 |     cfg.TRAINER.COOP.PREC = "fp16"  # fp16, fp32, amp
97 |     cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end"  # 'middle' or 'end' or 'front'
98 |
99 |     cfg.TRAINER.VPT = CN()
100 |     cfg.TRAINER.VPT.N_CTX = 16  # number of context vectors
101 |     cfg.TRAINER.VPT.L = 10  # number of Monte Carlo samples
102 |     cfg.TRAINER.VPT.CTX_INIT = ""  # initialization words
103 |     cfg.TRAINER.VPT.PREC = "fp16"  # fp16, fp32, amp
104 |     cfg.TRAINER.VPT.VPT_TYPE = "cocoopvpt"  # "cocoopvpt" or "coopvpt"
105 |
106 |     cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new
107 |
108 |
109 | def setup_cfg(args):
110 |     cfg = get_cfg_default()
111 |     extend_cfg(cfg)
112 |
113 |     # 1. From the dataset config file
114 |     if args.dataset_config_file:
115 |         cfg.merge_from_file(args.dataset_config_file)
116 |
117 |     # 2. From the method config file
118 |     if args.config_file:
119 |         cfg.merge_from_file(args.config_file)
120 |
121 |     # 3. From input arguments
122 |     reset_cfg(cfg, args)
123 |
124 |     # 4. From optional input arguments
125 |     cfg.merge_from_list(args.opts)
126 |
127 |     cfg.freeze()
128 |
129 |     return cfg
130 |
131 |
132 | def main(args):
135 |     cfg = setup_cfg(args)
136 |     if cfg.SEED >= 0:
137 |         print("Setting fixed seed: {}".format(cfg.SEED))
138 |         set_random_seed(cfg.SEED)
139 |     setup_logger(cfg.OUTPUT_DIR)
140 |
141 |     if torch.cuda.is_available() and cfg.USE_CUDA:
142 |         torch.backends.cudnn.benchmark = True
143 |
144 |     print_args(args, cfg)
145 |     print("Collecting env info ...")
146 |     print("** System info **\n{}\n".format(collect_env_info()))
147 |
148 |     trainer = build_trainer(cfg)
150 |
151 |     if args.eval_only:
152 |         trainer.load_model(args.model_dir, epoch=args.load_epoch)
153 |         trainer.test()
154 |         return
155 |
156 |     if not args.no_train:
157 |         trainer.train()
158 |
159 |
160 | if __name__ == "__main__":
161 |     parser = argparse.ArgumentParser()
162 |     parser.add_argument("--root", type=str, default="", help="path to dataset")
163 |     parser.add_argument("--output-dir", type=str, default="", help="output directory")
164 |     parser.add_argument(
165 |         "--resume",
166 |         type=str,
167 |         default="",
168 |         help="checkpoint directory (from which the training resumes)",
169 |     )
170 |     parser.add_argument(
171 |         "--seed", type=int, default=-1, help="only positive value enables a fixed seed"
172 |     )
173 |     parser.add_argument("--source-domains", type=str, nargs="+", help="source domains for DA/DG")
174 |     parser.add_argument("--target-domains", type=str, nargs="+", help="target domains for DA/DG")
175 |     parser.add_argument("--transforms", type=str, nargs="+", help="data augmentation methods")
176 |     parser.add_argument("--config-file", type=str, default="", help="path to config file")
177 |     parser.add_argument(
178 |         "--dataset-config-file",
179 |         type=str,
180 |         default="",
181 |         help="path to config file for dataset setup",
182 |     )
183 |     parser.add_argument("--trainer", type=str, default="", help="name of trainer")
184 |     parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone")
185 |     parser.add_argument("--head", type=str, default="", help="name of head")
186 |     parser.add_argument("--eval-only", action="store_true", help="evaluation only")
187 |
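As `setup_cfg()` above shows, configuration is layered: the dataset YAML, then the trainer YAML, then CLI flags via `reset_cfg()`, and finally the trailing `KEY VALUE` pairs collected in `args.opts`, with later stages overriding earlier ones. A small self-contained sketch of the last `merge_from_list` step (toy keys only, not the repository's full config):

```python
from yacs.config import CfgNode as CN

cfg = CN()
cfg.TRAINER = CN()
cfg.TRAINER.VPT = CN()
cfg.TRAINER.VPT.L = 10   # default, as in extend_cfg()
cfg.OPTIM = CN()
cfg.OPTIM.MAX_EPOCH = 10

# The same mechanism train.py uses for trailing command-line options,
# e.g. `... TRAINER.VPT.L 4 OPTIM.MAX_EPOCH 20`.
cfg.merge_from_list(["TRAINER.VPT.L", "4", "OPTIM.MAX_EPOCH", "20"])
print(cfg.TRAINER.VPT.L, cfg.OPTIM.MAX_EPOCH)  # -> 4 20
```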
parser.add_argument( 188 | "--model-dir", 189 | type=str, 190 | default="", 191 | help="load model from this directory for eval-only mode", 192 | ) 193 | parser.add_argument( 194 | "--load-epoch", type=int, help="load model weights at this epoch for evaluation" 195 | ) 196 | parser.add_argument("--no-train", action="store_true", help="do not call trainer.train()") 197 | parser.add_argument( 198 | "opts", 199 | default=None, 200 | nargs=argparse.REMAINDER, 201 | help="modify config options using the command-line", 202 | ) 203 | args = parser.parse_args() 204 | main(args) 205 | -------------------------------------------------------------------------------- /trainers/coop.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from torch.cuda.amp import GradScaler, autocast 7 | 8 | from dassl.engine import TRAINER_REGISTRY, TrainerX 9 | from dassl.metrics import compute_accuracy 10 | from dassl.utils import load_pretrained_weights, load_checkpoint 11 | from dassl.optim import build_optimizer, build_lr_scheduler 12 | 13 | from clip import clip 14 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | _tokenizer = _Tokenizer() 17 | 18 | 19 | def load_clip_to_cpu(cfg): 20 | backbone_name = cfg.MODEL.BACKBONE.NAME 21 | url = clip._MODELS[backbone_name] 22 | model_path = clip._download(url) 23 | 24 | try: 25 | # loading JIT archive 26 | model = torch.jit.load(model_path, map_location="cpu").eval() 27 | state_dict = None 28 | 29 | except RuntimeError: 30 | state_dict = torch.load(model_path, map_location="cpu") 31 | 32 | model = clip.build_model(state_dict or model.state_dict()) 33 | 34 | return model 35 | 36 | 37 | class TextEncoder(nn.Module): 38 | def __init__(self, clip_model): 39 | super().__init__() 40 | self.transformer = clip_model.transformer 41 | self.positional_embedding = clip_model.positional_embedding 42 | self.ln_final = clip_model.ln_final 43 | self.text_projection = clip_model.text_projection 44 | self.dtype = clip_model.dtype 45 | 46 | def forward(self, prompts, tokenized_prompts): 47 | x = prompts + self.positional_embedding.type(self.dtype) 48 | x = x.permute(1, 0, 2) # NLD -> LND 49 | x = self.transformer(x) 50 | x = x.permute(1, 0, 2) # LND -> NLD 51 | x = self.ln_final(x).type(self.dtype) 52 | 53 | # x.shape = [batch_size, n_ctx, transformer.width] 54 | # take features from the eot embedding (eot_token is the highest number in each sequence) 55 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 56 | 57 | return x 58 | 59 | 60 | class PromptLearner(nn.Module): 61 | def __init__(self, cfg, classnames, clip_model): 62 | super().__init__() 63 | n_cls = len(classnames) 64 | n_ctx = cfg.TRAINER.COOP.N_CTX 65 | ctx_init = cfg.TRAINER.COOP.CTX_INIT 66 | dtype = clip_model.dtype 67 | ctx_dim = clip_model.ln_final.weight.shape[0] 68 | clip_imsize = clip_model.visual.input_resolution 69 | cfg_imsize = cfg.INPUT.SIZE[0] 70 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 71 | 72 | if ctx_init: 73 | # use given words to initialize context vectors 74 | ctx_init = ctx_init.replace("_", " ") 75 | n_ctx = len(ctx_init.split(" ")) 76 | prompt = clip.tokenize(ctx_init) 77 | with torch.no_grad(): 78 | embedding = clip_model.token_embedding(prompt).type(dtype) 79 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :] 80 | prompt_prefix = 
ctx_init 81 | 82 | else: 83 | # random initialization 84 | if cfg.TRAINER.COOP.CSC: 85 | print("Initializing class-specific contexts") 86 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) 87 | else: 88 | print("Initializing a generic context") 89 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 90 | nn.init.normal_(ctx_vectors, std=0.02) 91 | prompt_prefix = " ".join(["X"] * n_ctx) 92 | 93 | print(f'Initial context: "{prompt_prefix}"') 94 | print(f"Number of context words (tokens): {n_ctx}") 95 | 96 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized 97 | 98 | classnames = [name.replace("_", " ") for name in classnames] 99 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 100 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 101 | 102 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) 103 | with torch.no_grad(): 104 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 105 | 106 | # These token vectors will be saved when in save_model(), 107 | # but they should be ignored in load_model() as we want to use 108 | # those computed using the current class names 109 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 110 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS 111 | 112 | self.n_cls = n_cls 113 | self.n_ctx = n_ctx 114 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 115 | self.name_lens = name_lens 116 | self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION 117 | 118 | def forward(self): 119 | ctx = self.ctx 120 | if ctx.dim() == 2: 121 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 122 | 123 | prefix = self.token_prefix 124 | suffix = self.token_suffix 125 | 126 | if self.class_token_position == "end": 127 | prompts = torch.cat( 128 | [ 129 | prefix, # (n_cls, 1, dim) 130 | ctx, # (n_cls, n_ctx, dim) 131 | suffix, # (n_cls, *, dim) 132 | ], 133 | dim=1, 134 | ) 135 | 136 | elif self.class_token_position == "middle": 137 | half_n_ctx = self.n_ctx // 2 138 | prompts = [] 139 | for i in range(self.n_cls): 140 | name_len = self.name_lens[i] 141 | prefix_i = prefix[i : i + 1, :, :] 142 | class_i = suffix[i : i + 1, :name_len, :] 143 | suffix_i = suffix[i : i + 1, name_len:, :] 144 | ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :] 145 | ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :] 146 | prompt = torch.cat( 147 | [ 148 | prefix_i, # (1, 1, dim) 149 | ctx_i_half1, # (1, n_ctx//2, dim) 150 | class_i, # (1, name_len, dim) 151 | ctx_i_half2, # (1, n_ctx//2, dim) 152 | suffix_i, # (1, *, dim) 153 | ], 154 | dim=1, 155 | ) 156 | prompts.append(prompt) 157 | prompts = torch.cat(prompts, dim=0) 158 | 159 | elif self.class_token_position == "front": 160 | prompts = [] 161 | for i in range(self.n_cls): 162 | name_len = self.name_lens[i] 163 | prefix_i = prefix[i : i + 1, :, :] 164 | class_i = suffix[i : i + 1, :name_len, :] 165 | suffix_i = suffix[i : i + 1, name_len:, :] 166 | ctx_i = ctx[i : i + 1, :, :] 167 | prompt = torch.cat( 168 | [ 169 | prefix_i, # (1, 1, dim) 170 | class_i, # (1, name_len, dim) 171 | ctx_i, # (1, n_ctx, dim) 172 | suffix_i, # (1, *, dim) 173 | ], 174 | dim=1, 175 | ) 176 | prompts.append(prompt) 177 | prompts = torch.cat(prompts, dim=0) 178 | 179 | else: 180 | raise ValueError 181 | 182 | return prompts 183 | 184 | 185 | class CustomCLIP(nn.Module): 186 | def __init__(self, cfg, classnames, clip_model): 187 | super().__init__() 188 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model) 189 | 
self.tokenized_prompts = self.prompt_learner.tokenized_prompts 190 | self.image_encoder = clip_model.visual 191 | self.text_encoder = TextEncoder(clip_model) 192 | self.logit_scale = clip_model.logit_scale 193 | self.dtype = clip_model.dtype 194 | 195 | def forward(self, image): 196 | image_features = self.image_encoder(image.type(self.dtype)) 197 | 198 | prompts = self.prompt_learner() 199 | tokenized_prompts = self.tokenized_prompts 200 | text_features = self.text_encoder(prompts, tokenized_prompts) 201 | 202 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 203 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 204 | 205 | logit_scale = self.logit_scale.exp() 206 | logits = logit_scale * image_features @ text_features.t() 207 | 208 | return logits 209 | 210 | 211 | @TRAINER_REGISTRY.register() 212 | class CoOp(TrainerX): 213 | """Context Optimization (CoOp). 214 | 215 | Learning to Prompt for Vision-Language Models 216 | https://arxiv.org/abs/2109.01134 217 | """ 218 | 219 | def check_cfg(self, cfg): 220 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"] 221 | 222 | def build_model(self): 223 | cfg = self.cfg 224 | classnames = self.dm.dataset.classnames 225 | 226 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 227 | clip_model = load_clip_to_cpu(cfg) 228 | 229 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp": 230 | # CLIP's default precision is fp16 231 | clip_model.float() 232 | 233 | print("Building custom CLIP") 234 | self.model = CustomCLIP(cfg, classnames, clip_model) 235 | 236 | print("Turning off gradients in both the image and the text encoder") 237 | for name, param in self.model.named_parameters(): 238 | if "prompt_learner" not in name: 239 | param.requires_grad_(False) 240 | 241 | if cfg.MODEL.INIT_WEIGHTS: 242 | load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS) 243 | 244 | self.model.to(self.device) 245 | # NOTE: only give prompt_learner to the optimizer 246 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 247 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 248 | self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched) 249 | 250 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None 251 | 252 | # Note that multi-gpu training could be slow because CLIP's size is 253 | # big, which slows down the copy operation in DataParallel 254 | device_count = torch.cuda.device_count() 255 | if device_count > 1: 256 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 257 | self.model = nn.DataParallel(self.model) 258 | 259 | def forward_backward(self, batch): 260 | image, label = self.parse_batch_train(batch) 261 | 262 | prec = self.cfg.TRAINER.COOP.PREC 263 | if prec == "amp": 264 | with autocast(): 265 | output = self.model(image) 266 | loss = F.cross_entropy(output, label) 267 | self.optim.zero_grad() 268 | self.scaler.scale(loss).backward() 269 | self.scaler.step(self.optim) 270 | self.scaler.update() 271 | else: 272 | output = self.model(image) 273 | loss = F.cross_entropy(output, label) 274 | self.model_backward_and_update(loss) 275 | 276 | loss_summary = { 277 | "loss": loss.item(), 278 | "acc": compute_accuracy(output, label)[0].item(), 279 | } 280 | 281 | if (self.batch_idx + 1) == self.num_batches: 282 | self.update_lr() 283 | 284 | return loss_summary 285 | 286 | def parse_batch_train(self, batch): 287 | input = batch["img"] 288 | label = batch["label"] 289 
| input = input.to(self.device) 290 | label = label.to(self.device) 291 | return input, label 292 | 293 | def load_model(self, directory, epoch=None): 294 | if not directory: 295 | print("Note that load_model() is skipped as no pretrained model is given") 296 | return 297 | 298 | names = self.get_model_names() 299 | 300 | # By default, the best model is loaded 301 | model_file = "model-best.pth.tar" 302 | 303 | if epoch is not None: 304 | model_file = "model.pth.tar-" + str(epoch) 305 | 306 | for name in names: 307 | model_path = osp.join(directory, name, model_file) 308 | 309 | if not osp.exists(model_path): 310 | raise FileNotFoundError('Model not found at "{}"'.format(model_path)) 311 | 312 | checkpoint = load_checkpoint(model_path) 313 | state_dict = checkpoint["state_dict"] 314 | epoch = checkpoint["epoch"] 315 | 316 | # Ignore fixed token vectors 317 | if "token_prefix" in state_dict: 318 | del state_dict["token_prefix"] 319 | 320 | if "token_suffix" in state_dict: 321 | del state_dict["token_suffix"] 322 | 323 | print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 324 | # set strict=False 325 | self._models[name].load_state_dict(state_dict, strict=False) 326 | -------------------------------------------------------------------------------- /trainers/imagenet_templates.py: -------------------------------------------------------------------------------- 1 | # source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb 2 | 3 | IMAGENET_TEMPLATES = [ 4 | "a bad photo of a {}.", 5 | "a photo of many {}.", 6 | "a sculpture of a {}.", 7 | "a photo of the hard to see {}.", 8 | "a low resolution photo of the {}.", 9 | "a rendering of a {}.", 10 | "graffiti of a {}.", 11 | "a bad photo of the {}.", 12 | "a cropped photo of the {}.", 13 | "a tattoo of a {}.", 14 | "the embroidered {}.", 15 | "a photo of a hard to see {}.", 16 | "a bright photo of a {}.", 17 | "a photo of a clean {}.", 18 | "a photo of a dirty {}.", 19 | "a dark photo of the {}.", 20 | "a drawing of a {}.", 21 | "a photo of my {}.", 22 | "the plastic {}.", 23 | "a photo of the cool {}.", 24 | "a close-up photo of a {}.", 25 | "a black and white photo of the {}.", 26 | "a painting of the {}.", 27 | "a painting of a {}.", 28 | "a pixelated photo of the {}.", 29 | "a sculpture of the {}.", 30 | "a bright photo of the {}.", 31 | "a cropped photo of a {}.", 32 | "a plastic {}.", 33 | "a photo of the dirty {}.", 34 | "a jpeg corrupted photo of a {}.", 35 | "a blurry photo of the {}.", 36 | "a photo of the {}.", 37 | "a good photo of the {}.", 38 | "a rendering of the {}.", 39 | "a {} in a video game.", 40 | "a photo of one {}.", 41 | "a doodle of a {}.", 42 | "a close-up photo of the {}.", 43 | "a photo of a {}.", 44 | "the origami {}.", 45 | "the {} in a video game.", 46 | "a sketch of a {}.", 47 | "a doodle of the {}.", 48 | "a origami {}.", 49 | "a low resolution photo of a {}.", 50 | "the toy {}.", 51 | "a rendition of the {}.", 52 | "a photo of the clean {}.", 53 | "a photo of a large {}.", 54 | "a rendition of a {}.", 55 | "a photo of a nice {}.", 56 | "a photo of a weird {}.", 57 | "a blurry photo of a {}.", 58 | "a cartoon {}.", 59 | "art of a {}.", 60 | "a sketch of the {}.", 61 | "a embroidered {}.", 62 | "a pixelated photo of a {}.", 63 | "itap of the {}.", 64 | "a jpeg corrupted photo of the {}.", 65 | "a good photo of a {}.", 66 | "a plushie {}.", 67 | "a photo of the nice {}.", 68 | "a photo of the small {}.", 69 | "a photo of the weird {}.", 70 | 
"the cartoon {}.", 71 | "art of the {}.", 72 | "a drawing of the {}.", 73 | "a photo of the large {}.", 74 | "a black and white photo of a {}.", 75 | "the plushie {}.", 76 | "a dark photo of a {}.", 77 | "itap of a {}.", 78 | "graffiti of the {}.", 79 | "a toy {}.", 80 | "itap of my {}.", 81 | "a photo of a cool {}.", 82 | "a photo of a small {}.", 83 | "a tattoo of the {}.", 84 | ] 85 | 86 | IMAGENET_TEMPLATES_SELECT = [ 87 | "itap of a {}.", 88 | "a bad photo of the {}.", 89 | "a origami {}.", 90 | "a photo of the large {}.", 91 | "a {} in a video game.", 92 | "art of the {}.", 93 | "a photo of the small {}.", 94 | ] 95 | -------------------------------------------------------------------------------- /trainers/vpt.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from collections import OrderedDict 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | from torch.cuda.amp import GradScaler, autocast 9 | 10 | from dassl.engine import TRAINER_REGISTRY, TrainerX 11 | from dassl.metrics import compute_accuracy 12 | from dassl.utils import load_pretrained_weights, load_checkpoint 13 | from dassl.optim import build_optimizer, build_lr_scheduler 14 | 15 | from clip import clip 16 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 17 | 18 | import pdb 19 | 20 | _tokenizer = _Tokenizer() 21 | 22 | 23 | class InferenceBlock(nn.Module): 24 | def __init__(self, input_units, d_theta, output_units): 25 | """ 26 | :param d_theta: dimensionality of the intermediate hidden layers. 27 | :param output_units: dimensionality of the output. 28 | :return: batch of outputs. 29 | """ 30 | super(InferenceBlock, self).__init__() 31 | self.module = nn.Sequential( 32 | nn.Linear(input_units, d_theta, bias=True), 33 | nn.ELU(inplace=True), 34 | nn.Linear(d_theta, d_theta, bias=True), 35 | nn.ELU(inplace=True), 36 | nn.Linear(d_theta, output_units, bias=True), 37 | ) 38 | 39 | def forward(self, inps): 40 | out = self.module(inps) 41 | return out 42 | 43 | 44 | class Amortized(nn.Module): 45 | def __init__(self, input_units=400, d_theta=400, output_units=400): 46 | super(Amortized, self).__init__() 47 | self.output_units = output_units 48 | self.weight_mean = InferenceBlock(input_units, d_theta, output_units) 49 | self.weight_log_variance = InferenceBlock(input_units, d_theta, output_units) 50 | 51 | def forward(self, inps): 52 | weight_mean = self.weight_mean(inps) 53 | weight_log_variance = self.weight_log_variance(inps) 54 | return weight_mean, weight_log_variance 55 | 56 | 57 | def load_clip_to_cpu(cfg): 58 | backbone_name = cfg.MODEL.BACKBONE.NAME 59 | url = clip._MODELS[backbone_name] 60 | model_path = clip._download(url) 61 | 62 | try: 63 | # loading JIT archive 64 | model = torch.jit.load(model_path, map_location="cpu").eval() 65 | state_dict = None 66 | 67 | except RuntimeError: 68 | state_dict = torch.load(model_path, map_location="cpu") 69 | 70 | model = clip.build_model(state_dict or model.state_dict()) 71 | 72 | return model 73 | 74 | 75 | class TextEncoder(nn.Module): 76 | def __init__(self, clip_model): 77 | super().__init__() 78 | self.transformer = clip_model.transformer 79 | self.positional_embedding = clip_model.positional_embedding 80 | self.ln_final = clip_model.ln_final # maybe layer normalization 81 | self.text_projection = clip_model.text_projection 82 | self.dtype = clip_model.dtype 83 | 84 | def forward(self, prompts, tokenized_prompts): 85 | # 
print(prompts.shape, tokenized_prompts.shape)
86 |         x = prompts + self.positional_embedding.type(self.dtype)
87 |         x = x.permute(1, 0, 2)  # NLD -> LND
88 |         x = self.transformer(x)
89 |         x = x.permute(1, 0, 2)  # LND -> NLD
90 |         x = self.ln_final(x).type(self.dtype)
91 |
92 |         # x.shape = [batch_size, n_ctx, transformer.width]
93 |         # take features from the eot embedding (eot_token is the highest number in each sequence)
94 |         x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
95 |
96 |         return x
97 |
98 |
99 | class PromptLearner(nn.Module):
100 |     def __init__(self, cfg, classnames, clip_model):
101 |         super().__init__()
102 |         n_cls = len(classnames)
103 |         n_ctx = cfg.TRAINER.VPT.N_CTX
104 |         ctx_init = cfg.TRAINER.VPT.CTX_INIT
105 |         self.L = cfg.TRAINER.VPT.L
106 |         self.vpt_type = cfg.TRAINER.VPT.VPT_TYPE
107 |         dtype = clip_model.dtype
108 |         ctx_dim = clip_model.ln_final.weight.shape[0]
109 |         vis_dim = clip_model.visual.output_dim
110 |         clip_imsize = clip_model.visual.input_resolution
111 |         cfg_imsize = cfg.INPUT.SIZE[0]
112 |         assert (
113 |             cfg_imsize == clip_imsize
114 |         ), f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
115 |
116 |         if ctx_init:
117 |             # use given words to initialize context vectors
118 |             ctx_init = ctx_init.replace("_", " ")
119 |             n_ctx = len(ctx_init.split(" "))
120 |             prompt = clip.tokenize(
121 |                 ctx_init
122 |             )  # a 1 x 77 tensor, where 77 is the maximum prompt length; the context tokens are followed by zero padding
123 |             with torch.no_grad():
124 |                 embedding = clip_model.token_embedding(prompt).type(dtype)  # 1 x 77 x ctx_dim
125 |             ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]  # n_ctx x ctx_dim
126 |             prompt_prefix = ctx_init
127 |         else:
128 |             # random initialization
129 |             ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
130 |             nn.init.normal_(ctx_vectors, std=0.02)
131 |             prompt_prefix = " ".join(["X"] * n_ctx)
132 |
133 |         print(f'Initial context: "{prompt_prefix}"')
134 |         print(f"Number of context words (tokens): {n_ctx}")
135 |
136 |         self.ctx = nn.Parameter(
137 |             ctx_vectors
138 |         )  # learnable context vectors, initialized from CTX_INIT if given, otherwise randomly
139 |
140 |         if self.vpt_type == "cocoopvpt":
141 |             self.meta_net = Amortized(
142 |                 input_units=vis_dim, d_theta=vis_dim // 2, output_units=ctx_dim
143 |             )
144 |             if cfg.TRAINER.VPT.PREC == "fp16":
145 |                 self.meta_net.half()
146 |         elif self.vpt_type == "coopvpt":
147 |             self.mean_posterior = nn.Parameter(torch.zeros(1, ctx_dim, dtype=dtype))
148 |             self.std_posterior = nn.Parameter(torch.rand(1, ctx_dim, dtype=dtype))
149 |         else:
150 |             raise ValueError(f"Type {self.vpt_type} is not supported.")
151 |
152 |         classnames = [name.replace("_", " ") for name in classnames]  # replace underscores with spaces
153 |         name_lens = [
154 |             len(_tokenizer.encode(name)) for name in classnames
155 |         ]  # the tokenizer may produce several tokens per class name, even for short names
156 |         prompts = [prompt_prefix + " " + name + "."
for name in classnames] 157 | 158 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn) 159 | with torch.no_grad(): 160 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 161 | 162 | # These token vectors will be saved when in save_model(), 163 | # but they should be ignored in load_model() as we want to use 164 | # those computed using the current class names 165 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 166 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS 167 | 168 | self.n_cls = n_cls 169 | self.n_ctx = n_ctx 170 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 171 | self.name_lens = name_lens 172 | 173 | def construct_prompts(self, ctx, prefix, suffix, label=None): 174 | # dim0 is either batch_size (during training) or n_cls (during testing) 175 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim) 176 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim) 177 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim) 178 | 179 | if label is not None: 180 | prefix = prefix[label] 181 | suffix = suffix[label] 182 | 183 | prompts = torch.cat( 184 | [ 185 | prefix, # (dim0, 1, dim) 186 | ctx, # (dim0, n_ctx, dim) 187 | suffix, # (dim0, *, dim) 188 | ], 189 | dim=1, 190 | ) 191 | 192 | return prompts 193 | 194 | def sample(self, mu, logvar, L): 195 | shape = (L,) + mu.size() 196 | eps = torch.randn(shape).type_as(mu) 197 | bias = mu.unsqueeze(0) + eps * logvar.exp().sqrt().unsqueeze(0) 198 | return bias 199 | 200 | def forward(self, im_features): 201 | prefix = self.token_prefix 202 | suffix = self.token_suffix 203 | ctx = self.ctx # (n_ctx, ctx_dim) 204 | 205 | if self.vpt_type == "cocoopvpt": 206 | bias_mu, bias_logvar = self.meta_net(im_features) # (1, ctx_dim) 207 | elif self.vpt_type == "coopvpt": 208 | bias_mu, bias_logvar = self.mean_posterior, self.std_posterior # (1, ctx_dim) 209 | else: 210 | raise ValueError(f"Type {self.vpt_type} is not supported.") 211 | 212 | bias = self.sample(bias_mu, bias_logvar, self.L) # (L, 1, ctx_dim) 213 | ctx = ctx.unsqueeze(0) # (1, n_ctx, ctx_dim) 214 | ctx_shifted = ctx + bias # (L, n_ctx, ctx_dim) 215 | 216 | # Use instance-conditioned context tokens for all classes 217 | prompts = [] 218 | for ctx_shifted_i in ctx_shifted: 219 | ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1) 220 | pts_i = self.construct_prompts(ctx_i, prefix, suffix) # (n_cls, n_tkn, ctx_dim) 221 | prompts.append(pts_i) 222 | prompts = torch.stack(prompts) 223 | 224 | return prompts, bias_mu, bias_logvar 225 | 226 | 227 | class CustomCLIP(nn.Module): 228 | def __init__(self, cfg, classnames, clip_model): 229 | super().__init__() 230 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model) 231 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 232 | self.L = self.prompt_learner.L 233 | self.image_encoder = clip_model.visual 234 | device_count = torch.cuda.device_count() 235 | # pdb.set_trace() 236 | if device_count > 1: 237 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 238 | self.text_encoder = nn.DataParallel(TextEncoder(clip_model)) 239 | else: 240 | self.text_encoder = TextEncoder(clip_model) 241 | self.logit_scale = clip_model.logit_scale 242 | self.dtype = clip_model.dtype 243 | 244 | def forward(self, image, label=None): 245 | tokenized_prompts = self.tokenized_prompts 246 | tokenized_prompts = torch.tile(tokenized_prompts, (self.L, 1)) 247 | logit_scale = 
self.logit_scale.exp()
248 |
251 |         image_features = self.image_encoder(image.type(self.dtype))  # (batch, vis_dim)
252 |         image_features = image_features / image_features.norm(dim=-1, keepdim=True)
253 |
254 |         prompts, mu, logvar = self.prompt_learner(image_features)  # L x NumClass x Length x DIM
255 |         _, NumClass, Length, dim = prompts.shape
256 |         prompts = prompts.view(-1, Length, dim)  # (L * NumClass) x Length x DIM
257 |         text_features = self.text_encoder(prompts, tokenized_prompts)
258 |
259 |         image_features = image_features.unsqueeze(0).expand((self.L, -1, -1))
260 |         text_features = text_features / text_features.norm(dim=-1, keepdim=True)
261 |         text_features = text_features.view(
262 |             -1, NumClass, text_features.shape[-1]
263 |         )  # L x NumClass x DIM
264 |
265 |         logits = logit_scale * torch.einsum("LBD,LCD->LBC", image_features, text_features)
266 |
270 |         log_p_y = torch.log_softmax(logits, dim=-1)  # L x batch x num_class
271 |
272 |         if self.prompt_learner.training:
273 |             # training: negative Monte Carlo estimate of log p(y|x), plus a small KL term (weight 0.001)
274 |             label_one_hot = torch.nn.functional.one_hot(label, num_classes=logits.shape[-1])
275 |             tile_label = torch.tile(label_one_hot.unsqueeze(0), (self.L, 1, 1))
276 |
277 |             task_log_py = self.nll(log_p_y, tile_label)
278 |             task_score = torch.logsumexp(task_log_py, dim=0) - torch.log(
279 |                 torch.Tensor([self.L]).type_as(logits)
280 |             )
281 |             task_loss = -task_score.mean(dim=-1)
282 |
283 |             return task_loss + 0.001 * self.kl_divergence(mu, logvar)
284 |         else:
285 |             average_prediction = torch.logsumexp(log_p_y, dim=0) - torch.log(
286 |                 torch.Tensor([self.L]).type_as(logits)
287 |             )
288 |             return average_prediction  # log of the MC-averaged predictive distribution
289 |
290 |     def kl_divergence(self, mu, logvar):
291 |         prior_mu = torch.zeros_like(mu)
292 |         prior_std = torch.ones_like(logvar)
293 |
294 |         prior = torch.distributions.Normal(loc=prior_mu, scale=prior_std)
295 |         post = torch.distributions.Normal(loc=mu, scale=logvar.exp().sqrt())
296 |
297 |         dist = torch.distributions.kl_divergence(post, prior).mean(dim=-1)
298 |         return dist
299 |
300 |     def nll(self, logits, targets):
301 |         task_log_py = (logits * targets).sum(dim=-1)
302 |         return task_log_py
303 |
304 |
305 | @TRAINER_REGISTRY.register()
306 | class VPT(TrainerX):
307 |     def check_cfg(self, cfg):
308 |         assert cfg.TRAINER.VPT.PREC in ["fp16", "fp32", "amp"]
309 |
310 |     def build_model(self):
311 |         cfg = self.cfg
312 |         classnames = self.dm.dataset.classnames  # List of class names for each benchmark.
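The test-time branch of `CustomCLIP.forward()` above averages the `L` per-sample predictive distributions in log space. A tiny self-contained sketch of just that reduction, with random logits standing in for the real CLIP logits and illustrative tensor sizes:

```python
import torch

L, batch, num_classes = 4, 2, 10              # illustrative sizes
logits = torch.randn(L, batch, num_classes)   # one set of logits per MC sample

log_p_y = torch.log_softmax(logits, dim=-1)
# log(1/L * sum_l p_l(y|x)), computed stably in log space
avg_log_prob = torch.logsumexp(log_p_y, dim=0) - torch.log(torch.tensor(float(L)))
print(avg_log_prob.shape)  # torch.Size([2, 10])
```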
313 | 314 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 315 | clip_model = load_clip_to_cpu(cfg) 316 | 317 | if cfg.TRAINER.VPT.PREC == "fp32" or cfg.TRAINER.VPT.PREC == "amp": 318 | # CLIP's default precision is fp16 319 | clip_model.float() 320 | 321 | print("Building custom CLIP") 322 | self.model = CustomCLIP(cfg, classnames, clip_model) 323 | 324 | print("Turning off gradients in both the image and the text encoder") 325 | name_to_update = "prompt_learner" 326 | 327 | for name, param in self.model.named_parameters(): 328 | if name_to_update not in name: 329 | param.requires_grad_(False) 330 | 331 | # Double check 332 | enabled = set() 333 | for name, param in self.model.named_parameters(): 334 | if param.requires_grad: 335 | enabled.add(name) 336 | print(f"Parameters to be updated: {enabled}") 337 | 338 | if cfg.MODEL.INIT_WEIGHTS: 339 | load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS) 340 | 341 | self.model.to(self.device) 342 | # NOTE: only give prompt_learner to the optimizer 343 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 344 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 345 | self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched) 346 | 347 | self.scaler = GradScaler() if cfg.TRAINER.VPT.PREC == "amp" else None 348 | 349 | # Note that multi-gpu training could be slow because CLIP's size is 350 | # big, which slows down the copy operation in DataParallel 351 | # device_count = torch.cuda.device_count() 352 | # # pdb.set_trace() 353 | # if device_count > 1: 354 | # print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 355 | # self.model = nn.DataParallel(self.model) 356 | 357 | def forward_backward(self, batch): 358 | # pdb.set_trace() 359 | image, label = self.parse_batch_train(batch) 360 | 361 | model = self.model 362 | optim = self.optim 363 | scaler = self.scaler 364 | 365 | prec = self.cfg.TRAINER.VPT.PREC 366 | if prec == "amp": 367 | with autocast(): 368 | loss = model(image, label) 369 | optim.zero_grad() 370 | scaler.scale(loss).backward() 371 | scaler.step(optim) 372 | scaler.update() 373 | else: 374 | loss = model(image, label) 375 | optim.zero_grad() 376 | loss.backward() 377 | optim.step() 378 | 379 | loss_summary = {"loss": loss.item()} 380 | 381 | if (self.batch_idx + 1) == self.num_batches: 382 | self.update_lr() 383 | 384 | return loss_summary 385 | 386 | def parse_batch_train(self, batch): 387 | input = batch["img"] 388 | label = batch["label"] 389 | input = input.to(self.device) 390 | label = label.to(self.device) 391 | return input, label 392 | 393 | def load_model(self, directory, epoch=None): 394 | if not directory: 395 | print("Note that load_model() is skipped as no pretrained model is given") 396 | return 397 | 398 | names = self.get_model_names() 399 | 400 | # By default, the best model is loaded 401 | model_file = "model-best.pth.tar" 402 | 403 | if epoch is not None: 404 | model_file = "model.pth.tar-" + str(epoch) 405 | 406 | for name in names: 407 | model_path = osp.join(directory, name, model_file) 408 | 409 | if not osp.exists(model_path): 410 | raise FileNotFoundError('Model not found at "{}"'.format(model_path)) 411 | 412 | checkpoint = load_checkpoint(model_path) 413 | state_dict = checkpoint["state_dict"] 414 | epoch = checkpoint["epoch"] 415 | 416 | # Ignore fixed token vectors 417 | if "token_prefix" in state_dict: 418 | del state_dict["token_prefix"] 419 | 420 | if "token_suffix" in state_dict: 421 | del 
state_dict["token_suffix"] 422 | 423 | print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 424 | # set strict=False 425 | self._models[name].load_state_dict(state_dict, strict=False) 426 | -------------------------------------------------------------------------------- /trainers/zsclip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from dassl.engine import TRAINER_REGISTRY, TrainerX 5 | from dassl.optim import build_optimizer, build_lr_scheduler 6 | 7 | from clip import clip 8 | from clip.model import convert_weights 9 | 10 | from .coop import load_clip_to_cpu 11 | from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT 12 | 13 | CUSTOM_TEMPLATES = { 14 | "OxfordPets": "a photo of a {}, a type of pet.", 15 | "OxfordFlowers": "a photo of a {}, a type of flower.", 16 | "FGVCAircraft": "a photo of a {}, a type of aircraft.", 17 | "DescribableTextures": "{} texture.", 18 | "EuroSAT": "a centered satellite photo of {}.", 19 | "StanfordCars": "a photo of a {}.", 20 | "Food101": "a photo of {}, a type of food.", 21 | "SUN397": "a photo of a {}.", 22 | "Caltech101": "a photo of a {}.", 23 | "UCF101": "a photo of a person doing {}.", 24 | "ImageNet": "a photo of a {}.", 25 | "ImageNetSketch": "a photo of a {}.", 26 | "ImageNetV2": "a photo of a {}.", 27 | "ImageNetA": "a photo of a {}.", 28 | "ImageNetR": "a photo of a {}.", 29 | } 30 | 31 | 32 | @TRAINER_REGISTRY.register() 33 | class ZeroshotCLIP(TrainerX): 34 | def build_model(self): 35 | cfg = self.cfg 36 | classnames = self.dm.dataset.classnames 37 | 38 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 39 | clip_model = load_clip_to_cpu(cfg) 40 | clip_model.to(self.device) 41 | 42 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] 43 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 44 | print(f"Prompts: {prompts}") 45 | prompts = torch.cat([clip.tokenize(p) for p in prompts]) 46 | prompts = prompts.to(self.device) 47 | 48 | with torch.no_grad(): 49 | text_features = clip_model.encode_text(prompts) 50 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 51 | 52 | self.text_features = text_features 53 | self.clip_model = clip_model 54 | 55 | def model_inference(self, image): 56 | image_features = self.clip_model.encode_image(image) 57 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 58 | logit_scale = self.clip_model.logit_scale.exp() 59 | logits = logit_scale * image_features @ self.text_features.t() 60 | return logits 61 | 62 | 63 | @TRAINER_REGISTRY.register() 64 | class ZeroshotCLIP2(ZeroshotCLIP): 65 | """Prompt ensembling.""" 66 | 67 | # templates = IMAGENET_TEMPLATES 68 | templates = IMAGENET_TEMPLATES_SELECT 69 | 70 | def build_model(self): 71 | cfg = self.cfg 72 | classnames = self.dm.dataset.classnames 73 | 74 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 75 | clip_model = load_clip_to_cpu(cfg) 76 | clip_model.to(self.device) 77 | 78 | for params in clip_model.parameters(): 79 | params.requires_grad_(False) 80 | 81 | # add custom-made prompt 82 | if cfg.DATASET.NAME != "ImageNet": 83 | self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]] 84 | 85 | num_temp = len(self.templates) 86 | print(f"Prompt ensembling (n={num_temp})") 87 | 88 | mean_text_features = 0 89 | for i, temp in enumerate(self.templates): 90 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 91 | prompts = 
torch.cat([clip.tokenize(p) for p in prompts]).to(self.device) 92 | text_features = clip_model.encode_text(prompts) 93 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 94 | mean_text_features = mean_text_features + text_features 95 | mean_text_features = mean_text_features / num_temp 96 | mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True) 97 | 98 | self.text_features = mean_text_features 99 | self.clip_model = clip_model 100 | --------------------------------------------------------------------------------
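For reference, the prompt-ensembling step in `ZeroshotCLIP2.build_model()` above boils down to a few lines: per-template text features are L2-normalized, averaged over templates, and the mean is normalized once more before serving as the zero-shot classifier weights. A minimal sketch with random features standing in for CLIP's `encode_text` outputs and illustrative sizes:

```python
import torch

num_templates, num_classes, dim = 7, 5, 512            # illustrative sizes
feats = torch.randn(num_templates, num_classes, dim)   # stand-in for encode_text outputs

feats = feats / feats.norm(dim=-1, keepdim=True)        # normalize per template
mean_feats = feats.mean(dim=0)                          # average over templates
mean_feats = mean_feats / mean_feats.norm(dim=-1, keepdim=True)  # renormalize
print(mean_feats.shape)  # torch.Size([5, 512])
```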