├── DATASETS.md
├── LICENSE
├── README.md
├── clip
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── clip.cpython-37.pyc
│   │   ├── clip.cpython-38.pyc
│   │   ├── model.cpython-37.pyc
│   │   ├── model.cpython-38.pyc
│   │   ├── simple_tokenizer.cpython-37.pyc
│   │   └── simple_tokenizer.cpython-38.pyc
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── configs
│   ├── datasets
│   │   ├── caltech101.yaml
│   │   ├── dtd.yaml
│   │   ├── eurosat.yaml
│   │   ├── fgvc_aircraft.yaml
│   │   ├── food101.yaml
│   │   ├── imagenet.yaml
│   │   ├── imagenet_a.yaml
│   │   ├── imagenet_r.yaml
│   │   ├── imagenet_sketch.yaml
│   │   ├── imagenetv2.yaml
│   │   ├── oxford_flowers.yaml
│   │   ├── oxford_pets.yaml
│   │   ├── stanford_cars.yaml
│   │   ├── sun397.yaml
│   │   └── ucf101.yaml
│   └── trainers
│       ├── CoCoOp
│       │   ├── rn50_c4_ep10_batch1_ctxv1.yaml
│       │   ├── rn50_ep100_init.yaml
│       │   ├── rn50_ep50.yaml
│       │   ├── vit_b16_c16_ep10_batch1.yaml
│       │   ├── vit_b16_c4_ep10_batch1.yaml
│       │   ├── vit_b16_c4_ep10_batch1_ctxv1.yaml
│       │   └── vit_b16_c8_ep10_batch1.yaml
│       ├── CoOp
│       │   ├── rn50.yaml
│       │   ├── rn50_ep100.yaml
│       │   ├── rn50_ep50.yaml
│       │   └── rn50_val.yaml
│       ├── KgCoOp
│       │   ├── .ipynb_checkpoints
│       │   │   └── vit_b16_ep100_ctxv1-checkpoint.yaml
│       │   ├── rn50_ep100.yaml
│       │   ├── rn50_ep100_b16.yaml
│       │   ├── vit_b16_ep100_ctxv1.yaml
│       │   ├── vit_b16_ep100_ctxv1_b128.yaml
│       │   ├── vit_b16_ep100_ctxv1_b16.yaml
│       │   └── vit_b16_ep100_ctxv1_b8.yaml
│       ├── ProGrad
│       │   ├── rn50.yaml
│       │   ├── rn50_ep100.yaml
│       │   └── rn50_ep50.yaml
│       └── TCP
│           └── vit_b16_ep100_ctxv1.yaml
├── datasets
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── caltech101.cpython-37.pyc
│   │   ├── caltech101.cpython-38.pyc
│   │   ├── dtd.cpython-37.pyc
│   │   ├── dtd.cpython-38.pyc
│   │   ├── eurosat.cpython-37.pyc
│   │   ├── eurosat.cpython-38.pyc
│   │   ├── fgvc_aircraft.cpython-37.pyc
│   │   ├── fgvc_aircraft.cpython-38.pyc
│   │   ├── food101.cpython-37.pyc
│   │   ├── food101.cpython-38.pyc
│   │   ├── imagenet.cpython-37.pyc
│   │   ├── imagenet.cpython-38.pyc
│   │   ├── imagenet_a.cpython-37.pyc
│   │   ├── imagenet_a.cpython-38.pyc
│   │   ├── imagenet_r.cpython-37.pyc
│   │   ├── imagenet_r.cpython-38.pyc
│   │   ├── imagenet_sketch.cpython-37.pyc
│   │   ├── imagenet_sketch.cpython-38.pyc
│   │   ├── imagenetv2.cpython-37.pyc
│   │   ├── imagenetv2.cpython-38.pyc
│   │   ├── oxford_flowers.cpython-37.pyc
│   │   ├── oxford_flowers.cpython-38.pyc
│   │   ├── oxford_pets.cpython-37.pyc
│   │   ├── oxford_pets.cpython-38.pyc
│   │   ├── stanford_cars.cpython-37.pyc
│   │   ├── stanford_cars.cpython-38.pyc
│   │   ├── sun397.cpython-37.pyc
│   │   ├── sun397.cpython-38.pyc
│   │   ├── ucf101.cpython-37.pyc
│   │   └── ucf101.cpython-38.pyc
│   ├── caltech101.py
│   ├── dtd.py
│   ├── eurosat.py
│   ├── fgvc_aircraft.py
│   ├── food101.py
│   ├── imagenet.py
│   ├── imagenet_a.py
│   ├── imagenet_r.py
│   ├── imagenet_sketch.py
│   ├── imagenetv2.py
│   ├── oxford_flowers.py
│   ├── oxford_pets.py
│   ├── stanford_cars.py
│   ├── sun397.py
│   └── ucf101.py
├── eval.sh
├── interpret_prompt.py
├── lpclip
│   ├── README.md
│   ├── feat_extractor.py
│   ├── feat_extractor.sh
│   ├── linear_probe.py
│   ├── linear_probe.sh
│   └── linear_probe_transfer.py
├── parse_test_res.py
├── requirements.txt
├── scripts
│   └── base2new_train.sh
├── train.py
└── trainers
    ├── .ipynb_checkpoints
    │   ├── kgcoop-checkpoint.py
    │   └── kgcoop_bk-checkpoint.py
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   ├── __init__.cpython-38.pyc
    │   ├── cocoop.cpython-37.pyc
    │   ├── cocoop.cpython-38.pyc
    │   ├── coop.cpython-37.pyc
    │   ├── coop.cpython-38.pyc
    │   ├── imagenet_templates.cpython-37.pyc
    │   ├── imagenet_templates.cpython-38.pyc
    │   ├── kgcoop.cpython-37.pyc
    │   ├── kgcoop.cpython-38.pyc
    │   ├── prograd.cpython-37.pyc
    │   ├── prograd.cpython-38.pyc
    │   ├── tcp.cpython-38.pyc
    │   ├── zsclip.cpython-37.pyc
    │   └── zsclip.cpython-38.pyc
    ├── clip_text
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-37.pyc
    │   │   ├── __init__.cpython-38.pyc
    │   │   ├── clip.cpython-37.pyc
    │   │   ├── clip.cpython-38.pyc
    │   │   ├── model.cpython-37.pyc
    │   │   ├── model.cpython-38.pyc
    │   │   ├── simple_tokenizer.cpython-37.pyc
    │   │   └── simple_tokenizer.cpython-38.pyc
    │   ├── bpe_simple_vocab_16e6.txt.gz
    │   ├── clip.py
    │   ├── model.py
    │   └── simple_tokenizer.py
    ├── cocoop.py
    ├── coop.py
    ├── imagenet_templates.py
    ├── prograd.py
    ├── tcp.py
    └── zsclip.py
/DATASETS.md:
--------------------------------------------------------------------------------
1 | # How to install datasets
2 |
3 | We suggest putting all datasets under the same folder (say `$DATA`) to ease management, and organizing them as described below so that the source code does not need to be modified. The file structure looks like
4 |
5 | ```
6 | $DATA/
7 | |–– imagenet/
8 | |–– caltech-101/
9 | |–– oxford_pets/
10 | |–– stanford_cars/
11 | ```
12 |
13 | If you have some datasets already installed somewhere else, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid duplicate download.
14 |
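For example, if ImageNet is already stored somewhere else on disk, a symbolic link is enough (a minimal sketch; the source path below is a placeholder, not part of this repo):

```python
import os

# Hypothetical location of an existing ImageNet download.
existing = "/data/imagenet"
data_root = os.path.expanduser(os.environ.get("DATA", "~/data"))
target = os.path.join(data_root, "imagenet")

os.makedirs(data_root, exist_ok=True)
if not os.path.exists(target):
    os.symlink(existing, target)  # $DATA/imagenet -> /data/imagenet
```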
15 | Datasets list:
16 | - [ImageNet](#imagenet)
17 | - [Caltech101](#caltech101)
18 | - [OxfordPets](#oxfordpets)
19 | - [StanfordCars](#stanfordcars)
20 | - [Flowers102](#flowers102)
21 | - [Food101](#food101)
22 | - [FGVCAircraft](#fgvcaircraft)
23 | - [SUN397](#sun397)
24 | - [DTD](#dtd)
25 | - [EuroSAT](#eurosat)
26 | - [UCF101](#ucf101)
27 | - [ImageNetV2](#imagenetv2)
28 | - [ImageNet-Sketch](#imagenet-sketch)
29 | - [ImageNet-A](#imagenet-a)
30 | - [ImageNet-R](#imagenet-r)
31 |
32 | The instructions to prepare each dataset are detailed below. To ensure reproducibility and fair comparison for future work, we provide fixed train/val/test splits for all datasets except ImageNet where the validation set is used as test set. The fixed splits are either from the original datasets (if available) or created by us.
33 |
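The fixed splits referenced below are plain JSON files named `split_zhou_<Dataset>.json`. Assuming they follow the CoOp convention, where each of `train`/`val`/`test` maps to a list of `[image_path, label, classname]` entries, they can be inspected like this (paths are placeholders):

```python
import json
import os

data_root = os.path.expanduser(os.environ.get("DATA", "~/data"))
split_file = os.path.join(data_root, "caltech-101", "split_zhou_Caltech101.json")

with open(split_file) as f:
    split = json.load(f)

# Assumed layout: {"train": [[impath, label, classname], ...], "val": [...], "test": [...]}
for name, items in split.items():
    print(f"{name}: {len(items)} entries, e.g. {items[0]}")
```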
34 | ### ImageNet
35 | - Create a folder named `imagenet/` under `$DATA`.
36 | - Create `images/` under `imagenet/`.
37 | - Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. The directory structure should look like
38 | ```
39 | imagenet/
40 | |–– images/
41 | | |–– train/ # contains 1,000 folders like n01440764, n01443537, etc.
42 | | |–– val/
43 | ```
44 | - If you have already downloaded ImageNet, you can create symbolic links that map the existing training and validation sets to `$DATA/imagenet/images`.
45 | - Download the `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb).
46 |
47 | ### Caltech101
48 | - Create a folder named `caltech-101/` under `$DATA`.
49 | - Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`.
50 | - Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`.
51 |
52 | The directory structure should look like
53 | ```
54 | caltech-101/
55 | |–– 101_ObjectCategories/
56 | |–– split_zhou_Caltech101.json
57 | ```
58 |
59 | ### OxfordPets
60 | - Create a folder named `oxford_pets/` under `$DATA`.
61 | - Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz.
62 | - Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz.
63 | - Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing).
64 |
65 | The directory structure should look like
66 | ```
67 | oxford_pets/
68 | |–– images/
69 | |–– annotations/
70 | |–– split_zhou_OxfordPets.json
71 | ```
72 |
73 | ### StanfordCars
74 | - Create a folder named `stanford_cars/` under `$DATA`.
75 | - Download the train images from http://ai.stanford.edu/~jkrause/car196/cars_train.tgz.
76 | - Download the test images from http://ai.stanford.edu/~jkrause/car196/cars_test.tgz.
77 | - Download the train labels from https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz.
78 | - Download the test labels from http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat.
79 | - Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing).
80 |
81 | The directory structure should look like
82 | ```
83 | stanford_cars/
84 | |–– cars_test/
85 | |–– cars_test_annos_withlabels.mat
86 | |–– cars_train/
87 | |–– devkit/
88 | |–– split_zhou_StanfordCars.json
89 | ```
90 |
91 | ### Flowers102
92 | - Create a folder named `oxford_flowers/` under `$DATA`.
93 | - Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively.
94 | - Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing).
95 | - Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing).
96 |
97 | The directory structure should look like
98 | ```
99 | oxford_flowers/
100 | |–– cat_to_name.json
101 | |–– imagelabels.mat
102 | |–– jpg/
103 | |–– split_zhou_OxfordFlowers.json
104 | ```
105 |
106 | ### Food101
107 | - Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`.
108 | - Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing).
109 |
110 | The directory structure should look like
111 | ```
112 | food-101/
113 | |–– images/
114 | |–– license_agreement.txt
115 | |–– meta/
116 | |–– README.txt
117 | |–– split_zhou_Food101.json
118 | ```
119 |
120 | ### FGVCAircraft
121 | - Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz.
122 | - Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`.
123 | - Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`.
124 |
125 | The directory structure should look like
126 | ```
127 | fgvc_aircraft/
128 | |–– images/
129 | |–– ... # a bunch of .txt files
130 | ```
131 |
132 | ### SUN397
133 | - Create a folder named `sun397/` under `$DATA`.
134 | - Download the images from http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz.
135 | - Download the partitions from https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip.
136 | - Extract these files under `$DATA/sun397/`.
137 | - Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing).
138 |
139 | The directory structure should look like
140 | ```
141 | sun397/
142 | |–– SUN397/
143 | |–– split_zhou_SUN397.json
144 | |–– ... # a bunch of .txt files
145 | ```
146 |
147 | ### DTD
148 | - Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. This should lead to `$DATA/dtd/`.
149 | - Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing).
150 |
151 | The directory structure should look like
152 | ```
153 | dtd/
154 | |–– images/
155 | |–– imdb/
156 | |–– labels/
157 | |–– split_zhou_DescribableTextures.json
158 | ```
159 |
160 | ### EuroSAT
161 | - Create a folder named `eurosat/` under `$DATA`.
162 | - Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`.
163 | - Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing).
164 |
165 | The directory structure should look like
166 | ```
167 | eurosat/
168 | |–– 2750/
169 | |–– split_zhou_EuroSAT.json
170 | ```
171 |
172 | ### UCF101
173 | - Create a folder named `ucf101/` under `$DATA`.
174 | - Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames.
175 | - Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing).
176 |
177 | The directory structure should look like
178 | ```
179 | ucf101/
180 | |–– UCF-101-midframes/
181 | |–– split_zhou_UCF101.json
182 | ```
183 |
184 | ### ImageNetV2
185 | - Create a folder named `imagenetv2/` under `$DATA`.
186 | - Go to this github repo https://github.com/modestyachts/ImageNetV2.
187 | - Download the matched-frequency dataset from https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz and extract it to `$DATA/imagenetv2/`.
188 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenetv2/`.
189 |
190 | The directory structure should look like
191 | ```
192 | imagenetv2/
193 | |–– imagenetv2-matched-frequency-format-val/
194 | |–– classnames.txt
195 | ```
196 |
197 | ### ImageNet-Sketch
198 | - Download the dataset from https://github.com/HaohanWang/ImageNet-Sketch.
199 | - Extract the dataset to `$DATA/imagenet-sketch`.
200 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-sketch/`.
201 |
202 | The directory structure should look like
203 | ```
204 | imagenet-sketch/
205 | |–– images/ # contains 1,000 folders whose names have the format of n*
206 | |–– classnames.txt
207 | ```
208 |
209 | ### ImageNet-A
210 | - Create a folder named `imagenet-adversarial/` under `$DATA`.
211 | - Download the dataset from https://github.com/hendrycks/natural-adv-examples and extract it to `$DATA/imagenet-adversarial/`.
212 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-adversarial/`.
213 |
214 | The directory structure should look like
215 | ```
216 | imagenet-adversarial/
217 | |–– imagenet-a/ # contains 200 folders whose names have the format of n*
218 | |–– classnames.txt
219 | ```
220 |
221 | ### ImageNet-R
222 | - Create a folder named `imagenet-rendition/` under `$DATA`.
223 | - Download the dataset from https://github.com/hendrycks/imagenet-r and extract it to `$DATA/imagenet-rendition/`.
224 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-rendition/`.
225 |
226 | The directory structure should look like
227 | ```
228 | imagenet-rendition/
229 | |–– imagenet-r/ # contains 200 folders whose names have the format of n*
230 | |–– classnames.txt
231 | ```
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Kaiyang Zhou
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## TCP: Textual-based Class-aware Prompt tuning for Visual-Language Model [CVPR24]
2 |
3 | > [**TCP: Textual-based Class-aware Prompt tuning for Visual-Language Model**](https://arxiv.org/abs/2311.18231)
4 | > Hantao Yao, Rui Zhang, Changsheng Xu
5 |
6 | ## How to Install
7 | This code is built on top of the toolbox [Dassl](https://github.com/KaiyangZhou/Dassl.pytorch). You can prepare the environment as follows:
8 |
9 | ```
10 | # Create a conda environment
11 | conda create -n dassl python=3.7
12 |
13 | # Activate the environment
14 | conda activate dassl
15 |
16 | # Install dependencies
17 | pip install -r requirements.txt
18 |
19 | # Install torch (version >= 1.7.1) and torchvision
20 | # Make sure to install the GPU version for reasonable training speed.
21 | # For example:
22 | conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
23 |
24 | # Install this library (no need to re-build if the source code is modified)
25 | python setup.py develop
26 | ```
27 |
28 | After that, run `pip install -r requirements.txt` under `Textual-based_Class-aware_prompt_tuning/` to install a few more packages required by [CLIP](https://github.com/openai/CLIP) (this should be done when `dassl` is activated). Then, you are ready to go.
29 |
30 | Follow [DATASETS.md](DATASETS.md) to install the datasets.
31 |
32 | ## [Important] Adjust `EPS` in the Adam optimizer
33 | Using the standard Adam settings on fp16 data produces NaN losses, so we set `eps` in Adam to 1e-3. See also the discussion at https://discuss.pytorch.org/t/adam-half-precision-nans/1765.
34 |
35 | Modify line 80 of `./Dassl.pytorch/dassl/optim/optimizer.py`:
36 |
37 | ```
38 | if optim == "adam":
39 | optimizer = torch.optim.Adam(
40 | param_groups,
41 | lr=lr,
42 | weight_decay=weight_decay,
43 | betas=(adam_beta1, adam_beta2),
44 | eps=1e-3,
45 | )
46 | ```
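
The reason a larger `eps` helps: the PyTorch default (1e-8) is below the smallest representable fp16 value, so it underflows to zero and no longer stabilizes Adam's denominator. A quick check (a minimal sketch, requires only PyTorch):

```python
import torch

# 1e-8 is smaller than the smallest fp16 subnormal (~6e-8), so it rounds to zero;
# Adam's denominator sqrt(v) + eps can then vanish and produce Inf/NaN updates.
print(torch.tensor(1e-8, dtype=torch.float16))  # tensor(0., dtype=torch.float16)
print(torch.tensor(1e-3, dtype=torch.float16))  # tensor(0.0010, dtype=torch.float16)
```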
47 |
48 |
49 |
50 |
51 | ## Generalization From Base to New Classes
52 |
53 | You will need `scripts/base2new_train.sh`. Scripts with the prefix `base2new_train` train a model on the base classes, while those with the prefix `base2new_test` evaluate the trained model on the new classes. Both kinds of scripts take a single argument, `DATASET`, which is a dataset name such as `imagenet` or `caltech101`; the valid names are the file names under `configs/datasets/`.
54 |
55 | Below is an example of how to run it:
56 |
57 | ```bash
58 | bash base2new_train.sh
59 | ```
60 |
61 | When the evaluation is done, you can use `parse_test_res.py` to automatically calculate the average results over the output directories produced by the commands above.
62 |
63 |
64 | To get the average performance on the base classes, run
65 |
66 | ```bash
67 | python parse_test_res.py output/base2new/train_base/stanford_cars/shots_16/CoCoOp/rn50_ep100
68 | ```
69 |
70 | To get the average performance on the new classes, run
71 |
72 | ```bash
73 | python parse_test_res.py output/base2new/test_new/stanford_cars/shots_16/CoCoOp/rn50_ep100 --test-log
74 | ```
75 |
76 | ## Citation
77 | If you use our work, please consider citing:
78 | ```bibtex
79 | @inproceedings{TCP24,
80 | title={TCP: Textual-based Class-aware Prompt tuning for Visual-Language Model},
81 | author={Yao, Hantao and Zhang, Rui and Xu, Changsheng},
82 | booktitle={The IEEE/CVF Conference on Computer Vision and Pattern Recognition},
83 | year={2024}
84 | }
85 | ```
86 |
87 |
--------------------------------------------------------------------------------
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/clip/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/clip.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/clip.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/clip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/clip.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/simple_tokenizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/simple_tokenizer.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/simple_tokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/__pycache__/simple_tokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40 | }
41 |
42 |
43 | #def _download(url: str, root: str):
44 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
45 | os.makedirs(root, exist_ok=True)
46 | filename = os.path.basename(url)
47 |
48 | expected_sha256 = url.split("/")[-2]
49 | download_target = os.path.join(root, filename)
50 |
51 | if os.path.exists(download_target) and not os.path.isfile(download_target):
52 | raise RuntimeError(f"{download_target} exists and is not a regular file")
53 |
54 | if os.path.isfile(download_target):
55 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
56 | return download_target
57 | else:
58 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
59 |
60 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
61 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
62 | while True:
63 | buffer = source.read(8192)
64 | if not buffer:
65 | break
66 |
67 | output.write(buffer)
68 | loop.update(len(buffer))
69 |
70 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
71 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
72 |
73 | return download_target
74 |
75 |
76 | def _convert_image_to_rgb(image):
77 | return image.convert("RGB")
78 |
79 |
80 | def _transform(n_px):
81 | return Compose([
82 | Resize(n_px, interpolation=BICUBIC),
83 | CenterCrop(n_px),
84 | _convert_image_to_rgb,
85 | ToTensor(),
86 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
87 | ])
88 |
89 |
90 | def available_models() -> List[str]:
91 | """Returns the names of available CLIP models"""
92 | return list(_MODELS.keys())
93 |
94 |
95 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
96 | """Load a CLIP model
97 |
98 | Parameters
99 | ----------
100 | name : str
101 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
102 |
103 | device : Union[str, torch.device]
104 | The device to put the loaded model
105 |
106 | jit : bool
107 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
108 |
109 | download_root: str
110 | path to download the model files; by default, it uses "~/.cache/clip"
111 |
112 | Returns
113 | -------
114 | model : torch.nn.Module
115 | The CLIP model
116 |
117 | preprocess : Callable[[PIL.Image], torch.Tensor]
118 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
119 | """
120 | if name in _MODELS:
121 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
122 | elif os.path.isfile(name):
123 | model_path = name
124 | else:
125 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
126 |
127 | with open(model_path, 'rb') as opened_file:
128 | try:
129 | # loading JIT archive
130 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
131 | state_dict = None
132 | except RuntimeError:
133 | # loading saved state dict
134 | if jit:
135 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
136 | jit = False
137 | state_dict = torch.load(opened_file, map_location="cpu")
138 |
139 | if not jit:
140 | model = build_model(state_dict or model.state_dict()).to(device)
141 | if str(device) == "cpu":
142 | model.float()
143 | return model, _transform(model.visual.input_resolution)
144 |
145 | # patch the device names
146 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
147 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
148 |
149 | def _node_get(node: torch._C.Node, key: str):
150 | """Gets attributes of a node which is polymorphic over return type.
151 |
152 | From https://github.com/pytorch/pytorch/pull/82628
153 | """
154 | sel = node.kindOf(key)
155 | return getattr(node, sel)(key)
156 |
157 | def patch_device(module):
158 | try:
159 | graphs = [module.graph] if hasattr(module, "graph") else []
160 | except RuntimeError:
161 | graphs = []
162 |
163 | if hasattr(module, "forward1"):
164 | graphs.append(module.forward1.graph)
165 |
166 | for graph in graphs:
167 | for node in graph.findAllNodes("prim::Constant"):
168 | if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
169 | node.copyAttributes(device_node)
170 |
171 | model.apply(patch_device)
172 | patch_device(model.encode_image)
173 | patch_device(model.encode_text)
174 |
175 | # patch dtype to float32 on CPU
176 | if str(device) == "cpu":
177 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
178 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
179 | float_node = float_input.node()
180 |
181 | def patch_float(module):
182 | try:
183 | graphs = [module.graph] if hasattr(module, "graph") else []
184 | except RuntimeError:
185 | graphs = []
186 |
187 | if hasattr(module, "forward1"):
188 | graphs.append(module.forward1.graph)
189 |
190 | for graph in graphs:
191 | for node in graph.findAllNodes("aten::to"):
192 | inputs = list(node.inputs())
193 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
194 | if _node_get(inputs[i].node(), "value") == 5:
195 | inputs[i].node().copyAttributes(float_node)
196 |
197 | model.apply(patch_float)
198 | patch_float(model.encode_image)
199 | patch_float(model.encode_text)
200 |
201 | model.float()
202 |
203 | return model, _transform(model.input_resolution.item())
204 |
205 |
206 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
207 | """
208 | Returns the tokenized representation of given input string(s)
209 |
210 | Parameters
211 | ----------
212 | texts : Union[str, List[str]]
213 | An input string or a list of input strings to tokenize
214 |
215 | context_length : int
216 | The context length to use; all CLIP models use 77 as the context length
217 |
218 | truncate: bool
219 | Whether to truncate the text in case its encoding is longer than the context length
220 |
221 | Returns
222 | -------
223 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
224 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
225 | """
226 | if isinstance(texts, str):
227 | texts = [texts]
228 |
229 | sot_token = _tokenizer.encoder["<|startoftext|>"]
230 | eot_token = _tokenizer.encoder["<|endoftext|>"]
231 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
232 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
233 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
234 | else:
235 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
236 |
237 | for i, tokens in enumerate(all_tokens):
238 | if len(tokens) > context_length:
239 | if truncate:
240 | tokens = tokens[:context_length]
241 | tokens[-1] = eot_token
242 | else:
243 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
244 | result[i, :len(tokens)] = torch.tensor(tokens)
245 |
246 | return result
247 |
--------------------------------------------------------------------------------
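For reference, a minimal usage sketch of `load()` and `tokenize()` from `clip/clip.py` above (it downloads the RN50 weights on first use; the prompts are arbitrary examples):

```python
import torch
import clip  # the local package defined above

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("RN50", device=device)

# Tokenize a few candidate prompts and encode them with the text encoder.
text = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)
with torch.no_grad():
    text_features = model.encode_text(text)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

print(text_features.shape)  # (2, embedding_dim)
```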
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
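A small round-trip sketch for `SimpleTokenizer` above (note that `encode` does not add the start/end-of-text markers; `clip.tokenize` adds those, and `decode` turns the `</w>` word-end suffix back into spaces):

```python
from clip.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("a photo of a cat")
print(ids)              # BPE token ids
print(tok.decode(ids))  # "a photo of a cat "
```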
/configs/datasets/caltech101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "Caltech101"
3 |
--------------------------------------------------------------------------------
/configs/datasets/dtd.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "DescribableTextures"
3 |
--------------------------------------------------------------------------------
/configs/datasets/eurosat.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "EuroSAT"
3 |
--------------------------------------------------------------------------------
/configs/datasets/fgvc_aircraft.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "FGVCAircraft"
3 |
--------------------------------------------------------------------------------
/configs/datasets/food101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "Food101"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenet.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNet"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenet_a.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNetA"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenet_r.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNetR"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenet_sketch.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNetSketch"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenetv2.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNetV2"
3 |
--------------------------------------------------------------------------------
/configs/datasets/oxford_flowers.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "OxfordFlowers"
--------------------------------------------------------------------------------
/configs/datasets/oxford_pets.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "OxfordPets"
--------------------------------------------------------------------------------
/configs/datasets/stanford_cars.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "StanfordCars"
3 |
--------------------------------------------------------------------------------
/configs/datasets/sun397.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SUN397"
3 |
--------------------------------------------------------------------------------
/configs/datasets/ucf101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "UCF101"
3 |
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/rn50_c4_ep10_batch1_ctxv1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 4
34 | CTX_INIT: True
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/rn50_ep100_init.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 16
34 | CTX_INIT: True
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/rn50_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 16
34 | CTX_INIT: True
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/vit_b16_c16_ep10_batch1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 16
34 | CTX_INIT: ""
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 4
34 | CTX_INIT: ""
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 16
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 4
34 | #CTX_INIT: "a photo of a"
35 | PREC: "fp16"
36 |
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/vit_b16_c8_ep10_batch1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 8
34 | CTX_INIT: ""
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_ep100.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
30 |
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_val.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 32
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | MODEL:
16 | BACKBONE:
17 | NAME: "RN50"
--------------------------------------------------------------------------------
/configs/trainers/KgCoOp/.ipynb_checkpoints/vit_b16_ep100_ctxv1-checkpoint.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "adam"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | CTX_INIT: False
34 |
--------------------------------------------------------------------------------
/configs/trainers/KgCoOp/rn50_ep100.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 256
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "adam" #"sgd"
17 | LR: 0.002 #0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
30 |
31 | TRAINER:
32 | COOP:
33 | CTX_INIT: False
34 |
--------------------------------------------------------------------------------
/configs/trainers/KgCoOp/rn50_ep100_b16.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 16
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "adam" #"sgd"
17 | LR: 0.002 #0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
30 |
31 | TRAINER:
32 | COOP:
33 | CTX_INIT: False
34 |
--------------------------------------------------------------------------------
/configs/trainers/KgCoOp/vit_b16_ep100_ctxv1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 500
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "adam"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 | #EPS: 1e-3
24 |
25 | TRAIN:
26 | PRINT_FREQ: 5
27 |
28 | MODEL:
29 | BACKBONE:
30 | NAME: "ViT-B/16"
31 |
32 | TRAINER:
33 | COOP:
34 | CTX_INIT: False
35 |
--------------------------------------------------------------------------------
/configs/trainers/KgCoOp/vit_b16_ep100_ctxv1_b128.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 128
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "adam"
17 | LR: 0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | CTX_INIT: True
34 |
--------------------------------------------------------------------------------
/configs/trainers/KgCoOp/vit_b16_ep100_ctxv1_b16.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 16
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "adam"
17 | LR: 0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | CTX_INIT: False
34 |
--------------------------------------------------------------------------------
/configs/trainers/KgCoOp/vit_b16_ep100_ctxv1_b8.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 8
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "adam"
17 | LR: 0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 10
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | CTX_INIT: False
34 |
--------------------------------------------------------------------------------
/configs/trainers/ProGrad/rn50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | LOSS:
25 | NAME: "prograd"
26 | T: 1.0
27 |
28 | TRAIN:
29 | PRINT_FREQ: 5
30 |
31 | MODEL:
32 | BACKBONE:
33 | NAME: "RN50"
34 |
35 | TRAINER:
36 | COOP:
37 | CTX_INIT: True
38 |
--------------------------------------------------------------------------------
/configs/trainers/ProGrad/rn50_ep100.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | LOSS:
25 | NAME: "prograd"
26 | T: 1.0
27 |
28 | TRAIN:
29 | PRINT_FREQ: 5
30 |
31 | MODEL:
32 | BACKBONE:
33 | NAME: "RN50"
34 |
35 | TRAINER:
36 | COOP:
37 | CTX_INIT: True
38 |
--------------------------------------------------------------------------------
/configs/trainers/ProGrad/rn50_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | LOSS:
25 | NAME: "prograd"
26 | T: 1.0
27 |
28 | TRAIN:
29 | PRINT_FREQ: 5
30 |
31 | MODEL:
32 | BACKBONE:
33 | NAME: "RN50"
34 |
35 | TRAINER:
36 | COOP:
37 | CTX_INIT: True
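
The ProGrad configs add a LOSS block (NAME: "prograd", T: 1.0) on top of the usual SGD + cosine recipe. As a rough standalone sketch only, and not the repository's implementation: ProGrad-style training is usually described as projecting away the part of the task gradient that conflicts with the general-knowledge (zero-shot CLIP) gradient, with T acting as the softening temperature for the general targets:

    import torch

    def prograd_step(g_task: torch.Tensor, g_general: torch.Tensor) -> torch.Tensor:
        # Illustrative sketch: keep only the component of the task gradient
        # that does not oppose the general-knowledge gradient.
        dot = torch.dot(g_task.flatten(), g_general.flatten())
        if dot >= 0:                       # no conflict: use the task gradient as-is
            return g_task
        # Conflict: remove the component of g_task pointing against g_general.
        return g_task - dot / g_general.norm().pow(2) * g_general

    g_t = torch.tensor([1.0, -2.0])
    g_g = torch.tensor([1.0, 1.0])
    print(prograd_step(g_t, g_g))          # tensor([ 1.5000, -1.5000]), orthogonal to g_g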
--------------------------------------------------------------------------------
/configs/trainers/TCP/vit_b16_ep100_ctxv1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 500
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "adam"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COOP:
33 | CTX_INIT: False
34 |
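
These trainer YAMLs are flat yacs-style configs that the training pipeline merges on top of Dassl's defaults. A minimal sketch of reading one directly with yacs (assuming yacs is installed; new_allowed=True is only needed because the sketch skips the default config that normally defines these keys):

    from yacs.config import CfgNode as CN

    cfg = CN(new_allowed=True)             # accept keys not present in an empty default
    cfg.merge_from_file("configs/trainers/TCP/vit_b16_ep100_ctxv1.yaml")
    cfg.freeze()

    print(cfg.OPTIM.NAME, cfg.OPTIM.LR)    # adam 0.002
    print(cfg.MODEL.BACKBONE.NAME)         # ViT-B/16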
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/caltech101.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
5 | from dassl.utils import mkdir_if_missing
6 |
7 | from .oxford_pets import OxfordPets
8 | from .dtd import DescribableTextures as DTD
9 |
10 | IGNORED = ["BACKGROUND_Google", "Faces_easy"]
11 | NEW_CNAMES = {
12 | "airplanes": "airplane",
13 | "Faces": "face",
14 | "Leopards": "leopard",
15 | "Motorbikes": "motorbike",
16 | }
17 |
18 |
19 | @DATASET_REGISTRY.register()
20 | class Caltech101(DatasetBase):
21 |
22 | dataset_dir = "caltech-101"
23 |
24 | def __init__(self, cfg):
25 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
26 | self.dataset_dir = os.path.join(root, self.dataset_dir)
27 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories")
28 | self.split_path = os.path.join(self.dataset_dir,
29 | "split_zhou_Caltech101.json")
30 | self.split_fewshot_dir = os.path.join(self.dataset_dir,
31 | "split_fewshot")
32 | mkdir_if_missing(self.split_fewshot_dir)
33 |
34 | if os.path.exists(self.split_path):
35 | train, val, test = OxfordPets.read_split(self.split_path,
36 | self.image_dir)
37 | else:
38 | train, val, test = DTD.read_and_split_data(self.image_dir,
39 | ignored=IGNORED,
40 | new_cnames=NEW_CNAMES)
41 | OxfordPets.save_split(train, val, test, self.split_path,
42 | self.image_dir)
43 |
44 | num_shots = cfg.DATASET.NUM_SHOTS
45 | if num_shots >= 1:
46 | seed = cfg.SEED
47 | preprocessed = os.path.join(self.split_fewshot_dir,
48 | f"shot_{num_shots}-seed_{seed}.pkl")
49 |
50 | if os.path.exists(preprocessed):
51 | print(
52 | f"Loading preprocessed few-shot data from {preprocessed}")
53 | with open(preprocessed, "rb") as file:
54 | data = pickle.load(file)
55 | train, val = data["train"], data["val"]
56 | else:
57 | train = self.generate_fewshot_dataset(train,
58 | num_shots=num_shots)
59 | val = self.generate_fewshot_dataset(val,
60 | num_shots=min(
61 | num_shots, 4))
62 | data = {"train": train, "val": val}
63 | print(f"Saving preprocessed few-shot data to {preprocessed}")
64 | with open(preprocessed, "wb") as file:
65 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
66 |
67 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
68 | train, val, test = OxfordPets.subsample_classes(train,
69 | val,
70 | test,
71 | subsample=subsample)
72 |
73 | super().__init__(train_x=train, val=val, test=test)
74 |
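
Most of the dataset loaders below repeat the same few-shot caching pattern used here: the sampled subset is pickled once per (num_shots, seed) pair and re-used on later runs, with the validation set capped at min(num_shots, 4) images per class. A small sketch of the cache key (the directory shown is hypothetical):

    import os

    def fewshot_cache_path(split_fewshot_dir, num_shots, seed):
        # Same naming scheme as the loaders: shot_{K}-seed_{S}.pkl
        return os.path.join(split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")

    print(fewshot_cache_path("/data/caltech-101/split_fewshot", num_shots=16, seed=1))
    # /data/caltech-101/split_fewshot/shot_16-seed_1.pkl
    print(min(16, 4))                      # at most 4 validation shots per class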
--------------------------------------------------------------------------------
/datasets/dtd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import random
4 |
5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
6 | from dassl.utils import listdir_nohidden, mkdir_if_missing
7 |
8 | from .oxford_pets import OxfordPets
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class DescribableTextures(DatasetBase):
13 |
14 | dataset_dir = "dtd"
15 |
16 | def __init__(self, cfg):
17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.image_dir = os.path.join(self.dataset_dir, "images")
20 | self.split_path = os.path.join(self.dataset_dir,
21 | "split_zhou_DescribableTextures.json")
22 | self.split_fewshot_dir = os.path.join(self.dataset_dir,
23 | "split_fewshot")
24 | mkdir_if_missing(self.split_fewshot_dir)
25 |
26 | if os.path.exists(self.split_path):
27 | train, val, test = OxfordPets.read_split(self.split_path,
28 | self.image_dir)
29 | else:
30 | train, val, test = self.read_and_split_data(self.image_dir)
31 | OxfordPets.save_split(train, val, test, self.split_path,
32 | self.image_dir)
33 |
34 | num_shots = cfg.DATASET.NUM_SHOTS
35 | if num_shots >= 1:
36 | seed = cfg.SEED
37 | preprocessed = os.path.join(self.split_fewshot_dir,
38 | f"shot_{num_shots}-seed_{seed}.pkl")
39 |
40 | if os.path.exists(preprocessed):
41 | print(
42 | f"Loading preprocessed few-shot data from {preprocessed}")
43 | with open(preprocessed, "rb") as file:
44 | data = pickle.load(file)
45 | train, val = data["train"], data["val"]
46 | else:
47 | train = self.generate_fewshot_dataset(train,
48 | num_shots=num_shots)
49 | val = self.generate_fewshot_dataset(val,
50 | num_shots=min(
51 | num_shots, 4))
52 | data = {"train": train, "val": val}
53 | print(f"Saving preprocessed few-shot data to {preprocessed}")
54 | with open(preprocessed, "wb") as file:
55 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
56 |
57 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
58 | train, val, test = OxfordPets.subsample_classes(train,
59 | val,
60 | test,
61 | subsample=subsample)
62 |
63 | super().__init__(train_x=train, val=val, test=test)
64 |
65 | @staticmethod
66 | def read_and_split_data(image_dir,
67 | p_trn=0.5,
68 | p_val=0.2,
69 | ignored=[],
70 | new_cnames=None):
71 | # The data are supposed to be organized into the following structure
72 | # =============
73 | # images/
74 | # dog/
75 | # cat/
76 | # horse/
77 | # =============
78 | categories = listdir_nohidden(image_dir)
79 | categories = [c for c in categories if c not in ignored]
80 | categories.sort()
81 |
82 | p_tst = 1 - p_trn - p_val
83 | print(
84 | f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test"
85 | )
86 |
87 | def _collate(ims, y, c):
88 | items = []
89 | for im in ims:
90 | item = Datum(impath=im, label=y,
91 | classname=c) # is already 0-based
92 | items.append(item)
93 | return items
94 |
95 | train, val, test = [], [], []
96 | for label, category in enumerate(categories):
97 | category_dir = os.path.join(image_dir, category)
98 | images = listdir_nohidden(category_dir)
99 | images = [os.path.join(category_dir, im) for im in images]
100 | random.shuffle(images)
101 | n_total = len(images)
102 | n_train = round(n_total * p_trn)
103 | n_val = round(n_total * p_val)
104 | n_test = n_total - n_train - n_val
105 | assert n_train > 0 and n_val > 0 and n_test > 0
106 |
107 | if new_cnames is not None and category in new_cnames:
108 | category = new_cnames[category]
109 |
110 | train.extend(_collate(images[:n_train], label, category))
111 | val.extend(
112 | _collate(images[n_train:n_train + n_val], label, category))
113 | test.extend(_collate(images[n_train + n_val:], label, category))
114 |
115 | return train, val, test
116 |
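
read_and_split_data gives each class roughly 50% train, 20% val and the remaining 30% test. A quick check of the per-class counts produced by the same rounding, using an example class of 40 images:

    def split_counts(n_total, p_trn=0.5, p_val=0.2):
        # Mirrors read_and_split_data: test takes whatever remains after rounding.
        n_train = round(n_total * p_trn)
        n_val = round(n_total * p_val)
        n_test = n_total - n_train - n_val
        return n_train, n_val, n_test

    print(split_counts(40))                # (20, 8, 12)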
--------------------------------------------------------------------------------
/datasets/eurosat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
5 | from dassl.utils import mkdir_if_missing
6 |
7 | from .oxford_pets import OxfordPets
8 | from .dtd import DescribableTextures as DTD
9 |
10 | NEW_CNAMES = {
11 | "AnnualCrop": "Annual Crop Land",
12 | "Forest": "Forest",
13 | "HerbaceousVegetation": "Herbaceous Vegetation Land",
14 | "Highway": "Highway or Road",
15 | "Industrial": "Industrial Buildings",
16 | "Pasture": "Pasture Land",
17 | "PermanentCrop": "Permanent Crop Land",
18 | "Residential": "Residential Buildings",
19 | "River": "River",
20 | "SeaLake": "Sea or Lake",
21 | }
22 |
23 |
24 | @DATASET_REGISTRY.register()
25 | class EuroSAT(DatasetBase):
26 |
27 | dataset_dir = "eurosat"
28 |
29 | def __init__(self, cfg):
30 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
31 | self.dataset_dir = os.path.join(root, self.dataset_dir)
32 | self.image_dir = os.path.join(self.dataset_dir, "2750")
33 | self.split_path = os.path.join(self.dataset_dir,
34 | "split_zhou_EuroSAT.json")
35 | self.split_fewshot_dir = os.path.join(self.dataset_dir,
36 | "split_fewshot")
37 | mkdir_if_missing(self.split_fewshot_dir)
38 |
39 | if os.path.exists(self.split_path):
40 | train, val, test = OxfordPets.read_split(self.split_path,
41 | self.image_dir)
42 | else:
43 | train, val, test = DTD.read_and_split_data(self.image_dir,
44 | new_cnames=NEW_CNAMES)
45 | OxfordPets.save_split(train, val, test, self.split_path,
46 | self.image_dir)
47 |
48 | num_shots = cfg.DATASET.NUM_SHOTS
49 | if num_shots >= 1:
50 | seed = cfg.SEED
51 | preprocessed = os.path.join(self.split_fewshot_dir,
52 | f"shot_{num_shots}-seed_{seed}.pkl")
53 |
54 | if os.path.exists(preprocessed):
55 | print(
56 | f"Loading preprocessed few-shot data from {preprocessed}")
57 | with open(preprocessed, "rb") as file:
58 | data = pickle.load(file)
59 | train, val = data["train"], data["val"]
60 | else:
61 | train = self.generate_fewshot_dataset(train,
62 | num_shots=num_shots)
63 | val = self.generate_fewshot_dataset(val,
64 | num_shots=min(
65 | num_shots, 4))
66 | data = {"train": train, "val": val}
67 | print(f"Saving preprocessed few-shot data to {preprocessed}")
68 | with open(preprocessed, "wb") as file:
69 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
70 |
71 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
72 | train, val, test = OxfordPets.subsample_classes(train,
73 | val,
74 | test,
75 | subsample=subsample)
76 |
77 | super().__init__(train_x=train, val=val, test=test)
78 |
79 | def update_classname(self, dataset_old):
80 | dataset_new = []
81 | for item_old in dataset_old:
82 | cname_old = item_old.classname
83 | cname_new = NEW_CNAMES[cname_old]
84 | item_new = Datum(impath=item_old.impath,
85 | label=item_old.label,
86 | classname=cname_new)
87 | dataset_new.append(item_new)
88 | return dataset_new
89 |
--------------------------------------------------------------------------------
/datasets/fgvc_aircraft.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
5 | from dassl.utils import mkdir_if_missing
6 |
7 | from .oxford_pets import OxfordPets
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class FGVCAircraft(DatasetBase):
12 |
13 | dataset_dir = "fgvc_aircraft"
14 |
15 | def __init__(self, cfg):
16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
17 | self.dataset_dir = os.path.join(root, self.dataset_dir)
18 | self.image_dir = os.path.join(self.dataset_dir, "images")
19 | self.split_fewshot_dir = os.path.join(self.dataset_dir,
20 | "split_fewshot")
21 | mkdir_if_missing(self.split_fewshot_dir)
22 |
23 | classnames = []
24 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f:
25 | lines = f.readlines()
26 | for line in lines:
27 | classnames.append(line.strip())
28 | cname2lab = {c: i for i, c in enumerate(classnames)}
29 |
30 | train = self.read_data(cname2lab, "images_variant_train.txt")
31 | val = self.read_data(cname2lab, "images_variant_val.txt")
32 | test = self.read_data(cname2lab, "images_variant_test.txt")
33 |
34 | num_shots = cfg.DATASET.NUM_SHOTS
35 | if num_shots >= 1:
36 | seed = cfg.SEED
37 | preprocessed = os.path.join(self.split_fewshot_dir,
38 | f"shot_{num_shots}-seed_{seed}.pkl")
39 |
40 | if os.path.exists(preprocessed):
41 | print(
42 | f"Loading preprocessed few-shot data from {preprocessed}")
43 | with open(preprocessed, "rb") as file:
44 | data = pickle.load(file)
45 | train, val = data["train"], data["val"]
46 | else:
47 | train = self.generate_fewshot_dataset(train,
48 | num_shots=num_shots)
49 | val = self.generate_fewshot_dataset(val,
50 | num_shots=min(
51 | num_shots, 4))
52 | data = {"train": train, "val": val}
53 | print(f"Saving preprocessed few-shot data to {preprocessed}")
54 | with open(preprocessed, "wb") as file:
55 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
56 |
57 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
58 | train, val, test = OxfordPets.subsample_classes(train,
59 | val,
60 | test,
61 | subsample=subsample)
62 |
63 | super().__init__(train_x=train, val=val, test=test)
64 |
65 | def read_data(self, cname2lab, split_file):
66 | filepath = os.path.join(self.dataset_dir, split_file)
67 | items = []
68 |
69 | with open(filepath, "r") as f:
70 | lines = f.readlines()
71 | for line in lines:
72 | line = line.strip().split(" ")
73 | imname = line[0] + ".jpg"
74 | classname = " ".join(line[1:])
75 | impath = os.path.join(self.image_dir, imname)
76 | label = cname2lab[classname]
77 | item = Datum(impath=impath, label=label, classname=classname)
78 | items.append(item)
79 |
80 | return items
81 |
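
FGVCAircraft.read_data expects each line of images_variant_*.txt to hold an image id followed by the variant name; the id becomes "<id>.jpg" under images/ and the rest of the line is the class. An illustration with a made-up id and one of the dataset's variant names:

    line = "0034309 707-320"               # "<image id> <variant>" as in images_variant_train.txt
    parts = line.strip().split(" ")
    imname = parts[0] + ".jpg"             # 0034309.jpg, resolved under images/
    classname = " ".join(parts[1:])        # "707-320"
    print(imname, classname)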
--------------------------------------------------------------------------------
/datasets/food101.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
5 | from dassl.utils import mkdir_if_missing
6 |
7 | from .oxford_pets import OxfordPets
8 | from .dtd import DescribableTextures as DTD
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class Food101(DatasetBase):
13 |
14 | dataset_dir = "food-101"
15 |
16 | def __init__(self, cfg):
17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.image_dir = os.path.join(self.dataset_dir, "images")
20 | self.split_path = os.path.join(self.dataset_dir,
21 | "split_zhou_Food101.json")
22 | self.split_fewshot_dir = os.path.join(self.dataset_dir,
23 | "split_fewshot")
24 | mkdir_if_missing(self.split_fewshot_dir)
25 |
26 | if os.path.exists(self.split_path):
27 | train, val, test = OxfordPets.read_split(self.split_path,
28 | self.image_dir)
29 | else:
30 | train, val, test = DTD.read_and_split_data(self.image_dir)
31 | OxfordPets.save_split(train, val, test, self.split_path,
32 | self.image_dir)
33 |
34 | num_shots = cfg.DATASET.NUM_SHOTS
35 | if num_shots >= 1:
36 | seed = cfg.SEED
37 | preprocessed = os.path.join(self.split_fewshot_dir,
38 | f"shot_{num_shots}-seed_{seed}.pkl")
39 |
40 | if os.path.exists(preprocessed):
41 | print(
42 | f"Loading preprocessed few-shot data from {preprocessed}")
43 | with open(preprocessed, "rb") as file:
44 | data = pickle.load(file)
45 | train, val = data["train"], data["val"]
46 | else:
47 | train = self.generate_fewshot_dataset(train,
48 | num_shots=num_shots)
49 | val = self.generate_fewshot_dataset(val,
50 | num_shots=min(
51 | num_shots, 4))
52 | data = {"train": train, "val": val}
53 | print(f"Saving preprocessed few-shot data to {preprocessed}")
54 | with open(preprocessed, "wb") as file:
55 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
56 |
57 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
58 | train, val, test = OxfordPets.subsample_classes(train,
59 | val,
60 | test,
61 | subsample=subsample)
62 |
63 | super().__init__(train_x=train, val=val, test=test)
64 |
--------------------------------------------------------------------------------
/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from collections import OrderedDict
4 |
5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
6 | from dassl.utils import listdir_nohidden, mkdir_if_missing
7 |
8 | from .oxford_pets import OxfordPets
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class ImageNet(DatasetBase):
13 |
14 | dataset_dir = "imagenet"
15 |
16 | def __init__(self, cfg):
17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.image_dir = os.path.join(self.dataset_dir, "images")
20 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl")
21 | self.split_fewshot_dir = os.path.join(self.dataset_dir,
22 | "split_fewshot")
23 | mkdir_if_missing(self.split_fewshot_dir)
24 |
25 | if os.path.exists(self.preprocessed):
26 | with open(self.preprocessed, "rb") as f:
27 | preprocessed = pickle.load(f)
28 | train = preprocessed["train"]
29 | test = preprocessed["test"]
30 | else:
31 | text_file = os.path.join(self.dataset_dir, "classnames.txt")
32 | classnames = self.read_classnames(text_file)
33 | train = self.read_data(classnames, "train")
34 | # Follow standard practice to perform evaluation on the val set
35 | # Also used as the val set (so evaluate the last-step model)
36 | test = self.read_data(classnames, "val")
37 |
38 | preprocessed = {"train": train, "test": test}
39 | with open(self.preprocessed, "wb") as f:
40 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL)
41 |
42 | num_shots = cfg.DATASET.NUM_SHOTS
43 | if num_shots >= 1:
44 | seed = cfg.SEED
45 | preprocessed = os.path.join(self.split_fewshot_dir,
46 | f"shot_{num_shots}-seed_{seed}.pkl")
47 |
48 | if os.path.exists(preprocessed):
49 | print(
50 | f"Loading preprocessed few-shot data from {preprocessed}")
51 | with open(preprocessed, "rb") as file:
52 | data = pickle.load(file)
53 | train = data["train"]
54 | else:
55 | train = self.generate_fewshot_dataset(train,
56 | num_shots=num_shots)
57 | data = {"train": train}
58 | print(f"Saving preprocessed few-shot data to {preprocessed}")
59 | with open(preprocessed, "wb") as file:
60 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
61 |
62 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
63 | train, test = OxfordPets.subsample_classes(train,
64 | test,
65 | subsample=subsample)
66 |
67 | super().__init__(train_x=train, val=test, test=test)
68 |
69 | @staticmethod
70 | def read_classnames(text_file):
71 | """Return a dictionary containing
72 | key-value pairs of folder name: class name.
73 | """
74 | classnames = OrderedDict()
75 | with open(text_file, "r") as f:
76 | lines = f.readlines()
77 | for line in lines:
78 | line = line.strip().split(" ")
79 | folder = line[0]
80 | classname = " ".join(line[1:])
81 | classnames[folder] = classname
82 | return classnames
83 |
84 | def read_data(self, classnames, split_dir):
85 | split_dir = os.path.join(self.image_dir, split_dir)
86 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())
87 | items = []
88 |
89 | for label, folder in enumerate(folders):
90 | imnames = listdir_nohidden(os.path.join(split_dir, folder))
91 | classname = classnames[folder]
92 | for imname in imnames:
93 | impath = os.path.join(split_dir, folder, imname)
94 | item = Datum(impath=impath, label=label, classname=classname)
95 | items.append(item)
96 |
97 | return items
98 |
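
read_classnames parses classnames.txt as one WordNet folder id followed by the human-readable class name per line. A tiny illustration of the same parse on a two-line sample (the two entries are the standard ImageNet pairs for tench and goldfish):

    from collections import OrderedDict

    sample = "n01440764 tench\nn01443537 goldfish\n"

    classnames = OrderedDict()
    for line in sample.splitlines():
        parts = line.strip().split(" ")
        folder, name = parts[0], " ".join(parts[1:])   # same split as read_classnames
        classnames[folder] = name

    print(classnames["n01440764"])         # tench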
--------------------------------------------------------------------------------
/datasets/imagenet_a.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 | from dassl.utils import listdir_nohidden
5 |
6 | from .imagenet import ImageNet
7 |
8 | TO_BE_IGNORED = ["README.txt"]
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class ImageNetA(DatasetBase):
13 | """ImageNet-A(dversarial).
14 |
15 | This dataset is used for testing only.
16 | """
17 |
18 | dataset_dir = "imagenet-adversarial"
19 |
20 | def __init__(self, cfg):
21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
22 | self.dataset_dir = os.path.join(root, self.dataset_dir)
23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-a")
24 |
25 | text_file = os.path.join(self.dataset_dir, "classnames.txt")
26 | classnames = ImageNet.read_classnames(text_file)
27 |
28 | data = self.read_data(classnames)
29 |
30 | super().__init__(train_x=data, test=data)
31 |
32 | def read_data(self, classnames):
33 | image_dir = self.image_dir
34 | folders = listdir_nohidden(image_dir, sort=True)
35 | folders = [f for f in folders if f not in TO_BE_IGNORED]
36 | items = []
37 |
38 | for label, folder in enumerate(folders):
39 | imnames = listdir_nohidden(os.path.join(image_dir, folder))
40 | classname = classnames[folder]
41 | for imname in imnames:
42 | impath = os.path.join(image_dir, folder, imname)
43 | item = Datum(impath=impath, label=label, classname=classname)
44 | items.append(item)
45 |
46 | return items
47 |
--------------------------------------------------------------------------------
/datasets/imagenet_r.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 | from dassl.utils import listdir_nohidden
5 |
6 | from .imagenet import ImageNet
7 |
8 | TO_BE_IGNORED = ["README.txt"]
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class ImageNetR(DatasetBase):
13 | """ImageNet-R(endition).
14 |
15 | This dataset is used for testing only.
16 | """
17 |
18 | dataset_dir = "imagenet-rendition"
19 |
20 | def __init__(self, cfg):
21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
22 | self.dataset_dir = os.path.join(root, self.dataset_dir)
23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-r")
24 |
25 | text_file = os.path.join(self.dataset_dir, "classnames.txt")
26 | classnames = ImageNet.read_classnames(text_file)
27 |
28 | data = self.read_data(classnames)
29 |
30 | super().__init__(train_x=data, test=data)
31 |
32 | def read_data(self, classnames):
33 | image_dir = self.image_dir
34 | folders = listdir_nohidden(image_dir, sort=True)
35 | folders = [f for f in folders if f not in TO_BE_IGNORED]
36 | items = []
37 |
38 | for label, folder in enumerate(folders):
39 | imnames = listdir_nohidden(os.path.join(image_dir, folder))
40 | classname = classnames[folder]
41 | for imname in imnames:
42 | impath = os.path.join(image_dir, folder, imname)
43 | item = Datum(impath=impath, label=label, classname=classname)
44 | items.append(item)
45 |
46 | return items
47 |
--------------------------------------------------------------------------------
/datasets/imagenet_sketch.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 | from dassl.utils import listdir_nohidden
5 |
6 | from .imagenet import ImageNet
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class ImageNetSketch(DatasetBase):
11 | """ImageNet-Sketch.
12 |
13 | This dataset is used for testing only.
14 | """
15 |
16 | dataset_dir = "imagenet-sketch"
17 |
18 | def __init__(self, cfg):
19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
20 | self.dataset_dir = os.path.join(root, self.dataset_dir)
21 | self.image_dir = os.path.join(self.dataset_dir, "images")
22 |
23 | text_file = os.path.join(self.dataset_dir, "classnames.txt")
24 | classnames = ImageNet.read_classnames(text_file)
25 |
26 | data = self.read_data(classnames)
27 |
28 | super().__init__(train_x=data, test=data)
29 |
30 | def read_data(self, classnames):
31 | image_dir = self.image_dir
32 | folders = listdir_nohidden(image_dir, sort=True)
33 | items = []
34 |
35 | for label, folder in enumerate(folders):
36 | imnames = listdir_nohidden(os.path.join(image_dir, folder))
37 | classname = classnames[folder]
38 | for imname in imnames:
39 | impath = os.path.join(image_dir, folder, imname)
40 | item = Datum(impath=impath, label=label, classname=classname)
41 | items.append(item)
42 |
43 | return items
44 |
--------------------------------------------------------------------------------
/datasets/imagenetv2.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 | from dassl.utils import listdir_nohidden
5 |
6 | from .imagenet import ImageNet
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class ImageNetV2(DatasetBase):
11 | """ImageNetV2.
12 |
13 | This dataset is used for testing only.
14 | """
15 |
16 | dataset_dir = "imagenetv2"
17 |
18 | def __init__(self, cfg):
19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
20 | self.dataset_dir = os.path.join(root, self.dataset_dir)
21 | image_dir = "imagenetv2-matched-frequency-format-val"
22 | self.image_dir = os.path.join(self.dataset_dir, image_dir)
23 |
24 | text_file = os.path.join(self.dataset_dir, "classnames.txt")
25 | classnames = ImageNet.read_classnames(text_file)
26 |
27 | data = self.read_data(classnames)
28 |
29 | super().__init__(train_x=data, test=data)
30 |
31 | def read_data(self, classnames):
32 | image_dir = self.image_dir
33 | folders = list(classnames.keys())
34 | items = []
35 |
36 | for label in range(1000):
37 | class_dir = os.path.join(image_dir, str(label))
38 | imnames = listdir_nohidden(class_dir)
39 | folder = folders[label]
40 | classname = classnames[folder]
41 | for imname in imnames:
42 | impath = os.path.join(class_dir, imname)
43 | item = Datum(impath=impath, label=label, classname=classname)
44 | items.append(item)
45 |
46 | return items
47 |
--------------------------------------------------------------------------------
/datasets/oxford_flowers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import random
4 | from scipy.io import loadmat
5 | from collections import defaultdict
6 |
7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
8 | from dassl.utils import read_json, mkdir_if_missing
9 |
10 | from .oxford_pets import OxfordPets
11 |
12 |
13 | @DATASET_REGISTRY.register()
14 | class OxfordFlowers(DatasetBase):
15 |
16 | dataset_dir = "oxford_flowers"
17 |
18 | def __init__(self, cfg):
19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
20 | self.dataset_dir = os.path.join(root, self.dataset_dir)
21 | self.image_dir = os.path.join(self.dataset_dir, "jpg")
22 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat")
23 | self.lab2cname_file = os.path.join(self.dataset_dir,
24 | "cat_to_name.json")
25 | self.split_path = os.path.join(self.dataset_dir,
26 | "split_zhou_OxfordFlowers.json")
27 | self.split_fewshot_dir = os.path.join(self.dataset_dir,
28 | "split_fewshot")
29 | mkdir_if_missing(self.split_fewshot_dir)
30 |
31 | if os.path.exists(self.split_path):
32 | train, val, test = OxfordPets.read_split(self.split_path,
33 | self.image_dir)
34 | else:
35 | train, val, test = self.read_data()
36 | OxfordPets.save_split(train, val, test, self.split_path,
37 | self.image_dir)
38 |
39 | num_shots = cfg.DATASET.NUM_SHOTS
40 | if num_shots >= 1:
41 | seed = cfg.SEED
42 | preprocessed = os.path.join(self.split_fewshot_dir,
43 | f"shot_{num_shots}-seed_{seed}.pkl")
44 |
45 | if os.path.exists(preprocessed):
46 | print(
47 | f"Loading preprocessed few-shot data from {preprocessed}")
48 | with open(preprocessed, "rb") as file:
49 | data = pickle.load(file)
50 | train, val = data["train"], data["val"]
51 | else:
52 | train = self.generate_fewshot_dataset(train,
53 | num_shots=num_shots)
54 | val = self.generate_fewshot_dataset(val,
55 | num_shots=min(
56 | num_shots, 4))
57 | data = {"train": train, "val": val}
58 | print(f"Saving preprocessed few-shot data to {preprocessed}")
59 | with open(preprocessed, "wb") as file:
60 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
61 |
62 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
63 | train, val, test = OxfordPets.subsample_classes(train,
64 | val,
65 | test,
66 | subsample=subsample)
67 |
68 | super().__init__(train_x=train, val=val, test=test)
69 |
70 | def read_data(self):
71 | tracker = defaultdict(list)
72 | label_file = loadmat(self.label_file)["labels"][0]
73 | for i, label in enumerate(label_file):
74 | imname = f"image_{str(i + 1).zfill(5)}.jpg"
75 | impath = os.path.join(self.image_dir, imname)
76 | label = int(label)
77 | tracker[label].append(impath)
78 |
79 | print("Splitting data into 50% train, 20% val, and 30% test")
80 |
81 | def _collate(ims, y, c):
82 | items = []
83 | for im in ims:
84 | item = Datum(impath=im, label=y - 1,
85 | classname=c) # convert to 0-based label
86 | items.append(item)
87 | return items
88 |
89 | lab2cname = read_json(self.lab2cname_file)
90 | train, val, test = [], [], []
91 | for label, impaths in tracker.items():
92 | random.shuffle(impaths)
93 | n_total = len(impaths)
94 | n_train = round(n_total * 0.5)
95 | n_val = round(n_total * 0.2)
96 | n_test = n_total - n_train - n_val
97 | assert n_train > 0 and n_val > 0 and n_test > 0
98 | cname = lab2cname[str(label)]
99 | train.extend(_collate(impaths[:n_train], label, cname))
100 | val.extend(_collate(impaths[n_train:n_train + n_val], label,
101 | cname))
102 | test.extend(_collate(impaths[n_train + n_val:], label, cname))
103 |
104 | return train, val, test
105 |
--------------------------------------------------------------------------------
/datasets/oxford_pets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import math
4 | import random
5 | from collections import defaultdict
6 |
7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
8 | from dassl.utils import read_json, write_json, mkdir_if_missing
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class OxfordPets(DatasetBase):
13 |
14 | dataset_dir = "oxford_pets"
15 |
16 | def __init__(self, cfg):
17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.image_dir = os.path.join(self.dataset_dir, "images")
20 | self.anno_dir = os.path.join(self.dataset_dir, "annotations")
21 | self.split_path = os.path.join(self.dataset_dir,
22 | "split_zhou_OxfordPets.json")
23 | self.split_fewshot_dir = os.path.join(self.dataset_dir,
24 | "split_fewshot")
25 | mkdir_if_missing(self.split_fewshot_dir)
26 |
27 | if os.path.exists(self.split_path):
28 | train, val, test = self.read_split(self.split_path, self.image_dir)
29 | else:
30 | trainval = self.read_data(split_file="trainval.txt")
31 | test = self.read_data(split_file="test.txt")
32 | train, val = self.split_trainval(trainval)
33 | self.save_split(train, val, test, self.split_path, self.image_dir)
34 |
35 | num_shots = cfg.DATASET.NUM_SHOTS
36 | if num_shots >= 1:
37 | seed = cfg.SEED
38 | preprocessed = os.path.join(self.split_fewshot_dir,
39 | f"shot_{num_shots}-seed_{seed}.pkl")
40 |
41 | if os.path.exists(preprocessed):
42 | print(
43 | f"Loading preprocessed few-shot data from {preprocessed}")
44 | with open(preprocessed, "rb") as file:
45 | data = pickle.load(file)
46 | train, val = data["train"], data["val"]
47 | else:
48 | train = self.generate_fewshot_dataset(train,
49 | num_shots=num_shots)
50 | val = self.generate_fewshot_dataset(val,
51 | num_shots=min(
52 | num_shots, 4))
53 | data = {"train": train, "val": val}
54 | print(f"Saving preprocessed few-shot data to {preprocessed}")
55 | with open(preprocessed, "wb") as file:
56 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
57 |
58 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
59 | train, val, test = self.subsample_classes(train,
60 | val,
61 | test,
62 | subsample=subsample)
63 |
64 | super().__init__(train_x=train, val=val, test=test)
65 |
66 | def read_data(self, split_file):
67 | filepath = os.path.join(self.anno_dir, split_file)
68 | items = []
69 |
70 | with open(filepath, "r") as f:
71 | lines = f.readlines()
72 | for line in lines:
73 | line = line.strip()
74 | imname, label, species, _ = line.split(" ")
75 | breed = imname.split("_")[:-1]
76 | breed = "_".join(breed)
77 | breed = breed.lower()
78 | imname += ".jpg"
79 | impath = os.path.join(self.image_dir, imname)
80 | label = int(label) - 1 # convert to 0-based index
81 | item = Datum(impath=impath, label=label, classname=breed)
82 | items.append(item)
83 |
84 | return items
85 |
86 | @staticmethod
87 | def split_trainval(trainval, p_val=0.2):
88 | p_trn = 1 - p_val
89 | print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val")
90 | tracker = defaultdict(list)
91 | for idx, item in enumerate(trainval):
92 | label = item.label
93 | tracker[label].append(idx)
94 |
95 | train, val = [], []
96 | for label, idxs in tracker.items():
97 | n_val = round(len(idxs) * p_val)
98 | assert n_val > 0
99 | random.shuffle(idxs)
100 | for n, idx in enumerate(idxs):
101 | item = trainval[idx]
102 | if n < n_val:
103 | val.append(item)
104 | else:
105 | train.append(item)
106 |
107 | return train, val
108 |
109 | @staticmethod
110 | def save_split(train, val, test, filepath, path_prefix):
111 | def _extract(items):
112 | out = []
113 | for item in items:
114 | impath = item.impath
115 | label = item.label
116 | classname = item.classname
117 | impath = impath.replace(path_prefix, "")
118 | if impath.startswith("/"):
119 | impath = impath[1:]
120 | out.append((impath, label, classname))
121 | return out
122 |
123 | train = _extract(train)
124 | val = _extract(val)
125 | test = _extract(test)
126 |
127 | split = {"train": train, "val": val, "test": test}
128 |
129 | write_json(split, filepath)
130 | print(f"Saved split to {filepath}")
131 |
132 | @staticmethod
133 | def read_split(filepath, path_prefix):
134 | def _convert(items):
135 | out = []
136 | for impath, label, classname in items:
137 | impath = os.path.join(path_prefix, impath)
138 | item = Datum(impath=impath,
139 | label=int(label),
140 | classname=classname)
141 | out.append(item)
142 | return out
143 |
144 | print(f"Reading split from {filepath}")
145 | split = read_json(filepath)
146 | train = _convert(split["train"])
147 | val = _convert(split["val"])
148 | test = _convert(split["test"])
149 |
150 | return train, val, test
151 |
152 | @staticmethod
153 | def subsample_classes(*args, subsample="all"):
154 | """Divide classes into two groups. The first group
155 | represents base classes while the second group represents
156 | new classes.
157 |
158 | Args:
159 | args: a list of datasets, e.g. train, val and test.
160 | subsample (str): what classes to subsample.
161 | """
162 | assert subsample in ["all", "base", "new"]
163 |
164 | if subsample == "all":
165 | return args
166 |
167 | dataset = args[0]
168 | labels = set()
169 | for item in dataset:
170 | labels.add(item.label)
171 | labels = list(labels)
172 | labels.sort()
173 | n = len(labels)
174 | # Divide classes into two halves
175 | m = math.ceil(n / 2)
176 |
177 | print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
178 | if subsample == "base":
179 | selected = labels[:m] # take the first half
180 | else:
181 | selected = labels[m:] # take the second half
182 | relabeler = {y: y_new for y_new, y in enumerate(selected)}
183 |
184 | output = []
185 | for dataset in args:
186 | dataset_new = []
187 | for item in dataset:
188 | if item.label not in selected:
189 | continue
190 | item_new = Datum(impath=item.impath,
191 | label=relabeler[item.label],
192 | classname=item.classname)
193 | dataset_new.append(item_new)
194 | output.append(dataset_new)
195 |
196 | return output
197 |
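
subsample_classes implements the base-to-new protocol: the sorted label set is cut in half (the base split takes the ceil half), and whichever half is selected is relabeled to start from 0. A toy run of the same selection logic for five classes:

    import math

    labels = [0, 1, 2, 3, 4]
    m = math.ceil(len(labels) / 2)         # 3

    base, new = labels[:m], labels[m:]     # base=[0, 1, 2], new=[3, 4]
    relabeler = {y: y_new for y_new, y in enumerate(new)}
    print(base, new, relabeler)            # [0, 1, 2] [3, 4] {3: 0, 4: 1}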
--------------------------------------------------------------------------------
/datasets/stanford_cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from scipy.io import loadmat
4 |
5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
6 | from dassl.utils import mkdir_if_missing
7 |
8 | from .oxford_pets import OxfordPets
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class StanfordCars(DatasetBase):
13 |
14 | dataset_dir = "stanford_cars"
15 |
16 | def __init__(self, cfg):
17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.split_path = os.path.join(self.dataset_dir,
20 | "split_zhou_StanfordCars.json")
21 | self.split_fewshot_dir = os.path.join(self.dataset_dir,
22 | "split_fewshot")
23 | mkdir_if_missing(self.split_fewshot_dir)
24 |
25 | if os.path.exists(self.split_path):
26 | train, val, test = OxfordPets.read_split(self.split_path,
27 | self.dataset_dir)
28 | else:
29 | trainval_file = os.path.join(self.dataset_dir, "devkit",
30 | "cars_train_annos.mat")
31 | test_file = os.path.join(self.dataset_dir,
32 | "cars_test_annos_withlabels.mat")
33 | meta_file = os.path.join(self.dataset_dir, "devkit",
34 | "cars_meta.mat")
35 | trainval = self.read_data("cars_train", trainval_file, meta_file)
36 | test = self.read_data("cars_test", test_file, meta_file)
37 | train, val = OxfordPets.split_trainval(trainval)
38 | OxfordPets.save_split(train, val, test, self.split_path,
39 | self.dataset_dir)
40 |
41 | num_shots = cfg.DATASET.NUM_SHOTS
42 | if num_shots >= 1:
43 | seed = cfg.SEED
44 | preprocessed = os.path.join(self.split_fewshot_dir,
45 | f"shot_{num_shots}-seed_{seed}.pkl")
46 |
47 | if os.path.exists(preprocessed):
48 | print(
49 | f"Loading preprocessed few-shot data from {preprocessed}")
50 | with open(preprocessed, "rb") as file:
51 | data = pickle.load(file)
52 | train, val = data["train"], data["val"]
53 | else:
54 | train = self.generate_fewshot_dataset(train,
55 | num_shots=num_shots)
56 | val = self.generate_fewshot_dataset(val,
57 | num_shots=min(
58 | num_shots, 4))
59 | data = {"train": train, "val": val}
60 | print(f"Saving preprocessed few-shot data to {preprocessed}")
61 | with open(preprocessed, "wb") as file:
62 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
63 |
64 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
65 | train, val, test = OxfordPets.subsample_classes(train,
66 | val,
67 | test,
68 | subsample=subsample)
69 |
70 | super().__init__(train_x=train, val=val, test=test)
71 |
72 | def read_data(self, image_dir, anno_file, meta_file):
73 | anno_file = loadmat(anno_file)["annotations"][0]
74 | meta_file = loadmat(meta_file)["class_names"][0]
75 | items = []
76 |
77 | for i in range(len(anno_file)):
78 | imname = anno_file[i]["fname"][0]
79 | impath = os.path.join(self.dataset_dir, image_dir, imname)
80 | label = anno_file[i]["class"][0, 0]
81 | label = int(label) - 1 # convert to 0-based index
82 | classname = meta_file[label][0]
83 | names = classname.split(" ")
84 | year = names.pop(-1)
85 | names.insert(0, year)
86 | classname = " ".join(names)
87 | item = Datum(impath=impath, label=label, classname=classname)
88 | items.append(item)
89 |
90 | return items
91 |
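
StanfordCars.read_data rewrites the class names from cars_meta.mat so that the trailing model year leads the string. Illustration (the class name is just an example of the usual "<make> <model> <year>" form):

    classname = "Audi A5 Coupe 2012"       # hypothetical cars_meta.mat entry
    names = classname.split(" ")
    year = names.pop(-1)
    names.insert(0, year)
    print(" ".join(names))                 # 2012 Audi A5 Coupe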
--------------------------------------------------------------------------------
/datasets/sun397.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
5 | from dassl.utils import mkdir_if_missing
6 |
7 | from .oxford_pets import OxfordPets
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class SUN397(DatasetBase):
12 |
13 | dataset_dir = "sun397"
14 |
15 | def __init__(self, cfg):
16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
17 | self.dataset_dir = os.path.join(root, self.dataset_dir)
18 | self.image_dir = os.path.join(self.dataset_dir, "SUN397")
19 | self.split_path = os.path.join(self.dataset_dir,
20 | "split_zhou_SUN397.json")
21 | self.split_fewshot_dir = os.path.join(self.dataset_dir,
22 | "split_fewshot")
23 | mkdir_if_missing(self.split_fewshot_dir)
24 |
25 | if os.path.exists(self.split_path):
26 | train, val, test = OxfordPets.read_split(self.split_path,
27 | self.image_dir)
28 | else:
29 | classnames = []
30 | with open(os.path.join(self.dataset_dir, "ClassName.txt"),
31 | "r") as f:
32 | lines = f.readlines()
33 | for line in lines:
34 | line = line.strip()[1:] # remove /
35 | classnames.append(line)
36 | cname2lab = {c: i for i, c in enumerate(classnames)}
37 | trainval = self.read_data(cname2lab, "Training_01.txt")
38 | test = self.read_data(cname2lab, "Testing_01.txt")
39 | train, val = OxfordPets.split_trainval(trainval)
40 | OxfordPets.save_split(train, val, test, self.split_path,
41 | self.image_dir)
42 |
43 | num_shots = cfg.DATASET.NUM_SHOTS
44 | if num_shots >= 1:
45 | seed = cfg.SEED
46 | preprocessed = os.path.join(self.split_fewshot_dir,
47 | f"shot_{num_shots}-seed_{seed}.pkl")
48 |
49 | if os.path.exists(preprocessed):
50 | print(
51 | f"Loading preprocessed few-shot data from {preprocessed}")
52 | with open(preprocessed, "rb") as file:
53 | data = pickle.load(file)
54 | train, val = data["train"], data["val"]
55 | else:
56 | train = self.generate_fewshot_dataset(train,
57 | num_shots=num_shots)
58 | val = self.generate_fewshot_dataset(val,
59 | num_shots=min(
60 | num_shots, 4))
61 | data = {"train": train, "val": val}
62 | print(f"Saving preprocessed few-shot data to {preprocessed}")
63 | with open(preprocessed, "wb") as file:
64 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
65 |
66 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
67 | train, val, test = OxfordPets.subsample_classes(train,
68 | val,
69 | test,
70 | subsample=subsample)
71 |
72 | super().__init__(train_x=train, val=val, test=test)
73 |
74 | def read_data(self, cname2lab, text_file):
75 | text_file = os.path.join(self.dataset_dir, text_file)
76 | items = []
77 |
78 | with open(text_file, "r") as f:
79 | lines = f.readlines()
80 | for line in lines:
81 | imname = line.strip()[1:] # remove /
82 | classname = os.path.dirname(imname)
83 | label = cname2lab[classname]
84 | impath = os.path.join(self.image_dir, imname)
85 |
86 | names = classname.split("/")[1:] # remove 1st letter
87 | names = names[::-1] # put words like indoor/outdoor at first
88 | classname = " ".join(names)
89 |
90 | item = Datum(impath=impath, label=label, classname=classname)
91 | items.append(item)
92 |
93 | return items
94 |
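The class names produced by `read_data` above come from SUN397's letter-indexed directory layout. A small standalone illustration of the transformation (example entries assumed):

```python
# Illustration of the classname handling in SUN397.read_data above.
# Entries look like "a/abbey" or "b/balcony/interior" after the leading "/"
# has been stripped and os.path.dirname has been applied.
def to_classname(entry: str) -> str:
    names = entry.split("/")[1:]   # drop the single-letter index folder
    names = names[::-1]            # move qualifiers such as interior/exterior to the front
    return " ".join(names)

print(to_classname("a/abbey"))              # -> "abbey"
print(to_classname("b/balcony/interior"))   # -> "interior balcony"
```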
--------------------------------------------------------------------------------
/datasets/ucf101.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import re
4 |
5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
6 | from dassl.utils import mkdir_if_missing
7 |
8 | from .oxford_pets import OxfordPets
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class UCF101(DatasetBase):
13 |
14 | dataset_dir = "ucf101"
15 |
16 | def __init__(self, cfg):
17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes")
20 | self.split_path = os.path.join(self.dataset_dir,
21 | "split_zhou_UCF101.json")
22 | self.split_fewshot_dir = os.path.join(self.dataset_dir,
23 | "split_fewshot")
24 | mkdir_if_missing(self.split_fewshot_dir)
25 |
26 | if os.path.exists(self.split_path):
27 | train, val, test = OxfordPets.read_split(self.split_path,
28 | self.image_dir)
29 | else:
30 | cname2lab = {}
31 | filepath = os.path.join(self.dataset_dir,
32 | "ucfTrainTestlist/classInd.txt")
33 | with open(filepath, "r") as f:
34 | lines = f.readlines()
35 | for line in lines:
36 | label, classname = line.strip().split(" ")
37 |                     label = int(label) - 1 # convert to 0-based index
38 | cname2lab[classname] = label
39 |
40 | trainval = self.read_data(cname2lab,
41 | "ucfTrainTestlist/trainlist01.txt")
42 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt")
43 | train, val = OxfordPets.split_trainval(trainval)
44 | OxfordPets.save_split(train, val, test, self.split_path,
45 | self.image_dir)
46 |
47 | num_shots = cfg.DATASET.NUM_SHOTS
48 | if num_shots >= 1:
49 | seed = cfg.SEED
50 | preprocessed = os.path.join(self.split_fewshot_dir,
51 | f"shot_{num_shots}-seed_{seed}.pkl")
52 |
53 | if os.path.exists(preprocessed):
54 | print(
55 | f"Loading preprocessed few-shot data from {preprocessed}")
56 | with open(preprocessed, "rb") as file:
57 | data = pickle.load(file)
58 | train, val = data["train"], data["val"]
59 | else:
60 | train = self.generate_fewshot_dataset(train,
61 | num_shots=num_shots)
62 | val = self.generate_fewshot_dataset(val,
63 | num_shots=min(
64 | num_shots, 4))
65 | data = {"train": train, "val": val}
66 | print(f"Saving preprocessed few-shot data to {preprocessed}")
67 | with open(preprocessed, "wb") as file:
68 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
69 |
70 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
71 | train, val, test = OxfordPets.subsample_classes(train,
72 | val,
73 | test,
74 | subsample=subsample)
75 |
76 | super().__init__(train_x=train, val=val, test=test)
77 |
78 | def read_data(self, cname2lab, text_file):
79 | text_file = os.path.join(self.dataset_dir, text_file)
80 | items = []
81 |
82 | with open(text_file, "r") as f:
83 | lines = f.readlines()
84 | for line in lines:
85 | line = line.strip().split(" ")[0] # trainlist: filename, label
86 | action, filename = line.split("/")
87 | label = cname2lab[action]
88 |
89 | elements = re.findall("[A-Z][^A-Z]*", action)
90 | renamed_action = "_".join(elements)
91 |
92 | filename = filename.replace(".avi", ".jpg")
93 | impath = os.path.join(self.image_dir, renamed_action, filename)
94 |
95 | item = Datum(impath=impath,
96 | label=label,
97 | classname=renamed_action)
98 | items.append(item)
99 |
100 | return items
101 |
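`read_data` above recovers word boundaries from the CamelCase action names in the UCF-101 file lists, so that the image path matches the underscore-separated folder names expected under `UCF-101-midframes`. A minimal illustration (example action name assumed):

```python
import re

# Illustration of the CamelCase splitting used in UCF101.read_data above.
action = "ApplyEyeMakeup"                      # folder name as it appears in trainlist01.txt
elements = re.findall("[A-Z][^A-Z]*", action)  # -> ['Apply', 'Eye', 'Makeup']
renamed_action = "_".join(elements)            # -> 'Apply_Eye_Makeup'
print(renamed_action)
```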
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for DATASET in caltech101 oxford_pets stanford_cars oxford_flowers food101 fgvc_aircraft dtd eurosat ucf101
4 | do
5 | python parse_test_res.py output_1108_4/base2new/train_base/${DATASET}/shots_16_1.0/TCP/vit_b16_ep100_ctxv1/ --test-log
6 | python parse_test_res.py output_1108_4_eval/base2new/test_new/${DATASET}/shots_16_1.0/TCP/vit_b16_ep100_ctxv1/ --test-log
7 | done
8 |
--------------------------------------------------------------------------------
/interpret_prompt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import torch
5 |
6 | from clip.simple_tokenizer import SimpleTokenizer
7 | from clip import clip
8 |
9 |
10 | def load_clip_to_cpu(backbone_name="RN50"):
11 | url = clip._MODELS[backbone_name]
12 | model_path = clip._download(url)
13 |
14 | try:
15 | # loading JIT archive
16 | model = torch.jit.load(model_path, map_location="cpu").eval()
17 | state_dict = None
18 |
19 | except RuntimeError:
20 | state_dict = torch.load(model_path, map_location="cpu")
21 |
22 | model = clip.build_model(state_dict or model.state_dict())
23 |
24 | return model
25 |
26 |
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument("fpath", type=str, help="Path to the learned prompt")
29 | parser.add_argument("topk", type=int, help="Select top-k similar words")
30 | args = parser.parse_args()
31 |
32 | fpath = args.fpath
33 | topk = args.topk
34 |
35 | assert os.path.exists(fpath)
36 |
37 | print(f"Return the top-{topk} matched words")
38 |
39 | tokenizer = SimpleTokenizer()
40 | clip_model = load_clip_to_cpu()
41 | token_embedding = clip_model.token_embedding.weight
42 | print(f"Size of token embedding: {token_embedding.shape}")
43 |
44 | prompt_learner = torch.load(fpath, map_location="cpu")["state_dict"]
45 | ctx = prompt_learner["ctx"]
46 | ctx = ctx.float()
47 | print(f"Size of context: {ctx.shape}")
48 |
49 | if ctx.dim() == 2:
50 | # Generic context
51 | distance = torch.cdist(ctx, token_embedding)
52 | print(f"Size of distance matrix: {distance.shape}")
53 | sorted_idxs = torch.argsort(distance, dim=1)
54 | sorted_idxs = sorted_idxs[:, :topk]
55 |
56 | for m, idxs in enumerate(sorted_idxs):
57 | words = [tokenizer.decoder[idx.item()] for idx in idxs]
58 | dist = [f"{distance[m, idx].item():.4f}" for idx in idxs]
59 | print(f"{m+1}: {words} {dist}")
60 |
61 | elif ctx.dim() == 3:
62 | # Class-specific context
63 | raise NotImplementedError
64 |
--------------------------------------------------------------------------------
/lpclip/README.md:
--------------------------------------------------------------------------------
1 | # Linear Probe CLIP
2 |
3 | To run linear probe baselines, make sure that your current working directory is `lpclip/`.
4 |
5 | Step 1: Extract Features using the CLIP Image Encoder
6 | ```bash
7 | sh feat_extractor.sh
8 | ```
9 |
10 | Step 2: Train few-shot linear probe
11 | ```bash
12 | sh linear_probe.sh
13 | ```
14 |
15 | We follow the instructions in Appendix A3 (p. 38) of [the original CLIP paper](https://arxiv.org/pdf/2103.00020.pdf), with a careful sweep over the L2 regularization strength (sketched below).
16 |
17 | Note: please pull the latest Dassl (commit `606a2c6` or later).
18 |
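The sweep referenced above is implemented in `linear_probe.py` further below: a coarse scan of the regularization strength `C` over powers of ten, followed by repeated halving of the search interval in log10-space around the best value. A simplified, self-contained sketch of that idea on toy data (scikit-learn assumed; not the repository's exact procedure):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=200, n_features=32, random_state=0)
X_tr, y_tr, X_val, y_val = X[:100], y[:100], X[100:], y[100:]

def val_acc(c):
    clf = LogisticRegression(solver="lbfgs", max_iter=1000, C=c).fit(X_tr, y_tr)
    return clf.score(X_val, y_val)

# 1) coarse scan over powers of ten
grid = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6]
c_peak = grid[int(np.argmax([val_acc(c) for c in grid]))]

# 2) shrink the interval around the peak in log10-space
lo, hi = np.log10(c_peak) - 1, np.log10(c_peak) + 1
for _ in range(8):
    if val_acc(10 ** lo) >= val_acc(10 ** hi):
        hi = 0.5 * (lo + hi)   # keep the left half
    else:
        lo = 0.5 * (lo + hi)   # keep the right half
print("selected C ~= %.4g" % 10 ** ((lo + hi) / 2))
```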
--------------------------------------------------------------------------------
/lpclip/feat_extractor.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import numpy as np
3 | import torch
4 | import sys
5 |
6 | sys.path.append(os.path.abspath(".."))
7 |
8 | from datasets.oxford_pets import OxfordPets
9 | from datasets.oxford_flowers import OxfordFlowers
10 | from datasets.fgvc_aircraft import FGVCAircraft
11 | from datasets.dtd import DescribableTextures
12 | from datasets.eurosat import EuroSAT
13 | from datasets.stanford_cars import StanfordCars
14 | from datasets.food101 import Food101
15 | from datasets.sun397 import SUN397
16 | from datasets.caltech101 import Caltech101
17 | from datasets.ucf101 import UCF101
18 | from datasets.imagenet import ImageNet
19 | from datasets.imagenetv2 import ImageNetV2
20 | from datasets.imagenet_sketch import ImageNetSketch
21 | from datasets.imagenet_a import ImageNetA
22 | from datasets.imagenet_r import ImageNetR
23 |
24 | from dassl.utils import setup_logger, set_random_seed, collect_env_info
25 | from dassl.config import get_cfg_default
26 | from dassl.data.transforms import build_transform
27 | from dassl.data import DatasetWrapper
28 |
29 | import clip
30 |
31 | # import pdb; pdb.set_trace()
32 |
33 |
34 | def print_args(args, cfg):
35 | print("***************")
36 | print("** Arguments **")
37 | print("***************")
38 | optkeys = list(args.__dict__.keys())
39 | optkeys.sort()
40 | for key in optkeys:
41 | print("{}: {}".format(key, args.__dict__[key]))
42 | print("************")
43 | print("** Config **")
44 | print("************")
45 | print(cfg)
46 |
47 |
48 | def reset_cfg(cfg, args):
49 | if args.root:
50 | cfg.DATASET.ROOT = args.root
51 |
52 | if args.output_dir:
53 | cfg.OUTPUT_DIR = args.output_dir
54 |
55 | if args.trainer:
56 | cfg.TRAINER.NAME = args.trainer
57 |
58 | if args.backbone:
59 | cfg.MODEL.BACKBONE.NAME = args.backbone
60 |
61 | if args.head:
62 | cfg.MODEL.HEAD.NAME = args.head
63 |
64 |
65 | def extend_cfg(cfg):
66 | """
67 | Add new config variables.
68 |
69 | E.g.
70 | from yacs.config import CfgNode as CN
71 | cfg.TRAINER.MY_MODEL = CN()
72 | cfg.TRAINER.MY_MODEL.PARAM_A = 1.
73 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
74 | cfg.TRAINER.MY_MODEL.PARAM_C = False
75 | """
76 | from yacs.config import CfgNode as CN
77 |
78 | cfg.TRAINER.OURS = CN()
79 | cfg.TRAINER.OURS.N_CTX = 10 # number of context vectors
80 | cfg.TRAINER.OURS.CSC = False # class-specific context
81 | cfg.TRAINER.OURS.CTX_INIT = "" # initialize context vectors with given words
82 | cfg.TRAINER.OURS.WEIGHT_U = 0.1 # weight for the unsupervised loss
83 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
84 |
85 |
86 | def setup_cfg(args):
87 | cfg = get_cfg_default()
88 | extend_cfg(cfg)
89 |
90 | # 1. From the dataset config file
91 | if args.dataset_config_file:
92 | cfg.merge_from_file(args.dataset_config_file)
93 |
94 | # 2. From the method config file
95 | if args.config_file:
96 | cfg.merge_from_file(args.config_file)
97 |
98 | # 3. From input arguments
99 | reset_cfg(cfg, args)
100 |
101 | cfg.freeze()
102 |
103 | return cfg
104 |
105 |
106 | def main(args):
107 | cfg = setup_cfg(args)
108 | if cfg.SEED >= 0:
109 | print("Setting fixed seed: {}".format(cfg.SEED))
110 | set_random_seed(cfg.SEED)
111 | setup_logger(cfg.OUTPUT_DIR)
112 |
113 | if torch.cuda.is_available() and cfg.USE_CUDA:
114 | torch.backends.cudnn.benchmark = True
115 |
116 | print_args(args, cfg)
117 | print("Collecting env info ...")
118 | print("** System info **\n{}\n".format(collect_env_info()))
119 |
120 | ######################################
121 | # Setup DataLoader
122 | ######################################
123 | dataset = eval(cfg.DATASET.NAME)(cfg)
124 |
125 | if args.split == "train":
126 | dataset_input = dataset.train_x
127 | elif args.split == "val":
128 | dataset_input = dataset.val
129 | else:
130 | dataset_input = dataset.test
131 |
132 | tfm_train = build_transform(cfg, is_train=False)
133 | data_loader = torch.utils.data.DataLoader(
134 | DatasetWrapper(cfg, dataset_input, transform=tfm_train,
135 | is_train=False),
136 | batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
137 | sampler=None,
138 | shuffle=False,
139 | num_workers=cfg.DATALOADER.NUM_WORKERS,
140 | drop_last=False,
141 | pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
142 | )
143 |
144 | ########################################
145 | # Setup Network
146 | ########################################
147 | clip_model, _ = clip.load("RN50", "cuda", jit=False)
148 | clip_model.eval()
149 |     ########################################
150 |     # Extract features with the CLIP image encoder
151 | feature_list = []
152 | label_list = []
153 | train_dataiter = iter(data_loader)
154 | for train_step in range(1, len(train_dataiter) + 1):
155 | batch = next(train_dataiter)
156 | data = batch["img"].cuda()
157 | feature = clip_model.visual(data)
158 | feature = feature.cpu()
159 | for idx in range(len(data)):
160 | feature_list.append(feature[idx].tolist())
161 | label_list.extend(batch["label"].tolist())
162 | save_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASET.NAME)
163 | os.makedirs(save_dir, exist_ok=True)
164 | save_filename = f"{args.split}"
165 | np.savez(
166 | os.path.join(save_dir, save_filename),
167 | feature_list=feature_list,
168 | label_list=label_list,
169 | )
170 |
171 |
172 | if __name__ == "__main__":
173 | parser = argparse.ArgumentParser()
174 | parser.add_argument("--root", type=str, default="", help="path to dataset")
175 | parser.add_argument("--output-dir",
176 | type=str,
177 | default="",
178 | help="output directory")
179 | parser.add_argument("--config-file",
180 | type=str,
181 | default="",
182 | help="path to config file")
183 | parser.add_argument(
184 | "--dataset-config-file",
185 | type=str,
186 | default="",
187 | help="path to config file for dataset setup",
188 | )
189 | parser.add_argument("--num-shot",
190 | type=int,
191 | default=1,
192 | help="number of shots")
193 | parser.add_argument("--split",
194 | type=str,
195 | choices=["train", "val", "test"],
196 | help="which split")
197 | parser.add_argument("--trainer",
198 | type=str,
199 | default="",
200 | help="name of trainer")
201 | parser.add_argument("--backbone",
202 | type=str,
203 | default="",
204 | help="name of CNN backbone")
205 | parser.add_argument("--head", type=str, default="", help="name of head")
206 | parser.add_argument("--seed",
207 | type=int,
208 | default=-1,
209 | help="only positive value enables a fixed seed")
210 | parser.add_argument("--eval-only",
211 | action="store_true",
212 | help="evaluation only")
213 | args = parser.parse_args()
214 | main(args)
215 |
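`feat_extractor.py` writes one `.npz` file per split containing `feature_list` and `label_list`. A minimal sketch of reading such a file back, mirroring what `linear_probe.py` does (the path below is an example and depends on your `--output-dir` and dataset name):

```python
import os
import numpy as np

# Example path; matches OUTPUT in feat_extractor.sh plus cfg.DATASET.NAME.
split_file = os.path.join("/data1/CoOpData/clip_feat", "ImageNet", "train.npz")
data = np.load(split_file)
features = np.asarray(data["feature_list"])  # (N, 1024) with the RN50 backbone used above
labels = np.asarray(data["label_list"])      # (N,)
print(features.shape, labels.shape)
```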
--------------------------------------------------------------------------------
/lpclip/feat_extractor.sh:
--------------------------------------------------------------------------------
1 | # sh feat_extractor.sh
2 | DATA=/data1/CoOpData
3 | OUTPUT='/data1/CoOpData/clip_feat/'
4 | SEED=1
5 |
6 | GPULIST=(0 1 2 3)
7 | GPUIDX=0
8 |
9 | # oxford_pets oxford_flowers fgvc_aircraft dtd eurosat stanford_cars food101 sun397 caltech101 ucf101 imagenet
10 | # imagenet oxford_pets oxford_flowers stanford_cars food101 caltech101
11 | for DATASET in imagenetv2 imagenet_sketch imagenet_a imagenet_r
12 | do
13 | for SPLIT in train val test
14 | do
15 | while true
16 | do
17 | sleep 10
18 | let STATIDX=GPULIST[GPUIDX]+2
19 | stat=$(gpustat | awk '{print $11}' | sed -n ${STATIDX}'p')
20 | if [ "$stat" -lt 20 ]
21 | then
22 | break
23 | fi
24 | let GPUIDX=(GPUIDX+1)%${#GPULIST[@]}
25 | echo $GPUIDX'N'
26 | done
27 | CUDA_VISIBLE_DEVICES=${GPULIST[${GPUIDX}]} python feat_extractor.py \
28 | --split ${SPLIT} \
29 | --root ${DATA} \
30 | --seed ${SEED} \
31 | --dataset-config-file ../configs/datasets/${DATASET}.yaml \
32 | --config-file ../configs/trainers/CoOp/rn50_val.yaml \
33 | --output-dir ${OUTPUT} \
34 | --eval-only &
35 | sleep 10
36 | done
37 | done
38 |
--------------------------------------------------------------------------------
/lpclip/linear_probe.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from sklearn.linear_model import LogisticRegression
4 | import argparse
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument("--dataset", type=str, default="", help="path to dataset")
8 | parser.add_argument("--num_step", type=int, default=8, help="number of steps")
9 | parser.add_argument("--num_run", type=int, default=10, help="number of runs")
10 | parser.add_argument("--feature_dir",
11 | type=str,
12 | default="clip_feat",
13 | help="feature dir path")
14 | args = parser.parse_args()
15 |
16 | dataset = args.dataset
17 | dataset_path = os.path.join(f"{args.feature_dir}", dataset)
18 |
19 | train_file = np.load(os.path.join(dataset_path, "train.npz"))
20 | train_feature, train_label = train_file["feature_list"], train_file[
21 | "label_list"]
22 | val_file = np.load(os.path.join(dataset_path, "val.npz"))
23 | val_feature, val_label = val_file["feature_list"], val_file["label_list"]
24 | test_file = np.load(os.path.join(dataset_path, "test.npz"))
25 | test_feature, test_label = test_file["feature_list"], test_file["label_list"]
26 |
27 | os.makedirs("report", exist_ok=True)
28 |
29 | val_shot_list = {1: 1, 2: 2, 4: 4, 8: 4, 16: 4}
30 |
31 | # for num_shot in [1, 2, 4, 8, 16]:
32 | for num_shot in [4, 16]:
33 | test_acc_step_list = np.zeros([args.num_run, args.num_step])
34 | for seed in range(1, args.num_run + 1):
35 | np.random.seed(seed)
36 | print(
37 | f"-- Seed: {seed} --------------------------------------------------------------"
38 | )
39 | # Sampling
40 | all_label_list = np.unique(train_label)
41 | selected_idx_list = []
42 | for label in all_label_list:
43 | label_collection = np.where(train_label == label)[0]
44 | selected_idx = np.random.choice(label_collection,
45 | size=num_shot,
46 | replace=False)
47 | selected_idx_list.extend(selected_idx)
48 |
49 | fewshot_train_feature = train_feature[selected_idx_list]
50 | fewshot_train_label = train_label[selected_idx_list]
51 |
52 | val_num_shot = val_shot_list[num_shot]
53 | val_selected_idx_list = []
54 | for label in all_label_list:
55 | label_collection = np.where(val_label == label)[0]
56 | selected_idx = np.random.choice(label_collection,
57 | size=val_num_shot,
58 | replace=False)
59 | val_selected_idx_list.extend(selected_idx)
60 |
61 | fewshot_val_feature = val_feature[val_selected_idx_list]
62 | fewshot_val_label = val_label[val_selected_idx_list]
63 |
64 | # search initialization
65 | search_list = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6]
66 | acc_list = []
67 | for c_weight in search_list:
68 | clf = LogisticRegression(solver="lbfgs",
69 | max_iter=1000,
70 | penalty="l2",
71 | C=c_weight).fit(fewshot_train_feature,
72 | fewshot_train_label)
73 | pred = clf.predict(fewshot_val_feature)
74 | acc_val = sum(pred == fewshot_val_label) / len(fewshot_val_label)
75 | acc_list.append(acc_val)
76 |
77 | print(acc_list, flush=True)
78 |
79 | # binary search
80 | peak_idx = np.argmax(acc_list)
81 | c_peak = search_list[peak_idx]
82 | c_left, c_right = 1e-1 * c_peak, 1e1 * c_peak
83 |
84 | def binary_search(c_left, c_right, seed, step, test_acc_step_list):
85 | clf_left = LogisticRegression(solver="lbfgs",
86 | max_iter=1000,
87 | penalty="l2",
88 | C=c_left).fit(
89 | fewshot_train_feature,
90 | fewshot_train_label)
91 | pred_left = clf_left.predict(fewshot_val_feature)
92 | acc_left = sum(
93 | pred_left == fewshot_val_label) / len(fewshot_val_label)
94 | print("Val accuracy (Left): {:.2f}".format(100 * acc_left),
95 | flush=True)
96 |
97 | clf_right = LogisticRegression(solver="lbfgs",
98 | max_iter=1000,
99 | penalty="l2",
100 | C=c_right).fit(
101 | fewshot_train_feature,
102 | fewshot_train_label)
103 | pred_right = clf_right.predict(fewshot_val_feature)
104 | acc_right = sum(
105 | pred_right == fewshot_val_label) / len(fewshot_val_label)
106 | print("Val accuracy (Right): {:.2f}".format(100 * acc_right),
107 | flush=True)
108 |
109 | # find maximum and update ranges
110 | if acc_left < acc_right:
111 | c_final = c_right
112 | clf_final = clf_right
113 | # range for the next step
114 | c_left = 0.5 * (np.log10(c_right) + np.log10(c_left))
115 | c_right = np.log10(c_right)
116 | else:
117 | c_final = c_left
118 | clf_final = clf_left
119 | # range for the next step
120 | c_right = 0.5 * (np.log10(c_right) + np.log10(c_left))
121 | c_left = np.log10(c_left)
122 |
123 | pred = clf_final.predict(test_feature)
124 | test_acc = 100 * sum(pred == test_label) / len(pred)
125 | print("Test Accuracy: {:.2f}".format(test_acc), flush=True)
126 | test_acc_step_list[seed - 1, step] = test_acc
127 |
128 | saveline = "{}, seed {}, {} shot, weight {}, test_acc {:.2f}\n".format(
129 | dataset, seed, num_shot, c_final, test_acc)
130 | with open(
131 | "./report/{}_s{}r{}_details.txt".format(
132 | 'clip_feat', args.num_step, args.num_run),
133 | "a+",
134 | ) as writer:
135 | writer.write(saveline)
136 | return (
137 | np.power(10, c_left),
138 | np.power(10, c_right),
139 | seed,
140 | step,
141 | test_acc_step_list,
142 | )
143 |
144 | for step in range(args.num_step):
145 | print(
146 | f"{dataset}, {num_shot} Shot, Round {step}: {c_left}/{c_right}",
147 | flush=True,
148 | )
149 | c_left, c_right, seed, step, test_acc_step_list = binary_search(
150 | c_left, c_right, seed, step, test_acc_step_list)
151 | # save results of last step
152 | test_acc_list = test_acc_step_list[:, -1]
153 | acc_mean = np.mean(test_acc_list)
154 | acc_std = np.std(test_acc_list)
155 | save_line = "{}, {} Shot, Test acc stat: {:.2f} ({:.2f})\n".format(
156 | dataset, num_shot, acc_mean, acc_std)
157 | print(save_line, flush=True)
158 | with open(
159 | "./report/{}_s{}r{}.txt".format('clip_feat', args.num_step,
160 | args.num_run),
161 | "a+",
162 | ) as writer:
163 | writer.write(save_line)
164 |
--------------------------------------------------------------------------------
/lpclip/linear_probe.sh:
--------------------------------------------------------------------------------
1 | feature_dir=/data1/CoOpData/clip_feat/
2 | # ImageNet OxfordPets OxfordFlowers StanfordCars Food101 Caltech101
3 | for DATASET in ImageNet
4 | do
5 | python linear_probe.py \
6 | --dataset ${DATASET} \
7 | --feature_dir ${feature_dir} \
8 | --num_step 8 \
9 | --num_run 3
10 | done
11 |
--------------------------------------------------------------------------------
/lpclip/linear_probe_transfer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from sklearn.linear_model import LogisticRegression
4 | import argparse
5 |
6 | parser = argparse.ArgumentParser()
7 | # parser.add_argument("--train_dataset",
8 | # type=str,
9 | # default="",
10 | # help="path to train dataset")
11 | # parser.add_argument("--test_dataset",
12 | # type=str,
13 | # default="",
14 | # help="path to test dataset")
15 | parser.add_argument("--num_step", type=int, default=8, help="number of steps")
16 | parser.add_argument("--num_run", type=int, default=10, help="number of runs")
17 | parser.add_argument("--feature_dir",
18 | type=str,
19 | default="/data1/CoOpData/clip_feat/",
20 | help="feature dir path")
21 | args = parser.parse_args()
22 |
23 | train_dataset = 'ImageNet'
24 | train_dataset_path = os.path.join(f"{args.feature_dir}", train_dataset)
25 | test_datasets = ['ImageNetV2', 'ImageNetSketch', 'ImageNetR', 'ImageNetA']
26 | test_dataset_paths = [
27 | os.path.join(f"{args.feature_dir}", test_dataset)
28 | for test_dataset in test_datasets
29 | ]
30 |
31 | train_file = np.load(os.path.join(train_dataset_path, "train.npz"))
32 | train_feature, train_label = train_file["feature_list"], train_file[
33 | "label_list"]
34 | val_file = np.load(os.path.join(train_dataset_path, "val.npz"))
35 | val_feature, val_label = val_file["feature_list"], val_file["label_list"]
36 |
37 | test_files = [
38 | np.load(os.path.join(test_dataset_path, "test.npz"))
39 | for test_dataset_path in test_dataset_paths
40 | ]
41 | test_features, test_labels = [
42 | test_file["feature_list"] for test_file in test_files
43 | ], [test_file["label_list"] for test_file in test_files]
44 |
45 | os.makedirs("report", exist_ok=True)
46 |
47 | val_shot_list = {1: 1, 2: 2, 4: 4, 8: 4, 16: 4}
48 |
49 | # for num_shot in [1, 2, 4, 8, 16]:
50 | for num_shot in [16]:
51 | test_acc_step_list = np.zeros(
52 | [len(test_datasets), args.num_run, args.num_step])
53 | for seed in range(1, args.num_run + 1):
54 | np.random.seed(seed)
55 | print(
56 | f"-- Seed: {seed} --------------------------------------------------------------"
57 | )
58 | # Sampling
59 | all_label_list = np.unique(train_label)
60 | selected_idx_list = []
61 | for label in all_label_list:
62 | label_collection = np.where(train_label == label)[0]
63 | selected_idx = np.random.choice(label_collection,
64 | size=num_shot,
65 | replace=False)
66 | selected_idx_list.extend(selected_idx)
67 |
68 | fewshot_train_feature = train_feature[selected_idx_list]
69 | fewshot_train_label = train_label[selected_idx_list]
70 |
71 | val_num_shot = val_shot_list[num_shot]
72 | val_selected_idx_list = []
73 | for label in all_label_list:
74 | label_collection = np.where(val_label == label)[0]
75 | selected_idx = np.random.choice(label_collection,
76 | size=val_num_shot,
77 | replace=False)
78 | val_selected_idx_list.extend(selected_idx)
79 |
80 | fewshot_val_feature = val_feature[val_selected_idx_list]
81 | fewshot_val_label = val_label[val_selected_idx_list]
82 |
83 | # search initialization
84 | search_list = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6]
85 | acc_list = []
86 | for c_weight in search_list:
87 | clf = LogisticRegression(solver="lbfgs",
88 | max_iter=1000,
89 | penalty="l2",
90 | C=c_weight).fit(fewshot_train_feature,
91 | fewshot_train_label)
92 | pred = clf.predict(fewshot_val_feature)
93 | acc_val = sum(pred == fewshot_val_label) / len(fewshot_val_label)
94 | acc_list.append(acc_val)
95 |
96 | print(acc_list, flush=True)
97 |
98 | # binary search
99 | peak_idx = np.argmax(acc_list)
100 | c_peak = search_list[peak_idx]
101 | c_left, c_right = 1e-1 * c_peak, 1e1 * c_peak
102 |
103 | def binary_search(c_left, c_right, seed, step, test_acc_step_list):
104 | clf_left = LogisticRegression(solver="lbfgs",
105 | max_iter=1000,
106 | penalty="l2",
107 | C=c_left).fit(
108 | fewshot_train_feature,
109 | fewshot_train_label)
110 | pred_left = clf_left.predict(fewshot_val_feature)
111 | acc_left = sum(
112 | pred_left == fewshot_val_label) / len(fewshot_val_label)
113 | print("Val accuracy (Left): {:.2f}".format(100 * acc_left),
114 | flush=True)
115 |
116 | clf_right = LogisticRegression(solver="lbfgs",
117 | max_iter=1000,
118 | penalty="l2",
119 | C=c_right).fit(
120 | fewshot_train_feature,
121 | fewshot_train_label)
122 | pred_right = clf_right.predict(fewshot_val_feature)
123 | acc_right = sum(
124 | pred_right == fewshot_val_label) / len(fewshot_val_label)
125 | print("Val accuracy (Right): {:.2f}".format(100 * acc_right),
126 | flush=True)
127 |
128 | # find maximum and update ranges
129 | if acc_left < acc_right:
130 | c_final = c_right
131 | clf_final = clf_right
132 | # range for the next step
133 | c_left = 0.5 * (np.log10(c_right) + np.log10(c_left))
134 | c_right = np.log10(c_right)
135 | else:
136 | c_final = c_left
137 | clf_final = clf_left
138 | # range for the next step
139 | c_right = 0.5 * (np.log10(c_right) + np.log10(c_left))
140 | c_left = np.log10(c_left)
141 |
142 | for i, (test_feature, test_label, test_dataset) in enumerate(
143 | zip(test_features, test_labels, test_datasets)):
144 | pred = clf_final.predict(test_feature)
145 | test_acc = 100 * sum(pred == test_label) / len(pred)
146 | print("Test Accuracy: {:.2f}".format(test_acc), flush=True)
147 | test_acc_step_list[i, seed - 1, step] = test_acc
148 |
149 |                 saveline = "{}, {}, seed {}, {} shot, weight {}, test_acc {:.2f}\n".format(
150 |                     train_dataset, test_dataset, seed, num_shot, c_final,
151 |                     test_acc)
152 |                 with open(
153 |                         "./report/{}_s{}r{}_details.txt".format(
154 |                             'clip_feat', args.num_step, args.num_run),
155 |                         "a+",
156 |                 ) as writer:
157 |                     writer.write(saveline)
158 | return (
159 | np.power(10, c_left),
160 | np.power(10, c_right),
161 | seed,
162 | step,
163 | test_acc_step_list,
164 | )
165 |
166 | for step in range(args.num_step):
167 | print(
168 | f"{train_dataset}, {num_shot} Shot, Round {step}: {c_left}/{c_right}",
169 | flush=True,
170 | )
171 | c_left, c_right, seed, step, test_acc_step_list = binary_search(
172 | c_left, c_right, seed, step, test_acc_step_list)
173 | # save results of last step
174 | test_acc_list = test_acc_step_list[:, :, -1]
175 |     acc_mean = np.mean(test_acc_list, axis=-1)  # NumPy uses `axis`, not `dim`
176 |     acc_std = np.std(test_acc_list, axis=-1)
177 | for i in range(len(test_datasets)):
178 | save_line = "{}, {}, {} Shot, Test acc stat: {:.2f} ({:.2f})\n".format(
179 | train_dataset, test_datasets[i], num_shot, acc_mean[i], acc_std[i])
180 | print(save_line, flush=True)
181 | with open(
182 | "./report/{}_s{}r{}.txt".format('clip_feat', args.num_step,
183 | args.num_run),
184 | "a+",
185 | ) as writer:
186 | writer.write(save_line)
187 |
--------------------------------------------------------------------------------
/parse_test_res.py:
--------------------------------------------------------------------------------
1 | """
2 | Goal
3 | ---
4 | 1. Read test results from log.txt files
5 | 2. Compute mean and std across different folders (seeds)
6 |
7 | Usage
8 | ---
9 | Assume the output files are saved under output/my_experiment,
10 | which contains results of different seeds, e.g.,
11 |
12 | my_experiment/
13 | seed1/
14 | log.txt
15 | seed2/
16 | log.txt
17 | seed3/
18 | log.txt
19 |
20 | Run the following command from the root directory:
21 |
22 |     $ python parse_test_res.py output/my_experiment
23 |
24 | Add --ci95 to the command if you want the 95% confidence
25 | interval instead of the standard deviation:
26 |
27 |     $ python parse_test_res.py output/my_experiment --ci95
28 |
29 | If my_experiment/ has the following structure,
30 |
31 | my_experiment/
32 | exp-1/
33 | seed1/
34 | log.txt
35 | ...
36 | seed2/
37 | log.txt
38 | ...
39 | seed3/
40 | log.txt
41 | ...
42 | exp-2/
43 | ...
44 | exp-3/
45 | ...
46 |
47 | Run
48 |
49 |     $ python parse_test_res.py output/my_experiment --multi-exp
50 | """
51 | import re
52 | import numpy as np
53 | import os.path as osp
54 | import argparse
55 | from collections import OrderedDict, defaultdict
56 |
57 | from dassl.utils import check_isfile, listdir_nohidden
58 |
59 |
60 | def compute_ci95(res):
61 | return 1.96 * np.std(res) / np.sqrt(len(res))
62 |
63 |
64 | def parse_function(*metrics, directory="", args=None, end_signal=None):
65 | #print(f"Parsing files in {directory}")
66 | subdirs = listdir_nohidden(directory, sort=True)
67 |
68 | outputs = []
69 |
70 | for subdir in subdirs:
71 | fpath = osp.join(directory, subdir, "log.txt")
72 | assert check_isfile(fpath)
73 | good_to_go = False
74 | output = OrderedDict()
75 |
76 | with open(fpath, "r") as f:
77 | lines = f.readlines()
78 |
79 | for line in lines:
80 | line = line.strip()
81 |
82 | if line == end_signal:
83 | good_to_go = True
84 |
85 | for metric in metrics:
86 | match = metric["regex"].search(line)
87 | if match and good_to_go:
88 | if "file" not in output:
89 | output["file"] = fpath
90 | num = float(match.group(1))
91 | name = metric["name"]
92 | output[name] = num
93 |
94 | if output:
95 | outputs.append(output)
96 |
97 | assert len(outputs) > 0, f"Nothing found in {directory}"
98 |
99 | metrics_results = defaultdict(list)
100 |
101 | for output in outputs:
102 | msg = ""
103 | for key, value in output.items():
104 | if isinstance(value, float):
105 | msg += f"{key}: {value:.2f}%. "
106 | else:
107 | msg += f"{key}: {value}. "
108 | if key != "file":
109 | metrics_results[key].append(value)
110 | #print(msg)
111 |
112 | output_results = OrderedDict()
113 |
114 | #print("===")
115 | #print(f"Summary of directory: {directory}")
116 | dir_sets = directory.split('/')
117 | #print(dir_sets)
118 | for key, values in metrics_results.items():
119 | avg = np.mean(values)
120 | std = compute_ci95(values) if args.ci95 else np.std(values)
121 | print(f"* {dir_sets[2]} {dir_sets[3]} {key}: {avg:.2f}% +- {std:.2f}%")
122 | output_results[key] = avg
123 | #print("===")
124 |
125 | return output_results
126 |
127 |
128 | def main(args, end_signal):
129 | metric = {
130 | "name": args.keyword,
131 | "regex": re.compile(fr"\* {args.keyword}: ([\.\deE+-]+)%"),
132 | }
133 |
134 | if args.multi_exp:
135 | final_results = defaultdict(list)
136 |
137 | for directory in listdir_nohidden(args.directory, sort=True):
138 | directory = osp.join(args.directory, directory)
139 | results = parse_function(metric,
140 | directory=directory,
141 | args=args,
142 | end_signal=end_signal)
143 |
144 | for key, value in results.items():
145 | final_results[key].append(value)
146 |
147 | print("Average performance")
148 | for key, values in final_results.items():
149 | avg = np.mean(values)
150 | print(f"* {key}: {avg:.2f}%")
151 |
152 | else:
153 | parse_function(metric,
154 | directory=args.directory,
155 | args=args,
156 | end_signal=end_signal)
157 |
158 |
159 | if __name__ == "__main__":
160 | parser = argparse.ArgumentParser()
161 | parser.add_argument("directory", type=str, help="path to directory")
162 | parser.add_argument("--ci95",
163 | action="store_true",
164 |                         help="compute 95%% confidence interval")
165 | parser.add_argument("--test-log",
166 | action="store_true",
167 | help="parse test-only logs")
168 | parser.add_argument("--multi-exp",
169 | action="store_true",
170 | help="parse multiple experiments")
171 | parser.add_argument("--keyword",
172 | default="accuracy",
173 | type=str,
174 | help="which keyword to extract")
175 | args = parser.parse_args()
176 |
177 | end_signal = "Finished training"
178 | if args.test_log:
179 | end_signal = "=> result"
180 |
181 | main(args, end_signal)
182 |
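`parse_test_res.py` only records a metric once the end signal has been seen (`Finished training`, or `=> result` with `--test-log`) and the keyword regex matches. A small illustration of the matching, using hypothetical log lines:

```python
import re

# Same pattern built in main() for the default keyword "accuracy".
regex = re.compile(r"\* accuracy: ([\.\deE+-]+)%")

sample_lines = [
    "=> result",             # end signal when --test-log is passed
    "* accuracy: 71.23%",    # hypothetical metric line from a log
    "* error: 28.77%",       # ignored: keyword does not match
]
for line in sample_lines:
    m = regex.search(line)
    if m:
        print(float(m.group(1)))   # -> 71.23
```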
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy
2 | regex
3 | tqdm
4 |
--------------------------------------------------------------------------------
/scripts/base2new_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ..
4 |
5 | # custom config
6 | DATA=XXXXX
7 | TRAINER=TCP
8 | WEIGHT=1.0
9 |
10 | CFG=vit_b16_ep100_ctxv1
11 | CTP=end # class token position (end or middle)
12 | NCTX=4 # number of context tokens
13 | SHOTS=16 # number of shots (1, 2, 4, 8, 16)
14 | CSC=False # class-specific context (False or True)
15 | FOLDER=output_1108
16 |
17 | for DATASET in caltech101 dtd eurosat fgvc_aircraft food101 oxford_flowers oxford_pets stanford_cars ucf101
18 | do
19 | for SEED in 1 2 3
20 | do
21 | DIR=${FOLDER}_${NCTX}/base2new/train_base/${DATASET}/shots_${SHOTS}_${WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
22 | if [ -d "$DIR" ]; then
23 | echo "Results are available in ${DIR}. Skip this job"
24 | else
25 | echo "Run this job and save the output to ${DIR}"
26 | CUDA_VISIBLE_DEVICES=0 python train.py \
27 | --root ${DATA} \
28 | --seed ${SEED} \
29 | --trainer ${TRAINER} \
30 | --dataset-config-file configs/datasets/${DATASET}.yaml \
31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
32 | --output-dir ${DIR} \
33 | TRAINER.COOP.N_CTX ${NCTX} \
34 | TRAINER.COOP.CSC ${CSC} \
35 | TRAINER.COOP.W ${WEIGHT} \
36 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
37 | DATASET.NUM_SHOTS ${SHOTS} \
38 | DATASET.SUBSAMPLE_CLASSES base
39 | fi
40 | done
41 |
42 |
43 | LOADEP=50
44 | SUB=new
45 | for SEED in 1 2 3
46 | do
47 | COMMON_DIR=${DATASET}/shots_${SHOTS}_${WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
48 | MODEL_DIR=${FOLDER}_${NCTX}/base2new/train_base/${COMMON_DIR}
49 | DIR=${FOLDER}_${NCTX}_eval/base2new/test_${SUB}/${COMMON_DIR}
50 |
51 | if [ -d "$DIR" ]; then
52 | echo "Results are available in ${DIR}. Skip this job"
53 | else
54 | echo "Run this job and save the output to ${DIR}"
55 | CUDA_VISIBLE_DEVICES=0 python train.py \
56 | --root ${DATA} \
57 | --seed ${SEED} \
58 | --trainer ${TRAINER} \
59 | --dataset-config-file configs/datasets/${DATASET}.yaml \
60 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
61 | --output-dir ${DIR} \
62 | --model-dir ${MODEL_DIR} \
63 | --load-epoch ${LOADEP} \
64 | --eval-only \
65 | TRAINER.COOP.N_CTX ${NCTX} \
66 | TRAINER.COOP.CSC ${CSC} \
67 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
68 | DATASET.NUM_SHOTS ${SHOTS} \
69 | DATASET.SUBSAMPLE_CLASSES ${SUB}
70 | fi
71 | done
72 | done
73 |
74 |
75 | for DATASET in sun397 imagenet
76 | do
77 | for SEED in 1 2 3
78 | do
79 | DIR=${FOLDER}_${NCTX}/base2new/train_base/${DATASET}/shots_${SHOTS}_${WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
80 | if [ -d "$DIR" ]; then
81 | echo "Results are available in ${DIR}. Skip this job"
82 | else
83 | echo "Run this job and save the output to ${DIR}"
84 | CUDA_VISIBLE_DEVICES=0 python train.py \
85 | --root ${DATA} \
86 | --seed ${SEED} \
87 | --trainer ${TRAINER} \
88 | --dataset-config-file configs/datasets/${DATASET}.yaml \
89 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
90 | --output-dir ${DIR} \
91 | TRAINER.COOP.N_CTX ${NCTX} \
92 | TRAINER.COOP.CSC ${CSC} \
93 | TRAINER.COOP.W ${WEIGHT} \
94 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
95 | DATASET.NUM_SHOTS ${SHOTS} \
96 | DATASET.SUBSAMPLE_CLASSES base
97 | fi
98 | done
99 |
100 |
101 | LOADEP=25
102 | SUB=new
103 | for SEED in 1 2 3
104 | do
105 | COMMON_DIR=${DATASET}/shots_${SHOTS}_${WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
106 | MODEL_DIR=${FOLDER}_${NCTX}/base2new/train_base/${COMMON_DIR}
107 | DIR=${FOLDER}_${NCTX}/base2new/test_${SUB}/${COMMON_DIR}
108 |
109 | if [ -d "$DIR" ]; then
110 | echo "Results are available in ${DIR}. Skip this job"
111 | else
112 | echo "Run this job and save the output to ${DIR}"
113 | CUDA_VISIBLE_DEVICES=0 python train.py \
114 | --root ${DATA} \
115 | --seed ${SEED} \
116 | --trainer ${TRAINER} \
117 | --dataset-config-file configs/datasets/${DATASET}.yaml \
118 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
119 | --output-dir ${DIR} \
120 | --model-dir ${MODEL_DIR} \
121 | --load-epoch ${LOADEP} \
122 | --eval-only \
123 | TRAINER.COOP.N_CTX ${NCTX} \
124 | TRAINER.COOP.CSC ${CSC} \
125 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
126 | DATASET.NUM_SHOTS ${SHOTS} \
127 | DATASET.SUBSAMPLE_CLASSES ${SUB}
128 | fi
129 | done
130 | done
131 |
132 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import time
4 | import os
5 |
6 | from dassl.utils import setup_logger, set_random_seed, collect_env_info
7 | from dassl.config import get_cfg_default
8 | from dassl.engine import build_trainer
9 |
10 | # custom
11 | import datasets.oxford_pets
12 | import datasets.oxford_flowers
13 | import datasets.fgvc_aircraft
14 | import datasets.dtd
15 | import datasets.eurosat
16 | import datasets.stanford_cars
17 | import datasets.food101
18 | import datasets.sun397
19 | import datasets.caltech101
20 | import datasets.ucf101
21 | import datasets.imagenet
22 |
23 | import datasets.imagenet_sketch
24 | import datasets.imagenetv2
25 | import datasets.imagenet_a
26 | import datasets.imagenet_r
27 |
28 | import trainers.coop
29 | import trainers.cocoop
30 | import trainers.zsclip
31 | import trainers.prograd
32 | import trainers.tcp
33 | def print_args(args, cfg):
34 | print("***************")
35 | print("** Arguments **")
36 | print("***************")
37 | optkeys = list(args.__dict__.keys())
38 | optkeys.sort()
39 | for key in optkeys:
40 | print("{}: {}".format(key, args.__dict__[key]))
41 | print("************")
42 | print("** Config **")
43 | print("************")
44 | print(cfg)
45 |
46 |
47 | def reset_cfg(cfg, args):
48 | if args.root:
49 | cfg.DATASET.ROOT = args.root
50 |
51 | if args.output_dir:
52 | cfg.OUTPUT_DIR = args.output_dir
53 |
54 | if args.resume:
55 | cfg.RESUME = args.resume
56 |
57 | if args.seed:
58 | cfg.SEED = args.seed
59 |
60 | if args.source_domains:
61 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains
62 |
63 | if args.target_domains:
64 | cfg.DATASET.TARGET_DOMAINS = args.target_domains
65 |
66 | if args.transforms:
67 | cfg.INPUT.TRANSFORMS = args.transforms
68 |
69 | if args.trainer:
70 | cfg.TRAINER.NAME = args.trainer
71 |
72 | if args.backbone:
73 | cfg.MODEL.BACKBONE.NAME = args.backbone
74 |
75 | if args.head:
76 | cfg.MODEL.HEAD.NAME = args.head
77 |
78 |
79 | def extend_cfg(cfg):
80 | """
81 | Add new config variables.
82 |
83 | E.g.
84 | from yacs.config import CfgNode as CN
85 | cfg.TRAINER.MY_MODEL = CN()
86 | cfg.TRAINER.MY_MODEL.PARAM_A = 1.
87 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
88 | cfg.TRAINER.MY_MODEL.PARAM_C = False
89 | """
90 | from yacs.config import CfgNode as CN
91 |
92 | cfg.TRAINER.COOP = CN()
93 | cfg.TRAINER.COOP.ALPHA = 1.0
94 | cfg.TRAINER.COOP.N_CTX = 16 # number of context vectors
95 | cfg.TRAINER.COOP.CSC = False # class-specific context
96 | cfg.TRAINER.COOP.CTX_INIT = False # initialization words
97 | cfg.TRAINER.COOP.W = 1.0
98 | cfg.TRAINER.COOP.PREC = "fp16" # fp16, fp32, amp
99 | cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front'
100 |
101 | cfg.TRAINER.COCOOP = CN()
102 | cfg.TRAINER.COCOOP.N_CTX = 16 # number of context vectors
103 | cfg.TRAINER.COCOOP.CTX_INIT = False # initialization words
104 | cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp
105 |
106 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
107 | """
108 | Add new config
109 | """
110 | cfg.LOSS = CN()
111 | cfg.LOSS.GM = False
112 | cfg.LOSS.NAME = ""
113 | cfg.LOSS.ALPHA = 0.
114 | cfg.LOSS.T = 1.
115 | cfg.LOSS.LAMBDA = 1.
116 |
117 |
118 | def setup_cfg(args):
119 | cfg = get_cfg_default()
120 | extend_cfg(cfg)
121 |
122 | # 1. From the dataset config file
123 | if args.dataset_config_file:
124 | cfg.merge_from_file(args.dataset_config_file)
125 |
126 | # 2. From the method config file
127 | if args.config_file:
128 | cfg.merge_from_file(args.config_file)
129 |
130 | # 3. From input arguments
131 | reset_cfg(cfg, args)
132 |
133 | # 4. From optional input arguments
134 | cfg.merge_from_list(args.opts)
135 |
136 |     if cfg.DATASET.NAME in ["ImageNet", "SUN397"]:
137 |         cfg.OPTIM.MAX_EPOCH = 25
138 | cfg.freeze()
139 |
140 | return cfg
141 |
142 |
143 | def main(args):
144 | cfg = setup_cfg(args)
145 | if cfg.SEED >= 0:
146 | print("Setting fixed seed: {}".format(cfg.SEED))
147 | set_random_seed(cfg.SEED)
148 | setup_logger(cfg.OUTPUT_DIR)
149 |
150 | if torch.cuda.is_available() and cfg.USE_CUDA:
151 | torch.backends.cudnn.benchmark = True
152 |
153 | print_args(args, cfg)
154 | print("Collecting env info ...")
155 | print("** System info **\n{}\n".format(collect_env_info()))
156 |
157 | trainer = build_trainer(cfg)
158 |
159 | if args.eval_only:
160 | trainer.load_model(args.model_dir, epoch=args.load_epoch)
161 | trainer.test()
162 | return
163 |
164 | if not args.no_train:
165 | trainer.train()
166 |
167 |
168 | if __name__ == "__main__":
169 | parser = argparse.ArgumentParser()
170 | parser.add_argument("--root", type=str, default="", help="path to dataset")
171 | parser.add_argument("--output-dir",
172 | type=str,
173 | default="",
174 | help="output directory")
175 | parser.add_argument(
176 | "--resume",
177 | type=str,
178 | default="",
179 | help="checkpoint directory (from which the training resumes)",
180 | )
181 | parser.add_argument("--seed",
182 | type=int,
183 | default=-1,
184 | help="only positive value enables a fixed seed")
185 | parser.add_argument("--source-domains",
186 | type=str,
187 | nargs="+",
188 | help="source domains for DA/DG")
189 | parser.add_argument("--target-domains",
190 | type=str,
191 | nargs="+",
192 | help="target domains for DA/DG")
193 | parser.add_argument("--transforms",
194 | type=str,
195 | nargs="+",
196 | help="data augmentation methods")
197 | parser.add_argument("--config-file",
198 | type=str,
199 | default="",
200 | help="path to config file")
201 | parser.add_argument(
202 | "--dataset-config-file",
203 | type=str,
204 | default="",
205 | help="path to config file for dataset setup",
206 | )
207 | parser.add_argument("--trainer",
208 | type=str,
209 | default="",
210 | help="name of trainer")
211 | parser.add_argument("--backbone",
212 | type=str,
213 | default="",
214 | help="name of CNN backbone")
215 | parser.add_argument("--head", type=str, default="", help="name of head")
216 | parser.add_argument("--eval-only",
217 | action="store_true",
218 | help="evaluation only")
219 | parser.add_argument(
220 | "--model-dir",
221 | type=str,
222 | default="",
223 | help="load model from this directory for eval-only mode",
224 | )
225 | parser.add_argument("--load-epoch",
226 | type=int,
227 | help="load model weights at this epoch for evaluation")
228 | parser.add_argument("--no-train",
229 | action="store_true",
230 | help="do not call trainer.train()")
231 | parser.add_argument(
232 | "opts",
233 | default=None,
234 | nargs=argparse.REMAINDER,
235 | help="modify config options using the command-line",
236 | )
237 | args = parser.parse_args()
238 | main(args)
239 |
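The trailing `opts` arguments (see `scripts/base2new_train.sh`) are applied via `cfg.merge_from_list` in `setup_cfg`. A tiny sketch with a toy yacs config, showing how the flat key/value list overrides defaults:

```python
# Sketch only: a toy config illustrating cfg.merge_from_list, the mechanism
# used by train.py for arguments such as DATASET.NUM_SHOTS 16.
from yacs.config import CfgNode as CN

cfg = CN()
cfg.DATASET = CN()
cfg.DATASET.NUM_SHOTS = 0
cfg.DATASET.SUBSAMPLE_CLASSES = "all"

opts = ["DATASET.NUM_SHOTS", "16", "DATASET.SUBSAMPLE_CLASSES", "base"]
cfg.merge_from_list(opts)  # values are coerced to the defaults' types
print(cfg.DATASET.NUM_SHOTS, cfg.DATASET.SUBSAMPLE_CLASSES)  # -> 16 base
```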
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__init__.py
--------------------------------------------------------------------------------
/trainers/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/cocoop.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/cocoop.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/cocoop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/cocoop.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/coop.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/coop.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/coop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/coop.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/imagenet_templates.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/imagenet_templates.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/imagenet_templates.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/imagenet_templates.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/kgcoop.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/kgcoop.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/kgcoop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/kgcoop.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/prograd.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/prograd.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/prograd.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/prograd.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/tcp.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/tcp.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/zsclip.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/zsclip.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/zsclip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/__pycache__/zsclip.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/clip_text/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/trainers/clip_text/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/clip_text/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/clip_text/__pycache__/clip.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/clip.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/clip_text/__pycache__/clip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/clip.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/clip_text/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/clip_text/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/clip_text/__pycache__/simple_tokenizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/simple_tokenizer.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/clip_text/__pycache__/simple_tokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/__pycache__/simple_tokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/clip_text/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htyao89/Textual-based_Class-aware_prompt_tuning/bd97bec6bd4901baaf318e9e8f073f618d12066e/trainers/clip_text/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/trainers/clip_text/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Union, List
6 |
7 | import torch
8 | from PIL import Image
9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10 | from tqdm import tqdm
11 |
12 | from .model import build_model
13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14 |
15 | try:
16 | from torchvision.transforms import InterpolationMode
17 | BICUBIC = InterpolationMode.BICUBIC
18 | except ImportError:
19 | BICUBIC = Image.BICUBIC
20 |
21 | if torch.__version__.split(".") < ["1", "7", "1"]:
22 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
23 |
24 | __all__ = ["available_models", "load", "tokenize"]
25 | _tokenizer = _Tokenizer()
26 |
27 | _MODELS = {
28 | "RN50":
29 | "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
30 | "RN101":
31 | "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32 | "RN50x4":
33 | "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16":
35 | "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
36 | "ViT-B/32":
37 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
38 | "ViT-B/16":
39 | "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
40 | }
41 |
42 |
43 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
44 | os.makedirs(root, exist_ok=True)
45 | filename = os.path.basename(url)
46 |
47 | expected_sha256 = url.split("/")[-2]
48 | download_target = os.path.join(root, filename)
49 |
50 | if os.path.exists(download_target) and not os.path.isfile(download_target):
51 | raise RuntimeError(
52 | f"{download_target} exists and is not a regular file")
53 |
54 | if os.path.isfile(download_target):
55 | if hashlib.sha256(open(download_target,
56 | "rb").read()).hexdigest() == expected_sha256:
57 | return download_target
58 | else:
59 | warnings.warn(
60 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
61 | )
62 |
63 | with urllib.request.urlopen(url) as source, open(download_target,
64 | "wb") as output:
65 | with tqdm(total=int(source.info().get("Content-Length")),
66 | ncols=80,
67 | unit='iB',
68 | unit_scale=True) as loop:
69 | while True:
70 | buffer = source.read(8192)
71 | if not buffer:
72 | break
73 |
74 | output.write(buffer)
75 | loop.update(len(buffer))
76 |
77 | if hashlib.sha256(open(download_target,
78 | "rb").read()).hexdigest() != expected_sha256:
79 | raise RuntimeError(
80 | "Model has been downloaded but the SHA256 checksum does not match"
81 | )
82 |
83 | return download_target
84 |
85 |
86 | def _transform(n_px):
87 | return Compose([
88 | Resize(n_px, interpolation=BICUBIC),
89 | CenterCrop(n_px),
90 | lambda image: image.convert("RGB"),
91 | ToTensor(),
92 | Normalize((0.48145466, 0.4578275, 0.40821073),
93 | (0.26862954, 0.26130258, 0.27577711)),
94 | ])
95 |
96 |
97 | def available_models() -> List[str]:
98 | """Returns the names of available CLIP models"""
99 | return list(_MODELS.keys())
100 |
101 |
102 | def load(name: str,
103 | device: Union[str, torch.device] = "cuda"
104 | if torch.cuda.is_available() else "cpu",
105 | jit=False):
106 | """Load a CLIP model
107 |
108 | Parameters
109 | ----------
110 | name : str
111 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
112 |
113 | device : Union[str, torch.device]
114 | The device to put the loaded model
115 |
116 | jit : bool
117 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
118 |
119 | Returns
120 | -------
121 | model : torch.nn.Module
122 | The CLIP model
123 |
124 | preprocess : Callable[[PIL.Image], torch.Tensor]
125 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
126 | """
127 | if name in _MODELS:
128 | model_path = _download(_MODELS[name])
129 | elif os.path.isfile(name):
130 | model_path = name
131 | else:
132 | raise RuntimeError(
133 | f"Model {name} not found; available models = {available_models()}")
134 |
135 | try:
136 | # loading JIT archive
137 | model = torch.jit.load(model_path,
138 | map_location=device if jit else "cpu").eval()
139 | state_dict = None
140 | except RuntimeError:
141 | # loading saved state dict
142 | if jit:
143 | warnings.warn(
144 | f"File {model_path} is not a JIT archive. Loading as a state dict instead"
145 | )
146 | jit = False
147 | state_dict = torch.load(model_path, map_location="cpu")
148 |
149 | if not jit:
150 | model = build_model(state_dict or model.state_dict()).to(device)
151 | if str(device) == "cpu":
152 | model.float()
153 | return model, _transform(model.visual.input_resolution)
154 |
155 | # patch the device names
156 | device_holder = torch.jit.trace(
157 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
158 | device_node = [
159 | n for n in device_holder.graph.findAllNodes("prim::Constant")
160 | if "Device" in repr(n)
161 | ][-1]
162 |
163 | def patch_device(module):
164 | try:
165 | graphs = [module.graph] if hasattr(module, "graph") else []
166 | except RuntimeError:
167 | graphs = []
168 |
169 | if hasattr(module, "forward1"):
170 | graphs.append(module.forward1.graph)
171 |
172 | for graph in graphs:
173 | for node in graph.findAllNodes("prim::Constant"):
174 | if "value" in node.attributeNames() and str(
175 | node["value"]).startswith("cuda"):
176 | node.copyAttributes(device_node)
177 |
178 | model.apply(patch_device)
179 | patch_device(model.encode_image)
180 | patch_device(model.encode_text)
181 |
182 | # patch dtype to float32 on CPU
183 | if str(device) == "cpu":
184 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(),
185 | example_inputs=[])
186 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
187 | float_node = float_input.node()
188 |
189 | def patch_float(module):
190 | try:
191 | graphs = [module.graph] if hasattr(module, "graph") else []
192 | except RuntimeError:
193 | graphs = []
194 |
195 | if hasattr(module, "forward1"):
196 | graphs.append(module.forward1.graph)
197 |
198 | for graph in graphs:
199 | for node in graph.findAllNodes("aten::to"):
200 | inputs = list(node.inputs())
201 | for i in [
202 | 1, 2
203 | ]: # dtype can be the second or third argument to aten::to()
204 | if inputs[i].node()["value"] == 5:
205 | inputs[i].node().copyAttributes(float_node)
206 |
207 | model.apply(patch_float)
208 | patch_float(model.encode_image)
209 | patch_float(model.encode_text)
210 |
211 | model.float()
212 |
213 | return model, _transform(model.input_resolution.item())
214 |
215 |
216 | def tokenize(texts: Union[str, List[str]],
217 | context_length: int = 77,
218 | truncate: bool = False) -> torch.LongTensor:
219 | """
220 | Returns the tokenized representation of given input string(s)
221 |
222 | Parameters
223 | ----------
224 | texts : Union[str, List[str]]
225 | An input string or a list of input strings to tokenize
226 |
227 | context_length : int
228 | The context length to use; all CLIP models use 77 as the context length
229 |
230 | truncate: bool
231 | Whether to truncate the text in case its encoding is longer than the context length
232 |
233 | Returns
234 | -------
235 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
236 | """
237 | if isinstance(texts, str):
238 | texts = [texts]
239 |
240 | sot_token = _tokenizer.encoder["<|startoftext|>"]
241 | eot_token = _tokenizer.encoder["<|endoftext|>"]
242 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
243 | for text in texts]
244 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
245 |
246 | for i, tokens in enumerate(all_tokens):
247 | if len(tokens) > context_length:
248 | if truncate:
249 | tokens = tokens[:context_length]
250 | tokens[-1] = eot_token
251 | else:
252 | raise RuntimeError(
253 | f"Input {texts[i]} is too long for context length {context_length}"
254 | )
255 | result[i, :len(tokens)] = torch.tensor(tokens)
256 |
257 | return result
258 |
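A minimal usage sketch for the load/tokenize API defined above (illustrative only; it assumes the top-level `clip` package import used elsewhere in this repo, network access for the checkpoint download, and a local image named "example.jpg", which is a made-up filename):

    import torch
    from PIL import Image

    from clip import clip  # assumed import path, as used in trainers/coop.py

    # load the ViT-B/16 backbone and its matching preprocessing transform
    model, preprocess = clip.load("ViT-B/16", device="cpu", jit=False)

    image = preprocess(Image.open("example.jpg")).unsqueeze(0)      # (1, 3, 224, 224)
    text = clip.tokenize(["a photo of a cat", "a photo of a dog"])  # (2, 77)

    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)
        # cosine similarities scaled by the learned temperature, as in CLIP
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logits = model.logit_scale.exp() * image_features @ text_features.t()

    print(logits.softmax(dim=-1))  # per-image probabilities over the two captions
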
--------------------------------------------------------------------------------
/trainers/clip_text/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)),
13 | "bpe_simple_vocab_16e6.txt.gz")
14 |
15 |
16 | @lru_cache()
17 | def bytes_to_unicode():
18 | """
19 | Returns a dict mapping utf-8 bytes to corresponding unicode strings.
20 | The reversible bpe codes work on unicode strings.
21 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
22 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
23 | This is a significant percentage of your normal, say, 32K bpe vocab.
24 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings
25 | that avoid mapping to whitespace/control characters the bpe code barfs on.
26 | """
27 | bs = list(range(ord("!"),
28 | ord("~") + 1)) + list(range(
29 | ord("¡"),
30 | ord("¬") + 1)) + list(range(ord("®"),
31 | ord("ÿ") + 1))
32 | cs = bs[:]
33 | n = 0
34 | for b in range(2**8):
35 | if b not in bs:
36 | bs.append(b)
37 | cs.append(2**8 + n)
38 | n += 1
39 | cs = [chr(n) for n in cs]
40 | return dict(zip(bs, cs))
41 |
42 |
43 | def get_pairs(word):
44 | """Return set of symbol pairs in a word.
45 | Word is represented as tuple of symbols (symbols being variable-length strings).
46 | """
47 | pairs = set()
48 | prev_char = word[0]
49 | for char in word[1:]:
50 | pairs.add((prev_char, char))
51 | prev_char = char
52 | return pairs
53 |
54 |
55 | def basic_clean(text):
56 | text = ftfy.fix_text(text)
57 | text = html.unescape(html.unescape(text))
58 | return text.strip()
59 |
60 |
61 | def whitespace_clean(text):
62 | text = re.sub(r'\s+', ' ', text)
63 | text = text.strip()
64 | return text
65 |
66 |
67 | class SimpleTokenizer(object):
68 | def __init__(self, bpe_path: str = default_bpe()):
69 | self.byte_encoder = bytes_to_unicode()
70 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
71 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
72 | merges = merges[1:49152 - 256 - 2 + 1]
73 | merges = [tuple(merge.split()) for merge in merges]
74 | vocab = list(bytes_to_unicode().values())
75 | vocab = vocab + [v + '</w>' for v in vocab]
76 | for merge in merges:
77 | vocab.append(''.join(merge))
78 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
79 | self.encoder = dict(zip(vocab, range(len(vocab))))
80 | self.decoder = {v: k for k, v in self.encoder.items()}
81 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
82 | self.cache = {
83 | '<|startoftext|>': '<|startoftext|>',
84 | '<|endoftext|>': '<|endoftext|>'
85 | }
86 | self.pat = re.compile(
87 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
88 | re.IGNORECASE)
89 |
90 | def bpe(self, token):
91 | if token in self.cache:
92 | return self.cache[token]
93 | word = tuple(token[:-1]) + (token[-1] + '</w>', )
94 | pairs = get_pairs(word)
95 |
96 | if not pairs:
97 | return token + '</w>'
98 |
99 | while True:
100 | bigram = min(
101 | pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
102 | if bigram not in self.bpe_ranks:
103 | break
104 | first, second = bigram
105 | new_word = []
106 | i = 0
107 | while i < len(word):
108 | try:
109 | j = word.index(first, i)
110 | new_word.extend(word[i:j])
111 | i = j
112 | except ValueError:  # `first` does not occur in the rest of the word
113 | new_word.extend(word[i:])
114 | break
115 |
116 | if word[i] == first and i < len(word) - 1 and word[
117 | i + 1] == second:
118 | new_word.append(first + second)
119 | i += 2
120 | else:
121 | new_word.append(word[i])
122 | i += 1
123 | new_word = tuple(new_word)
124 | word = new_word
125 | if len(word) == 1:
126 | break
127 | else:
128 | pairs = get_pairs(word)
129 | word = ' '.join(word)
130 | self.cache[token] = word
131 | return word
132 |
133 | def encode(self, text):
134 | bpe_tokens = []
135 | text = whitespace_clean(basic_clean(text)).lower()
136 | for token in re.findall(self.pat, text):
137 | token = ''.join(self.byte_encoder[b]
138 | for b in token.encode('utf-8'))
139 | bpe_tokens.extend(self.encoder[bpe_token]
140 | for bpe_token in self.bpe(token).split(' '))
141 | return bpe_tokens
142 |
143 | def decode(self, tokens):
144 | text = ''.join([self.decoder[token] for token in tokens])
145 | text = bytearray([self.byte_decoder[c] for c in text
146 | ]).decode('utf-8',
147 | errors="replace").replace('</w>', ' ')
148 | return text
149 |
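A short round-trip sketch for the byte-level BPE tokenizer above (illustrative only; the import path is assumed to mirror the package layout):

    from clip.simple_tokenizer import SimpleTokenizer  # assumed import path

    tok = SimpleTokenizer()
    ids = tok.encode("A photo of a cat!")  # lower-cased, cleaned, then BPE-encoded
    print(ids)                             # list of integer BPE token ids
    print(tok.decode(ids))                 # "a photo of a cat ! " ('</w>' markers become spaces)
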
--------------------------------------------------------------------------------
/trainers/coop.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 | from torch.cuda.amp import GradScaler, autocast
7 |
8 | from dassl.engine import TRAINER_REGISTRY, TrainerX
9 | from dassl.metrics import compute_accuracy
10 | from dassl.utils import load_pretrained_weights, load_checkpoint
11 | from dassl.optim import build_optimizer, build_lr_scheduler
12 |
13 | from clip import clip
14 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | _tokenizer = _Tokenizer()
17 |
18 |
19 | def load_clip_to_cpu(cfg):
20 | backbone_name = cfg.MODEL.BACKBONE.NAME
21 | url = clip._MODELS[backbone_name]
22 | model_path = clip._download(url)
23 |
24 | try:
25 | # loading JIT archive
26 | model = torch.jit.load(model_path, map_location="cpu").eval()
27 | state_dict = None
28 |
29 | except RuntimeError:
30 | state_dict = torch.load(model_path, map_location="cpu")
31 |
32 | model = clip.build_model(state_dict or model.state_dict())
33 |
34 | return model
35 |
36 |
37 | class TextEncoder(nn.Module):
38 | def __init__(self, clip_model):
39 | super().__init__()
40 | self.transformer = clip_model.transformer
41 | self.positional_embedding = clip_model.positional_embedding
42 | self.ln_final = clip_model.ln_final
43 | self.text_projection = clip_model.text_projection
44 | self.dtype = clip_model.dtype
45 |
46 | def forward(self, prompts, tokenized_prompts):
47 | x = prompts + self.positional_embedding.type(self.dtype)
48 | x = x.permute(1, 0, 2) # NLD -> LND
49 | x = self.transformer(x)
50 | x = x.permute(1, 0, 2) # LND -> NLD
51 | x = self.ln_final(x).type(self.dtype)
52 |
53 | # x.shape = [batch_size, n_ctx, transformer.width]
54 | # take features from the eot embedding (eot_token is the highest number in each sequence)
55 | x = x[torch.arange(x.shape[0]),
56 | tokenized_prompts.argmax(dim=-1)] @ self.text_projection
57 |
58 | return x
59 |
60 |
61 | class PromptLearner(nn.Module):
62 | def __init__(self, cfg, classnames, clip_model):
63 | super().__init__()
64 | n_cls = len(classnames)
65 | n_ctx = cfg.TRAINER.COOP.N_CTX
66 | ctx_init = cfg.TRAINER.COOP.CTX_INIT
67 | dtype = clip_model.dtype
68 | ctx_dim = clip_model.ln_final.weight.shape[0]
69 | clip_imsize = clip_model.visual.input_resolution
70 | cfg_imsize = cfg.INPUT.SIZE[0]
71 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
72 |
73 | # if ctx_init:
74 | # # use given words to initialize context vectors
75 | # ctx_init = ctx_init.replace("_", " ")
76 | # n_ctx = len(ctx_init.split(" "))
77 | # prompt = clip.tokenize(ctx_init)
78 | # with torch.no_grad():
79 | # embedding = clip_model.token_embedding(prompt).type(dtype)
80 | # ctx_vectors = embedding[0, 1:1 + n_ctx, :]
81 | # prompt_prefix = ctx_init
82 | if ctx_init:
83 | ctx_init = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
84 | ctx_init = ctx_init.replace(" {}.", "")
85 | ctx_init = ctx_init.replace("_", " ")
86 | prompt_n_ctx = len(ctx_init.split(" "))
87 |
88 | assert n_ctx >= prompt_n_ctx, f"#tokens ({n_ctx}) must be greater than or equal to #initial prompt tokens ({prompt_n_ctx}, {ctx_init})"
89 |
90 | prompt = clip.tokenize(ctx_init)
91 | with torch.no_grad():
92 | embedding = clip_model.token_embedding(prompt).type(dtype)
93 |
94 | ctx_vectors = torch.zeros(n_ctx, ctx_dim, dtype=dtype)
95 |
96 | ctx_vectors[n_ctx - prompt_n_ctx:, :] = embedding[0, 1:1 +
97 | prompt_n_ctx, :]
98 | prompt_prefix = " ".join(["X"] * (n_ctx - prompt_n_ctx))
99 | prompt_prefix = f"{prompt_prefix} {ctx_init}"
100 | else:
101 | # random initialization
102 | if cfg.TRAINER.COOP.CSC:
103 | print("Initializing class-specific contexts")
104 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
105 | else:
106 | print("Initializing a generic context")
107 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
108 | nn.init.normal_(ctx_vectors, std=0.02)
109 | prompt_prefix = " ".join(["X"] * n_ctx)
110 |
111 | print(f'Initial context: "{prompt_prefix}"')
112 | print(f"Number of context words (tokens): {n_ctx}")
113 |
114 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized
115 |
116 | classnames = [name.replace("_", " ") for name in classnames]
117 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
118 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
119 |
120 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
121 | with torch.no_grad():
122 | embedding = clip_model.token_embedding(tokenized_prompts).type(
123 | dtype)
124 |
125 | # These token vectors will be saved by save_model(),
126 | # but they should be ignored in load_model() as we want to use
127 | # those computed using the current class names
128 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
129 | self.register_buffer("token_suffix",
130 | embedding[:, 1 + n_ctx:, :]) # CLS, EOS
131 |
132 | self.n_cls = n_cls
133 | self.n_ctx = n_ctx
134 | self.tokenized_prompts = tokenized_prompts # torch.Tensor
135 | self.name_lens = name_lens
136 | self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION
137 |
138 | def forward(self):
139 | ctx = self.ctx
140 | if ctx.dim() == 2:
141 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
142 |
143 | prefix = self.token_prefix
144 | suffix = self.token_suffix
145 |
146 | if self.class_token_position == "end":
147 | prompts = torch.cat(
148 | [
149 | prefix, # (n_cls, 1, dim)
150 | ctx, # (n_cls, n_ctx, dim)
151 | suffix, # (n_cls, *, dim)
152 | ],
153 | dim=1,
154 | )
155 |
156 | elif self.class_token_position == "middle":
157 | half_n_ctx = self.n_ctx // 2
158 | prompts = []
159 | for i in range(self.n_cls):
160 | name_len = self.name_lens[i]
161 | prefix_i = prefix[i:i + 1, :, :]
162 | class_i = suffix[i:i + 1, :name_len, :]
163 | suffix_i = suffix[i:i + 1, name_len:, :]
164 | ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :]
165 | ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :]
166 | prompt = torch.cat(
167 | [
168 | prefix_i, # (1, 1, dim)
169 | ctx_i_half1, # (1, n_ctx//2, dim)
170 | class_i, # (1, name_len, dim)
171 | ctx_i_half2, # (1, n_ctx//2, dim)
172 | suffix_i, # (1, *, dim)
173 | ],
174 | dim=1,
175 | )
176 | prompts.append(prompt)
177 | prompts = torch.cat(prompts, dim=0)
178 |
179 | elif self.class_token_position == "front":
180 | prompts = []
181 | for i in range(self.n_cls):
182 | name_len = self.name_lens[i]
183 | prefix_i = prefix[i:i + 1, :, :]
184 | class_i = suffix[i:i + 1, :name_len, :]
185 | suffix_i = suffix[i:i + 1, name_len:, :]
186 | ctx_i = ctx[i:i + 1, :, :]
187 | prompt = torch.cat(
188 | [
189 | prefix_i, # (1, 1, dim)
190 | class_i, # (1, name_len, dim)
191 | ctx_i, # (1, n_ctx, dim)
192 | suffix_i, # (1, *, dim)
193 | ],
194 | dim=1,
195 | )
196 | prompts.append(prompt)
197 | prompts = torch.cat(prompts, dim=0)
198 |
199 | else:
200 | raise ValueError(f"Unknown class_token_position: {self.class_token_position}")
201 |
202 | return prompts
203 |
204 |
205 | CUSTOM_TEMPLATES = {
206 | # "OxfordPets": "a photo of a {}, a type of pet.",
207 | "OxfordPets": "a type of pet, a photo of a {}.",
208 | # "OxfordFlowers": "a photo of a {}, a type of flower.",
209 | "OxfordFlowers": "a type of flower, a photo of a {}.",
210 | "FGVCAircraft": "a type of aircraft, a photo of a {}.",
211 | "DescribableTextures": "a texture of {}.",
212 | "EuroSAT": "a centered satellite photo of {}.",
213 | "StanfordCars": "a photo of a {}.",
214 | # "Food101": "a photo of {}, a type of food.",
215 | "Food101": "a type of food, a photo of {}.",
216 | "SUN397": "a photo of a {}.",
217 | "Caltech101": "a photo of a {}.",
218 | "UCF101": "a photo of a person doing {}.",
219 | "ImageNet": "a photo of a {}.",
220 | "ImageNetSketch": "a photo of a {}.",
221 | "ImageNetV2": "a photo of a {}.",
222 | "ImageNetA": "a photo of a {}.",
223 | "ImageNetR": "a photo of a {}.",
224 | }
225 |
226 |
227 | class CustomCLIP(nn.Module):
228 | def __init__(self, cfg, classnames, clip_model):
229 | super().__init__()
230 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
231 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
232 | self.image_encoder = clip_model.visual
233 | self.text_encoder = TextEncoder(clip_model)
234 | self.logit_scale = clip_model.logit_scale
235 | self.dtype = clip_model.dtype
236 |
237 | def forward(self, image):
238 | image_features = self.image_encoder(image.type(self.dtype))
239 |
240 | prompts = self.prompt_learner()
241 | tokenized_prompts = self.tokenized_prompts
242 | text_features = self.text_encoder(prompts, tokenized_prompts)
243 |
244 | image_features = image_features / image_features.norm(dim=-1,
245 | keepdim=True)
246 | text_features = text_features / text_features.norm(dim=-1,
247 | keepdim=True)
248 |
249 | logit_scale = self.logit_scale.exp()
250 | logits = logit_scale * image_features @ text_features.t()
251 |
252 | return logits
253 |
254 |
255 | @TRAINER_REGISTRY.register()
256 | class CoOp(TrainerX):
257 | """Context Optimization (CoOp).
258 |
259 | Learning to Prompt for Vision-Language Models
260 | https://arxiv.org/abs/2109.01134
261 | """
262 | def check_cfg(self, cfg):
263 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
264 |
265 | def build_model(self):
266 | cfg = self.cfg
267 | classnames = self.dm.dataset.classnames
268 |
269 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
270 | clip_model = load_clip_to_cpu(cfg)
271 |
272 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
273 | # CLIP's default precision is fp16
274 | clip_model.float()
275 |
276 | print("Building custom CLIP")
277 | self.model = CustomCLIP(cfg, classnames, clip_model)
278 |
279 | print("Turning off gradients in both the image and the text encoder")
280 | for name, param in self.model.named_parameters():
281 | if "prompt_learner" not in name:
282 | param.requires_grad_(False)
283 |
284 | if cfg.MODEL.INIT_WEIGHTS:
285 | load_pretrained_weights(self.model.prompt_learner,
286 | cfg.MODEL.INIT_WEIGHTS)
287 |
288 | self.model.to(self.device)
289 | # NOTE: only give prompt_learner to the optimizer
290 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
291 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
292 | self.register_model("prompt_learner", self.model.prompt_learner,
293 | self.optim, self.sched)
294 |
295 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None
296 |
297 | # Note that multi-gpu training could be slow because CLIP's size is
298 | # big, which slows down the copy operation in DataParallel
299 | device_count = torch.cuda.device_count()
300 | if device_count > 1:
301 | print(
302 | f"Multiple GPUs detected (n_gpus={device_count}), use all of them!"
303 | )
304 | self.model = nn.DataParallel(self.model)
305 |
306 | def forward_backward(self, batch):
307 | image, label = self.parse_batch_train(batch)
308 |
309 | prec = self.cfg.TRAINER.COOP.PREC
310 | if prec == "amp":
311 | with autocast():
312 | output = self.model(image)
313 | loss = F.cross_entropy(output, label)
314 | self.optim.zero_grad()
315 | self.scaler.scale(loss).backward()
316 | self.scaler.step(self.optim)
317 | self.scaler.update()
318 | else:
319 | output = self.model(image)
320 | loss = F.cross_entropy(output, label)
321 | self.model_backward_and_update(loss)
322 |
323 | loss_summary = {
324 | "loss": loss.item(),
325 | "acc": compute_accuracy(output, label)[0].item(),
326 | }
327 |
328 | if (self.batch_idx + 1) == self.num_batches:
329 | self.update_lr()
330 |
331 | return loss_summary
332 |
333 | def parse_batch_train(self, batch):
334 | input = batch["img"]
335 | label = batch["label"]
336 | input = input.to(self.device)
337 | label = label.to(self.device)
338 | return input, label
339 |
340 | def load_model(self, directory, epoch=None):
341 | if not directory:
342 | print(
343 | "Note that load_model() is skipped as no pretrained model is given"
344 | )
345 | return
346 |
347 | names = self.get_model_names()
348 |
349 | # By default, the best model is loaded
350 | model_file = "model-best.pth.tar"
351 |
352 | if epoch is not None:
353 | model_file = "model.pth.tar-" + str(epoch)
354 |
355 | for name in names:
356 | model_path = osp.join(directory, name, model_file)
357 |
358 | if not osp.exists(model_path):
359 | raise FileNotFoundError(
360 | 'Model not found at "{}"'.format(model_path))
361 |
362 | checkpoint = load_checkpoint(model_path)
363 | state_dict = checkpoint["state_dict"]
364 | epoch = checkpoint["epoch"]
365 |
366 | # Ignore fixed token vectors
367 | if "token_prefix" in state_dict:
368 | del state_dict["token_prefix"]
369 |
370 | if "token_suffix" in state_dict:
371 | del state_dict["token_suffix"]
372 |
373 | print("Loading weights to {} "
374 | 'from "{}" (epoch = {})'.format(name, model_path, epoch))
375 | # set strict=False
376 | self._models[name].load_state_dict(state_dict, strict=False)
377 |
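A small shape-only sketch of the "end" prompt assembly performed in PromptLearner.forward above (toy dimensions for illustration; in the trainer the context width is CLIP's ln_final width and the sequence length is 77):

    import torch

    n_cls, n_ctx, ctx_dim, seq_len = 5, 16, 8, 77  # toy sizes

    prefix = torch.zeros(n_cls, 1, ctx_dim)                    # SOS embedding per class
    ctx = torch.zeros(n_ctx, ctx_dim)                          # shared learnable context vectors
    ctx = ctx.unsqueeze(0).expand(n_cls, -1, -1)               # broadcast to every class
    suffix = torch.zeros(n_cls, seq_len - 1 - n_ctx, ctx_dim)  # class name tokens, EOS, padding

    prompts = torch.cat([prefix, ctx, suffix], dim=1)          # (n_cls, 77, ctx_dim)
    assert prompts.shape == (n_cls, seq_len, ctx_dim)
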
--------------------------------------------------------------------------------
/trainers/imagenet_templates.py:
--------------------------------------------------------------------------------
1 | # source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
2 |
3 | IMAGENET_TEMPLATES = [
4 | "a bad photo of a {}.",
5 | "a photo of many {}.",
6 | "a sculpture of a {}.",
7 | "a photo of the hard to see {}.",
8 | "a low resolution photo of the {}.",
9 | "a rendering of a {}.",
10 | "graffiti of a {}.",
11 | "a bad photo of the {}.",
12 | "a cropped photo of the {}.",
13 | "a tattoo of a {}.",
14 | "the embroidered {}.",
15 | "a photo of a hard to see {}.",
16 | "a bright photo of a {}.",
17 | "a photo of a clean {}.",
18 | "a photo of a dirty {}.",
19 | "a dark photo of the {}.",
20 | "a drawing of a {}.",
21 | "a photo of my {}.",
22 | "the plastic {}.",
23 | "a photo of the cool {}.",
24 | "a close-up photo of a {}.",
25 | "a black and white photo of the {}.",
26 | "a painting of the {}.",
27 | "a painting of a {}.",
28 | "a pixelated photo of the {}.",
29 | "a sculpture of the {}.",
30 | "a bright photo of the {}.",
31 | "a cropped photo of a {}.",
32 | "a plastic {}.",
33 | "a photo of the dirty {}.",
34 | "a jpeg corrupted photo of a {}.",
35 | "a blurry photo of the {}.",
36 | "a photo of the {}.",
37 | "a good photo of the {}.",
38 | "a rendering of the {}.",
39 | "a {} in a video game.",
40 | "a photo of one {}.",
41 | "a doodle of a {}.",
42 | "a close-up photo of the {}.",
43 | "a photo of a {}.",
44 | "the origami {}.",
45 | "the {} in a video game.",
46 | "a sketch of a {}.",
47 | "a doodle of the {}.",
48 | "a origami {}.",
49 | "a low resolution photo of a {}.",
50 | "the toy {}.",
51 | "a rendition of the {}.",
52 | "a photo of the clean {}.",
53 | "a photo of a large {}.",
54 | "a rendition of a {}.",
55 | "a photo of a nice {}.",
56 | "a photo of a weird {}.",
57 | "a blurry photo of a {}.",
58 | "a cartoon {}.",
59 | "art of a {}.",
60 | "a sketch of the {}.",
61 | "a embroidered {}.",
62 | "a pixelated photo of a {}.",
63 | "itap of the {}.",
64 | "a jpeg corrupted photo of the {}.",
65 | "a good photo of a {}.",
66 | "a plushie {}.",
67 | "a photo of the nice {}.",
68 | "a photo of the small {}.",
69 | "a photo of the weird {}.",
70 | "the cartoon {}.",
71 | "art of the {}.",
72 | "a drawing of the {}.",
73 | "a photo of the large {}.",
74 | "a black and white photo of a {}.",
75 | "the plushie {}.",
76 | "a dark photo of a {}.",
77 | "itap of a {}.",
78 | "graffiti of the {}.",
79 | "a toy {}.",
80 | "itap of my {}.",
81 | "a photo of a cool {}.",
82 | "a photo of a small {}.",
83 | "a tattoo of the {}.",
84 | ]
85 |
86 | IMAGENET_TEMPLATES_SELECT = [
87 | "itap of a {}.",
88 | "a bad photo of the {}.",
89 | "a origami {}.",
90 | "a photo of the large {}.",
91 | "a {} in a video game.",
92 | "art of the {}.",
93 | "a photo of the small {}.",
94 | ]
95 |
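A minimal sketch of how these templates expand into per-class prompts (the per-template feature averaging happens in ZeroshotCLIP2 below; the class names here are made up):

    from trainers.imagenet_templates import IMAGENET_TEMPLATES_SELECT

    classnames = ["golden retriever", "tabby cat"]  # example class names
    prompts = [t.format(c) for c in classnames for t in IMAGENET_TEMPLATES_SELECT]
    # e.g. "itap of a golden retriever.", "a bad photo of the golden retriever.", ...
    print(len(prompts))  # 2 classes x 7 templates = 14 prompts
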
--------------------------------------------------------------------------------
/trainers/zsclip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from dassl.engine import TRAINER_REGISTRY, TrainerX
5 | from dassl.optim import build_optimizer, build_lr_scheduler
6 |
7 | from clip import clip
8 | from clip.model import convert_weights
9 |
10 | from .coop import load_clip_to_cpu
11 | from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT
12 |
13 | CUSTOM_TEMPLATES = {
14 | # "OxfordPets": "a photo of a {}, a type of pet.",
15 | "OxfordPets": "a type of pet, a photo of a {}.",
16 | # "OxfordFlowers": "a photo of a {}, a type of flower.",
17 | "OxfordFlowers": "a type of flower, a photo of a {}.",
18 | "FGVCAircraft": "a photo of a {}, a type of aircraft.",
19 | "DescribableTextures": "{} texture.",
20 | "EuroSAT": "a centered satellite photo of {}.",
21 | "StanfordCars": "a photo of a {}.",
22 | # "Food101": "a photo of {}, a type of food.",
23 | "Food101": "a type of food, a photo of {}.",
24 | "SUN397": "a photo of a {}.",
25 | "Caltech101": "a photo of a {}.",
26 | "UCF101": "a photo of a person doing {}.",
27 | "ImageNet": "a photo of a {}.",
28 | "ImageNetSketch": "a photo of a {}.",
29 | "ImageNetV2": "a photo of a {}.",
30 | "ImageNetA": "a photo of a {}.",
31 | "ImageNetR": "a photo of a {}.",
32 | }
33 |
34 |
35 | @TRAINER_REGISTRY.register()
36 | class ZeroshotCLIP(TrainerX):
37 | def build_model(self):
38 | cfg = self.cfg
39 | classnames = self.dm.dataset.classnames
40 |
41 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
42 | clip_model = load_clip_to_cpu(cfg)
43 | clip_model.to(self.device)
44 |
45 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
46 | prompts = [temp.format(c.replace("_", " ")) for c in classnames]
47 | print(f"Prompts: {prompts}")
48 | prompts = torch.cat([clip.tokenize(p) for p in prompts])
49 | prompts = prompts.to(self.device)
50 |
51 | with torch.no_grad():
52 | text_features = clip_model.encode_text(prompts)
53 | text_features = text_features / text_features.norm(dim=-1,
54 | keepdim=True)
55 |
56 | self.text_features = text_features
57 | self.clip_model = clip_model
58 |
59 | def model_inference(self, image):
60 | image_features = self.clip_model.encode_image(image)
61 | image_features = image_features / image_features.norm(dim=-1,
62 | keepdim=True)
63 | logit_scale = self.clip_model.logit_scale.exp()
64 | logits = logit_scale * image_features @ self.text_features.t()
65 | return logits
66 |
67 |
68 | @TRAINER_REGISTRY.register()
69 | class ZeroshotCLIP2(ZeroshotCLIP):
70 | """Prompt ensembling."""
71 |
72 | # templates = IMAGENET_TEMPLATES
73 | templates = IMAGENET_TEMPLATES_SELECT
74 |
75 | def build_model(self):
76 | cfg = self.cfg
77 | classnames = self.dm.dataset.classnames
78 |
79 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
80 | clip_model = load_clip_to_cpu(cfg)
81 | clip_model.to(self.device)
82 |
83 | for params in clip_model.parameters():
84 | params.requires_grad_(False)
85 |
86 | # add custom-made prompt
87 | if cfg.DATASET.NAME != "ImageNet":
88 | self.templates = self.templates + [CUSTOM_TEMPLATES[cfg.DATASET.NAME]]  # avoid mutating the shared class-level list
89 |
90 | num_temp = len(self.templates)
91 | print(f"Prompt ensembling (n={num_temp})")
92 |
93 | mean_text_features = 0
94 | for i, temp in enumerate(self.templates):
95 | prompts = [temp.format(c.replace("_", " ")) for c in classnames]
96 | prompts = torch.cat([clip.tokenize(p)
97 | for p in prompts]).to(self.device)
98 | text_features = clip_model.encode_text(prompts)
99 | text_features = text_features / text_features.norm(dim=-1,
100 | keepdim=True)
101 | mean_text_features = mean_text_features + text_features
102 | mean_text_features = mean_text_features / num_temp
103 | mean_text_features = mean_text_features / mean_text_features.norm(
104 | dim=-1, keepdim=True)
105 |
106 | self.text_features = mean_text_features
107 | self.clip_model = clip_model
108 |
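A compact sketch of the feature-ensembling step in ZeroshotCLIP2.build_model above, with random vectors standing in for CLIP text embeddings:

    import torch

    n_templates, n_cls, dim = 7, 100, 512
    feats = torch.randn(n_templates, n_cls, dim)
    feats = feats / feats.norm(dim=-1, keepdim=True)  # normalise each template's class features
    mean_feats = feats.mean(dim=0)                    # average over templates, per class
    mean_feats = mean_feats / mean_feats.norm(dim=-1, keepdim=True)  # re-normalise the ensemble
    print(mean_feats.shape)  # torch.Size([100, 512])
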
--------------------------------------------------------------------------------