├── DATASETS.md ├── LICENSE ├── README.md ├── clip ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── clip.cpython-37.pyc │ ├── clip.cpython-38.pyc │ ├── model.cpython-37.pyc │ ├── model.cpython-38.pyc │ ├── simple_tokenizer.cpython-37.pyc │ └── simple_tokenizer.cpython-38.pyc ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── configs ├── datasets │ ├── caltech101.yaml │ ├── dtd.yaml │ ├── eurosat.yaml │ ├── fgvc_aircraft.yaml │ ├── food101.yaml │ ├── imagenet.yaml │ ├── imagenet_a.yaml │ ├── imagenet_r.yaml │ ├── imagenet_sketch.yaml │ ├── imagenetv2.yaml │ ├── oxford_flowers.yaml │ ├── oxford_pets.yaml │ ├── stanford_cars.yaml │ ├── sun397.yaml │ └── ucf101.yaml └── trainers │ ├── CoCoOp │ ├── rn50_c4_ep10_batch1_ctxv1.yaml │ ├── rn50_ep100_init.yaml │ ├── rn50_ep50.yaml │ ├── vit_b16_c16_ep10_batch1.yaml │ ├── vit_b16_c4_ep10_batch1.yaml │ ├── vit_b16_c4_ep10_batch1_ctxv1.yaml │ └── vit_b16_c8_ep10_batch1.yaml │ ├── CoOp │ ├── rn50.yaml │ ├── rn50_ep100.yaml │ ├── rn50_ep50.yaml │ └── rn50_val.yaml │ ├── KgCoOp │ ├── .ipynb_checkpoints │ │ └── vit_b16_ep100_ctxv1-checkpoint.yaml │ ├── rn50_ep100.yaml │ ├── rn50_ep100_b16.yaml │ ├── vit_b16_ep100_ctxv1.yaml │ ├── vit_b16_ep100_ctxv1_b128.yaml │ ├── vit_b16_ep100_ctxv1_b16.yaml │ └── vit_b16_ep100_ctxv1_b8.yaml │ ├── ProGrad │ ├── rn50.yaml │ ├── rn50_ep100.yaml │ └── rn50_ep50.yaml │ └── TCP │ └── vit_b16_ep100_ctxv1.yaml ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── caltech101.cpython-37.pyc │ ├── caltech101.cpython-38.pyc │ ├── dtd.cpython-37.pyc │ ├── dtd.cpython-38.pyc │ ├── eurosat.cpython-37.pyc │ ├── eurosat.cpython-38.pyc │ ├── fgvc_aircraft.cpython-37.pyc │ ├── fgvc_aircraft.cpython-38.pyc │ ├── food101.cpython-37.pyc │ ├── food101.cpython-38.pyc │ ├── imagenet.cpython-37.pyc │ ├── imagenet.cpython-38.pyc │ ├── imagenet_a.cpython-37.pyc │ ├── imagenet_a.cpython-38.pyc │ ├── imagenet_r.cpython-37.pyc │ ├── imagenet_r.cpython-38.pyc │ ├── imagenet_sketch.cpython-37.pyc │ ├── imagenet_sketch.cpython-38.pyc │ ├── imagenetv2.cpython-37.pyc │ ├── imagenetv2.cpython-38.pyc │ ├── oxford_flowers.cpython-37.pyc │ ├── oxford_flowers.cpython-38.pyc │ ├── oxford_pets.cpython-37.pyc │ ├── oxford_pets.cpython-38.pyc │ ├── stanford_cars.cpython-37.pyc │ ├── stanford_cars.cpython-38.pyc │ ├── sun397.cpython-37.pyc │ ├── sun397.cpython-38.pyc │ ├── ucf101.cpython-37.pyc │ └── ucf101.cpython-38.pyc ├── caltech101.py ├── dtd.py ├── eurosat.py ├── fgvc_aircraft.py ├── food101.py ├── imagenet.py ├── imagenet_a.py ├── imagenet_r.py ├── imagenet_sketch.py ├── imagenetv2.py ├── oxford_flowers.py ├── oxford_pets.py ├── stanford_cars.py ├── sun397.py └── ucf101.py ├── eval.sh ├── interpret_prompt.py ├── lpclip ├── README.md ├── feat_extractor.py ├── feat_extractor.sh ├── linear_probe.py ├── linear_probe.sh └── linear_probe_transfer.py ├── parse_test_res.py ├── requirements.txt ├── scripts └── base2new_train.sh ├── train.py └── trainers ├── .ipynb_checkpoints ├── kgcoop-checkpoint.py └── kgcoop_bk-checkpoint.py ├── __init__.py ├── __pycache__ ├── __init__.cpython-37.pyc ├── __init__.cpython-38.pyc ├── cocoop.cpython-37.pyc ├── cocoop.cpython-38.pyc ├── coop.cpython-37.pyc ├── coop.cpython-38.pyc ├── imagenet_templates.cpython-37.pyc ├── imagenet_templates.cpython-38.pyc ├── kgcoop.cpython-37.pyc ├── kgcoop.cpython-38.pyc ├── prograd.cpython-37.pyc ├── prograd.cpython-38.pyc ├── tcp.cpython-38.pyc ├── 
zsclip.cpython-37.pyc └── zsclip.cpython-38.pyc ├── clip_text ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── clip.cpython-37.pyc │ ├── clip.cpython-38.pyc │ ├── model.cpython-37.pyc │ ├── model.cpython-38.pyc │ ├── simple_tokenizer.cpython-37.pyc │ └── simple_tokenizer.cpython-38.pyc ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── cocoop.py ├── coop.py ├── imagenet_templates.py ├── prograd.py ├── tcp.py └── zsclip.py /DATASETS.md: -------------------------------------------------------------------------------- 1 | # How to install datasets 2 | 3 | We suggest putting all datasets under the same folder (say `$DATA`) to ease management and following the instructions below to organize datasets to avoid modifying the source code. The file structure looks like 4 | 5 | ``` 6 | $DATA/ 7 | |–– imagenet/ 8 | |–– caltech-101/ 9 | |–– oxford_pets/ 10 | |–– stanford_cars/ 11 | ``` 12 | 13 | If you have some datasets already installed somewhere else, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid duplicate download. 14 | 15 | Datasets list: 16 | - [ImageNet](#imagenet) 17 | - [Caltech101](#caltech101) 18 | - [OxfordPets](#oxfordpets) 19 | - [StanfordCars](#stanfordcars) 20 | - [Flowers102](#flowers102) 21 | - [Food101](#food101) 22 | - [FGVCAircraft](#fgvcaircraft) 23 | - [SUN397](#sun397) 24 | - [DTD](#dtd) 25 | - [EuroSAT](#eurosat) 26 | - [UCF101](#ucf101) 27 | - [ImageNetV2](#imagenetv2) 28 | - [ImageNet-Sketch](#imagenet-sketch) 29 | - [ImageNet-A](#imagenet-a) 30 | - [ImageNet-R](#imagenet-r) 31 | 32 | The instructions to prepare each dataset are detailed below. To ensure reproducibility and fair comparison for future work, we provide fixed train/val/test splits for all datasets except ImageNet where the validation set is used as test set. The fixed splits are either from the original datasets (if available) or created by us. 33 | 34 | ### ImageNet 35 | - Create a folder named `imagenet/` under `$DATA`. 36 | - Create `images/` under `imagenet/`. 37 | - Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. The directory structure should look like 38 | ``` 39 | imagenet/ 40 | |–– images/ 41 | | |–– train/ # contains 1,000 folders like n01440764, n01443537, etc. 42 | | |–– val/ 43 | ``` 44 | - If you had downloaded the ImageNet dataset before, you can create symbolic links to map the training and validation sets to `$DATA/imagenet/images`. 45 | - Download the `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb). 46 | 47 | ### Caltech101 48 | - Create a folder named `caltech-101/` under `$DATA`. 49 | - Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`. 50 | - Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`. 
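The download-and-extract steps above can also be scripted. A minimal sketch, assuming `$DATA` is set and `wget`/`tar` are available; the archive URL is the one listed above, and the split file still has to be fetched manually from the Google Drive link.

```bash
# Minimal sketch: download and extract Caltech101 under $DATA/caltech-101
mkdir -p "$DATA/caltech-101"
wget -P "$DATA/caltech-101" http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz
tar -xzf "$DATA/caltech-101/101_ObjectCategories.tar.gz" -C "$DATA/caltech-101"
# split_zhou_Caltech101.json must still be downloaded manually from the Google Drive link above
```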
51 | 52 | The directory structure should look like 53 | ``` 54 | caltech-101/ 55 | |–– 101_ObjectCategories/ 56 | |–– split_zhou_Caltech101.json 57 | ``` 58 | 59 | ### OxfordPets 60 | - Create a folder named `oxford_pets/` under `$DATA`. 61 | - Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz. 62 | - Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz. 63 | - Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing). 64 | 65 | The directory structure should look like 66 | ``` 67 | oxford_pets/ 68 | |–– images/ 69 | |–– annotations/ 70 | |–– split_zhou_OxfordPets.json 71 | ``` 72 | 73 | ### StanfordCars 74 | - Create a folder named `stanford_cars/` under `$DATA`. 75 | - Download the train images http://ai.stanford.edu/~jkrause/car196/cars_train.tgz. 76 | - Download the test images http://ai.stanford.edu/~jkrause/car196/cars_test.tgz. 77 | - Download the train labels https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz. 78 | - Download the test labels http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat. 79 | - Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing). 80 | 81 | The directory structure should look like 82 | ``` 83 | stanford_cars/ 84 | |–– cars_test\ 85 | |–– cars_test_annos_withlabels.mat 86 | |–– cars_train\ 87 | |–– devkit\ 88 | |–– split_zhou_StanfordCars.json 89 | ``` 90 | 91 | ### Flowers102 92 | - Create a folder named `oxford_flowers/` under `$DATA`. 93 | - Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively. 94 | - Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing). 95 | - Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing). 96 | 97 | The directory structure should look like 98 | ``` 99 | oxford_flowers/ 100 | |–– cat_to_name.json 101 | |–– imagelabels.mat 102 | |–– jpg/ 103 | |–– split_zhou_OxfordFlowers.json 104 | ``` 105 | 106 | ### Food101 107 | - Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`. 108 | - Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing). 109 | 110 | The directory structure should look like 111 | ``` 112 | food-101/ 113 | |–– images/ 114 | |–– license_agreement.txt 115 | |–– meta/ 116 | |–– README.txt 117 | |–– split_zhou_Food101.json 118 | ``` 119 | 120 | ### FGVCAircraft 121 | - Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz. 122 | - Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`. 123 | - Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`. 124 | 125 | The directory structure should look like 126 | ``` 127 | fgvc_aircraft/ 128 | |–– images/ 129 | |–– ... # a bunch of .txt files 130 | ``` 131 | 132 | ### SUN397 133 | - Create a folder named `sun397/` under `$DATA`. 134 | - Download the images http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz. 
135 | - Download the partitions https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip. 136 | - Extract these files under `$DATA/sun397/`. 137 | - Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing). 138 | 139 | The directory structure should look like 140 | ``` 141 | sun397/ 142 | |–– SUN397/ 143 | |–– split_zhou_SUN397.json 144 | |–– ... # a bunch of .txt files 145 | ``` 146 | 147 | ### DTD 148 | - Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. This should lead to `$DATA/dtd/`. 149 | - Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing). 150 | 151 | The directory structure should look like 152 | ``` 153 | dtd/ 154 | |–– images/ 155 | |–– imdb/ 156 | |–– labels/ 157 | |–– split_zhou_DescribableTextures.json 158 | ``` 159 | 160 | ### EuroSAT 161 | - Create a folder named `eurosat/` under `$DATA`. 162 | - Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`. 163 | - Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing). 164 | 165 | The directory structure should look like 166 | ``` 167 | eurosat/ 168 | |–– 2750/ 169 | |–– split_zhou_EuroSAT.json 170 | ``` 171 | 172 | ### UCF101 173 | - Create a folder named `ucf101/` under `$DATA`. 174 | - Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames. 175 | - Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing). 176 | 177 | The directory structure should look like 178 | ``` 179 | ucf101/ 180 | |–– UCF-101-midframes/ 181 | |–– split_zhou_UCF101.json 182 | ``` 183 | 184 | ### ImageNetV2 185 | - Create a folder named `imagenetv2/` under `$DATA`. 186 | - Go to this github repo https://github.com/modestyachts/ImageNetV2. 187 | - Download the matched-frequency dataset from https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz and extract it to `$DATA/imagenetv2/`. 188 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenetv2/`. 189 | 190 | The directory structure should look like 191 | ``` 192 | imagenetv2/ 193 | |–– imagenetv2-matched-frequency-format-val/ 194 | |–– classnames.txt 195 | ``` 196 | 197 | ### ImageNet-Sketch 198 | - Download the dataset from https://github.com/HaohanWang/ImageNet-Sketch. 199 | - Extract the dataset to `$DATA/imagenet-sketch`. 200 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-sketch/`. 201 | 202 | The directory structure should look like 203 | ``` 204 | imagenet-sketch/ 205 | |–– images/ # contains 1,000 folders whose names have the format of n* 206 | |–– classnames.txt 207 | ``` 208 | 209 | ### ImageNet-A 210 | - Create a folder named `imagenet-adversarial/` under `$DATA`. 211 | - Download the dataset from https://github.com/hendrycks/natural-adv-examples and extract it to `$DATA/imagenet-adversarial/`. 212 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-adversarial/`. 
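These steps can likewise be scripted. A minimal sketch, assuming `$DATA` is set and the archive has already been downloaded from the repository above (the file name `imagenet-a.tar` is an assumption; use whatever the release actually provides):

```bash
# Minimal sketch: set up ImageNet-A under $DATA/imagenet-adversarial
mkdir -p "$DATA/imagenet-adversarial"
tar -xf imagenet-a.tar -C "$DATA/imagenet-adversarial"   # assumed archive name; yields the imagenet-a/ folder
cp "$DATA/imagenet/classnames.txt" "$DATA/imagenet-adversarial/"
```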
213 | 214 | The directory structure should look like 215 | ``` 216 | imagenet-adversarial/ 217 | |–– imagenet-a/ # contains 200 folders whose names have the format of n* 218 | |–– classnames.txt 219 | ``` 220 | 221 | ### ImageNet-R 222 | - Create a folder named `imagenet-rendition/` under `$DATA`. 223 | - Download the dataset from https://github.com/hendrycks/imagenet-r and extract it to `$DATA/imagenet-rendition/`. 224 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-rendition/`. 225 | 226 | The directory structure should look like 227 | ``` 228 | imagenet-rendition/ 229 | |–– imagenet-r/ # contains 200 folders whose names have the format of n* 230 | |–– classnames.txt 231 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kaiyang Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## TCP: Textual-based Class-aware Prompt tuning for Visual-Language Model[CVPR24] 2 | 3 | > [**TCP: Textual-based Class-aware Prompt tuning for Visual-Language Model**](https://arxiv.org/abs/2311.18231)
4 | > Hantao Yao, Rui Zhang, Changsheng Xu 5 | 6 | ## How to Install 7 | This code is built on top of the [Dassl](https://github.com/KaiyangZhou/Dassl.pytorch) toolbox. You can prepare the environment as follows: 8 | 9 | ``` 10 | # Create a conda environment 11 | conda create -n dassl python=3.7 12 | 13 | # Activate the environment 14 | conda activate dassl 15 | 16 | # Install dependencies 17 | pip install -r requirements.txt 18 | 19 | # Install torch (version >= 1.7.1) and torchvision 20 | # Please make sure you install the GPU version for speed. 21 | # For example: 22 | conda install pytorch torchvision cudatoolkit=10.1 -c pytorch 23 | 24 | # Install this library (no need to re-build if the source code is modified) 25 | python setup.py develop 26 | ``` 27 | 28 | After that, run `pip install -r requirements.txt` under `Textual-based_Class-aware_prompt_tuning/` to install a few more packages required by [CLIP](https://github.com/openai/CLIP) (this should be done with the `dassl` environment activated). Then you are ready to go. 29 | 30 | Follow [DATASETS.md](DATASETS.md) to install the datasets. 31 | 32 | ## [Important] Adjust `EPS` in the Adam optimizer 33 | Using the standard Adam optimizer on fp16 data produces NaN losses, so we set `EPS` in Adam to 1e-3. The issue is also discussed at https://discuss.pytorch.org/t/adam-half-precision-nans/1765. 34 | 35 | Line 80 of `./Dassl.pytorch/dassl/optim/optimizer.py`: 36 | 37 | ``` 38 | if optim == "adam": 39 | optimizer = torch.optim.Adam( 40 | param_groups, 41 | lr=lr, 42 | weight_decay=weight_decay, 43 | betas=(adam_beta1, adam_beta2), 44 | eps=1e-3, 45 | ) 46 | ``` 47 | 48 | 49 | 50 | 51 | ## Generalization From Base to New Classes 52 | 53 | You will need `base2new_train.sh` (under `scripts/`). The scripts with the prefix `base2new_train` train a model on base classes, while the ones with the prefix `base2new_test` evaluate the trained model on new classes. Both kinds of scripts take a single input argument, `DATASET`, which is a dataset name such as `imagenet` or `caltech101`; the valid names are the file names in `configs/datasets/`. 54 | 55 | Below we provide an example of how to train and evaluate the model on ImageNet. 56 | 57 | ```bash 58 | bash base2new_train.sh 59 | ``` 60 | 61 | When the evaluation is done, you can use `parse_test_res.py` to automatically calculate the average results.
For instance, after you finish the evaluation using the aforementioned commands, you would get 62 | 63 | 64 | Then, to get the average performance on the base classes, run 65 | 66 | ```bash 67 | python parse_test_res.py output/base2new/train_base/stanford_cars/shots_16/CoCoOp/rn50_ep100 68 | ``` 69 | 70 | To get the average performance on the new classes, run 71 | 72 | ```bash 73 | python parse_test_res.py output/base2new/test_new/stanford_cars/shots_16/CoCoOp/rn50_ep100 --test-log 74 | ``` 75 | 76 | ## Citation 77 | If you use our work, please consider citing: 78 | ```bibtex 79 | @inproceedings{TCP24, 80 | title={TCP: Textual-based Class-aware Prompt tuning for Visual-Language Model}, 81 | author={Hantao Yao, Rui Zhang, Changsheng Xu}, 82 | booktitle={The IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 83 | year={2024} 84 | } 85 | ``` 86 | 87 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/clip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/clip.cpython-37.pyc -------------------------------------------------------------------------------- /clip/__pycache__/clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/clip.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /clip/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/simple_tokenizer.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/simple_tokenizer.cpython-37.pyc -------------------------------------------------------------------------------- /clip/__pycache__/simple_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | #def _download(url: str, root: str): 44 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 45 | os.makedirs(root, exist_ok=True) 46 | filename = os.path.basename(url) 47 | 48 | expected_sha256 = url.split("/")[-2] 49 | download_target = os.path.join(root, 
filename) 50 | 51 | if os.path.exists(download_target) and not os.path.isfile(download_target): 52 | raise RuntimeError(f"{download_target} exists and is not a regular file") 53 | 54 | if os.path.isfile(download_target): 55 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 56 | return download_target 57 | else: 58 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 59 | 60 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 61 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 62 | while True: 63 | buffer = source.read(8192) 64 | if not buffer: 65 | break 66 | 67 | output.write(buffer) 68 | loop.update(len(buffer)) 69 | 70 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 71 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 72 | 73 | return download_target 74 | 75 | 76 | def _convert_image_to_rgb(image): 77 | return image.convert("RGB") 78 | 79 | 80 | def _transform(n_px): 81 | return Compose([ 82 | Resize(n_px, interpolation=BICUBIC), 83 | CenterCrop(n_px), 84 | _convert_image_to_rgb, 85 | ToTensor(), 86 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 87 | ]) 88 | 89 | 90 | def available_models() -> List[str]: 91 | """Returns the names of available CLIP models""" 92 | return list(_MODELS.keys()) 93 | 94 | 95 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 96 | """Load a CLIP model 97 | 98 | Parameters 99 | ---------- 100 | name : str 101 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 102 | 103 | device : Union[str, torch.device] 104 | The device to put the loaded model 105 | 106 | jit : bool 107 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 108 | 109 | download_root: str 110 | path to download the model files; by default, it uses "~/.cache/clip" 111 | 112 | Returns 113 | ------- 114 | model : torch.nn.Module 115 | The CLIP model 116 | 117 | preprocess : Callable[[PIL.Image], torch.Tensor] 118 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 119 | """ 120 | if name in _MODELS: 121 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 122 | elif os.path.isfile(name): 123 | model_path = name 124 | else: 125 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 126 | 127 | with open(model_path, 'rb') as opened_file: 128 | try: 129 | # loading JIT archive 130 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 131 | state_dict = None 132 | except RuntimeError: 133 | # loading saved state dict 134 | if jit: 135 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 136 | jit = False 137 | state_dict = torch.load(opened_file, map_location="cpu") 138 | 139 | if not jit: 140 | model = build_model(state_dict or model.state_dict()).to(device) 141 | if str(device) == "cpu": 142 | model.float() 143 | return model, _transform(model.visual.input_resolution) 144 | 145 | # patch the device names 146 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 147 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 148 | 149 | def _node_get(node: torch._C.Node, key: str): 150 | """Gets attributes of a node which is polymorphic over return type. 151 | 152 | From https://github.com/pytorch/pytorch/pull/82628 153 | """ 154 | sel = node.kindOf(key) 155 | return getattr(node, sel)(key) 156 | 157 | def patch_device(module): 158 | try: 159 | graphs = [module.graph] if hasattr(module, "graph") else [] 160 | except RuntimeError: 161 | graphs = [] 162 | 163 | if hasattr(module, "forward1"): 164 | graphs.append(module.forward1.graph) 165 | 166 | for graph in graphs: 167 | for node in graph.findAllNodes("prim::Constant"): 168 | if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): 169 | node.copyAttributes(device_node) 170 | 171 | model.apply(patch_device) 172 | patch_device(model.encode_image) 173 | patch_device(model.encode_text) 174 | 175 | # patch dtype to float32 on CPU 176 | if str(device) == "cpu": 177 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 178 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 179 | float_node = float_input.node() 180 | 181 | def patch_float(module): 182 | try: 183 | graphs = [module.graph] if hasattr(module, "graph") else [] 184 | except RuntimeError: 185 | graphs = [] 186 | 187 | if hasattr(module, "forward1"): 188 | graphs.append(module.forward1.graph) 189 | 190 | for graph in graphs: 191 | for node in graph.findAllNodes("aten::to"): 192 | inputs = list(node.inputs()) 193 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 194 | if _node_get(inputs[i].node(), "value") == 5: 195 | inputs[i].node().copyAttributes(float_node) 196 | 197 | model.apply(patch_float) 198 | patch_float(model.encode_image) 199 | patch_float(model.encode_text) 200 | 201 | model.float() 202 | 203 | return model, _transform(model.input_resolution.item()) 204 | 205 | 206 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 207 | """ 208 | Returns the tokenized representation of given input string(s) 209 | 210 | Parameters 211 | ---------- 212 | texts : Union[str, List[str]] 213 | An input string or a list of input strings to tokenize 214 | 215 | context_length : int 216 | The context length to use; all CLIP models use 77 as the context length 217 | 218 | truncate: bool 219 | Whether to truncate the text in case its encoding is longer than the context length 220 | 221 | Returns 222 | ------- 223 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 224 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
225 | """ 226 | if isinstance(texts, str): 227 | texts = [texts] 228 | 229 | sot_token = _tokenizer.encoder["<|startoftext|>"] 230 | eot_token = _tokenizer.encoder["<|endoftext|>"] 231 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 232 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 233 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 234 | else: 235 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 236 | 237 | for i, tokens in enumerate(all_tokens): 238 | if len(tokens) > context_length: 239 | if truncate: 240 | tokens = tokens[:context_length] 241 | tokens[-1] = eot_token 242 | else: 243 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 244 | result[i, :len(tokens)] = torch.tensor(tokens) 245 | 246 | return result 247 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /configs/datasets/caltech101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Caltech101" 3 | -------------------------------------------------------------------------------- /configs/datasets/dtd.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "DescribableTextures" 3 | -------------------------------------------------------------------------------- 
/configs/datasets/eurosat.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "EuroSAT" 3 | -------------------------------------------------------------------------------- /configs/datasets/fgvc_aircraft.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "FGVCAircraft" 3 | -------------------------------------------------------------------------------- /configs/datasets/food101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Food101" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNet" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetA" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetR" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetSketch" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenetv2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetV2" 3 | -------------------------------------------------------------------------------- /configs/datasets/oxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordFlowers" -------------------------------------------------------------------------------- /configs/datasets/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordPets" -------------------------------------------------------------------------------- /configs/datasets/stanford_cars.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "StanfordCars" 3 | -------------------------------------------------------------------------------- /configs/datasets/sun397.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SUN397" 3 | -------------------------------------------------------------------------------- /configs/datasets/ucf101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "UCF101" 3 | -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/rn50_c4_ep10_batch1_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | 
OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: True 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/rn50_ep100_init.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 16 34 | CTX_INIT: True 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/rn50_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 16 34 | CTX_INIT: True 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c16_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 16 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", 
"normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 16 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | #CTX_INIT: "a photo of a" 35 | PREC: "fp16" 36 | -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c8_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 8 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 
0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_val.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 32 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | MODEL: 16 | BACKBONE: 17 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/KgCoOp/.ipynb_checkpoints/vit_b16_ep100_ctxv1-checkpoint.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "adam" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: False 34 | -------------------------------------------------------------------------------- /configs/trainers/KgCoOp/rn50_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 256 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "adam" #"sgd" 17 | LR: 0.002 #0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: False 34 | -------------------------------------------------------------------------------- /configs/trainers/KgCoOp/rn50_ep100_b16.yaml: 
-------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 16 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "adam" #"sgd" 17 | LR: 0.002 #0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: False 34 | -------------------------------------------------------------------------------- /configs/trainers/KgCoOp/vit_b16_ep100_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 500 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "adam" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | #EPS: 1e-3 24 | 25 | TRAIN: 26 | PRINT_FREQ: 5 27 | 28 | MODEL: 29 | BACKBONE: 30 | NAME: "ViT-B/16" 31 | 32 | TRAINER: 33 | COOP: 34 | CTX_INIT: False 35 | -------------------------------------------------------------------------------- /configs/trainers/KgCoOp/vit_b16_ep100_ctxv1_b128.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 128 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "adam" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: True 34 | -------------------------------------------------------------------------------- /configs/trainers/KgCoOp/vit_b16_ep100_ctxv1_b16.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 16 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "adam" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: False 34 | -------------------------------------------------------------------------------- /configs/trainers/KgCoOp/vit_b16_ep100_ctxv1_b8.yaml: 
-------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 8 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "adam" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 10 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: False 34 | -------------------------------------------------------------------------------- /configs/trainers/ProGrad/rn50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | LOSS: 25 | NAME: "prograd" 26 | T: 1.0 27 | 28 | TRAIN: 29 | PRINT_FREQ: 5 30 | 31 | MODEL: 32 | BACKBONE: 33 | NAME: "RN50" 34 | 35 | TRAINER: 36 | COOP: 37 | CTX_INIT: True 38 | -------------------------------------------------------------------------------- /configs/trainers/ProGrad/rn50_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | LOSS: 25 | NAME: "prograd" 26 | T: 1.0 27 | 28 | TRAIN: 29 | PRINT_FREQ: 5 30 | 31 | MODEL: 32 | BACKBONE: 33 | NAME: "RN50" 34 | 35 | TRAINER: 36 | COOP: 37 | CTX_INIT: True 38 | -------------------------------------------------------------------------------- /configs/trainers/ProGrad/rn50_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | LOSS: 25 | NAME: "prograd" 26 | T: 1.0 27 | 28 | TRAIN: 29 | PRINT_FREQ: 5 30 | 31 | MODEL: 32 | BACKBONE: 33 | NAME: "RN50" 34 | 35 | TRAINER: 36 | COOP: 37 | CTX_INIT: True -------------------------------------------------------------------------------- 
/configs/trainers/TCP/vit_b16_ep100_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 500 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "adam" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: False 34 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/caltech101.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/caltech101.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/caltech101.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/caltech101.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dtd.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/dtd.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dtd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/dtd.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/eurosat.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/eurosat.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/eurosat.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/eurosat.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/fgvc_aircraft.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/fgvc_aircraft.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/fgvc_aircraft.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/fgvc_aircraft.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/food101.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/food101.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/food101.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/food101.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/imagenet.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/imagenet.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_a.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/imagenet_a.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_a.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/imagenet_a.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_r.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/imagenet_r.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_r.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/imagenet_r.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_sketch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/imagenet_sketch.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_sketch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/imagenet_sketch.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenetv2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/imagenetv2.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/imagenetv2.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_flowers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/oxford_flowers.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_flowers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/oxford_flowers.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_pets.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/oxford_pets.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_pets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/oxford_pets.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/stanford_cars.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/stanford_cars.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/stanford_cars.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/stanford_cars.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sun397.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/sun397.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sun397.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/sun397.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/ucf101.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/ucf101.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/ucf101.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__pycache__/ucf101.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | IGNORED = ["BACKGROUND_Google", "Faces_easy"] 11 | NEW_CNAMES = { 12 | "airplanes": "airplane", 13 | "Faces": "face", 14 | "Leopards": "leopard", 15 | "Motorbikes": "motorbike", 16 | } 17 | 18 | 19 | @DATASET_REGISTRY.register() 20 | class Caltech101(DatasetBase): 21 | 22 | dataset_dir = 
"caltech-101" 23 | 24 | def __init__(self, cfg): 25 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 26 | self.dataset_dir = os.path.join(root, self.dataset_dir) 27 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories") 28 | self.split_path = os.path.join(self.dataset_dir, 29 | "split_zhou_Caltech101.json") 30 | self.split_fewshot_dir = os.path.join(self.dataset_dir, 31 | "split_fewshot") 32 | mkdir_if_missing(self.split_fewshot_dir) 33 | 34 | if os.path.exists(self.split_path): 35 | train, val, test = OxfordPets.read_split(self.split_path, 36 | self.image_dir) 37 | else: 38 | train, val, test = DTD.read_and_split_data(self.image_dir, 39 | ignored=IGNORED, 40 | new_cnames=NEW_CNAMES) 41 | OxfordPets.save_split(train, val, test, self.split_path, 42 | self.image_dir) 43 | 44 | num_shots = cfg.DATASET.NUM_SHOTS 45 | if num_shots >= 1: 46 | seed = cfg.SEED 47 | preprocessed = os.path.join(self.split_fewshot_dir, 48 | f"shot_{num_shots}-seed_{seed}.pkl") 49 | 50 | if os.path.exists(preprocessed): 51 | print( 52 | f"Loading preprocessed few-shot data from {preprocessed}") 53 | with open(preprocessed, "rb") as file: 54 | data = pickle.load(file) 55 | train, val = data["train"], data["val"] 56 | else: 57 | train = self.generate_fewshot_dataset(train, 58 | num_shots=num_shots) 59 | val = self.generate_fewshot_dataset(val, 60 | num_shots=min( 61 | num_shots, 4)) 62 | data = {"train": train, "val": val} 63 | print(f"Saving preprocessed few-shot data to {preprocessed}") 64 | with open(preprocessed, "wb") as file: 65 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 66 | 67 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 68 | train, val, test = OxfordPets.subsample_classes(train, 69 | val, 70 | test, 71 | subsample=subsample) 72 | 73 | super().__init__(train_x=train, val=val, test=test) 74 | -------------------------------------------------------------------------------- /datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import listdir_nohidden, mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class DescribableTextures(DatasetBase): 13 | 14 | dataset_dir = "dtd" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, 21 | "split_zhou_DescribableTextures.json") 22 | self.split_fewshot_dir = os.path.join(self.dataset_dir, 23 | "split_fewshot") 24 | mkdir_if_missing(self.split_fewshot_dir) 25 | 26 | if os.path.exists(self.split_path): 27 | train, val, test = OxfordPets.read_split(self.split_path, 28 | self.image_dir) 29 | else: 30 | train, val, test = self.read_and_split_data(self.image_dir) 31 | OxfordPets.save_split(train, val, test, self.split_path, 32 | self.image_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, 38 | f"shot_{num_shots}-seed_{seed}.pkl") 39 | 40 | if os.path.exists(preprocessed): 41 | print( 42 | f"Loading preprocessed few-shot data from {preprocessed}") 43 | with open(preprocessed, "rb") as file: 44 | data = pickle.load(file) 45 | train, val = data["train"], data["val"] 46 | 
else: 47 | train = self.generate_fewshot_dataset(train, 48 | num_shots=num_shots) 49 | val = self.generate_fewshot_dataset(val, 50 | num_shots=min( 51 | num_shots, 4)) 52 | data = {"train": train, "val": val} 53 | print(f"Saving preprocessed few-shot data to {preprocessed}") 54 | with open(preprocessed, "wb") as file: 55 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 56 | 57 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 58 | train, val, test = OxfordPets.subsample_classes(train, 59 | val, 60 | test, 61 | subsample=subsample) 62 | 63 | super().__init__(train_x=train, val=val, test=test) 64 | 65 | @staticmethod 66 | def read_and_split_data(image_dir, 67 | p_trn=0.5, 68 | p_val=0.2, 69 | ignored=[], 70 | new_cnames=None): 71 | # The data are supposed to be organized into the following structure 72 | # ============= 73 | # images/ 74 | # dog/ 75 | # cat/ 76 | # horse/ 77 | # ============= 78 | categories = listdir_nohidden(image_dir) 79 | categories = [c for c in categories if c not in ignored] 80 | categories.sort() 81 | 82 | p_tst = 1 - p_trn - p_val 83 | print( 84 | f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test" 85 | ) 86 | 87 | def _collate(ims, y, c): 88 | items = [] 89 | for im in ims: 90 | item = Datum(impath=im, label=y, 91 | classname=c) # is already 0-based 92 | items.append(item) 93 | return items 94 | 95 | train, val, test = [], [], [] 96 | for label, category in enumerate(categories): 97 | category_dir = os.path.join(image_dir, category) 98 | images = listdir_nohidden(category_dir) 99 | images = [os.path.join(category_dir, im) for im in images] 100 | random.shuffle(images) 101 | n_total = len(images) 102 | n_train = round(n_total * p_trn) 103 | n_val = round(n_total * p_val) 104 | n_test = n_total - n_train - n_val 105 | assert n_train > 0 and n_val > 0 and n_test > 0 106 | 107 | if new_cnames is not None and category in new_cnames: 108 | category = new_cnames[category] 109 | 110 | train.extend(_collate(images[:n_train], label, category)) 111 | val.extend( 112 | _collate(images[n_train:n_train + n_val], label, category)) 113 | test.extend(_collate(images[n_train + n_val:], label, category)) 114 | 115 | return train, val, test 116 | -------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | NEW_CNAMES = { 11 | "AnnualCrop": "Annual Crop Land", 12 | "Forest": "Forest", 13 | "HerbaceousVegetation": "Herbaceous Vegetation Land", 14 | "Highway": "Highway or Road", 15 | "Industrial": "Industrial Buildings", 16 | "Pasture": "Pasture Land", 17 | "PermanentCrop": "Permanent Crop Land", 18 | "Residential": "Residential Buildings", 19 | "River": "River", 20 | "SeaLake": "Sea or Lake", 21 | } 22 | 23 | 24 | @DATASET_REGISTRY.register() 25 | class EuroSAT(DatasetBase): 26 | 27 | dataset_dir = "eurosat" 28 | 29 | def __init__(self, cfg): 30 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 31 | self.dataset_dir = os.path.join(root, self.dataset_dir) 32 | self.image_dir = os.path.join(self.dataset_dir, "2750") 33 | self.split_path = os.path.join(self.dataset_dir, 34 | "split_zhou_EuroSAT.json") 35 | self.split_fewshot_dir = os.path.join(self.dataset_dir, 36 | 
"split_fewshot") 37 | mkdir_if_missing(self.split_fewshot_dir) 38 | 39 | if os.path.exists(self.split_path): 40 | train, val, test = OxfordPets.read_split(self.split_path, 41 | self.image_dir) 42 | else: 43 | train, val, test = DTD.read_and_split_data(self.image_dir, 44 | new_cnames=NEW_CNAMES) 45 | OxfordPets.save_split(train, val, test, self.split_path, 46 | self.image_dir) 47 | 48 | num_shots = cfg.DATASET.NUM_SHOTS 49 | if num_shots >= 1: 50 | seed = cfg.SEED 51 | preprocessed = os.path.join(self.split_fewshot_dir, 52 | f"shot_{num_shots}-seed_{seed}.pkl") 53 | 54 | if os.path.exists(preprocessed): 55 | print( 56 | f"Loading preprocessed few-shot data from {preprocessed}") 57 | with open(preprocessed, "rb") as file: 58 | data = pickle.load(file) 59 | train, val = data["train"], data["val"] 60 | else: 61 | train = self.generate_fewshot_dataset(train, 62 | num_shots=num_shots) 63 | val = self.generate_fewshot_dataset(val, 64 | num_shots=min( 65 | num_shots, 4)) 66 | data = {"train": train, "val": val} 67 | print(f"Saving preprocessed few-shot data to {preprocessed}") 68 | with open(preprocessed, "wb") as file: 69 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 70 | 71 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 72 | train, val, test = OxfordPets.subsample_classes(train, 73 | val, 74 | test, 75 | subsample=subsample) 76 | 77 | super().__init__(train_x=train, val=val, test=test) 78 | 79 | def update_classname(self, dataset_old): 80 | dataset_new = [] 81 | for item_old in dataset_old: 82 | cname_old = item_old.classname 83 | cname_new = NEW_CLASSNAMES[cname_old] 84 | item_new = Datum(impath=item_old.impath, 85 | label=item_old.label, 86 | classname=cname_new) 87 | dataset_new.append(item_new) 88 | return dataset_new 89 | -------------------------------------------------------------------------------- /datasets/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class FGVCAircraft(DatasetBase): 12 | 13 | dataset_dir = "fgvc_aircraft" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "images") 19 | self.split_fewshot_dir = os.path.join(self.dataset_dir, 20 | "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | classnames = [] 24 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: 25 | lines = f.readlines() 26 | for line in lines: 27 | classnames.append(line.strip()) 28 | cname2lab = {c: i for i, c in enumerate(classnames)} 29 | 30 | train = self.read_data(cname2lab, "images_variant_train.txt") 31 | val = self.read_data(cname2lab, "images_variant_val.txt") 32 | test = self.read_data(cname2lab, "images_variant_test.txt") 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, 38 | f"shot_{num_shots}-seed_{seed}.pkl") 39 | 40 | if os.path.exists(preprocessed): 41 | print( 42 | f"Loading preprocessed few-shot data from {preprocessed}") 43 | with open(preprocessed, "rb") as file: 44 | data = pickle.load(file) 45 | train, val = data["train"], data["val"] 46 | else: 47 | train = 
self.generate_fewshot_dataset(train, 48 | num_shots=num_shots) 49 | val = self.generate_fewshot_dataset(val, 50 | num_shots=min( 51 | num_shots, 4)) 52 | data = {"train": train, "val": val} 53 | print(f"Saving preprocessed few-shot data to {preprocessed}") 54 | with open(preprocessed, "wb") as file: 55 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 56 | 57 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 58 | train, val, test = OxfordPets.subsample_classes(train, 59 | val, 60 | test, 61 | subsample=subsample) 62 | 63 | super().__init__(train_x=train, val=val, test=test) 64 | 65 | def read_data(self, cname2lab, split_file): 66 | filepath = os.path.join(self.dataset_dir, split_file) 67 | items = [] 68 | 69 | with open(filepath, "r") as f: 70 | lines = f.readlines() 71 | for line in lines: 72 | line = line.strip().split(" ") 73 | imname = line[0] + ".jpg" 74 | classname = " ".join(line[1:]) 75 | impath = os.path.join(self.image_dir, imname) 76 | label = cname2lab[classname] 77 | item = Datum(impath=impath, label=label, classname=classname) 78 | items.append(item) 79 | 80 | return items 81 | -------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class Food101(DatasetBase): 13 | 14 | dataset_dir = "food-101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, 21 | "split_zhou_Food101.json") 22 | self.split_fewshot_dir = os.path.join(self.dataset_dir, 23 | "split_fewshot") 24 | mkdir_if_missing(self.split_fewshot_dir) 25 | 26 | if os.path.exists(self.split_path): 27 | train, val, test = OxfordPets.read_split(self.split_path, 28 | self.image_dir) 29 | else: 30 | train, val, test = DTD.read_and_split_data(self.image_dir) 31 | OxfordPets.save_split(train, val, test, self.split_path, 32 | self.image_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, 38 | f"shot_{num_shots}-seed_{seed}.pkl") 39 | 40 | if os.path.exists(preprocessed): 41 | print( 42 | f"Loading preprocessed few-shot data from {preprocessed}") 43 | with open(preprocessed, "rb") as file: 44 | data = pickle.load(file) 45 | train, val = data["train"], data["val"] 46 | else: 47 | train = self.generate_fewshot_dataset(train, 48 | num_shots=num_shots) 49 | val = self.generate_fewshot_dataset(val, 50 | num_shots=min( 51 | num_shots, 4)) 52 | data = {"train": train, "val": val} 53 | print(f"Saving preprocessed few-shot data to {preprocessed}") 54 | with open(preprocessed, "wb") as file: 55 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 56 | 57 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 58 | train, val, test = OxfordPets.subsample_classes(train, 59 | val, 60 | test, 61 | subsample=subsample) 62 | 63 | super().__init__(train_x=train, val=val, test=test) 64 | -------------------------------------------------------------------------------- /datasets/imagenet.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import OrderedDict 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import listdir_nohidden, mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNet(DatasetBase): 13 | 14 | dataset_dir = "imagenet" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, 22 | "split_fewshot") 23 | mkdir_if_missing(self.split_fewshot_dir) 24 | 25 | if os.path.exists(self.preprocessed): 26 | with open(self.preprocessed, "rb") as f: 27 | preprocessed = pickle.load(f) 28 | train = preprocessed["train"] 29 | test = preprocessed["test"] 30 | else: 31 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 32 | classnames = self.read_classnames(text_file) 33 | train = self.read_data(classnames, "train") 34 | # Follow standard practice to perform evaluation on the val set 35 | # Also used as the val set (so evaluate the last-step model) 36 | test = self.read_data(classnames, "val") 37 | 38 | preprocessed = {"train": train, "test": test} 39 | with open(self.preprocessed, "wb") as f: 40 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL) 41 | 42 | num_shots = cfg.DATASET.NUM_SHOTS 43 | if num_shots >= 1: 44 | seed = cfg.SEED 45 | preprocessed = os.path.join(self.split_fewshot_dir, 46 | f"shot_{num_shots}-seed_{seed}.pkl") 47 | 48 | if os.path.exists(preprocessed): 49 | print( 50 | f"Loading preprocessed few-shot data from {preprocessed}") 51 | with open(preprocessed, "rb") as file: 52 | data = pickle.load(file) 53 | train = data["train"] 54 | else: 55 | train = self.generate_fewshot_dataset(train, 56 | num_shots=num_shots) 57 | data = {"train": train} 58 | print(f"Saving preprocessed few-shot data to {preprocessed}") 59 | with open(preprocessed, "wb") as file: 60 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 61 | 62 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 63 | train, test = OxfordPets.subsample_classes(train, 64 | test, 65 | subsample=subsample) 66 | 67 | super().__init__(train_x=train, val=test, test=test) 68 | 69 | @staticmethod 70 | def read_classnames(text_file): 71 | """Return a dictionary containing 72 | key-value pairs of : . 
73 | """ 74 | classnames = OrderedDict() 75 | with open(text_file, "r") as f: 76 | lines = f.readlines() 77 | for line in lines: 78 | line = line.strip().split(" ") 79 | folder = line[0] 80 | classname = " ".join(line[1:]) 81 | classnames[folder] = classname 82 | return classnames 83 | 84 | def read_data(self, classnames, split_dir): 85 | split_dir = os.path.join(self.image_dir, split_dir) 86 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir()) 87 | items = [] 88 | 89 | for label, folder in enumerate(folders): 90 | imnames = listdir_nohidden(os.path.join(split_dir, folder)) 91 | classname = classnames[folder] 92 | for imname in imnames: 93 | impath = os.path.join(split_dir, folder, imname) 94 | item = Datum(impath=impath, label=label, classname=classname) 95 | items.append(item) 96 | 97 | return items 98 | -------------------------------------------------------------------------------- /datasets/imagenet_a.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetA(DatasetBase): 13 | """ImageNet-A(dversarial). 14 | 15 | This dataset is used for testing only. 16 | """ 17 | 18 | dataset_dir = "imagenet-adversarial" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-a") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /datasets/imagenet_r.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetR(DatasetBase): 13 | """ImageNet-R(endition). 14 | 15 | This dataset is used for testing only. 
16 | """ 17 | 18 | dataset_dir = "imagenet-rendition" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-r") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /datasets/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetSketch(DatasetBase): 11 | """ImageNet-Sketch. 12 | 13 | This dataset is used for testing only. 14 | """ 15 | 16 | dataset_dir = "imagenet-sketch" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "images") 22 | 23 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 24 | classnames = ImageNet.read_classnames(text_file) 25 | 26 | data = self.read_data(classnames) 27 | 28 | super().__init__(train_x=data, test=data) 29 | 30 | def read_data(self, classnames): 31 | image_dir = self.image_dir 32 | folders = listdir_nohidden(image_dir, sort=True) 33 | items = [] 34 | 35 | for label, folder in enumerate(folders): 36 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 37 | classname = classnames[folder] 38 | for imname in imnames: 39 | impath = os.path.join(image_dir, folder, imname) 40 | item = Datum(impath=impath, label=label, classname=classname) 41 | items.append(item) 42 | 43 | return items 44 | -------------------------------------------------------------------------------- /datasets/imagenetv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetV2(DatasetBase): 11 | """ImageNetV2. 12 | 13 | This dataset is used for testing only. 
14 | """ 15 | 16 | dataset_dir = "imagenetv2" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | image_dir = "imagenetv2-matched-frequency-format-val" 22 | self.image_dir = os.path.join(self.dataset_dir, image_dir) 23 | 24 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 25 | classnames = ImageNet.read_classnames(text_file) 26 | 27 | data = self.read_data(classnames) 28 | 29 | super().__init__(train_x=data, test=data) 30 | 31 | def read_data(self, classnames): 32 | image_dir = self.image_dir 33 | folders = list(classnames.keys()) 34 | items = [] 35 | 36 | for label in range(1000): 37 | class_dir = os.path.join(image_dir, str(label)) 38 | imnames = listdir_nohidden(class_dir) 39 | folder = folders[label] 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(class_dir, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | from scipy.io import loadmat 5 | from collections import defaultdict 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import read_json, mkdir_if_missing 9 | 10 | from .oxford_pets import OxfordPets 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class OxfordFlowers(DatasetBase): 15 | 16 | dataset_dir = "oxford_flowers" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "jpg") 22 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat") 23 | self.lab2cname_file = os.path.join(self.dataset_dir, 24 | "cat_to_name.json") 25 | self.split_path = os.path.join(self.dataset_dir, 26 | "split_zhou_OxfordFlowers.json") 27 | self.split_fewshot_dir = os.path.join(self.dataset_dir, 28 | "split_fewshot") 29 | mkdir_if_missing(self.split_fewshot_dir) 30 | 31 | if os.path.exists(self.split_path): 32 | train, val, test = OxfordPets.read_split(self.split_path, 33 | self.image_dir) 34 | else: 35 | train, val, test = self.read_data() 36 | OxfordPets.save_split(train, val, test, self.split_path, 37 | self.image_dir) 38 | 39 | num_shots = cfg.DATASET.NUM_SHOTS 40 | if num_shots >= 1: 41 | seed = cfg.SEED 42 | preprocessed = os.path.join(self.split_fewshot_dir, 43 | f"shot_{num_shots}-seed_{seed}.pkl") 44 | 45 | if os.path.exists(preprocessed): 46 | print( 47 | f"Loading preprocessed few-shot data from {preprocessed}") 48 | with open(preprocessed, "rb") as file: 49 | data = pickle.load(file) 50 | train, val = data["train"], data["val"] 51 | else: 52 | train = self.generate_fewshot_dataset(train, 53 | num_shots=num_shots) 54 | val = self.generate_fewshot_dataset(val, 55 | num_shots=min( 56 | num_shots, 4)) 57 | data = {"train": train, "val": val} 58 | print(f"Saving preprocessed few-shot data to {preprocessed}") 59 | with open(preprocessed, "wb") as file: 60 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 61 | 62 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 63 | train, val, test = OxfordPets.subsample_classes(train, 64 | val, 65 | test, 66 | subsample=subsample) 67 | 68 | 
super().__init__(train_x=train, val=val, test=test) 69 | 70 | def read_data(self): 71 | tracker = defaultdict(list) 72 | label_file = loadmat(self.label_file)["labels"][0] 73 | for i, label in enumerate(label_file): 74 | imname = f"image_{str(i + 1).zfill(5)}.jpg" 75 | impath = os.path.join(self.image_dir, imname) 76 | label = int(label) 77 | tracker[label].append(impath) 78 | 79 | print("Splitting data into 50% train, 20% val, and 30% test") 80 | 81 | def _collate(ims, y, c): 82 | items = [] 83 | for im in ims: 84 | item = Datum(impath=im, label=y - 1, 85 | classname=c) # convert to 0-based label 86 | items.append(item) 87 | return items 88 | 89 | lab2cname = read_json(self.lab2cname_file) 90 | train, val, test = [], [], [] 91 | for label, impaths in tracker.items(): 92 | random.shuffle(impaths) 93 | n_total = len(impaths) 94 | n_train = round(n_total * 0.5) 95 | n_val = round(n_total * 0.2) 96 | n_test = n_total - n_train - n_val 97 | assert n_train > 0 and n_val > 0 and n_test > 0 98 | cname = lab2cname[str(label)] 99 | train.extend(_collate(impaths[:n_train], label, cname)) 100 | val.extend(_collate(impaths[n_train:n_train + n_val], label, 101 | cname)) 102 | test.extend(_collate(impaths[n_train + n_val:], label, cname)) 103 | 104 | return train, val, test 105 | -------------------------------------------------------------------------------- /datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import math 4 | import random 5 | from collections import defaultdict 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import read_json, write_json, mkdir_if_missing 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class OxfordPets(DatasetBase): 13 | 14 | dataset_dir = "oxford_pets" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.anno_dir = os.path.join(self.dataset_dir, "annotations") 21 | self.split_path = os.path.join(self.dataset_dir, 22 | "split_zhou_OxfordPets.json") 23 | self.split_fewshot_dir = os.path.join(self.dataset_dir, 24 | "split_fewshot") 25 | mkdir_if_missing(self.split_fewshot_dir) 26 | 27 | if os.path.exists(self.split_path): 28 | train, val, test = self.read_split(self.split_path, self.image_dir) 29 | else: 30 | trainval = self.read_data(split_file="trainval.txt") 31 | test = self.read_data(split_file="test.txt") 32 | train, val = self.split_trainval(trainval) 33 | self.save_split(train, val, test, self.split_path, self.image_dir) 34 | 35 | num_shots = cfg.DATASET.NUM_SHOTS 36 | if num_shots >= 1: 37 | seed = cfg.SEED 38 | preprocessed = os.path.join(self.split_fewshot_dir, 39 | f"shot_{num_shots}-seed_{seed}.pkl") 40 | 41 | if os.path.exists(preprocessed): 42 | print( 43 | f"Loading preprocessed few-shot data from {preprocessed}") 44 | with open(preprocessed, "rb") as file: 45 | data = pickle.load(file) 46 | train, val = data["train"], data["val"] 47 | else: 48 | train = self.generate_fewshot_dataset(train, 49 | num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, 51 | num_shots=min( 52 | num_shots, 4)) 53 | data = {"train": train, "val": val} 54 | print(f"Saving preprocessed few-shot data to {preprocessed}") 55 | with open(preprocessed, "wb") as file: 56 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 57 | 58 | subsample = 
cfg.DATASET.SUBSAMPLE_CLASSES 59 | train, val, test = self.subsample_classes(train, 60 | val, 61 | test, 62 | subsample=subsample) 63 | 64 | super().__init__(train_x=train, val=val, test=test) 65 | 66 | def read_data(self, split_file): 67 | filepath = os.path.join(self.anno_dir, split_file) 68 | items = [] 69 | 70 | with open(filepath, "r") as f: 71 | lines = f.readlines() 72 | for line in lines: 73 | line = line.strip() 74 | imname, label, species, _ = line.split(" ") 75 | breed = imname.split("_")[:-1] 76 | breed = "_".join(breed) 77 | breed = breed.lower() 78 | imname += ".jpg" 79 | impath = os.path.join(self.image_dir, imname) 80 | label = int(label) - 1 # convert to 0-based index 81 | item = Datum(impath=impath, label=label, classname=breed) 82 | items.append(item) 83 | 84 | return items 85 | 86 | @staticmethod 87 | def split_trainval(trainval, p_val=0.2): 88 | p_trn = 1 - p_val 89 | print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val") 90 | tracker = defaultdict(list) 91 | for idx, item in enumerate(trainval): 92 | label = item.label 93 | tracker[label].append(idx) 94 | 95 | train, val = [], [] 96 | for label, idxs in tracker.items(): 97 | n_val = round(len(idxs) * p_val) 98 | assert n_val > 0 99 | random.shuffle(idxs) 100 | for n, idx in enumerate(idxs): 101 | item = trainval[idx] 102 | if n < n_val: 103 | val.append(item) 104 | else: 105 | train.append(item) 106 | 107 | return train, val 108 | 109 | @staticmethod 110 | def save_split(train, val, test, filepath, path_prefix): 111 | def _extract(items): 112 | out = [] 113 | for item in items: 114 | impath = item.impath 115 | label = item.label 116 | classname = item.classname 117 | impath = impath.replace(path_prefix, "") 118 | if impath.startswith("/"): 119 | impath = impath[1:] 120 | out.append((impath, label, classname)) 121 | return out 122 | 123 | train = _extract(train) 124 | val = _extract(val) 125 | test = _extract(test) 126 | 127 | split = {"train": train, "val": val, "test": test} 128 | 129 | write_json(split, filepath) 130 | print(f"Saved split to {filepath}") 131 | 132 | @staticmethod 133 | def read_split(filepath, path_prefix): 134 | def _convert(items): 135 | out = [] 136 | for impath, label, classname in items: 137 | impath = os.path.join(path_prefix, impath) 138 | item = Datum(impath=impath, 139 | label=int(label), 140 | classname=classname) 141 | out.append(item) 142 | return out 143 | 144 | print(f"Reading split from {filepath}") 145 | split = read_json(filepath) 146 | train = _convert(split["train"]) 147 | val = _convert(split["val"]) 148 | test = _convert(split["test"]) 149 | 150 | return train, val, test 151 | 152 | @staticmethod 153 | def subsample_classes(*args, subsample="all"): 154 | """Divide classes into two groups. The first group 155 | represents base classes while the second group represents 156 | new classes. 157 | 158 | Args: 159 | args: a list of datasets, e.g. train, val and test. 160 | subsample (str): what classes to subsample. 
161 | """ 162 | assert subsample in ["all", "base", "new"] 163 | 164 | if subsample == "all": 165 | return args 166 | 167 | dataset = args[0] 168 | labels = set() 169 | for item in dataset: 170 | labels.add(item.label) 171 | labels = list(labels) 172 | labels.sort() 173 | n = len(labels) 174 | # Divide classes into two halves 175 | m = math.ceil(n / 2) 176 | 177 | print(f"SUBSAMPLE {subsample.upper()} CLASSES!") 178 | if subsample == "base": 179 | selected = labels[:m] # take the first half 180 | else: 181 | selected = labels[m:] # take the second half 182 | relabeler = {y: y_new for y_new, y in enumerate(selected)} 183 | 184 | output = [] 185 | for dataset in args: 186 | dataset_new = [] 187 | for item in dataset: 188 | if item.label not in selected: 189 | continue 190 | item_new = Datum(impath=item.impath, 191 | label=relabeler[item.label], 192 | classname=item.classname) 193 | dataset_new.append(item_new) 194 | output.append(dataset_new) 195 | 196 | return output 197 | -------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from scipy.io import loadmat 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class StanfordCars(DatasetBase): 13 | 14 | dataset_dir = "stanford_cars" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.split_path = os.path.join(self.dataset_dir, 20 | "split_zhou_StanfordCars.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, 22 | "split_fewshot") 23 | mkdir_if_missing(self.split_fewshot_dir) 24 | 25 | if os.path.exists(self.split_path): 26 | train, val, test = OxfordPets.read_split(self.split_path, 27 | self.dataset_dir) 28 | else: 29 | trainval_file = os.path.join(self.dataset_dir, "devkit", 30 | "cars_train_annos.mat") 31 | test_file = os.path.join(self.dataset_dir, 32 | "cars_test_annos_withlabels.mat") 33 | meta_file = os.path.join(self.dataset_dir, "devkit", 34 | "cars_meta.mat") 35 | trainval = self.read_data("cars_train", trainval_file, meta_file) 36 | test = self.read_data("cars_test", test_file, meta_file) 37 | train, val = OxfordPets.split_trainval(trainval) 38 | OxfordPets.save_split(train, val, test, self.split_path, 39 | self.dataset_dir) 40 | 41 | num_shots = cfg.DATASET.NUM_SHOTS 42 | if num_shots >= 1: 43 | seed = cfg.SEED 44 | preprocessed = os.path.join(self.split_fewshot_dir, 45 | f"shot_{num_shots}-seed_{seed}.pkl") 46 | 47 | if os.path.exists(preprocessed): 48 | print( 49 | f"Loading preprocessed few-shot data from {preprocessed}") 50 | with open(preprocessed, "rb") as file: 51 | data = pickle.load(file) 52 | train, val = data["train"], data["val"] 53 | else: 54 | train = self.generate_fewshot_dataset(train, 55 | num_shots=num_shots) 56 | val = self.generate_fewshot_dataset(val, 57 | num_shots=min( 58 | num_shots, 4)) 59 | data = {"train": train, "val": val} 60 | print(f"Saving preprocessed few-shot data to {preprocessed}") 61 | with open(preprocessed, "wb") as file: 62 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 63 | 64 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 65 | train, val, test = OxfordPets.subsample_classes(train, 66 | val, 67 | test, 68 | 
subsample=subsample) 69 | 70 | super().__init__(train_x=train, val=val, test=test) 71 | 72 | def read_data(self, image_dir, anno_file, meta_file): 73 | anno_file = loadmat(anno_file)["annotations"][0] 74 | meta_file = loadmat(meta_file)["class_names"][0] 75 | items = [] 76 | 77 | for i in range(len(anno_file)): 78 | imname = anno_file[i]["fname"][0] 79 | impath = os.path.join(self.dataset_dir, image_dir, imname) 80 | label = anno_file[i]["class"][0, 0] 81 | label = int(label) - 1 # convert to 0-based index 82 | classname = meta_file[label][0] 83 | names = classname.split(" ") 84 | year = names.pop(-1) 85 | names.insert(0, year) 86 | classname = " ".join(names) 87 | item = Datum(impath=impath, label=label, classname=classname) 88 | items.append(item) 89 | 90 | return items 91 | -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SUN397(DatasetBase): 12 | 13 | dataset_dir = "sun397" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "SUN397") 19 | self.split_path = os.path.join(self.dataset_dir, 20 | "split_zhou_SUN397.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, 22 | "split_fewshot") 23 | mkdir_if_missing(self.split_fewshot_dir) 24 | 25 | if os.path.exists(self.split_path): 26 | train, val, test = OxfordPets.read_split(self.split_path, 27 | self.image_dir) 28 | else: 29 | classnames = [] 30 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), 31 | "r") as f: 32 | lines = f.readlines() 33 | for line in lines: 34 | line = line.strip()[1:] # remove / 35 | classnames.append(line) 36 | cname2lab = {c: i for i, c in enumerate(classnames)} 37 | trainval = self.read_data(cname2lab, "Training_01.txt") 38 | test = self.read_data(cname2lab, "Testing_01.txt") 39 | train, val = OxfordPets.split_trainval(trainval) 40 | OxfordPets.save_split(train, val, test, self.split_path, 41 | self.image_dir) 42 | 43 | num_shots = cfg.DATASET.NUM_SHOTS 44 | if num_shots >= 1: 45 | seed = cfg.SEED 46 | preprocessed = os.path.join(self.split_fewshot_dir, 47 | f"shot_{num_shots}-seed_{seed}.pkl") 48 | 49 | if os.path.exists(preprocessed): 50 | print( 51 | f"Loading preprocessed few-shot data from {preprocessed}") 52 | with open(preprocessed, "rb") as file: 53 | data = pickle.load(file) 54 | train, val = data["train"], data["val"] 55 | else: 56 | train = self.generate_fewshot_dataset(train, 57 | num_shots=num_shots) 58 | val = self.generate_fewshot_dataset(val, 59 | num_shots=min( 60 | num_shots, 4)) 61 | data = {"train": train, "val": val} 62 | print(f"Saving preprocessed few-shot data to {preprocessed}") 63 | with open(preprocessed, "wb") as file: 64 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 65 | 66 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 67 | train, val, test = OxfordPets.subsample_classes(train, 68 | val, 69 | test, 70 | subsample=subsample) 71 | 72 | super().__init__(train_x=train, val=val, test=test) 73 | 74 | def read_data(self, cname2lab, text_file): 75 | text_file = os.path.join(self.dataset_dir, text_file) 76 | items = [] 77 
| 78 | with open(text_file, "r") as f: 79 | lines = f.readlines() 80 | for line in lines: 81 | imname = line.strip()[1:] # remove / 82 | classname = os.path.dirname(imname) 83 | label = cname2lab[classname] 84 | impath = os.path.join(self.image_dir, imname) 85 | 86 | names = classname.split("/")[1:] # remove 1st letter 87 | names = names[::-1] # put words like indoor/outdoor at first 88 | classname = " ".join(names) 89 | 90 | item = Datum(impath=impath, label=label, classname=classname) 91 | items.append(item) 92 | 93 | return items 94 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import re 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class UCF101(DatasetBase): 13 | 14 | dataset_dir = "ucf101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes") 20 | self.split_path = os.path.join(self.dataset_dir, 21 | "split_zhou_UCF101.json") 22 | self.split_fewshot_dir = os.path.join(self.dataset_dir, 23 | "split_fewshot") 24 | mkdir_if_missing(self.split_fewshot_dir) 25 | 26 | if os.path.exists(self.split_path): 27 | train, val, test = OxfordPets.read_split(self.split_path, 28 | self.image_dir) 29 | else: 30 | cname2lab = {} 31 | filepath = os.path.join(self.dataset_dir, 32 | "ucfTrainTestlist/classInd.txt") 33 | with open(filepath, "r") as f: 34 | lines = f.readlines() 35 | for line in lines: 36 | label, classname = line.strip().split(" ") 37 | label = int(label) - 1 # conver to 0-based index 38 | cname2lab[classname] = label 39 | 40 | trainval = self.read_data(cname2lab, 41 | "ucfTrainTestlist/trainlist01.txt") 42 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt") 43 | train, val = OxfordPets.split_trainval(trainval) 44 | OxfordPets.save_split(train, val, test, self.split_path, 45 | self.image_dir) 46 | 47 | num_shots = cfg.DATASET.NUM_SHOTS 48 | if num_shots >= 1: 49 | seed = cfg.SEED 50 | preprocessed = os.path.join(self.split_fewshot_dir, 51 | f"shot_{num_shots}-seed_{seed}.pkl") 52 | 53 | if os.path.exists(preprocessed): 54 | print( 55 | f"Loading preprocessed few-shot data from {preprocessed}") 56 | with open(preprocessed, "rb") as file: 57 | data = pickle.load(file) 58 | train, val = data["train"], data["val"] 59 | else: 60 | train = self.generate_fewshot_dataset(train, 61 | num_shots=num_shots) 62 | val = self.generate_fewshot_dataset(val, 63 | num_shots=min( 64 | num_shots, 4)) 65 | data = {"train": train, "val": val} 66 | print(f"Saving preprocessed few-shot data to {preprocessed}") 67 | with open(preprocessed, "wb") as file: 68 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 69 | 70 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 71 | train, val, test = OxfordPets.subsample_classes(train, 72 | val, 73 | test, 74 | subsample=subsample) 75 | 76 | super().__init__(train_x=train, val=val, test=test) 77 | 78 | def read_data(self, cname2lab, text_file): 79 | text_file = os.path.join(self.dataset_dir, text_file) 80 | items = [] 81 | 82 | with open(text_file, "r") as f: 83 | lines = f.readlines() 84 | for line in lines: 85 | line = line.strip().split(" 
")[0] # trainlist: filename, label 86 | action, filename = line.split("/") 87 | label = cname2lab[action] 88 | 89 | elements = re.findall("[A-Z][^A-Z]*", action) 90 | renamed_action = "_".join(elements) 91 | 92 | filename = filename.replace(".avi", ".jpg") 93 | impath = os.path.join(self.image_dir, renamed_action, filename) 94 | 95 | item = Datum(impath=impath, 96 | label=label, 97 | classname=renamed_action) 98 | items.append(item) 99 | 100 | return items 101 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for DATASET in caltech101 oxford_pets stanford_cars oxford_flowers food101 fgvc_aircraft dtd eurosat ucf101 4 | do 5 | python parse_test_res.py output_1108_4/base2new/train_base/${DATASET}/shots_16_1.0/TCP/vit_b16_ep100_ctxv1/ --test-log 6 | python parse_test_res.py output_1108_4_eval/base2new/test_new/${DATASET}/shots_16_1.0/TCP/vit_b16_ep100_ctxv1/ --test-log 7 | done 8 | -------------------------------------------------------------------------------- /interpret_prompt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import torch 5 | 6 | from clip.simple_tokenizer import SimpleTokenizer 7 | from clip import clip 8 | 9 | 10 | def load_clip_to_cpu(backbone_name="RN50"): 11 | url = clip._MODELS[backbone_name] 12 | model_path = clip._download(url) 13 | 14 | try: 15 | # loading JIT archive 16 | model = torch.jit.load(model_path, map_location="cpu").eval() 17 | state_dict = None 18 | 19 | except RuntimeError: 20 | state_dict = torch.load(model_path, map_location="cpu") 21 | 22 | model = clip.build_model(state_dict or model.state_dict()) 23 | 24 | return model 25 | 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("fpath", type=str, help="Path to the learned prompt") 29 | parser.add_argument("topk", type=int, help="Select top-k similar words") 30 | args = parser.parse_args() 31 | 32 | fpath = args.fpath 33 | topk = args.topk 34 | 35 | assert os.path.exists(fpath) 36 | 37 | print(f"Return the top-{topk} matched words") 38 | 39 | tokenizer = SimpleTokenizer() 40 | clip_model = load_clip_to_cpu() 41 | token_embedding = clip_model.token_embedding.weight 42 | print(f"Size of token embedding: {token_embedding.shape}") 43 | 44 | prompt_learner = torch.load(fpath, map_location="cpu")["state_dict"] 45 | ctx = prompt_learner["ctx"] 46 | ctx = ctx.float() 47 | print(f"Size of context: {ctx.shape}") 48 | 49 | if ctx.dim() == 2: 50 | # Generic context 51 | distance = torch.cdist(ctx, token_embedding) 52 | print(f"Size of distance matrix: {distance.shape}") 53 | sorted_idxs = torch.argsort(distance, dim=1) 54 | sorted_idxs = sorted_idxs[:, :topk] 55 | 56 | for m, idxs in enumerate(sorted_idxs): 57 | words = [tokenizer.decoder[idx.item()] for idx in idxs] 58 | dist = [f"{distance[m, idx].item():.4f}" for idx in idxs] 59 | print(f"{m+1}: {words} {dist}") 60 | 61 | elif ctx.dim() == 3: 62 | # Class-specific context 63 | raise NotImplementedError 64 | -------------------------------------------------------------------------------- /lpclip/README.md: -------------------------------------------------------------------------------- 1 | # Linear Probe CLIP 2 | 3 | To run linear probe baselines, make sure that your current working directory is `lpclip/`. 
4 | 5 | Step 1: Extract Features using the CLIP Image Encoder 6 | ```bash 7 | sh feat_extractor.sh 8 | ``` 9 | 10 | Step 2: Train few-shot linear probe 11 | ```bash 12 | sh linear_probe.sh 13 | ``` 14 | 15 | We follow the instructions stated in the Appendix A3 (pp.38) of [the original CLIP paper](https://arxiv.org/pdf/2103.00020.pdf), with a careful hyperparameter sweep. 16 | 17 | Note: please pull the latest Dassl (version >= `606a2c6`). 18 | -------------------------------------------------------------------------------- /lpclip/feat_extractor.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import numpy as np 3 | import torch 4 | import sys 5 | 6 | sys.path.append(os.path.abspath("..")) 7 | 8 | from datasets.oxford_pets import OxfordPets 9 | from datasets.oxford_flowers import OxfordFlowers 10 | from datasets.fgvc_aircraft import FGVCAircraft 11 | from datasets.dtd import DescribableTextures 12 | from datasets.eurosat import EuroSAT 13 | from datasets.stanford_cars import StanfordCars 14 | from datasets.food101 import Food101 15 | from datasets.sun397 import SUN397 16 | from datasets.caltech101 import Caltech101 17 | from datasets.ucf101 import UCF101 18 | from datasets.imagenet import ImageNet 19 | from datasets.imagenetv2 import ImageNetV2 20 | from datasets.imagenet_sketch import ImageNetSketch 21 | from datasets.imagenet_a import ImageNetA 22 | from datasets.imagenet_r import ImageNetR 23 | 24 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 25 | from dassl.config import get_cfg_default 26 | from dassl.data.transforms import build_transform 27 | from dassl.data import DatasetWrapper 28 | 29 | import clip 30 | 31 | # import pdb; pdb.set_trace() 32 | 33 | 34 | def print_args(args, cfg): 35 | print("***************") 36 | print("** Arguments **") 37 | print("***************") 38 | optkeys = list(args.__dict__.keys()) 39 | optkeys.sort() 40 | for key in optkeys: 41 | print("{}: {}".format(key, args.__dict__[key])) 42 | print("************") 43 | print("** Config **") 44 | print("************") 45 | print(cfg) 46 | 47 | 48 | def reset_cfg(cfg, args): 49 | if args.root: 50 | cfg.DATASET.ROOT = args.root 51 | 52 | if args.output_dir: 53 | cfg.OUTPUT_DIR = args.output_dir 54 | 55 | if args.trainer: 56 | cfg.TRAINER.NAME = args.trainer 57 | 58 | if args.backbone: 59 | cfg.MODEL.BACKBONE.NAME = args.backbone 60 | 61 | if args.head: 62 | cfg.MODEL.HEAD.NAME = args.head 63 | 64 | 65 | def extend_cfg(cfg): 66 | """ 67 | Add new config variables. 68 | 69 | E.g. 70 | from yacs.config import CfgNode as CN 71 | cfg.TRAINER.MY_MODEL = CN() 72 | cfg.TRAINER.MY_MODEL.PARAM_A = 1. 73 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5 74 | cfg.TRAINER.MY_MODEL.PARAM_C = False 75 | """ 76 | from yacs.config import CfgNode as CN 77 | 78 | cfg.TRAINER.OURS = CN() 79 | cfg.TRAINER.OURS.N_CTX = 10 # number of context vectors 80 | cfg.TRAINER.OURS.CSC = False # class-specific context 81 | cfg.TRAINER.OURS.CTX_INIT = "" # initialize context vectors with given words 82 | cfg.TRAINER.OURS.WEIGHT_U = 0.1 # weight for the unsupervised loss 83 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new 84 | 85 | 86 | def setup_cfg(args): 87 | cfg = get_cfg_default() 88 | extend_cfg(cfg) 89 | 90 | # 1. From the dataset config file 91 | if args.dataset_config_file: 92 | cfg.merge_from_file(args.dataset_config_file) 93 | 94 | # 2. From the method config file 95 | if args.config_file: 96 | cfg.merge_from_file(args.config_file) 97 | 98 | # 3. 
From input arguments 99 | reset_cfg(cfg, args) 100 | 101 | cfg.freeze() 102 | 103 | return cfg 104 | 105 | 106 | def main(args): 107 | cfg = setup_cfg(args) 108 | if cfg.SEED >= 0: 109 | print("Setting fixed seed: {}".format(cfg.SEED)) 110 | set_random_seed(cfg.SEED) 111 | setup_logger(cfg.OUTPUT_DIR) 112 | 113 | if torch.cuda.is_available() and cfg.USE_CUDA: 114 | torch.backends.cudnn.benchmark = True 115 | 116 | print_args(args, cfg) 117 | print("Collecting env info ...") 118 | print("** System info **\n{}\n".format(collect_env_info())) 119 | 120 | ###################################### 121 | # Setup DataLoader 122 | ###################################### 123 | dataset = eval(cfg.DATASET.NAME)(cfg) 124 | 125 | if args.split == "train": 126 | dataset_input = dataset.train_x 127 | elif args.split == "val": 128 | dataset_input = dataset.val 129 | else: 130 | dataset_input = dataset.test 131 | 132 | tfm_train = build_transform(cfg, is_train=False) 133 | data_loader = torch.utils.data.DataLoader( 134 | DatasetWrapper(cfg, dataset_input, transform=tfm_train, 135 | is_train=False), 136 | batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE, 137 | sampler=None, 138 | shuffle=False, 139 | num_workers=cfg.DATALOADER.NUM_WORKERS, 140 | drop_last=False, 141 | pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA), 142 | ) 143 | 144 | ######################################## 145 | # Setup Network 146 | ######################################## 147 | clip_model, _ = clip.load("RN50", "cuda", jit=False) 148 | clip_model.eval() 149 | ################################################################################################################### 150 | # Start Feature Extractor 151 | feature_list = [] 152 | label_list = [] 153 | train_dataiter = iter(data_loader) 154 | for train_step in range(1, len(train_dataiter) + 1): 155 | batch = next(train_dataiter) 156 | data = batch["img"].cuda() 157 | feature = clip_model.visual(data) 158 | feature = feature.cpu() 159 | for idx in range(len(data)): 160 | feature_list.append(feature[idx].tolist()) 161 | label_list.extend(batch["label"].tolist()) 162 | save_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASET.NAME) 163 | os.makedirs(save_dir, exist_ok=True) 164 | save_filename = f"{args.split}" 165 | np.savez( 166 | os.path.join(save_dir, save_filename), 167 | feature_list=feature_list, 168 | label_list=label_list, 169 | ) 170 | 171 | 172 | if __name__ == "__main__": 173 | parser = argparse.ArgumentParser() 174 | parser.add_argument("--root", type=str, default="", help="path to dataset") 175 | parser.add_argument("--output-dir", 176 | type=str, 177 | default="", 178 | help="output directory") 179 | parser.add_argument("--config-file", 180 | type=str, 181 | default="", 182 | help="path to config file") 183 | parser.add_argument( 184 | "--dataset-config-file", 185 | type=str, 186 | default="", 187 | help="path to config file for dataset setup", 188 | ) 189 | parser.add_argument("--num-shot", 190 | type=int, 191 | default=1, 192 | help="number of shots") 193 | parser.add_argument("--split", 194 | type=str, 195 | choices=["train", "val", "test"], 196 | help="which split") 197 | parser.add_argument("--trainer", 198 | type=str, 199 | default="", 200 | help="name of trainer") 201 | parser.add_argument("--backbone", 202 | type=str, 203 | default="", 204 | help="name of CNN backbone") 205 | parser.add_argument("--head", type=str, default="", help="name of head") 206 | parser.add_argument("--seed", 207 | type=int, 208 | default=-1, 209 | help="only positive value enables a fixed seed") 
210 | parser.add_argument("--eval-only", 211 | action="store_true", 212 | help="evaluation only") 213 | args = parser.parse_args() 214 | main(args) 215 | -------------------------------------------------------------------------------- /lpclip/feat_extractor.sh: -------------------------------------------------------------------------------- 1 | # sh feat_extractor.sh 2 | DATA=/data1/CoOpData 3 | OUTPUT='/data1/CoOpData/clip_feat/' 4 | SEED=1 5 | 6 | GPULIST=(0 1 2 3) 7 | GPUIDX=0 8 | 9 | # oxford_pets oxford_flowers fgvc_aircraft dtd eurosat stanford_cars food101 sun397 caltech101 ucf101 imagenet 10 | # imagenet oxford_pets oxford_flowers stanford_cars food101 caltech101 11 | for DATASET in imagenetv2 imagenet_sketch imagenet_a imagenet_r 12 | do 13 | for SPLIT in train val test 14 | do 15 | while true 16 | do 17 | sleep 10 18 | let STATIDX=GPULIST[GPUIDX]+2 19 | stat=$(gpustat | awk '{print $11}' | sed -n ${STATIDX}'p') 20 | if [ "$stat" -lt 20 ] 21 | then 22 | break 23 | fi 24 | let GPUIDX=(GPUIDX+1)%${#GPULIST[@]} 25 | echo $GPUIDX'N' 26 | done 27 | CUDA_VISIBLE_DEVICES=${GPULIST[${GPUIDX}]} python feat_extractor.py \ 28 | --split ${SPLIT} \ 29 | --root ${DATA} \ 30 | --seed ${SEED} \ 31 | --dataset-config-file ../configs/datasets/${DATASET}.yaml \ 32 | --config-file ../configs/trainers/CoOp/rn50_val.yaml \ 33 | --output-dir ${OUTPUT} \ 34 | --eval-only & 35 | sleep 10 36 | done 37 | done 38 | -------------------------------------------------------------------------------- /lpclip/linear_probe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from sklearn.linear_model import LogisticRegression 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--dataset", type=str, default="", help="path to dataset") 8 | parser.add_argument("--num_step", type=int, default=8, help="number of steps") 9 | parser.add_argument("--num_run", type=int, default=10, help="number of runs") 10 | parser.add_argument("--feature_dir", 11 | type=str, 12 | default="clip_feat", 13 | help="feature dir path") 14 | args = parser.parse_args() 15 | 16 | dataset = args.dataset 17 | dataset_path = os.path.join(f"{args.feature_dir}", dataset) 18 | 19 | train_file = np.load(os.path.join(dataset_path, "train.npz")) 20 | train_feature, train_label = train_file["feature_list"], train_file[ 21 | "label_list"] 22 | val_file = np.load(os.path.join(dataset_path, "val.npz")) 23 | val_feature, val_label = val_file["feature_list"], val_file["label_list"] 24 | test_file = np.load(os.path.join(dataset_path, "test.npz")) 25 | test_feature, test_label = test_file["feature_list"], test_file["label_list"] 26 | 27 | os.makedirs("report", exist_ok=True) 28 | 29 | val_shot_list = {1: 1, 2: 2, 4: 4, 8: 4, 16: 4} 30 | 31 | # for num_shot in [1, 2, 4, 8, 16]: 32 | for num_shot in [4, 16]: 33 | test_acc_step_list = np.zeros([args.num_run, args.num_step]) 34 | for seed in range(1, args.num_run + 1): 35 | np.random.seed(seed) 36 | print( 37 | f"-- Seed: {seed} --------------------------------------------------------------" 38 | ) 39 | # Sampling 40 | all_label_list = np.unique(train_label) 41 | selected_idx_list = [] 42 | for label in all_label_list: 43 | label_collection = np.where(train_label == label)[0] 44 | selected_idx = np.random.choice(label_collection, 45 | size=num_shot, 46 | replace=False) 47 | selected_idx_list.extend(selected_idx) 48 | 49 | fewshot_train_feature = train_feature[selected_idx_list] 50 | fewshot_train_label = 
train_label[selected_idx_list] 51 | 52 | val_num_shot = val_shot_list[num_shot] 53 | val_selected_idx_list = [] 54 | for label in all_label_list: 55 | label_collection = np.where(val_label == label)[0] 56 | selected_idx = np.random.choice(label_collection, 57 | size=val_num_shot, 58 | replace=False) 59 | val_selected_idx_list.extend(selected_idx) 60 | 61 | fewshot_val_feature = val_feature[val_selected_idx_list] 62 | fewshot_val_label = val_label[val_selected_idx_list] 63 | 64 | # search initialization 65 | search_list = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6] 66 | acc_list = [] 67 | for c_weight in search_list: 68 | clf = LogisticRegression(solver="lbfgs", 69 | max_iter=1000, 70 | penalty="l2", 71 | C=c_weight).fit(fewshot_train_feature, 72 | fewshot_train_label) 73 | pred = clf.predict(fewshot_val_feature) 74 | acc_val = sum(pred == fewshot_val_label) / len(fewshot_val_label) 75 | acc_list.append(acc_val) 76 | 77 | print(acc_list, flush=True) 78 | 79 | # binary search 80 | peak_idx = np.argmax(acc_list) 81 | c_peak = search_list[peak_idx] 82 | c_left, c_right = 1e-1 * c_peak, 1e1 * c_peak 83 | 84 | def binary_search(c_left, c_right, seed, step, test_acc_step_list): 85 | clf_left = LogisticRegression(solver="lbfgs", 86 | max_iter=1000, 87 | penalty="l2", 88 | C=c_left).fit( 89 | fewshot_train_feature, 90 | fewshot_train_label) 91 | pred_left = clf_left.predict(fewshot_val_feature) 92 | acc_left = sum( 93 | pred_left == fewshot_val_label) / len(fewshot_val_label) 94 | print("Val accuracy (Left): {:.2f}".format(100 * acc_left), 95 | flush=True) 96 | 97 | clf_right = LogisticRegression(solver="lbfgs", 98 | max_iter=1000, 99 | penalty="l2", 100 | C=c_right).fit( 101 | fewshot_train_feature, 102 | fewshot_train_label) 103 | pred_right = clf_right.predict(fewshot_val_feature) 104 | acc_right = sum( 105 | pred_right == fewshot_val_label) / len(fewshot_val_label) 106 | print("Val accuracy (Right): {:.2f}".format(100 * acc_right), 107 | flush=True) 108 | 109 | # find maximum and update ranges 110 | if acc_left < acc_right: 111 | c_final = c_right 112 | clf_final = clf_right 113 | # range for the next step 114 | c_left = 0.5 * (np.log10(c_right) + np.log10(c_left)) 115 | c_right = np.log10(c_right) 116 | else: 117 | c_final = c_left 118 | clf_final = clf_left 119 | # range for the next step 120 | c_right = 0.5 * (np.log10(c_right) + np.log10(c_left)) 121 | c_left = np.log10(c_left) 122 | 123 | pred = clf_final.predict(test_feature) 124 | test_acc = 100 * sum(pred == test_label) / len(pred) 125 | print("Test Accuracy: {:.2f}".format(test_acc), flush=True) 126 | test_acc_step_list[seed - 1, step] = test_acc 127 | 128 | saveline = "{}, seed {}, {} shot, weight {}, test_acc {:.2f}\n".format( 129 | dataset, seed, num_shot, c_final, test_acc) 130 | with open( 131 | "./report/{}_s{}r{}_details.txt".format( 132 | 'clip_feat', args.num_step, args.num_run), 133 | "a+", 134 | ) as writer: 135 | writer.write(saveline) 136 | return ( 137 | np.power(10, c_left), 138 | np.power(10, c_right), 139 | seed, 140 | step, 141 | test_acc_step_list, 142 | ) 143 | 144 | for step in range(args.num_step): 145 | print( 146 | f"{dataset}, {num_shot} Shot, Round {step}: {c_left}/{c_right}", 147 | flush=True, 148 | ) 149 | c_left, c_right, seed, step, test_acc_step_list = binary_search( 150 | c_left, c_right, seed, step, test_acc_step_list) 151 | # save results of last step 152 | test_acc_list = test_acc_step_list[:, -1] 153 | acc_mean = np.mean(test_acc_list) 154 | acc_std = np.std(test_acc_list) 155 | save_line = "{}, {} Shot, 
Test acc stat: {:.2f} ({:.2f})\n".format( 156 | dataset, num_shot, acc_mean, acc_std) 157 | print(save_line, flush=True) 158 | with open( 159 | "./report/{}_s{}r{}.txt".format('clip_feat', args.num_step, 160 | args.num_run), 161 | "a+", 162 | ) as writer: 163 | writer.write(save_line) 164 | -------------------------------------------------------------------------------- /lpclip/linear_probe.sh: -------------------------------------------------------------------------------- 1 | feature_dir=/data1/CoOpData/clip_feat/ 2 | # ImageNet OxfordPets OxfordFlowers StanfordCars Food101 Caltech101 3 | for DATASET in ImageNet 4 | do 5 | python linear_probe.py \ 6 | --dataset ${DATASET} \ 7 | --feature_dir ${feature_dir} \ 8 | --num_step 8 \ 9 | --num_run 3 10 | done 11 | -------------------------------------------------------------------------------- /lpclip/linear_probe_transfer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from sklearn.linear_model import LogisticRegression 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | # parser.add_argument("--train_dataset", 8 | # type=str, 9 | # default="", 10 | # help="path to train dataset") 11 | # parser.add_argument("--test_dataset", 12 | # type=str, 13 | # default="", 14 | # help="path to test dataset") 15 | parser.add_argument("--num_step", type=int, default=8, help="number of steps") 16 | parser.add_argument("--num_run", type=int, default=10, help="number of runs") 17 | parser.add_argument("--feature_dir", 18 | type=str, 19 | default="/data1/CoOpData/clip_feat/", 20 | help="feature dir path") 21 | args = parser.parse_args() 22 | 23 | train_dataset = 'ImageNet' 24 | train_dataset_path = os.path.join(f"{args.feature_dir}", train_dataset) 25 | test_datasets = ['ImageNetV2', 'ImageNetSketch', 'ImageNetR', 'ImageNetA'] 26 | test_dataset_paths = [ 27 | os.path.join(f"{args.feature_dir}", test_dataset) 28 | for test_dataset in test_datasets 29 | ] 30 | 31 | train_file = np.load(os.path.join(train_dataset_path, "train.npz")) 32 | train_feature, train_label = train_file["feature_list"], train_file[ 33 | "label_list"] 34 | val_file = np.load(os.path.join(train_dataset_path, "val.npz")) 35 | val_feature, val_label = val_file["feature_list"], val_file["label_list"] 36 | 37 | test_files = [ 38 | np.load(os.path.join(test_dataset_path, "test.npz")) 39 | for test_dataset_path in test_dataset_paths 40 | ] 41 | test_features, test_labels = [ 42 | test_file["feature_list"] for test_file in test_files 43 | ], [test_file["label_list"] for test_file in test_files] 44 | 45 | os.makedirs("report", exist_ok=True) 46 | 47 | val_shot_list = {1: 1, 2: 2, 4: 4, 8: 4, 16: 4} 48 | 49 | # for num_shot in [1, 2, 4, 8, 16]: 50 | for num_shot in [16]: 51 | test_acc_step_list = np.zeros( 52 | [len(test_datasets), args.num_run, args.num_step]) 53 | for seed in range(1, args.num_run + 1): 54 | np.random.seed(seed) 55 | print( 56 | f"-- Seed: {seed} --------------------------------------------------------------" 57 | ) 58 | # Sampling 59 | all_label_list = np.unique(train_label) 60 | selected_idx_list = [] 61 | for label in all_label_list: 62 | label_collection = np.where(train_label == label)[0] 63 | selected_idx = np.random.choice(label_collection, 64 | size=num_shot, 65 | replace=False) 66 | selected_idx_list.extend(selected_idx) 67 | 68 | fewshot_train_feature = train_feature[selected_idx_list] 69 | fewshot_train_label = train_label[selected_idx_list] 70 | 71 | val_num_shot = 
val_shot_list[num_shot] 72 | val_selected_idx_list = [] 73 | for label in all_label_list: 74 | label_collection = np.where(val_label == label)[0] 75 | selected_idx = np.random.choice(label_collection, 76 | size=val_num_shot, 77 | replace=False) 78 | val_selected_idx_list.extend(selected_idx) 79 | 80 | fewshot_val_feature = val_feature[val_selected_idx_list] 81 | fewshot_val_label = val_label[val_selected_idx_list] 82 | 83 | # search initialization 84 | search_list = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6] 85 | acc_list = [] 86 | for c_weight in search_list: 87 | clf = LogisticRegression(solver="lbfgs", 88 | max_iter=1000, 89 | penalty="l2", 90 | C=c_weight).fit(fewshot_train_feature, 91 | fewshot_train_label) 92 | pred = clf.predict(fewshot_val_feature) 93 | acc_val = sum(pred == fewshot_val_label) / len(fewshot_val_label) 94 | acc_list.append(acc_val) 95 | 96 | print(acc_list, flush=True) 97 | 98 | # binary search 99 | peak_idx = np.argmax(acc_list) 100 | c_peak = search_list[peak_idx] 101 | c_left, c_right = 1e-1 * c_peak, 1e1 * c_peak 102 | 103 | def binary_search(c_left, c_right, seed, step, test_acc_step_list): 104 | clf_left = LogisticRegression(solver="lbfgs", 105 | max_iter=1000, 106 | penalty="l2", 107 | C=c_left).fit( 108 | fewshot_train_feature, 109 | fewshot_train_label) 110 | pred_left = clf_left.predict(fewshot_val_feature) 111 | acc_left = sum( 112 | pred_left == fewshot_val_label) / len(fewshot_val_label) 113 | print("Val accuracy (Left): {:.2f}".format(100 * acc_left), 114 | flush=True) 115 | 116 | clf_right = LogisticRegression(solver="lbfgs", 117 | max_iter=1000, 118 | penalty="l2", 119 | C=c_right).fit( 120 | fewshot_train_feature, 121 | fewshot_train_label) 122 | pred_right = clf_right.predict(fewshot_val_feature) 123 | acc_right = sum( 124 | pred_right == fewshot_val_label) / len(fewshot_val_label) 125 | print("Val accuracy (Right): {:.2f}".format(100 * acc_right), 126 | flush=True) 127 | 128 | # find maximum and update ranges 129 | if acc_left < acc_right: 130 | c_final = c_right 131 | clf_final = clf_right 132 | # range for the next step 133 | c_left = 0.5 * (np.log10(c_right) + np.log10(c_left)) 134 | c_right = np.log10(c_right) 135 | else: 136 | c_final = c_left 137 | clf_final = clf_left 138 | # range for the next step 139 | c_right = 0.5 * (np.log10(c_right) + np.log10(c_left)) 140 | c_left = np.log10(c_left) 141 | 142 | for i, (test_feature, test_label, test_dataset) in enumerate( 143 | zip(test_features, test_labels, test_datasets)): 144 | pred = clf_final.predict(test_feature) 145 | test_acc = 100 * sum(pred == test_label) / len(pred) 146 | print("Test Accuracy: {:.2f}".format(test_acc), flush=True) 147 | test_acc_step_list[i, seed - 1, step] = test_acc 148 | 149 | saveline = "{}, {}, seed {}, {} shot, weight {}, test_acc {:.2f}\n".format( 150 | train_dataset, test_dataset, seed, num_shot, c_final, 151 | test_acc) 152 | with open( 153 | "./report/{}_s{}r{}_details.txt".format( 154 | 'clip_feat', args.num_step, args.num_run), 155 | "a+", 156 | ) as writer: 157 | writer.write(saveline) 158 | return ( 159 | np.power(10, c_left), 160 | np.power(10, c_right), 161 | seed, 162 | step, 163 | test_acc_step_list, 164 | ) 165 | 166 | for step in range(args.num_step): 167 | print( 168 | f"{train_dataset}, {num_shot} Shot, Round {step}: {c_left}/{c_right}", 169 | flush=True, 170 | ) 171 | c_left, c_right, seed, step, test_acc_step_list = binary_search( 172 | c_left, c_right, seed, step, test_acc_step_list) 173 | # save results of last step 174 | test_acc_list = 
test_acc_step_list[:, :, -1] 175 | acc_mean = np.mean(test_acc_list, axis=-1) 176 | acc_std = np.std(test_acc_list, axis=-1) 177 | for i in range(len(test_datasets)): 178 | save_line = "{}, {}, {} Shot, Test acc stat: {:.2f} ({:.2f})\n".format( 179 | train_dataset, test_datasets[i], num_shot, acc_mean[i], acc_std[i]) 180 | print(save_line, flush=True) 181 | with open( 182 | "./report/{}_s{}r{}.txt".format('clip_feat', args.num_step, 183 | args.num_run), 184 | "a+", 185 | ) as writer: 186 | writer.write(save_line) 187 | -------------------------------------------------------------------------------- /parse_test_res.py: -------------------------------------------------------------------------------- 1 | """ 2 | Goal 3 | --- 4 | 1. Read test results from log.txt files 5 | 2. Compute mean and std across different folders (seeds) 6 | 7 | Usage 8 | --- 9 | Assume the output files are saved under output/my_experiment, 10 | which contains results of different seeds, e.g., 11 | 12 | my_experiment/ 13 | seed1/ 14 | log.txt 15 | seed2/ 16 | log.txt 17 | seed3/ 18 | log.txt 19 | 20 | Run the following command from the root directory: 21 | 22 | $ python parse_test_res.py output/my_experiment 23 | 24 | Add --ci95 to the argument if you want to get a 95% confidence 25 | interval instead of the standard deviation: 26 | 27 | $ python parse_test_res.py output/my_experiment --ci95 28 | 29 | If my_experiment/ has the following structure, 30 | 31 | my_experiment/ 32 | exp-1/ 33 | seed1/ 34 | log.txt 35 | ... 36 | seed2/ 37 | log.txt 38 | ... 39 | seed3/ 40 | log.txt 41 | ... 42 | exp-2/ 43 | ... 44 | exp-3/ 45 | ... 46 | 47 | Run 48 | 49 | $ python parse_test_res.py output/my_experiment --multi-exp 50 | """ 51 | import re 52 | import numpy as np 53 | import os.path as osp 54 | import argparse 55 | from collections import OrderedDict, defaultdict 56 | 57 | from dassl.utils import check_isfile, listdir_nohidden 58 | 59 | 60 | def compute_ci95(res): 61 | return 1.96 * np.std(res) / np.sqrt(len(res)) 62 | 63 | 64 | def parse_function(*metrics, directory="", args=None, end_signal=None): 65 | #print(f"Parsing files in {directory}") 66 | subdirs = listdir_nohidden(directory, sort=True) 67 | 68 | outputs = [] 69 | 70 | for subdir in subdirs: 71 | fpath = osp.join(directory, subdir, "log.txt") 72 | assert check_isfile(fpath) 73 | good_to_go = False 74 | output = OrderedDict() 75 | 76 | with open(fpath, "r") as f: 77 | lines = f.readlines() 78 | 79 | for line in lines: 80 | line = line.strip() 81 | 82 | if line == end_signal: 83 | good_to_go = True 84 | 85 | for metric in metrics: 86 | match = metric["regex"].search(line) 87 | if match and good_to_go: 88 | if "file" not in output: 89 | output["file"] = fpath 90 | num = float(match.group(1)) 91 | name = metric["name"] 92 | output[name] = num 93 | 94 | if output: 95 | outputs.append(output) 96 | 97 | assert len(outputs) > 0, f"Nothing found in {directory}" 98 | 99 | metrics_results = defaultdict(list) 100 | 101 | for output in outputs: 102 | msg = "" 103 | for key, value in output.items(): 104 | if isinstance(value, float): 105 | msg += f"{key}: {value:.2f}%. " 106 | else: 107 | msg += f"{key}: {value}. 
" 108 | if key != "file": 109 | metrics_results[key].append(value) 110 | #print(msg) 111 | 112 | output_results = OrderedDict() 113 | 114 | #print("===") 115 | #print(f"Summary of directory: {directory}") 116 | dir_sets = directory.split('/') 117 | #print(dir_sets) 118 | for key, values in metrics_results.items(): 119 | avg = np.mean(values) 120 | std = compute_ci95(values) if args.ci95 else np.std(values) 121 | print(f"* {dir_sets[2]} {dir_sets[3]} {key}: {avg:.2f}% +- {std:.2f}%") 122 | output_results[key] = avg 123 | #print("===") 124 | 125 | return output_results 126 | 127 | 128 | def main(args, end_signal): 129 | metric = { 130 | "name": args.keyword, 131 | "regex": re.compile(fr"\* {args.keyword}: ([\.\deE+-]+)%"), 132 | } 133 | 134 | if args.multi_exp: 135 | final_results = defaultdict(list) 136 | 137 | for directory in listdir_nohidden(args.directory, sort=True): 138 | directory = osp.join(args.directory, directory) 139 | results = parse_function(metric, 140 | directory=directory, 141 | args=args, 142 | end_signal=end_signal) 143 | 144 | for key, value in results.items(): 145 | final_results[key].append(value) 146 | 147 | print("Average performance") 148 | for key, values in final_results.items(): 149 | avg = np.mean(values) 150 | print(f"* {key}: {avg:.2f}%") 151 | 152 | else: 153 | parse_function(metric, 154 | directory=args.directory, 155 | args=args, 156 | end_signal=end_signal) 157 | 158 | 159 | if __name__ == "__main__": 160 | parser = argparse.ArgumentParser() 161 | parser.add_argument("directory", type=str, help="path to directory") 162 | parser.add_argument("--ci95", 163 | action="store_true", 164 | help=r"compute 95\% confidence interval") 165 | parser.add_argument("--test-log", 166 | action="store_true", 167 | help="parse test-only logs") 168 | parser.add_argument("--multi-exp", 169 | action="store_true", 170 | help="parse multiple experiments") 171 | parser.add_argument("--keyword", 172 | default="accuracy", 173 | type=str, 174 | help="which keyword to extract") 175 | args = parser.parse_args() 176 | 177 | end_signal = "Finished training" 178 | if args.test_log: 179 | end_signal = "=> result" 180 | 181 | main(args, end_signal) 182 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | -------------------------------------------------------------------------------- /scripts/base2new_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 4 | 5 | # custom config 6 | DATA=XXXXX 7 | TRAINER=TCP 8 | WEIGHT=1.0 9 | 10 | CFG=vit_b16_ep100_ctxv1 11 | CTP=end # class token position (end or middle) 12 | NCTX=4 # number of context tokens 13 | SHOTS=16 # number of shots (1, 2, 4, 8, 16) 14 | CSC=False # class-specific context (False or True) 15 | FOLDER=output_1108 16 | 17 | for DATASET in caltech101 dtd eurosat fgvc_aircraft food101 oxford_flowers oxford_pets stanford_cars ucf101 18 | do 19 | for SEED in 1 2 3 20 | do 21 | DIR=${FOLDER}_${NCTX}/base2new/train_base/${DATASET}/shots_${SHOTS}_${WEIGHT}/${TRAINER}/${CFG}/seed${SEED} 22 | if [ -d "$DIR" ]; then 23 | echo "Results are available in ${DIR}. 
Skip this job" 24 | else 25 | echo "Run this job and save the output to ${DIR}" 26 | CUDA_VISIBLE_DEVICES=0 python train.py \ 27 | --root ${DATA} \ 28 | --seed ${SEED} \ 29 | --trainer ${TRAINER} \ 30 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 32 | --output-dir ${DIR} \ 33 | TRAINER.COOP.N_CTX ${NCTX} \ 34 | TRAINER.COOP.CSC ${CSC} \ 35 | TRAINER.COOP.W ${WEIGHT} \ 36 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 37 | DATASET.NUM_SHOTS ${SHOTS} \ 38 | DATASET.SUBSAMPLE_CLASSES base 39 | fi 40 | done 41 | 42 | 43 | LOADEP=50 44 | SUB=new 45 | for SEED in 1 2 3 46 | do 47 | COMMON_DIR=${DATASET}/shots_${SHOTS}_${WEIGHT}/${TRAINER}/${CFG}/seed${SEED} 48 | MODEL_DIR=${FOLDER}_${NCTX}/base2new/train_base/${COMMON_DIR} 49 | DIR=${FOLDER}_${NCTX}_eval/base2new/test_${SUB}/${COMMON_DIR} 50 | 51 | if [ -d "$DIR" ]; then 52 | echo "Results are available in ${DIR}. Skip this job" 53 | else 54 | echo "Run this job and save the output to ${DIR}" 55 | CUDA_VISIBLE_DEVICES=0 python train.py \ 56 | --root ${DATA} \ 57 | --seed ${SEED} \ 58 | --trainer ${TRAINER} \ 59 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 60 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 61 | --output-dir ${DIR} \ 62 | --model-dir ${MODEL_DIR} \ 63 | --load-epoch ${LOADEP} \ 64 | --eval-only \ 65 | TRAINER.COOP.N_CTX ${NCTX} \ 66 | TRAINER.COOP.CSC ${CSC} \ 67 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 68 | DATASET.NUM_SHOTS ${SHOTS} \ 69 | DATASET.SUBSAMPLE_CLASSES ${SUB} 70 | fi 71 | done 72 | done 73 | 74 | 75 | for DATASET in sun397 imagenet 76 | do 77 | for SEED in 1 2 3 78 | do 79 | DIR=${FOLDER}_${NCTX}/base2new/train_base/${DATASET}/shots_${SHOTS}_${WEIGHT}/${TRAINER}/${CFG}/seed${SEED} 80 | if [ -d "$DIR" ]; then 81 | echo "Results are available in ${DIR}. Skip this job" 82 | else 83 | echo "Run this job and save the output to ${DIR}" 84 | CUDA_VISIBLE_DEVICES=0 python train.py \ 85 | --root ${DATA} \ 86 | --seed ${SEED} \ 87 | --trainer ${TRAINER} \ 88 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 89 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 90 | --output-dir ${DIR} \ 91 | TRAINER.COOP.N_CTX ${NCTX} \ 92 | TRAINER.COOP.CSC ${CSC} \ 93 | TRAINER.COOP.W ${WEIGHT} \ 94 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 95 | DATASET.NUM_SHOTS ${SHOTS} \ 96 | DATASET.SUBSAMPLE_CLASSES base 97 | fi 98 | done 99 | 100 | 101 | LOADEP=25 102 | SUB=new 103 | for SEED in 1 2 3 104 | do 105 | COMMON_DIR=${DATASET}/shots_${SHOTS}_${WEIGHT}/${TRAINER}/${CFG}/seed${SEED} 106 | MODEL_DIR=${FOLDER}_${NCTX}/base2new/train_base/${COMMON_DIR} 107 | DIR=${FOLDER}_${NCTX}/base2new/test_${SUB}/${COMMON_DIR} 108 | 109 | if [ -d "$DIR" ]; then 110 | echo "Results are available in ${DIR}. 
Skip this job" 111 | else 112 | echo "Run this job and save the output to ${DIR}" 113 | CUDA_VISIBLE_DEVICES=0 python train.py \ 114 | --root ${DATA} \ 115 | --seed ${SEED} \ 116 | --trainer ${TRAINER} \ 117 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 118 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 119 | --output-dir ${DIR} \ 120 | --model-dir ${MODEL_DIR} \ 121 | --load-epoch ${LOADEP} \ 122 | --eval-only \ 123 | TRAINER.COOP.N_CTX ${NCTX} \ 124 | TRAINER.COOP.CSC ${CSC} \ 125 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 126 | DATASET.NUM_SHOTS ${SHOTS} \ 127 | DATASET.SUBSAMPLE_CLASSES ${SUB} 128 | fi 129 | done 130 | done 131 | 132 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import time 4 | import os 5 | 6 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 7 | from dassl.config import get_cfg_default 8 | from dassl.engine import build_trainer 9 | 10 | # custom 11 | import datasets.oxford_pets 12 | import datasets.oxford_flowers 13 | import datasets.fgvc_aircraft 14 | import datasets.dtd 15 | import datasets.eurosat 16 | import datasets.stanford_cars 17 | import datasets.food101 18 | import datasets.sun397 19 | import datasets.caltech101 20 | import datasets.ucf101 21 | import datasets.imagenet 22 | 23 | import datasets.imagenet_sketch 24 | import datasets.imagenetv2 25 | import datasets.imagenet_a 26 | import datasets.imagenet_r 27 | 28 | import trainers.coop 29 | import trainers.cocoop 30 | import trainers.zsclip 31 | import trainers.prograd 32 | import trainers.tcp 33 | def print_args(args, cfg): 34 | print("***************") 35 | print("** Arguments **") 36 | print("***************") 37 | optkeys = list(args.__dict__.keys()) 38 | optkeys.sort() 39 | for key in optkeys: 40 | print("{}: {}".format(key, args.__dict__[key])) 41 | print("************") 42 | print("** Config **") 43 | print("************") 44 | print(cfg) 45 | 46 | 47 | def reset_cfg(cfg, args): 48 | if args.root: 49 | cfg.DATASET.ROOT = args.root 50 | 51 | if args.output_dir: 52 | cfg.OUTPUT_DIR = args.output_dir 53 | 54 | if args.resume: 55 | cfg.RESUME = args.resume 56 | 57 | if args.seed: 58 | cfg.SEED = args.seed 59 | 60 | if args.source_domains: 61 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains 62 | 63 | if args.target_domains: 64 | cfg.DATASET.TARGET_DOMAINS = args.target_domains 65 | 66 | if args.transforms: 67 | cfg.INPUT.TRANSFORMS = args.transforms 68 | 69 | if args.trainer: 70 | cfg.TRAINER.NAME = args.trainer 71 | 72 | if args.backbone: 73 | cfg.MODEL.BACKBONE.NAME = args.backbone 74 | 75 | if args.head: 76 | cfg.MODEL.HEAD.NAME = args.head 77 | 78 | 79 | def extend_cfg(cfg): 80 | """ 81 | Add new config variables. 82 | 83 | E.g. 84 | from yacs.config import CfgNode as CN 85 | cfg.TRAINER.MY_MODEL = CN() 86 | cfg.TRAINER.MY_MODEL.PARAM_A = 1. 
87 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5 88 | cfg.TRAINER.MY_MODEL.PARAM_C = False 89 | """ 90 | from yacs.config import CfgNode as CN 91 | 92 | cfg.TRAINER.COOP = CN() 93 | cfg.TRAINER.COOP.ALPHA = 1.0 94 | cfg.TRAINER.COOP.N_CTX = 16 # number of context vectors 95 | cfg.TRAINER.COOP.CSC = False # class-specific context 96 | cfg.TRAINER.COOP.CTX_INIT = False # initialization words 97 | cfg.TRAINER.COOP.W = 1.0 98 | cfg.TRAINER.COOP.PREC = "fp16" # fp16, fp32, amp 99 | cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front' 100 | 101 | cfg.TRAINER.COCOOP = CN() 102 | cfg.TRAINER.COCOOP.N_CTX = 16 # number of context vectors 103 | cfg.TRAINER.COCOOP.CTX_INIT = False # initialization words 104 | cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp 105 | 106 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new 107 | """ 108 | Add new config 109 | """ 110 | cfg.LOSS = CN() 111 | cfg.LOSS.GM = False 112 | cfg.LOSS.NAME = "" 113 | cfg.LOSS.ALPHA = 0. 114 | cfg.LOSS.T = 1. 115 | cfg.LOSS.LAMBDA = 1. 116 | 117 | 118 | def setup_cfg(args): 119 | cfg = get_cfg_default() 120 | extend_cfg(cfg) 121 | 122 | # 1. From the dataset config file 123 | if args.dataset_config_file: 124 | cfg.merge_from_file(args.dataset_config_file) 125 | 126 | # 2. From the method config file 127 | if args.config_file: 128 | cfg.merge_from_file(args.config_file) 129 | 130 | # 3. From input arguments 131 | reset_cfg(cfg, args) 132 | 133 | # 4. From optional input arguments 134 | cfg.merge_from_list(args.opts) 135 | 136 | if cfg.DATASET.NAME in ["ImageNet",'SUN397']: 137 | cfg.OPTIM.MAX_EPOCH=25 138 | cfg.freeze() 139 | 140 | return cfg 141 | 142 | 143 | def main(args): 144 | cfg = setup_cfg(args) 145 | if cfg.SEED >= 0: 146 | print("Setting fixed seed: {}".format(cfg.SEED)) 147 | set_random_seed(cfg.SEED) 148 | setup_logger(cfg.OUTPUT_DIR) 149 | 150 | if torch.cuda.is_available() and cfg.USE_CUDA: 151 | torch.backends.cudnn.benchmark = True 152 | 153 | print_args(args, cfg) 154 | print("Collecting env info ...") 155 | print("** System info **\n{}\n".format(collect_env_info())) 156 | 157 | trainer = build_trainer(cfg) 158 | 159 | if args.eval_only: 160 | trainer.load_model(args.model_dir, epoch=args.load_epoch) 161 | trainer.test() 162 | return 163 | 164 | if not args.no_train: 165 | trainer.train() 166 | 167 | 168 | if __name__ == "__main__": 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument("--root", type=str, default="", help="path to dataset") 171 | parser.add_argument("--output-dir", 172 | type=str, 173 | default="", 174 | help="output directory") 175 | parser.add_argument( 176 | "--resume", 177 | type=str, 178 | default="", 179 | help="checkpoint directory (from which the training resumes)", 180 | ) 181 | parser.add_argument("--seed", 182 | type=int, 183 | default=-1, 184 | help="only positive value enables a fixed seed") 185 | parser.add_argument("--source-domains", 186 | type=str, 187 | nargs="+", 188 | help="source domains for DA/DG") 189 | parser.add_argument("--target-domains", 190 | type=str, 191 | nargs="+", 192 | help="target domains for DA/DG") 193 | parser.add_argument("--transforms", 194 | type=str, 195 | nargs="+", 196 | help="data augmentation methods") 197 | parser.add_argument("--config-file", 198 | type=str, 199 | default="", 200 | help="path to config file") 201 | parser.add_argument( 202 | "--dataset-config-file", 203 | type=str, 204 | default="", 205 | help="path to config file for dataset setup", 206 | ) 207 | parser.add_argument("--trainer", 208 | type=str, 209 
| default="", 210 | help="name of trainer") 211 | parser.add_argument("--backbone", 212 | type=str, 213 | default="", 214 | help="name of CNN backbone") 215 | parser.add_argument("--head", type=str, default="", help="name of head") 216 | parser.add_argument("--eval-only", 217 | action="store_true", 218 | help="evaluation only") 219 | parser.add_argument( 220 | "--model-dir", 221 | type=str, 222 | default="", 223 | help="load model from this directory for eval-only mode", 224 | ) 225 | parser.add_argument("--load-epoch", 226 | type=int, 227 | help="load model weights at this epoch for evaluation") 228 | parser.add_argument("--no-train", 229 | action="store_true", 230 | help="do not call trainer.train()") 231 | parser.add_argument( 232 | "opts", 233 | default=None, 234 | nargs=argparse.REMAINDER, 235 | help="modify config options using the command-line", 236 | ) 237 | args = parser.parse_args() 238 | main(args) 239 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__init__.py -------------------------------------------------------------------------------- /trainers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/cocoop.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/cocoop.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/cocoop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/cocoop.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/coop.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/coop.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/coop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/coop.cpython-38.pyc 
-------------------------------------------------------------------------------- /trainers/__pycache__/imagenet_templates.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/imagenet_templates.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/imagenet_templates.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/imagenet_templates.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/kgcoop.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/kgcoop.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/kgcoop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/kgcoop.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/prograd.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/prograd.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/prograd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/prograd.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/tcp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/tcp.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/zsclip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/zsclip.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/zsclip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/zsclip.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/clip_text/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /trainers/clip_text/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/clip_text/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/clip_text/__pycache__/clip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/clip.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/clip_text/__pycache__/clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/clip.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/clip_text/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/clip_text/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/clip_text/__pycache__/simple_tokenizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/simple_tokenizer.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/clip_text/__pycache__/simple_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/clip_text/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /trainers/clip_text/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | if torch.__version__.split(".") < ["1", "7", "1"]: 22 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 23 | 24 | __all__ = ["available_models", "load", "tokenize"] 25 | _tokenizer = _Tokenizer() 26 | 27 | _MODELS = { 28 | "RN50": 29 | "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 30 | "RN101": 31 | "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": 33 | "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": 35 | "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 36 | "ViT-B/32": 37 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 38 | "ViT-B/16": 39 | "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError( 52 | f"{download_target} exists and is not a regular file") 53 | 54 | if os.path.isfile(download_target): 55 | if hashlib.sha256(open(download_target, 56 | "rb").read()).hexdigest() == expected_sha256: 57 | return download_target 58 | else: 59 | warnings.warn( 60 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 61 | ) 62 | 63 | with urllib.request.urlopen(url) as source, open(download_target, 64 | "wb") as output: 65 | with tqdm(total=int(source.info().get("Content-Length")), 66 | ncols=80, 67 | unit='iB', 68 | unit_scale=True) as loop: 69 | while True: 70 | buffer = source.read(8192) 71 | if not buffer: 72 | break 73 | 74 | output.write(buffer) 75 | loop.update(len(buffer)) 76 | 77 | if hashlib.sha256(open(download_target, 78 | "rb").read()).hexdigest() != expected_sha256: 79 | raise RuntimeError( 80 | f"Model has been downloaded but the SHA256 checksum does not not match" 81 | ) 82 | 83 | return download_target 84 | 85 | 86 | def _transform(n_px): 87 | return Compose([ 88 | Resize(n_px, interpolation=BICUBIC), 89 | CenterCrop(n_px), 90 | lambda image: image.convert("RGB"), 91 | 
ToTensor(), 92 | Normalize((0.48145466, 0.4578275, 0.40821073), 93 | (0.26862954, 0.26130258, 0.27577711)), 94 | ]) 95 | 96 | 97 | def available_models() -> List[str]: 98 | """Returns the names of available CLIP models""" 99 | return list(_MODELS.keys()) 100 | 101 | 102 | def load(name: str, 103 | device: Union[str, torch.device] = "cuda" 104 | if torch.cuda.is_available() else "cpu", 105 | jit=False): 106 | """Load a CLIP model 107 | 108 | Parameters 109 | ---------- 110 | name : str 111 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 112 | 113 | device : Union[str, torch.device] 114 | The device to put the loaded model 115 | 116 | jit : bool 117 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 118 | 119 | Returns 120 | ------- 121 | model : torch.nn.Module 122 | The CLIP model 123 | 124 | preprocess : Callable[[PIL.Image], torch.Tensor] 125 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 126 | """ 127 | if name in _MODELS: 128 | model_path = _download(_MODELS[name]) 129 | elif os.path.isfile(name): 130 | model_path = name 131 | else: 132 | raise RuntimeError( 133 | f"Model {name} not found; available models = {available_models()}") 134 | 135 | try: 136 | # loading JIT archive 137 | model = torch.jit.load(model_path, 138 | map_location=device if jit else "cpu").eval() 139 | state_dict = None 140 | except RuntimeError: 141 | # loading saved state dict 142 | if jit: 143 | warnings.warn( 144 | f"File {model_path} is not a JIT archive. Loading as a state dict instead" 145 | ) 146 | jit = False 147 | state_dict = torch.load(model_path, map_location="cpu") 148 | 149 | if not jit: 150 | model = build_model(state_dict or model.state_dict()).to(device) 151 | if str(device) == "cpu": 152 | model.float() 153 | return model, _transform(model.visual.input_resolution) 154 | 155 | # patch the device names 156 | device_holder = torch.jit.trace( 157 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 158 | device_node = [ 159 | n for n in device_holder.graph.findAllNodes("prim::Constant") 160 | if "Device" in repr(n) 161 | ][-1] 162 | 163 | def patch_device(module): 164 | try: 165 | graphs = [module.graph] if hasattr(module, "graph") else [] 166 | except RuntimeError: 167 | graphs = [] 168 | 169 | if hasattr(module, "forward1"): 170 | graphs.append(module.forward1.graph) 171 | 172 | for graph in graphs: 173 | for node in graph.findAllNodes("prim::Constant"): 174 | if "value" in node.attributeNames() and str( 175 | node["value"]).startswith("cuda"): 176 | node.copyAttributes(device_node) 177 | 178 | model.apply(patch_device) 179 | patch_device(model.encode_image) 180 | patch_device(model.encode_text) 181 | 182 | # patch dtype to float32 on CPU 183 | if str(device) == "cpu": 184 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), 185 | example_inputs=[]) 186 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 187 | float_node = float_input.node() 188 | 189 | def patch_float(module): 190 | try: 191 | graphs = [module.graph] if hasattr(module, "graph") else [] 192 | except RuntimeError: 193 | graphs = [] 194 | 195 | if hasattr(module, "forward1"): 196 | graphs.append(module.forward1.graph) 197 | 198 | for graph in graphs: 199 | for node in graph.findAllNodes("aten::to"): 200 | inputs = list(node.inputs()) 201 | for i in [ 202 | 1, 2 203 | ]: # dtype can be the second or third argument to 
aten::to() 204 | if inputs[i].node()["value"] == 5: 205 | inputs[i].node().copyAttributes(float_node) 206 | 207 | model.apply(patch_float) 208 | patch_float(model.encode_image) 209 | patch_float(model.encode_text) 210 | 211 | model.float() 212 | 213 | return model, _transform(model.input_resolution.item()) 214 | 215 | 216 | def tokenize(texts: Union[str, List[str]], 217 | context_length: int = 77, 218 | truncate: bool = False) -> torch.LongTensor: 219 | """ 220 | Returns the tokenized representation of given input string(s) 221 | 222 | Parameters 223 | ---------- 224 | texts : Union[str, List[str]] 225 | An input string or a list of input strings to tokenize 226 | 227 | context_length : int 228 | The context length to use; all CLIP models use 77 as the context length 229 | 230 | truncate: bool 231 | Whether to truncate the text in case its encoding is longer than the context length 232 | 233 | Returns 234 | ------- 235 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 236 | """ 237 | if isinstance(texts, str): 238 | texts = [texts] 239 | 240 | sot_token = _tokenizer.encoder["<|startoftext|>"] 241 | eot_token = _tokenizer.encoder["<|endoftext|>"] 242 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] 243 | for text in texts] 244 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 245 | 246 | for i, tokens in enumerate(all_tokens): 247 | if len(tokens) > context_length: 248 | if truncate: 249 | tokens = tokens[:context_length] 250 | tokens[-1] = eot_token 251 | else: 252 | raise RuntimeError( 253 | f"Input {texts[i]} is too long for context length {context_length}" 254 | ) 255 | result[i, :len(tokens)] = torch.tensor(tokens) 256 | 257 | return result 258 | -------------------------------------------------------------------------------- /trainers/clip_text/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), 13 | "bpe_simple_vocab_16e6.txt.gz") 14 | 15 | 16 | @lru_cache() 17 | def bytes_to_unicode(): 18 | """ 19 | Returns list of utf-8 bytes and a corresponding list of unicode strings. 20 | The reversible bpe codes work on unicode strings. 21 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 22 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 23 | This is a significant percentage of your normal, say, 32K bpe vocab. 24 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 25 | And avoids mapping to whitespace/control characters the bpe code barfs on. 26 | """ 27 | bs = list(range(ord("!"), 28 | ord("~") + 1)) + list(range( 29 | ord("¡"), 30 | ord("¬") + 1)) + list(range(ord("®"), 31 | ord("ÿ") + 1)) 32 | cs = bs[:] 33 | n = 0 34 | for b in range(2**8): 35 | if b not in bs: 36 | bs.append(b) 37 | cs.append(2**8 + n) 38 | n += 1 39 | cs = [chr(n) for n in cs] 40 | return dict(zip(bs, cs)) 41 | 42 | 43 | def get_pairs(word): 44 | """Return set of symbol pairs in a word. 45 | Word is represented as tuple of symbols (symbols being variable-length strings). 
46 | """ 47 | pairs = set() 48 | prev_char = word[0] 49 | for char in word[1:]: 50 | pairs.add((prev_char, char)) 51 | prev_char = char 52 | return pairs 53 | 54 | 55 | def basic_clean(text): 56 | text = ftfy.fix_text(text) 57 | text = html.unescape(html.unescape(text)) 58 | return text.strip() 59 | 60 | 61 | def whitespace_clean(text): 62 | text = re.sub(r'\s+', ' ', text) 63 | text = text.strip() 64 | return text 65 | 66 | 67 | class SimpleTokenizer(object): 68 | def __init__(self, bpe_path: str = default_bpe()): 69 | self.byte_encoder = bytes_to_unicode() 70 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 71 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 72 | merges = merges[1:49152 - 256 - 2 + 1] 73 | merges = [tuple(merge.split()) for merge in merges] 74 | vocab = list(bytes_to_unicode().values()) 75 | vocab = vocab + [v + '' for v in vocab] 76 | for merge in merges: 77 | vocab.append(''.join(merge)) 78 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 79 | self.encoder = dict(zip(vocab, range(len(vocab)))) 80 | self.decoder = {v: k for k, v in self.encoder.items()} 81 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 82 | self.cache = { 83 | '<|startoftext|>': '<|startoftext|>', 84 | '<|endoftext|>': '<|endoftext|>' 85 | } 86 | self.pat = re.compile( 87 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 88 | re.IGNORECASE) 89 | 90 | def bpe(self, token): 91 | if token in self.cache: 92 | return self.cache[token] 93 | word = tuple(token[:-1]) + (token[-1] + '', ) 94 | pairs = get_pairs(word) 95 | 96 | if not pairs: 97 | return token + '' 98 | 99 | while True: 100 | bigram = min( 101 | pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 102 | if bigram not in self.bpe_ranks: 103 | break 104 | first, second = bigram 105 | new_word = [] 106 | i = 0 107 | while i < len(word): 108 | try: 109 | j = word.index(first, i) 110 | new_word.extend(word[i:j]) 111 | i = j 112 | except: 113 | new_word.extend(word[i:]) 114 | break 115 | 116 | if word[i] == first and i < len(word) - 1 and word[ 117 | i + 1] == second: 118 | new_word.append(first + second) 119 | i += 2 120 | else: 121 | new_word.append(word[i]) 122 | i += 1 123 | new_word = tuple(new_word) 124 | word = new_word 125 | if len(word) == 1: 126 | break 127 | else: 128 | pairs = get_pairs(word) 129 | word = ' '.join(word) 130 | self.cache[token] = word 131 | return word 132 | 133 | def encode(self, text): 134 | bpe_tokens = [] 135 | text = whitespace_clean(basic_clean(text)).lower() 136 | for token in re.findall(self.pat, text): 137 | token = ''.join(self.byte_encoder[b] 138 | for b in token.encode('utf-8')) 139 | bpe_tokens.extend(self.encoder[bpe_token] 140 | for bpe_token in self.bpe(token).split(' ')) 141 | return bpe_tokens 142 | 143 | def decode(self, tokens): 144 | text = ''.join([self.decoder[token] for token in tokens]) 145 | text = bytearray([self.byte_decoder[c] for c in text 146 | ]).decode('utf-8', 147 | errors="replace").replace('', ' ') 148 | return text 149 | -------------------------------------------------------------------------------- /trainers/coop.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from torch.cuda.amp import GradScaler, autocast 7 | 8 | from dassl.engine import TRAINER_REGISTRY, TrainerX 9 | from dassl.metrics import compute_accuracy 10 | from 
dassl.utils import load_pretrained_weights, load_checkpoint 11 | from dassl.optim import build_optimizer, build_lr_scheduler 12 | 13 | from clip import clip 14 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | _tokenizer = _Tokenizer() 17 | 18 | 19 | def load_clip_to_cpu(cfg): 20 | backbone_name = cfg.MODEL.BACKBONE.NAME 21 | url = clip._MODELS[backbone_name] 22 | model_path = clip._download(url) 23 | 24 | try: 25 | # loading JIT archive 26 | model = torch.jit.load(model_path, map_location="cpu").eval() 27 | state_dict = None 28 | 29 | except RuntimeError: 30 | state_dict = torch.load(model_path, map_location="cpu") 31 | 32 | model = clip.build_model(state_dict or model.state_dict()) 33 | 34 | return model 35 | 36 | 37 | class TextEncoder(nn.Module): 38 | def __init__(self, clip_model): 39 | super().__init__() 40 | self.transformer = clip_model.transformer 41 | self.positional_embedding = clip_model.positional_embedding 42 | self.ln_final = clip_model.ln_final 43 | self.text_projection = clip_model.text_projection 44 | self.dtype = clip_model.dtype 45 | 46 | def forward(self, prompts, tokenized_prompts): 47 | x = prompts + self.positional_embedding.type(self.dtype) 48 | x = x.permute(1, 0, 2) # NLD -> LND 49 | x = self.transformer(x) 50 | x = x.permute(1, 0, 2) # LND -> NLD 51 | x = self.ln_final(x).type(self.dtype) 52 | 53 | # x.shape = [batch_size, n_ctx, transformer.width] 54 | # take features from the eot embedding (eot_token is the highest number in each sequence) 55 | x = x[torch.arange(x.shape[0]), 56 | tokenized_prompts.argmax(dim=-1)] @ self.text_projection 57 | 58 | return x 59 | 60 | 61 | class PromptLearner(nn.Module): 62 | def __init__(self, cfg, classnames, clip_model): 63 | super().__init__() 64 | n_cls = len(classnames) 65 | n_ctx = cfg.TRAINER.COOP.N_CTX 66 | ctx_init = cfg.TRAINER.COOP.CTX_INIT 67 | dtype = clip_model.dtype 68 | ctx_dim = clip_model.ln_final.weight.shape[0] 69 | clip_imsize = clip_model.visual.input_resolution 70 | cfg_imsize = cfg.INPUT.SIZE[0] 71 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})" 72 | 73 | # if ctx_init: 74 | # # use given words to initialize context vectors 75 | # ctx_init = ctx_init.replace("_", " ") 76 | # n_ctx = len(ctx_init.split(" ")) 77 | # prompt = clip.tokenize(ctx_init) 78 | # with torch.no_grad(): 79 | # embedding = clip_model.token_embedding(prompt).type(dtype) 80 | # ctx_vectors = embedding[0, 1:1 + n_ctx, :] 81 | # prompt_prefix = ctx_init 82 | if ctx_init: 83 | ctx_init = CUSTOM_TEMPLATES[cfg.DATASET.NAME] 84 | ctx_init = ctx_init.replace(" {}.", "") 85 | ctx_init = ctx_init.replace("_", " ") 86 | prompt_n_ctx = len(ctx_init.split(" ")) 87 | 88 | assert n_ctx >= prompt_n_ctx, f"#tokens ({n_ctx}) must be greater than or equal to #initial prompt tokens ({prompt_n_ctx}, {ctx_init})" 89 | 90 | prompt = clip.tokenize(ctx_init) 91 | with torch.no_grad(): 92 | embedding = clip_model.token_embedding(prompt).type(dtype) 93 | 94 | ctx_vectors = torch.zeros(n_ctx, ctx_dim, dtype=dtype) 95 | 96 | ctx_vectors[n_ctx - prompt_n_ctx:, :] = embedding[0, 1:1 + 97 | prompt_n_ctx, :] 98 | prompt_prefix = " ".join(["X"] * (n_ctx - prompt_n_ctx)) 99 | prompt_prefix = f"{prompt_prefix} {ctx_init}" 100 | else: 101 | # random initialization 102 | if cfg.TRAINER.COOP.CSC: 103 | print("Initializing class-specific contexts") 104 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) 105 | else: 106 | print("Initializing a generic context") 107 | ctx_vectors = 
torch.empty(n_ctx, ctx_dim, dtype=dtype) 108 | nn.init.normal_(ctx_vectors, std=0.02) 109 | prompt_prefix = " ".join(["X"] * n_ctx) 110 | 111 | print(f'Initial context: "{prompt_prefix}"') 112 | print(f"Number of context words (tokens): {n_ctx}") 113 | 114 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized 115 | 116 | classnames = [name.replace("_", " ") for name in classnames] 117 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 118 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 119 | 120 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) 121 | with torch.no_grad(): 122 | embedding = clip_model.token_embedding(tokenized_prompts).type( 123 | dtype) 124 | 125 | # These token vectors will be saved when in save_model(), 126 | # but they should be ignored in load_model() as we want to use 127 | # those computed using the current class names 128 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 129 | self.register_buffer("token_suffix", 130 | embedding[:, 1 + n_ctx:, :]) # CLS, EOS 131 | 132 | self.n_cls = n_cls 133 | self.n_ctx = n_ctx 134 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 135 | self.name_lens = name_lens 136 | self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION 137 | 138 | def forward(self): 139 | ctx = self.ctx 140 | if ctx.dim() == 2: 141 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 142 | 143 | prefix = self.token_prefix 144 | suffix = self.token_suffix 145 | 146 | if self.class_token_position == "end": 147 | prompts = torch.cat( 148 | [ 149 | prefix, # (n_cls, 1, dim) 150 | ctx, # (n_cls, n_ctx, dim) 151 | suffix, # (n_cls, *, dim) 152 | ], 153 | dim=1, 154 | ) 155 | 156 | elif self.class_token_position == "middle": 157 | half_n_ctx = self.n_ctx // 2 158 | prompts = [] 159 | for i in range(self.n_cls): 160 | name_len = self.name_lens[i] 161 | prefix_i = prefix[i:i + 1, :, :] 162 | class_i = suffix[i:i + 1, :name_len, :] 163 | suffix_i = suffix[i:i + 1, name_len:, :] 164 | ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :] 165 | ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :] 166 | prompt = torch.cat( 167 | [ 168 | prefix_i, # (1, 1, dim) 169 | ctx_i_half1, # (1, n_ctx//2, dim) 170 | class_i, # (1, name_len, dim) 171 | ctx_i_half2, # (1, n_ctx//2, dim) 172 | suffix_i, # (1, *, dim) 173 | ], 174 | dim=1, 175 | ) 176 | prompts.append(prompt) 177 | prompts = torch.cat(prompts, dim=0) 178 | 179 | elif self.class_token_position == "front": 180 | prompts = [] 181 | for i in range(self.n_cls): 182 | name_len = self.name_lens[i] 183 | prefix_i = prefix[i:i + 1, :, :] 184 | class_i = suffix[i:i + 1, :name_len, :] 185 | suffix_i = suffix[i:i + 1, name_len:, :] 186 | ctx_i = ctx[i:i + 1, :, :] 187 | prompt = torch.cat( 188 | [ 189 | prefix_i, # (1, 1, dim) 190 | class_i, # (1, name_len, dim) 191 | ctx_i, # (1, n_ctx, dim) 192 | suffix_i, # (1, *, dim) 193 | ], 194 | dim=1, 195 | ) 196 | prompts.append(prompt) 197 | prompts = torch.cat(prompts, dim=0) 198 | 199 | else: 200 | raise ValueError 201 | 202 | return prompts 203 | 204 | 205 | CUSTOM_TEMPLATES = { 206 | # "OxfordPets": "a photo of a {}, a type of pet.", 207 | "OxfordPets": "a type of pet, a photo of a {}.", 208 | # "OxfordFlowers": "a photo of a {}, a type of flower.", 209 | "OxfordFlowers": "a type of flower, a photo of a {}.", 210 | "FGVCAircraft": "a type of aircraft, a photo of a {}.", 211 | "DescribableTextures": "a texture of {}.", 212 | "EuroSAT": "a centered satellite photo of {}.", 213 | "StanfordCars": "a photo of a 
{}.", 214 | # "Food101": "a photo of {}, a type of food.", 215 | "Food101": "a type of food, a photo of {}.", 216 | "SUN397": "a photo of a {}.", 217 | "Caltech101": "a photo of a {}.", 218 | "UCF101": "a photo of a person doing {}.", 219 | "ImageNet": "a photo of a {}.", 220 | "ImageNetSketch": "a photo of a {}.", 221 | "ImageNetV2": "a photo of a {}.", 222 | "ImageNetA": "a photo of a {}.", 223 | "ImageNetR": "a photo of a {}.", 224 | } 225 | 226 | 227 | class CustomCLIP(nn.Module): 228 | def __init__(self, cfg, classnames, clip_model): 229 | super().__init__() 230 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model) 231 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 232 | self.image_encoder = clip_model.visual 233 | self.text_encoder = TextEncoder(clip_model) 234 | self.logit_scale = clip_model.logit_scale 235 | self.dtype = clip_model.dtype 236 | 237 | def forward(self, image): 238 | image_features = self.image_encoder(image.type(self.dtype)) 239 | 240 | prompts = self.prompt_learner() 241 | tokenized_prompts = self.tokenized_prompts 242 | text_features = self.text_encoder(prompts, tokenized_prompts) 243 | 244 | image_features = image_features / image_features.norm(dim=-1, 245 | keepdim=True) 246 | text_features = text_features / text_features.norm(dim=-1, 247 | keepdim=True) 248 | 249 | logit_scale = self.logit_scale.exp() 250 | logits = logit_scale * image_features @ text_features.t() 251 | 252 | return logits 253 | 254 | 255 | @TRAINER_REGISTRY.register() 256 | class CoOp(TrainerX): 257 | """Context Optimization (CoOp). 258 | 259 | Learning to Prompt for Vision-Language Models 260 | https://arxiv.org/abs/2109.01134 261 | """ 262 | def check_cfg(self, cfg): 263 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"] 264 | 265 | def build_model(self): 266 | cfg = self.cfg 267 | classnames = self.dm.dataset.classnames 268 | 269 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 270 | clip_model = load_clip_to_cpu(cfg) 271 | 272 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp": 273 | # CLIP's default precision is fp16 274 | clip_model.float() 275 | 276 | print("Building custom CLIP") 277 | self.model = CustomCLIP(cfg, classnames, clip_model) 278 | 279 | print("Turning off gradients in both the image and the text encoder") 280 | for name, param in self.model.named_parameters(): 281 | if "prompt_learner" not in name: 282 | param.requires_grad_(False) 283 | 284 | if cfg.MODEL.INIT_WEIGHTS: 285 | load_pretrained_weights(self.model.prompt_learner, 286 | cfg.MODEL.INIT_WEIGHTS) 287 | 288 | self.model.to(self.device) 289 | # NOTE: only give prompt_learner to the optimizer 290 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 291 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 292 | self.register_model("prompt_learner", self.model.prompt_learner, 293 | self.optim, self.sched) 294 | 295 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None 296 | 297 | # Note that multi-gpu training could be slow because CLIP's size is 298 | # big, which slows down the copy operation in DataParallel 299 | device_count = torch.cuda.device_count() 300 | if device_count > 1: 301 | print( 302 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!" 
303 | ) 304 | self.model = nn.DataParallel(self.model) 305 | 306 | def forward_backward(self, batch): 307 | image, label = self.parse_batch_train(batch) 308 | 309 | prec = self.cfg.TRAINER.COOP.PREC 310 | if prec == "amp": 311 | with autocast(): 312 | output = self.model(image) 313 | loss = F.cross_entropy(output, label) 314 | self.optim.zero_grad() 315 | self.scaler.scale(loss).backward() 316 | self.scaler.step(self.optim) 317 | self.scaler.update() 318 | else: 319 | output = self.model(image) 320 | loss = F.cross_entropy(output, label) 321 | self.model_backward_and_update(loss) 322 | 323 | loss_summary = { 324 | "loss": loss.item(), 325 | "acc": compute_accuracy(output, label)[0].item(), 326 | } 327 | 328 | if (self.batch_idx + 1) == self.num_batches: 329 | self.update_lr() 330 | 331 | return loss_summary 332 | 333 | def parse_batch_train(self, batch): 334 | input = batch["img"] 335 | label = batch["label"] 336 | input = input.to(self.device) 337 | label = label.to(self.device) 338 | return input, label 339 | 340 | def load_model(self, directory, epoch=None): 341 | if not directory: 342 | print( 343 | "Note that load_model() is skipped as no pretrained model is given" 344 | ) 345 | return 346 | 347 | names = self.get_model_names() 348 | 349 | # By default, the best model is loaded 350 | model_file = "model-best.pth.tar" 351 | 352 | if epoch is not None: 353 | model_file = "model.pth.tar-" + str(epoch) 354 | 355 | for name in names: 356 | model_path = osp.join(directory, name, model_file) 357 | 358 | if not osp.exists(model_path): 359 | raise FileNotFoundError( 360 | 'Model not found at "{}"'.format(model_path)) 361 | 362 | checkpoint = load_checkpoint(model_path) 363 | state_dict = checkpoint["state_dict"] 364 | epoch = checkpoint["epoch"] 365 | 366 | # Ignore fixed token vectors 367 | if "token_prefix" in state_dict: 368 | del state_dict["token_prefix"] 369 | 370 | if "token_suffix" in state_dict: 371 | del state_dict["token_suffix"] 372 | 373 | print("Loading weights to {} " 374 | 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 375 | # set strict=False 376 | self._models[name].load_state_dict(state_dict, strict=False) 377 | -------------------------------------------------------------------------------- /trainers/imagenet_templates.py: -------------------------------------------------------------------------------- 1 | # source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb 2 | 3 | IMAGENET_TEMPLATES = [ 4 | "a bad photo of a {}.", 5 | "a photo of many {}.", 6 | "a sculpture of a {}.", 7 | "a photo of the hard to see {}.", 8 | "a low resolution photo of the {}.", 9 | "a rendering of a {}.", 10 | "graffiti of a {}.", 11 | "a bad photo of the {}.", 12 | "a cropped photo of the {}.", 13 | "a tattoo of a {}.", 14 | "the embroidered {}.", 15 | "a photo of a hard to see {}.", 16 | "a bright photo of a {}.", 17 | "a photo of a clean {}.", 18 | "a photo of a dirty {}.", 19 | "a dark photo of the {}.", 20 | "a drawing of a {}.", 21 | "a photo of my {}.", 22 | "the plastic {}.", 23 | "a photo of the cool {}.", 24 | "a close-up photo of a {}.", 25 | "a black and white photo of the {}.", 26 | "a painting of the {}.", 27 | "a painting of a {}.", 28 | "a pixelated photo of the {}.", 29 | "a sculpture of the {}.", 30 | "a bright photo of the {}.", 31 | "a cropped photo of a {}.", 32 | "a plastic {}.", 33 | "a photo of the dirty {}.", 34 | "a jpeg corrupted photo of a {}.", 35 | "a blurry photo of the {}.", 36 | "a photo of the {}.", 37 | "a good 
photo of the {}.", 38 | "a rendering of the {}.", 39 | "a {} in a video game.", 40 | "a photo of one {}.", 41 | "a doodle of a {}.", 42 | "a close-up photo of the {}.", 43 | "a photo of a {}.", 44 | "the origami {}.", 45 | "the {} in a video game.", 46 | "a sketch of a {}.", 47 | "a doodle of the {}.", 48 | "a origami {}.", 49 | "a low resolution photo of a {}.", 50 | "the toy {}.", 51 | "a rendition of the {}.", 52 | "a photo of the clean {}.", 53 | "a photo of a large {}.", 54 | "a rendition of a {}.", 55 | "a photo of a nice {}.", 56 | "a photo of a weird {}.", 57 | "a blurry photo of a {}.", 58 | "a cartoon {}.", 59 | "art of a {}.", 60 | "a sketch of the {}.", 61 | "a embroidered {}.", 62 | "a pixelated photo of a {}.", 63 | "itap of the {}.", 64 | "a jpeg corrupted photo of the {}.", 65 | "a good photo of a {}.", 66 | "a plushie {}.", 67 | "a photo of the nice {}.", 68 | "a photo of the small {}.", 69 | "a photo of the weird {}.", 70 | "the cartoon {}.", 71 | "art of the {}.", 72 | "a drawing of the {}.", 73 | "a photo of the large {}.", 74 | "a black and white photo of a {}.", 75 | "the plushie {}.", 76 | "a dark photo of a {}.", 77 | "itap of a {}.", 78 | "graffiti of the {}.", 79 | "a toy {}.", 80 | "itap of my {}.", 81 | "a photo of a cool {}.", 82 | "a photo of a small {}.", 83 | "a tattoo of the {}.", 84 | ] 85 | 86 | IMAGENET_TEMPLATES_SELECT = [ 87 | "itap of a {}.", 88 | "a bad photo of the {}.", 89 | "a origami {}.", 90 | "a photo of the large {}.", 91 | "a {} in a video game.", 92 | "art of the {}.", 93 | "a photo of the small {}.", 94 | ] 95 | -------------------------------------------------------------------------------- /trainers/zsclip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from dassl.engine import TRAINER_REGISTRY, TrainerX 5 | from dassl.optim import build_optimizer, build_lr_scheduler 6 | 7 | from clip import clip 8 | from clip.model import convert_weights 9 | 10 | from .coop import load_clip_to_cpu 11 | from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT 12 | 13 | CUSTOM_TEMPLATES = { 14 | # "OxfordPets": "a photo of a {}, a type of pet.", 15 | "OxfordPets": "a type of pet, a photo of a {}.", 16 | # "OxfordFlowers": "a photo of a {}, a type of flower.", 17 | "OxfordFlowers": "a type of flower, a photo of a {}.", 18 | "FGVCAircraft": "a photo of a {}, a type of aircraft.", 19 | "DescribableTextures": "{} texture.", 20 | "EuroSAT": "a centered satellite photo of {}.", 21 | "StanfordCars": "a photo of a {}.", 22 | # "Food101": "a photo of {}, a type of food.", 23 | "Food101": "a type of food, a photo of {}.", 24 | "SUN397": "a photo of a {}.", 25 | "Caltech101": "a photo of a {}.", 26 | "UCF101": "a photo of a person doing {}.", 27 | "ImageNet": "a photo of a {}.", 28 | "ImageNetSketch": "a photo of a {}.", 29 | "ImageNetV2": "a photo of a {}.", 30 | "ImageNetA": "a photo of a {}.", 31 | "ImageNetR": "a photo of a {}.", 32 | } 33 | 34 | 35 | @TRAINER_REGISTRY.register() 36 | class ZeroshotCLIP(TrainerX): 37 | def build_model(self): 38 | cfg = self.cfg 39 | classnames = self.dm.dataset.classnames 40 | 41 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 42 | clip_model = load_clip_to_cpu(cfg) 43 | clip_model.to(self.device) 44 | 45 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] 46 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 47 | print(f"Prompts: {prompts}") 48 | prompts = torch.cat([clip.tokenize(p) for p in 
prompts]) 49 | prompts = prompts.to(self.device) 50 | 51 | with torch.no_grad(): 52 | text_features = clip_model.encode_text(prompts) 53 | text_features = text_features / text_features.norm(dim=-1, 54 | keepdim=True) 55 | 56 | self.text_features = text_features 57 | self.clip_model = clip_model 58 | 59 | def model_inference(self, image): 60 | image_features = self.clip_model.encode_image(image) 61 | image_features = image_features / image_features.norm(dim=-1, 62 | keepdim=True) 63 | logit_scale = self.clip_model.logit_scale.exp() 64 | logits = logit_scale * image_features @ self.text_features.t() 65 | return logits 66 | 67 | 68 | @TRAINER_REGISTRY.register() 69 | class ZeroshotCLIP2(ZeroshotCLIP): 70 | """Prompt ensembling.""" 71 | 72 | # templates = IMAGENET_TEMPLATES 73 | templates = IMAGENET_TEMPLATES_SELECT 74 | 75 | def build_model(self): 76 | cfg = self.cfg 77 | classnames = self.dm.dataset.classnames 78 | 79 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 80 | clip_model = load_clip_to_cpu(cfg) 81 | clip_model.to(self.device) 82 | 83 | for params in clip_model.parameters(): 84 | params.requires_grad_(False) 85 | 86 | # add custom-made prompt 87 | if cfg.DATASET.NAME != "ImageNet": 88 | self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]] 89 | 90 | num_temp = len(self.templates) 91 | print(f"Prompt ensembling (n={num_temp})") 92 | 93 | mean_text_features = 0 94 | for i, temp in enumerate(self.templates): 95 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 96 | prompts = torch.cat([clip.tokenize(p) 97 | for p in prompts]).to(self.device) 98 | text_features = clip_model.encode_text(prompts) 99 | text_features = text_features / text_features.norm(dim=-1, 100 | keepdim=True) 101 | mean_text_features = mean_text_features + text_features 102 | mean_text_features = mean_text_features / num_temp 103 | mean_text_features = mean_text_features / mean_text_features.norm( 104 | dim=-1, keepdim=True) 105 | 106 | self.text_features = mean_text_features 107 | self.clip_model = clip_model 108 | --------------------------------------------------------------------------------
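
Below is a minimal usage sketch (not a file from this repository) showing how the vendored `clip` package and the `CUSTOM_TEMPLATES` prompts combine for zero-shot inference, mirroring what `ZeroshotCLIP.build_model` and `model_inference` in `trainers/zsclip.py` do. The backbone name, the example class names, and the image path `example.jpg` are illustrative assumptions, not values taken from the repository.

```python
# Minimal zero-shot CLIP inference sketch (illustrative only).
# Assumed placeholders: backbone "ViT-B/16", the two class names, "example.jpg".
import torch
from PIL import Image

from clip import clip  # the repository's vendored CLIP wrapper (clip/clip.py)

device = "cuda" if torch.cuda.is_available() else "cpu"
# load() returns the model and the matching torchvision preprocessing transform
model, preprocess = clip.load("ViT-B/16", device=device)

classnames = ["abyssinian", "beagle"]         # placeholder class names
template = "a type of pet, a photo of a {}."  # CUSTOM_TEMPLATES["OxfordPets"]
prompts = clip.tokenize([template.format(c) for c in classnames]).to(device)

with torch.no_grad():
    # text side: one embedding per class, L2-normalized
    text_features = model.encode_text(prompts)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # image side: preprocess a single image and normalize its embedding
    image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
    image_features = model.encode_image(image)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # temperature-scaled cosine similarities, as in ZeroshotCLIP.model_inference
    logits = model.logit_scale.exp() * image_features @ text_features.t()
    print("prediction:", classnames[logits.argmax(dim=-1).item()])
```

As in the trainers above, both feature sets are L2-normalized and the dot product is scaled by `logit_scale.exp()`, so the logits are temperature-scaled cosine similarities; the learnable-prompt trainers (CoOp and the other variants in `trainers/`) only replace the fixed template text with optimized context vectors while keeping this same scoring rule.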