├── .gitignore ├── DATASETS.md ├── LICENSE ├── README.md ├── apt ├── OODRB │ ├── README.md │ ├── __init__.py │ ├── cifar.py │ ├── dataset.py │ ├── image_ids │ │ ├── CIFAR10-R_image_ids.txt │ │ ├── CINIC-10_image_ids.txt │ │ ├── imagenet-a_image_ids.txt │ │ ├── imagenet-r_image_ids.txt │ │ ├── imagenet-v2_image_ids.txt │ │ ├── imagenet_test_image_ids.txt │ │ └── objectnet_image_ids.txt │ └── imagenet.py ├── backbone │ └── README.md ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ └── simple_tokenizer.py ├── configs │ ├── datasets │ │ ├── caltech101.yaml │ │ ├── dtd.yaml │ │ ├── eurosat.yaml │ │ ├── fgvc_aircraft.yaml │ │ ├── food101.yaml │ │ ├── imagenet.yaml │ │ ├── imagenet_a.yaml │ │ ├── imagenet_r.yaml │ │ ├── imagenet_sketch.yaml │ │ ├── imagenetv2.yaml │ │ ├── oxford_flowers.yaml │ │ ├── oxford_pets.yaml │ │ ├── stanford_cars.yaml │ │ ├── sun397.yaml │ │ └── ucf101.yaml │ └── trainers │ │ ├── APT │ │ ├── rn101.yaml │ │ ├── rn101_ep50.yaml │ │ ├── rn50.yaml │ │ ├── rn50_ctxv1.yaml │ │ ├── rn50_ep100.yaml │ │ ├── rn50_ep50.yaml │ │ ├── rn50_ep50_ctxv1.yaml │ │ ├── rn50_val.yaml │ │ ├── vit_b16.yaml │ │ ├── vit_b16_ctxv1.yaml │ │ ├── vit_b16_ep100.yaml │ │ ├── vit_b16_ep100_ctxv1.yaml │ │ ├── vit_b16_ep50.yaml │ │ ├── vit_b16_ep50_ctxv1.yaml │ │ ├── vit_b32.yaml │ │ ├── vit_b32_ep100.yaml │ │ ├── vit_b32_ep20.yaml │ │ ├── vit_b32_ep50.yaml │ │ └── vit_b32_st.yaml │ │ └── CoOp │ │ ├── rn101.yaml │ │ ├── rn101_ep50.yaml │ │ ├── rn50.yaml │ │ ├── rn50_ctxv1.yaml │ │ ├── rn50_ep100.yaml │ │ ├── rn50_ep50.yaml │ │ ├── rn50_ep50_ctxv1.yaml │ │ ├── rn50_val.yaml │ │ ├── vit_b16.yaml │ │ ├── vit_b16_ctxv1.yaml │ │ ├── vit_b16_ep100.yaml │ │ ├── vit_b16_ep100_ctxv1.yaml │ │ ├── vit_b16_ep50.yaml │ │ ├── vit_b16_ep50_ctxv1.yaml │ │ ├── vit_b32.yaml │ │ ├── vit_b32_ep20.yaml │ │ └── vit_b32_ep50.yaml ├── datasets │ ├── __init__.py │ ├── caltech101.py │ ├── dtd.py │ ├── eurosat.py │ ├── fgvc_aircraft.py │ ├── food101.py │ ├── imagenet.py │ ├── imagenet_a.py │ ├── imagenet_r.py │ ├── imagenet_sketch.py │ ├── imagenetv2.py │ ├── oxford_flowers.py │ ├── oxford_pets.py │ ├── stanford_cars.py │ ├── sun397.py │ └── ucf101.py ├── evaluate.py ├── interpret_prompt.py ├── scripts │ ├── APT.sh │ ├── CoOp.sh │ └── eval.sh ├── train.py ├── trainers │ ├── __init__.py │ ├── apt.py │ └── imagenet_templates.py └── utils.py ├── assets └── one_word_boost.png └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /DATASETS.md: -------------------------------------------------------------------------------- 1 | The instructions below are adapted from [CoOp](https://github.com/KaiyangZhou/CoOp/blob/main/DATASETS.md) with some modification to add few Out-Of-Distribution datasets at the end. It appears that the official link provided may be invalid. You can download the required data from various platforms and then follow the file structure to place the corresponding data accordingly. 2 | 3 | # How to install datasets 4 | 5 | We suggest putting all datasets under the same folder (say `$DATA`) to ease management and following the instructions below to organize datasets to avoid modifying the source code. The file structure looks like 6 | 7 | ``` 8 | $DATA/ 9 | |–– imagenet/ 10 | |–– caltech-101/ 11 | |–– oxford_pets/ 12 | |–– stanford_cars/ 13 | ``` 14 | 15 | If you have some datasets already installed somewhere else, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid duplicate download. 16 | 17 | Datasets list: 18 | - [ImageNet](#imagenet) 19 | - [Caltech101](#caltech101) 20 | - [OxfordPets](#oxfordpets) 21 | - [StanfordCars](#stanfordcars) 22 | - [Flowers102](#flowers102) 23 | - [Food101](#food101) 24 | - [FGVCAircraft](#fgvcaircraft) 25 | - [SUN397](#sun397) 26 | - [DTD](#dtd) 27 | - [EuroSAT](#eurosat) 28 | - [UCF101](#ucf101) 29 | 30 | The instructions to prepare each dataset are detailed below. To ensure reproducibility and fair comparison for future work, we provide fixed train/val/test splits for all datasets except ImageNet where the validation set is used as test set. The fixed splits are either from the original datasets (if available) or created by us. 31 | 32 | ### ImageNet 33 | - Create a folder named `imagenet/` under `$DATA`. 34 | - Create `images/` under `imagenet/`. 35 | - Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. The directory structure should look like 36 | ``` 37 | imagenet/ 38 | |–– images/ 39 | | |–– train/ # contains 1,000 folders like n01440764, n01443537, etc. 40 | | |–– val/ 41 | ``` 42 | - If you had downloaded the ImageNet dataset before, you can create symbolic links to map the training and validation sets to `$DATA/imagenet/images`. 43 | - Download the `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb). 44 | 45 | ### Caltech101 46 | - Create a folder named `caltech-101/` under `$DATA`. 47 | - Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`. 48 | - Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`. 49 | 50 | The directory structure should look like 51 | ``` 52 | caltech-101/ 53 | |–– 101_ObjectCategories/ 54 | |–– split_zhou_Caltech101.json 55 | ``` 56 | 57 | ### OxfordPets 58 | - Create a folder named `oxford_pets/` under `$DATA`. 59 | - Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz. 
60 | - Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz. 61 | - Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing). 62 | 63 | The directory structure should look like 64 | ``` 65 | oxford_pets/ 66 | |–– images/ 67 | |–– annotations/ 68 | |–– split_zhou_OxfordPets.json 69 | ``` 70 | 71 | ### StanfordCars 72 | - Create a folder named `stanford_cars/` under `$DATA`. 73 | - Download the train images http://ai.stanford.edu/~jkrause/car196/cars_train.tgz. 74 | - Download the test images http://ai.stanford.edu/~jkrause/car196/cars_test.tgz. 75 | - Download the train labels https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz. 76 | - Download the test labels http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat. 77 | - Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing). 78 | 79 | The directory structure should look like 80 | ``` 81 | stanford_cars/ 82 | |–– cars_test\ 83 | |–– cars_test_annos_withlabels.mat 84 | |–– cars_train\ 85 | |–– devkit\ 86 | |–– split_zhou_StanfordCars.json 87 | ``` 88 | 89 | ### Flowers102 90 | - Create a folder named `oxford_flowers/` under `$DATA`. 91 | - Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively. 92 | - Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing). 93 | - Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing). 94 | 95 | The directory structure should look like 96 | ``` 97 | oxford_flowers/ 98 | |–– cat_to_name.json 99 | |–– imagelabels.mat 100 | |–– jpg/ 101 | |–– split_zhou_OxfordFlowers.json 102 | ``` 103 | 104 | ### Food101 105 | - Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`. 106 | - Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing). 107 | 108 | The directory structure should look like 109 | ``` 110 | food-101/ 111 | |–– images/ 112 | |–– license_agreement.txt 113 | |–– meta/ 114 | |–– README.txt 115 | |–– split_zhou_Food101.json 116 | ``` 117 | 118 | ### FGVCAircraft 119 | - Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz. 120 | - Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`. 121 | - Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`. 122 | 123 | The directory structure should look like 124 | ``` 125 | fgvc_aircraft/ 126 | |–– images/ 127 | |–– ... # a bunch of .txt files 128 | ``` 129 | 130 | ### SUN397 131 | - Create a folder named `sun397/` under `$DATA`. 132 | - Download the images http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz. 133 | - Download the partitions https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip. 134 | - Extract these files under `$DATA/sun397/`. 135 | - Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing). 
136 | 137 | The directory structure should look like 138 | ``` 139 | sun397/ 140 | |–– SUN397/ 141 | |–– split_zhou_SUN397.json 142 | |–– ... # a bunch of .txt files 143 | ``` 144 | 145 | ### DTD 146 | - Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. This should lead to `$DATA/dtd/`. 147 | - Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing). 148 | 149 | The directory structure should look like 150 | ``` 151 | dtd/ 152 | |–– images/ 153 | |–– imdb/ 154 | |–– labels/ 155 | |–– split_zhou_DescribableTextures.json 156 | ``` 157 | 158 | ### EuroSAT 159 | - Create a folder named `eurosat/` under `$DATA`. 160 | - Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`. 161 | - Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing). 162 | 163 | The directory structure should look like 164 | ``` 165 | eurosat/ 166 | |–– 2750/ 167 | |–– split_zhou_EuroSAT.json 168 | ``` 169 | 170 | ### UCF101 171 | - Create a folder named `ucf101/` under `$DATA`. 172 | - Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames. 173 | - Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing). 174 | 175 | The directory structure should look like 176 | ``` 177 | ucf101/ 178 | |–– UCF-101-midframes/ 179 | |–– split_zhou_UCF101.json 180 | ``` 181 | 182 | ### ImageNet-Sketch 183 | 184 | - Download the dataset from https://github.com/HaohanWang/ImageNet-Sketch. 185 | - Extract the dataset to `$DATA/imagenet-sketch`. 186 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-sketch/`. 187 | 188 | The directory structure should look like 189 | 190 | ``` 191 | imagenet-sketch/ 192 | |–– images/ # contains 1,000 folders whose names have the format of n* 193 | |–– classnames.txt 194 | ``` 195 | 196 | ### ImageNet-V2 197 | 198 | - Download the dataset from https://huggingface.co/datasets/vaishaal/ImageNetV2/tree/main. 199 | - Extract the dataset and rename `$DATA/imagenetv2-matched-frequency-format-val`. 200 | 201 | The directory structure should look like 202 | 203 | ``` 204 | imagenetv2-matched-frequency-format-val/ 205 | |–– 1000 folders # named 0-999 206 | ``` 207 | 208 | ### ImageNet-R 209 | 210 | - Download the dataset from https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar. 211 | - Extract the dataset and rename `$DATA/imagenet-r`. 212 | 213 | The directory structure should look like 214 | 215 | ``` 216 | imagenet-r/ 217 | |–– folders # named by ImageNet class ID 218 | ``` 219 | 220 | ### ObjectNet 221 | 222 | Todo: upload the processed ObjectNet dataset. 
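
### Sanity-checking the layout (optional)

Once the datasets above are installed, it can be worth verifying that `$DATA` matches the expected layout before launching training. The snippet below is a minimal sketch and not part of the original tooling: the folder names are taken from the instructions above, and the `DATA` path is a placeholder to be replaced with your actual dataset root.

```python
# Minimal sanity check of the $DATA layout described above (illustrative only).
import os

DATA = os.path.expanduser("~/data")  # placeholder; replace with your $DATA root

EXPECTED_DIRS = [
    "imagenet/images/train", "imagenet/images/val",
    "caltech-101/101_ObjectCategories",
    "oxford_pets/images", "oxford_pets/annotations",
    "stanford_cars/cars_train", "stanford_cars/cars_test",
    "oxford_flowers/jpg",
    "food-101/images",
    "fgvc_aircraft/images",
    "sun397/SUN397",
    "dtd/images",
    "eurosat/2750",
    "ucf101/UCF-101-midframes",
    "imagenet-sketch/images",
    "imagenetv2-matched-frequency-format-val",
    "imagenet-r",
]

for rel in EXPECTED_DIRS:
    path = os.path.join(DATA, rel)
    status = "ok" if os.path.isdir(path) else "MISSING"
    print(f"[{status}] {path}")
```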
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Lin Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # APT: Adversarial Prompt Tuning 2 | The official code of the paper "One Prompt Word is Enough to Boost Adversarial Robustness for Pre-trained Vision-Language Models" which is accepted by the main conference of CVPR 2024. 3 | 4 | **Abstract:** Large pre-trained Vision-Language Models (VLMs) like CLIP, despite having remarkable generalization ability, are highly vulnerable to adversarial examples. This work studies the adversarial robustness of VLMs from the novel perspective of the text prompt instead of the extensively studied model weights (frozen in this work). We first show that **the effectiveness of both adversarial attack and defense are sensitive to the used text prompt**. Inspired by this, we **propose a method to improve resilience to adversarial attacks by learning a robust text prompt for VLMs**. The proposed method, named Adversarial Prompt Tuning (APT), is effective while being both computationally and data efficient. Extensive experiments are conducted across 15 datasets and 4 data sparsity schemes (from 1-shot to full training data settings) to show APT's superiority over hand-engineered prompts and other state-of-the-art adaption methods. APT demonstrated excellent abilities in terms of the in-distribution performance and the generalization under input distribution shift and across datasets. Surprisingly, by simply adding one learned word to the prompts, APT can significantly boost the accuracy and robustness ($\epsilon=4/255$) over the hand-engineered prompts by +13\% and +8.5\% on average respectively. The improvement further increases, in our most effective setting, to +26.4\% for accuracy and +16.7\% for robustness. 5 | 6 | Arxiv: https://arxiv.org/abs/2403.01849. 7 | 8 | ![one word boost](assets/one_word_boost.png) 9 | 10 | ## Preparation 11 | 12 | ### Code 13 | 14 | This code is built on top of [CoOp](https://github.com/KaiyangZhou/CoOp) which extensively uses the toolbox [Dassl.pytorch](https://github.com/KaiyangZhou/Dassl.pytorch) so you need to install the `dassl` environment first. 
Simply follow the instructions described [here](https://github.com/KaiyangZhou/Dassl.pytorch#installation) to install `dassl`. After that, run `pip install -r requirements.txt` to install a few more packages (this should be done while `dassl` is activated). Then, you are ready to go. 15 | 16 | ### Data 17 | 18 | Follow [DATASETS.md](DATASETS.md) to install the datasets. After successfully setting up the datasets, the data directory variable, `DATA`, in each script under `/apt/scripts` MUST be updated with the root path of those datasets. 19 | 20 | ### Pre-trained Robust CLIP Backbone 21 | 22 | We adopt as backbone the pre-trained adversarially-robust CLIP models from [TeCoA](https://github.com/cvlab-columbia/ZSRobust4FoundationModel). The pre-trained weights used are provided [here](https://emckclac-my.sharepoint.com/:f:/g/personal/k19010102_kcl_ac_uk/EmZ98eFLv71FqQyqPLvWNTkBYNAKPyx_wYEDjNPx7smKCA?e=8AB51S). To run the code, the pre-trained backbone models should be placed under the directory `/backbone`. The code currently supports two architectures: ViT-B/32 (named `vitb32`) and ResNet50 (named `rn50`). For example, when tuning ViT-B/32 at epsilon=4/255, the path to the checkpoint is `/apt/backbone/vitb32_eps4.pth.tar`. Note that our code can be easily adapted to load other pre-trained models as backbone. 23 | 24 | ## Adversarial Prompt Tuning 25 | 26 | The following command (executed under the directory `/apt`) runs APT to tune the text prompt using the experiment setting specified by `/apt/configs/trainers/APT/vit_b32_ep50.yaml`: 27 | 28 | ```bash 29 | bash scripts/APT.sh imagenet vit_b32_ep50 end 16 16 False 4 2.67 3 0 onfly 0 30 | ``` 31 | 32 | The above arguments are, in order: 33 | 34 | 1. the dataset ID. The list of supported dataset IDs is given by the names of the dataset source code files under the directory `apt/datasets`. 35 | 2. the training configuration identifier. For the full specification, please refer to the corresponding file. Other predefined configurations are available under `apt/configs/trainers/APT`. 36 | 3. the position of the class token 37 | 4. the number of context vectors, `M` 38 | 5. the number of shots, `N`; `-1` for tuning with the entire training set 39 | 6. the variant of APT: True for Class-Specific Context (CSC); False for Unified Context (UC) 40 | 7. the training perturbation budget, `\epsilon` 41 | 8. the step size of the training adversary, `\alpha` 42 | 9. the number of steps of the training adversary 43 | 10. the seed of the run 44 | 11. the prompting strategy: "perturbed", "constant" or "onfly" 45 | 12. the step size, `\alpha`, for perturbing the text prompt if the prompting strategy "perturbed" is used 46 | 47 | ### Pre-trained Prompt Weights 48 | 49 | To facilitate reproducibility, the pre-trained text prompt weights are provided [here](https://emckclac-my.sharepoint.com/:f:/g/personal/k19010102_kcl_ac_uk/EmZ98eFLv71FqQyqPLvWNTkBYNAKPyx_wYEDjNPx7smKCA). 50 | 51 | ## Adversarial Evaluation 52 | 53 | ### In-Distribution 54 | 55 | The following command evaluates the tuned text prompt against the PGD attack on the test set of the same dataset as used for training: 56 | 57 | ```bash 58 | python evaluate.py path_to_checkpoint --cls-prompt prompter --attack pgd 59 | ``` 60 | 61 | * Replace `path_to_checkpoint` with the real path ending at the seed-level directory, e.g., `output/imagenet/CoOpAT/vit_b32_ep50_st_16shots/nctx16_cscFalse_ctpend/eps4_alpha2.67_step3/seed0`. 
62 | * By default, the same perturbation budget is used for evaluation as for training; it is read from the saved configuration file. 63 | * `--cls-prompt` specifies the text prompt for classification (inference). By specifying `prompter`, the prompt is loaded from the saved APT-tuned weights. Otherwise, a string template is expected, e.g., "`a photo of a {}`", where `{}` is necessary and will be automatically replaced by the real class label. 64 | * `--atk-prompt` specifies the text prompt for the attack, i.e., for generating adversarial examples. By default, it uses the same prompt as `--cls-prompt` unless specified otherwise. 65 | * The current code implements 4 adversarial attacks: PGD, TPGD, CW and AutoAttack. They are identified in `--attack` by `pgd`, `tpgd`, `cw` and `aa`, respectively. 66 | 67 | The evaluation result will be saved under the provided `path_to_checkpoint` directory in a file named `evaluation.yaml`. 68 | 69 | ### Zero-shot / Out-Of-Distribution 70 | 71 | The following command evaluates the tuned text prompt against the PGD attack on another dataset, OxfordFlowers in this case: 72 | 73 | ```bash 74 | python evaluate.py path_to_checkpoint --dataset OxfordFlowers --cls-prompt prompter --attack pgd 75 | ``` 76 | 77 | * `--dataset` specifies the target dataset to be evaluated on. Note that the dataset names can be found in the corresponding `.yaml` files [here](https://github.com/TreeLLi/APT/tree/main/apt/configs/datasets), which differ from the dataset IDs defined in [Adversarial Prompt Tuning](#adversarial-prompt-tuning). 78 | 79 | The evaluation result will be saved under the provided `path_to_checkpoint` directory in a file named `dist_shift.yaml`. 80 | 81 | ## Dependency 82 | 83 | The code is heavily built on top of the following projects: 84 | 85 | * [CoOp (IJCV 2022)](https://github.com/KaiyangZhou/CoOp) 86 | * [OODRobustBench (ICML 2024 and ICLRW-DMLR 2024)](https://github.com/OODRobustBench/OODRobustBench) 87 | 88 | We sincerely appreciate their help! 89 | 90 | ## Citation 91 | 92 | ``` 93 | @inproceedings{li2024apt, 94 | title={One Prompt Word is Enough to Boost Adversarial Robustness for Pre-trained Vision-Language Models}, 95 | author={Lin Li* and Haoyan Guan* and Jianing Qiu and Michael Spratling}, 96 | booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 97 | year={2024} 98 | } 99 | ``` 100 | 101 | -------------------------------------------------------------------------------- /apt/OODRB/README.md: -------------------------------------------------------------------------------- 1 | This code is copied from [OODRobustBench](https://github.com/OODRobustBench/OODRobustBench). 
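
The package provides `ImageFolder`-style wrappers for the distribution-shifted test sets (see `imagenet.py`, `cifar.py` and `dataset.py` below). The snippet below is a minimal usage sketch and not part of the original code: it assumes the `apt/` directory is the working directory (so that `OODRB` is importable) and that `DATA` points to a dataset root organised as in `DATASETS.md`.

```python
# Illustrative usage of the OODRB dataset wrappers (sketch, not part of the repo).
# Assumptions: run from the `apt/` directory; `DATA` follows DATASETS.md.
import os

import torchvision.transforms as T
from torch.utils.data import DataLoader

from OODRB.imagenet import ImageNet  # distribution-shifted ImageNet variants

DATA = os.path.expanduser("~/data")  # placeholder; replace with your $DATA root

# Preprocessing mirroring CLIP's input pipeline (see apt/clip/clip.py).
preprocess = T.Compose([
    T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize((0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)),
])

# ImageNet-V2 ("v2" variant); other variants include "R", "A" and "ON".
dataset = ImageNet(DATA, variant="v2", transform=preprocess)

loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)
for images, targets, paths in loader:
    break  # feed `images`/`targets` to the evaluation pipeline here
```

Note that, unlike a plain `torchvision.datasets.ImageFolder`, `CustomImageFolder` returns `(image, target, path)` triples, so loops over the loader unpack three values.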
-------------------------------------------------------------------------------- /apt/OODRB/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TreeLLi/APT/c2e99e09329b9a2be87d02e5c7ae02092d2ac95d/apt/OODRB/__init__.py -------------------------------------------------------------------------------- /apt/OODRB/cifar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets import ImageFolder 6 | 7 | from .dataset import CustomImageFolder 8 | 9 | CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 10 | CIFAR10_CLS_TO_IDX = {cls:idx for idx, cls in enumerate(CIFAR10_CLASSES)} 11 | 12 | def auto_find_path(root, data_dir, file_name): 13 | data_path = os.path.join(root, file_name) 14 | if os.path.isfile(data_path): 15 | return root, data_path 16 | 17 | data_dir = os.path.join(root, data_dir) 18 | data_path = os.path.join(data_dir, file_name) 19 | if not os.path.isfile(data_path): 20 | raise Exception(f'!failed to find {file_name} under {root} or {data_dir}.') 21 | 22 | return data_dir, data_path 23 | 24 | class CIFAR10_1(Dataset): 25 | 26 | NAME = 'CIFAR10.1' 27 | DEFAULT_ROOT_DIR = 'cifar-10.1' 28 | DATA_FILES = { 29 | 'v4' : ('cifar10.1_v4_data.npy', 'cifar10.1_v4_labels.npy'), 30 | 'v6' : ('cifar10.1_v6_data.npy', 'cifar10.1_v6_labels.npy') 31 | } 32 | 33 | def __init__(self, root, version='v6', transform=None, target_transform=None): 34 | assert version in self.DATA_FILES 35 | 36 | data_fname, target_fname = self.DATA_FILES[version] 37 | root, data_path = auto_find_path(root, self.DEFAULT_ROOT_DIR, data_fname) 38 | target_path = os.path.join(root, target_fname) 39 | 40 | data = np.load(data_path) 41 | target = np.load(target_path) 42 | 43 | assert data.shape[0] == target.shape[0] 44 | true_size = 2021 if version == 'v4' else 2000 45 | assert data.shape[0] == true_size 46 | 47 | self.data, self.target = data, target 48 | self.transform, self.target_transform = transform, target_transform 49 | 50 | def __len__(self): 51 | return self.data.shape[0] 52 | 53 | def __getitem__(self, idx): 54 | data, target = self.data[idx], self.target[idx] 55 | if self.transform is not None: 56 | data = self.transform(data) 57 | 58 | if self.target_transform is not None: 59 | target = self.target_transform(target) 60 | 61 | return data, target.item() 62 | 63 | 64 | class CIFAR10_2(Dataset): 65 | NAME = 'CIFAR10.2' 66 | DEFAULT_ROOT_DIR = 'cifar-10.2' 67 | DATA_FILES = { 68 | 'train' : 'cifar102_train.npz', 69 | 'test' : 'cifar102_test.npz' 70 | } 71 | 72 | def __init__(self, root, split='test', transform=None, target_transform=None): 73 | _, data_path = auto_find_path(root, self.DEFAULT_ROOT_DIR, self.DATA_FILES[split]) 74 | data = np.load(data_path) 75 | self.images = data['images'] 76 | self.targets = data['labels'] 77 | 78 | true_size = 10000 if split == 'train' else 2000 79 | assert self.images.shape[0] == true_size 80 | 81 | self.transform = transform 82 | self.target_transform = target_transform 83 | 84 | def __len__(self): 85 | return self.images.shape[0] 86 | 87 | def __getitem__(self, idx): 88 | image, target = self.images[idx], self.targets[idx] 89 | if self.transform is not None: 90 | image = self.transform(image) 91 | 92 | if self.target_transform is not None: 93 | target = self.target_transform(target) 94 | 95 | return image, 
target.item() 96 | 97 | class CINIC10_(ImageFolder): 98 | SPLITS = ['train', 'val', 'test'] 99 | 100 | DATA_DIR = 'CINIC-10' 101 | 102 | MEAN = [0.47889522, 0.47227842, 0.43047404] 103 | STD = [0.24205776, 0.23828046, 0.25874835] 104 | 105 | def __init__(self, root, split='test', transform=None, target_transform=None): 106 | assert split in self.SPLITS 107 | 108 | data_dir = os.path.join(root, split) 109 | if not os.path.isdir(data_dir): 110 | root = os.path.join(root, self.DATA_DIR) 111 | data_dir = os.path.join(root, split) 112 | 113 | assert os.path.isdir(data_dir) 114 | 115 | filter_cifar10_test_data = lambda x: 'cifar10' not in x 116 | 117 | super().__init__(data_dir, 118 | transform=transform, 119 | target_transform=target_transform, 120 | is_valid_file=filter_cifar10_test_data) 121 | 122 | class CINIC10(CustomImageFolder): 123 | def __init__(self, root, transform=None, target_transform=None): 124 | super().__init__(os.path.join(root, 'CINIC-10/test'), 125 | transform=transform, 126 | target_transform=target_transform, 127 | class_to_idx=CIFAR10_CLS_TO_IDX, 128 | data_list='image_ids/CINIC-10_image_ids.txt') 129 | 130 | 131 | class CIFAR10_R(CustomImageFolder): 132 | def __init__(self, root, transform=None, target_transform=None): 133 | super().__init__(os.path.join(root, 'cifar-10-r'), 134 | transform=transform, 135 | target_transform=target_transform, 136 | class_to_idx=CIFAR10_CLS_TO_IDX, 137 | data_list='image_ids/CIFAR10-R_image_ids.txt') 138 | 139 | DATASETS = { 140 | 'cifar10.1' : CIFAR10_1, 141 | 'cifar10.2' : CIFAR10_2, 142 | 'cinic' : CINIC10, 143 | 'cifar10-r': CIFAR10_R 144 | } 145 | -------------------------------------------------------------------------------- /apt/OODRB/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is based on the code from https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py. 3 | """ 4 | import pkg_resources 5 | 6 | from torchvision.datasets.vision import VisionDataset 7 | 8 | import torch 9 | import torch.utils.data as data 10 | import torchvision.transforms as transforms 11 | 12 | from PIL import Image, ImageFile 13 | 14 | import os 15 | import os.path 16 | import sys 17 | 18 | ImageFile.LOAD_TRUNCATED_IMAGES = True 19 | 20 | def make_custom_dataset(root, path_imgs, class_to_idx): 21 | with open(pkg_resources.resource_filename(__name__, path_imgs), 'r') as f: 22 | fnames = f.readlines() 23 | images = [(os.path.join(root, 24 | c.split('\n')[0]), class_to_idx[c.split('/')[0]]) 25 | for c in fnames] 26 | 27 | return images 28 | 29 | 30 | class CustomDatasetFolder(VisionDataset): 31 | """A generic data loader where the samples are arranged in this way: :: 32 | root/class_x/xxx.ext 33 | root/class_x/xxy.ext 34 | root/class_x/xxz.ext 35 | root/class_y/123.ext 36 | root/class_y/nsdf3.ext 37 | root/class_y/asd932_.ext 38 | Args: 39 | root (string): Root directory path. 40 | loader (callable): A function to load a sample given its path. 41 | extensions (tuple[string]): A list of allowed extensions. 42 | both extensions and is_valid_file should not be passed. 43 | transform (callable, optional): A function/transform that takes in 44 | a sample and returns a transformed version. 45 | E.g, ``transforms.RandomCrop`` for images. 46 | target_transform (callable, optional): A function/transform that takes 47 | in the target and transforms it. 
48 | is_valid_file (callable, optional): A function that takes path of an Image file 49 | and check if the file is a valid_file (used to check of corrupt files) 50 | both extensions and is_valid_file should not be passed. 51 | Attributes: 52 | classes (list): List of the class names. 53 | class_to_idx (dict): Dict with items (class_name, class_index). 54 | samples (list): List of (sample path, class_index) tuples 55 | targets (list): The class_index value for each image in the dataset 56 | """ 57 | 58 | def __init__(self, 59 | root, 60 | loader, 61 | extensions=None, 62 | transform=None, 63 | target_transform=None, 64 | data_list=None, 65 | class_to_idx=None, 66 | is_valid_file=None): 67 | super(CustomDatasetFolder, self).__init__(root) 68 | self.transform = transform 69 | self.target_transform = target_transform 70 | classes, _class_to_idx = self._find_classes(self.root) 71 | class_to_idx = _class_to_idx if class_to_idx is None else class_to_idx 72 | 73 | samples = make_custom_dataset( 74 | self.root, data_list, 75 | class_to_idx) 76 | if len(samples) == 0: 77 | raise (RuntimeError("Found 0 files in subfolders of: " + 78 | self.root + "\n" 79 | "Supported extensions are: " + 80 | ",".join(extensions))) 81 | 82 | self.loader = loader 83 | self.extensions = extensions 84 | 85 | self.classes = classes 86 | self.class_to_idx = class_to_idx 87 | self.samples = samples 88 | self.targets = [s[1] for s in samples] 89 | 90 | def _find_classes(self, dir): 91 | """ 92 | Finds the class folders in a dataset. 93 | Args: 94 | dir (string): Root directory path. 95 | Returns: 96 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 97 | Ensures: 98 | No class is a subdirectory of another. 99 | """ 100 | if sys.version_info >= (3, 5): 101 | # Faster and available in Python 3.5 and above 102 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 103 | else: 104 | classes = [ 105 | d for d in os.listdir(dir) 106 | if os.path.isdir(os.path.join(dir, d)) 107 | ] 108 | classes.sort() 109 | class_to_idx = {classes[i]: i for i in range(len(classes))} 110 | return classes, class_to_idx 111 | 112 | def __getitem__(self, index): 113 | """ 114 | Args: 115 | index (int): Index 116 | Returns: 117 | tuple: (sample, target) where target is class_index of the target class. 
118 | """ 119 | path, target = self.samples[index] 120 | sample = self.loader(path) 121 | if self.transform is not None: 122 | sample = self.transform(sample) 123 | if self.target_transform is not None: 124 | target = self.target_transform(target) 125 | return sample, target, path 126 | 127 | def __len__(self): 128 | return len(self.samples) 129 | 130 | 131 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', 132 | '.tiff', '.webp') 133 | 134 | 135 | def pil_loader(path): 136 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 137 | with open(path, 'rb') as f: 138 | img = Image.open(f) 139 | return img.convert('RGB') 140 | 141 | 142 | def accimage_loader(path): 143 | import accimage 144 | try: 145 | return accimage.Image(path) 146 | except IOError: 147 | # Potentially a decoding problem, fall back to PIL.Image 148 | return pil_loader(path) 149 | 150 | 151 | def default_loader(path): 152 | from torchvision import get_image_backend 153 | if get_image_backend() == 'accimage': 154 | return accimage_loader(path) 155 | else: 156 | return pil_loader(path) 157 | 158 | 159 | class CustomImageFolder(CustomDatasetFolder): 160 | """A generic data loader where the images are arranged in this way: :: 161 | root/dog/xxx.png 162 | root/dog/xxy.png 163 | root/dog/xxz.png 164 | root/cat/123.png 165 | root/cat/nsdf3.png 166 | root/cat/asd932_.png 167 | Args: 168 | root (string): Root directory path. 169 | transform (callable, optional): A function/transform that takes in an PIL image 170 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 171 | target_transform (callable, optional): A function/transform that takes in the 172 | target and transforms it. 173 | loader (callable, optional): A function to load an image given its path. 174 | is_valid_file (callable, optional): A function that takes path of an Image file 175 | and check if the file is a valid_file (used to check of corrupt files) 176 | Attributes: 177 | classes (list): List of the class names. 178 | class_to_idx (dict): Dict with items (class_name, class_index). 
179 | imgs (list): List of (image path, class_index) tuples 180 | """ 181 | 182 | def __init__(self, 183 | root, 184 | transform=None, 185 | target_transform=None, 186 | loader=default_loader, 187 | data_list=None, 188 | class_to_idx=None, 189 | is_valid_file=None): 190 | super(CustomImageFolder, 191 | self).__init__(root, 192 | loader, 193 | IMG_EXTENSIONS if is_valid_file is None else None, 194 | transform=transform, 195 | target_transform=target_transform, 196 | data_list=data_list, 197 | class_to_idx=class_to_idx, 198 | is_valid_file=is_valid_file) 199 | 200 | self.imgs = self.samples 201 | -------------------------------------------------------------------------------- /apt/OODRB/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .dataset import CustomImageFolder 4 | 5 | from addict import Dict 6 | 7 | DATASETS = Dict() 8 | 9 | DATASETS[None].data_dir = 'imagenet' 10 | DATASETS[None].data_list = 'image_ids/imagenet_test_image_ids.txt' 11 | 12 | DATASETS.R.data_dir = 'imagenet-r' 13 | DATASETS.R.data_list = 'image_ids/imagenet-r_image_ids.txt' 14 | 15 | DATASETS.A.data_dir = 'imagenet-a' 16 | DATASETS.A.data_list = 'image_ids/imagenet-a_image_ids.txt' 17 | 18 | DATASETS.v2.data_dir = 'imagenetv2-matched-frequency-format-val' 19 | DATASETS.v2.data_list = 'image_ids/imagenet-v2_image_ids.txt' 20 | 21 | DATASETS.ON.data_dir = 'objectnet' 22 | DATASETS.ON.data_list = 'image_ids/objectnet_image_ids.txt' 23 | 24 | 25 | class ImageNet(CustomImageFolder): 26 | VARIANTS = list(DATASETS.keys()) 27 | VARIANTS.remove(None) 28 | 29 | def __init__(self, root, variant=None, split=None, transform=None, target_transform=None): 30 | data_dir = os.path.join(root, DATASETS[variant].data_dir) 31 | split = split if variant == None else None # no split for non-base variants 32 | if split is not None: data_dir = os.path.join(data_dir, split) 33 | assert os.path.isdir(data_dir) 34 | 35 | if variant in ['R', 'A', 'ON']: 36 | # remap targets to ImageNet class idx 37 | class_to_idx = ImageNet(root, split='val').class_to_idx 38 | elif variant == 'v2': 39 | class_to_idx = {str(i):i for i in range(1000)} 40 | else: 41 | class_to_idx = None 42 | 43 | super().__init__(data_dir, 44 | transform=transform, 45 | target_transform=target_transform, 46 | class_to_idx=class_to_idx, 47 | data_list=DATASETS[variant].data_list) 48 | -------------------------------------------------------------------------------- /apt/backbone/README.md: -------------------------------------------------------------------------------- 1 | # Put pre-trained robust backbone here. 2 | 3 | We adopt as backbone the pre-trained adversarially-robust CLIP models from [TeCoA](https://github.com/cvlab-columbia/ZSRobust4FoundationModel). The used pre-trained weights are provided [here](https://emckclac-my.sharepoint.com/:f:/g/personal/k19010102_kcl_ac_uk/EmZ98eFLv71FqQyqPLvWNTkBYNAKPyx_wYEDjNPx7smKCA?e=8AB51S). To run the code, the pre-trained backbone models should be placed under this directory. The code currently supports two architectures: ViT-B/32 (named `vitb32`) and ResNet50 (named `rn50`). Taking an example of tuning ViT-B/32 at epsilon=4/255, the name of checkpoint is `vitb32_eps4.pth.tar`. Note that our code can be easily adapted to load other pre-trained models as backbone. 
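
For reference, the sketch below shows how the expected checkpoint path follows from the architecture name and the perturbation budget. The helper is illustrative only and not part of the repository; it simply encodes the `<arch>_eps<budget>.pth.tar` naming convention described above.

```python
# Hypothetical helper (not part of the repository) encoding the checkpoint
# naming convention described above: <arch>_eps<budget>.pth.tar, where
# <budget> is the numerator of the perturbation budget (e.g. 4 for 4/255).
import os

BACKBONE_DIR = "apt/backbone"  # relative to the repository root
SUPPORTED_ARCHS = ("vitb32", "rn50")

def backbone_checkpoint(arch: str, eps: int) -> str:
    """Return the expected path of a pre-trained robust CLIP backbone."""
    if arch not in SUPPORTED_ARCHS:
        raise ValueError(f"unsupported architecture: {arch}")
    path = os.path.join(BACKBONE_DIR, f"{arch}_eps{eps}.pth.tar")
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"{path} not found; download it from the TeCoA weights link above")
    return path

# Example: backbone_checkpoint("vitb32", 4) -> "apt/backbone/vitb32_eps4.pth.tar"
```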
-------------------------------------------------------------------------------- /apt/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /apt/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TreeLLi/APT/c2e99e09329b9a2be87d02e5c7ae02092d2ac95d/apt/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /apt/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise 
RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px, interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | """Returns the names of available CLIP models""" 83 | return list(_MODELS.keys()) 84 | 85 | 86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 87 | """Load a CLIP model 88 | 89 | Parameters 90 | ---------- 91 | name : str 92 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 93 | 94 | device : Union[str, torch.device] 95 | The device to put the loaded model 96 | 97 | jit : bool 98 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 99 | 100 | Returns 101 | ------- 102 | model : torch.nn.Module 103 | The CLIP model 104 | 105 | preprocess : Callable[[PIL.Image], torch.Tensor] 106 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 107 | """ 108 | if name in _MODELS: 109 | model_path = _download(_MODELS[name]) 110 | elif os.path.isfile(name): 111 | model_path = name 112 | else: 113 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 114 | 115 | try: 116 | # loading JIT archive 117 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 118 | state_dict = None 119 | except RuntimeError: 120 | # loading saved state dict 121 | if jit: 122 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 123 | jit = False 124 | state_dict = torch.load(model_path, map_location="cpu") 125 | 126 | if not jit: 127 | model = build_model(state_dict or model.state_dict()).to(device) 128 | if str(device) == "cpu": 129 | model.float() 130 | return model, _transform(model.visual.input_resolution) 131 | 132 | # patch the device names 133 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 134 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 135 | 136 | def patch_device(module): 137 | try: 138 | graphs = [module.graph] if hasattr(module, "graph") else [] 139 | except RuntimeError: 140 | graphs = [] 141 | 142 | if hasattr(module, "forward1"): 143 | graphs.append(module.forward1.graph) 144 | 145 | for graph in graphs: 146 | for node in graph.findAllNodes("prim::Constant"): 147 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 148 | node.copyAttributes(device_node) 149 | 150 | model.apply(patch_device) 151 | patch_device(model.encode_image) 152 | patch_device(model.encode_text) 153 | 154 | # patch dtype to float32 on CPU 155 | if str(device) == "cpu": 156 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 157 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 158 | float_node = float_input.node() 159 | 160 | def patch_float(module): 161 | try: 162 | graphs = [module.graph] if hasattr(module, "graph") else [] 163 | except RuntimeError: 164 | graphs = [] 165 | 166 | if hasattr(module, "forward1"): 167 | graphs.append(module.forward1.graph) 168 | 169 | for graph in graphs: 170 | for node in graph.findAllNodes("aten::to"): 171 | inputs = list(node.inputs()) 172 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 173 | if inputs[i].node()["value"] == 5: 174 | inputs[i].node().copyAttributes(float_node) 175 | 176 | model.apply(patch_float) 177 | patch_float(model.encode_image) 178 | patch_float(model.encode_text) 179 | 180 | model.float() 181 | 182 | return model, _transform(model.input_resolution.item()) 183 | 184 | 185 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 186 | """ 187 | Returns the tokenized representation of given input string(s) 188 | 189 | Parameters 190 | ---------- 191 | texts : Union[str, List[str]] 192 | An input string or a list of input strings to tokenize 193 | 194 | context_length : int 195 | The context length to use; all CLIP models use 77 as the context length 196 | 197 | truncate: bool 198 | Whether to truncate the text in case its encoding is longer than the context length 199 | 200 | Returns 201 | ------- 202 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 203 | """ 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | 207 | sot_token = _tokenizer.encoder["<|startoftext|>"] 208 | eot_token = _tokenizer.encoder["<|endoftext|>"] 209 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 210 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 211 | 212 | for i, tokens in enumerate(all_tokens): 213 | if len(tokens) > context_length: 214 | if truncate: 215 | tokens = tokens[:context_length] 216 | tokens[-1] = eot_token 217 | else: 218 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 219 | result[i, 
:len(tokens)] = torch.tensor(tokens) 220 | 221 | return result 222 | -------------------------------------------------------------------------------- /apt/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to 
torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = 
x + self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 204 | super().__init__() 205 | self.input_resolution = input_resolution 206 | self.output_dim = output_dim 207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 208 | 209 | scale = width ** -0.5 210 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 212 | self.ln_pre = LayerNorm(width) 213 | 214 | self.transformer = Transformer(width, layers, heads) 215 | 216 | self.ln_post = LayerNorm(width) 217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 218 | 219 | def forward(self, x: torch.Tensor): 220 | x = self.conv1(x) # shape = [*, width, grid, grid] 221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 224 | x = x + self.positional_embedding.to(x.dtype) 225 | x = self.ln_pre(x) 226 | 227 | x = x.permute(1, 0, 2) # NLD -> LND 228 | x = self.transformer(x) 229 | x = x.permute(1, 0, 2) # LND -> NLD 230 | 231 | x = self.ln_post(x[:, 0, :]) 232 | 233 | if self.proj is not None: 234 | x = x @ self.proj 235 | 236 | return x 237 | 238 | 239 | class CLIP(nn.Module): 240 | def __init__(self, 241 | embed_dim: int, 242 | # vision 243 | image_resolution: int, 244 | vision_layers: Union[Tuple[int, int, int, int], int], 245 | vision_width: int, 246 | vision_patch_size: int, 247 | # text 248 | context_length: int, 249 | vocab_size: int, 250 | transformer_width: int, 251 | transformer_heads: int, 252 | transformer_layers: int 253 | ): 254 | super().__init__() 255 | 256 | self.context_length = context_length 257 | 258 | if isinstance(vision_layers, (tuple, list)): 259 | vision_heads = vision_width * 32 // 64 260 | self.visual = ModifiedResNet( 261 | layers=vision_layers, 262 | output_dim=embed_dim, 263 | heads=vision_heads, 264 | input_resolution=image_resolution, 265 | width=vision_width 266 | ) 267 | else: 268 | vision_heads = vision_width // 64 269 | self.visual = VisionTransformer( 270 | input_resolution=image_resolution, 271 | patch_size=vision_patch_size, 272 | width=vision_width, 273 | layers=vision_layers, 274 | heads=vision_heads, 275 | output_dim=embed_dim 276 | ) 277 | 278 | self.transformer = Transformer( 279 | width=transformer_width, 280 | layers=transformer_layers, 281 | heads=transformer_heads, 282 | attn_mask=self.build_attention_mask() 283 | ) 284 | 285 | self.vocab_size = vocab_size 286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 288 | 
self.ln_final = LayerNorm(transformer_width) 289 | 290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 292 | 293 | self.initialize_parameters() 294 | 295 | def initialize_parameters(self): 296 | nn.init.normal_(self.token_embedding.weight, std=0.02) 297 | nn.init.normal_(self.positional_embedding, std=0.01) 298 | 299 | if isinstance(self.visual, ModifiedResNet): 300 | if self.visual.attnpool is not None: 301 | std = self.visual.attnpool.c_proj.in_features ** -0.5 302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 306 | 307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 308 | for name, param in resnet_block.named_parameters(): 309 | if name.endswith("bn3.weight"): 310 | nn.init.zeros_(param) 311 | 312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 313 | attn_std = self.transformer.width ** -0.5 314 | fc_std = (2 * self.transformer.width) ** -0.5 315 | for block in self.transformer.resblocks: 316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 320 | 321 | if self.text_projection is not None: 322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 323 | 324 | def build_attention_mask(self): 325 | # lazily create causal attention mask, with full attention between the vision tokens 326 | # pytorch uses additive attention mask; fill with -inf 327 | mask = torch.empty(self.context_length, self.context_length) 328 | mask.fill_(float("-inf")) 329 | mask.triu_(1) # zero out the lower diagonal 330 | return mask 331 | 332 | @property 333 | def dtype(self): 334 | return self.visual.conv1.weight.dtype 335 | 336 | def encode_image(self, image): 337 | return self.visual(image.type(self.dtype)) 338 | 339 | def encode_text(self, text): 340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 341 | 342 | x = x + self.positional_embedding.type(self.dtype) 343 | x = x.permute(1, 0, 2) # NLD -> LND 344 | x = self.transformer(x) 345 | x = x.permute(1, 0, 2) # LND -> NLD 346 | x = self.ln_final(x).type(self.dtype) 347 | 348 | # x.shape = [batch_size, n_ctx, transformer.width] 349 | # take features from the eot embedding (eot_token is the highest number in each sequence) 350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 351 | 352 | return x 353 | 354 | def forward(self, image, text): 355 | image_features = self.encode_image(image) 356 | text_features = self.encode_text(text) 357 | 358 | # normalized features 359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 361 | 362 | # cosine similarity as logits 363 | logit_scale = self.logit_scale.exp() 364 | logits_per_image = logit_scale * image_features @ text_features.t() 365 | logits_per_text = logit_scale * text_features @ image_features.t() 366 | 367 | # shape = [global_batch_size, global_batch_size] 368 | return logits_per_image, logits_per_text 369 | 370 | 371 | def 
convert_weights(model: nn.Module): 372 | """Convert applicable model parameters to fp16""" 373 | 374 | def _convert_weights_to_fp16(l): 375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 376 | l.weight.data = l.weight.data.half() 377 | if l.bias is not None: 378 | l.bias.data = l.bias.data.half() 379 | 380 | if isinstance(l, nn.MultiheadAttention): 381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 382 | tensor = getattr(l, attr) 383 | if tensor is not None: 384 | tensor.data = tensor.data.half() 385 | 386 | for name in ["text_projection", "proj"]: 387 | if hasattr(l, name): 388 | attr = getattr(l, name) 389 | if attr is not None: 390 | attr.data = attr.data.half() 391 | 392 | model.apply(_convert_weights_to_fp16) 393 | 394 | 395 | def build_model(state_dict: dict): 396 | vit = "visual.proj" in state_dict 397 | 398 | if vit: 399 | vision_width = state_dict["visual.conv1.weight"].shape[0] 400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 403 | image_resolution = vision_patch_size * grid_size 404 | else: 405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 406 | vision_layers = tuple(counts) 407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 409 | vision_patch_size = None 410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 411 | image_resolution = output_width * 32 412 | 413 | embed_dim = state_dict["text_projection"].shape[1] 414 | context_length = state_dict["positional_embedding"].shape[0] 415 | vocab_size = state_dict["token_embedding.weight"].shape[0] 416 | transformer_width = state_dict["ln_final.weight"].shape[0] 417 | transformer_heads = transformer_width // 64 418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 419 | 420 | model = CLIP( 421 | embed_dim, 422 | image_resolution, vision_layers, vision_width, vision_patch_size, 423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 424 | ) 425 | 426 | for key in ["input_resolution", "context_length", "vocab_size"]: 427 | if key in state_dict: 428 | del state_dict[key] 429 | 430 | convert_weights(model) 431 | model.load_state_dict(state_dict) 432 | return model.eval() 433 | -------------------------------------------------------------------------------- /apt/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a significant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | The mapping also avoids whitespace/control characters that the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
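# each regex-matched token is byte-encoded, BPE-merged (carrying the trailing '</w>'
# end-of-word marker), and mapped to integer ids through self.encoder; decode() below
# inverts the mapping and turns '</w>' back into spaces.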
127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /apt/configs/datasets/caltech101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Caltech101" 3 | -------------------------------------------------------------------------------- /apt/configs/datasets/dtd.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "DescribableTextures" 3 | -------------------------------------------------------------------------------- /apt/configs/datasets/eurosat.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "EuroSAT" 3 | -------------------------------------------------------------------------------- /apt/configs/datasets/fgvc_aircraft.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "FGVCAircraft" 3 | -------------------------------------------------------------------------------- /apt/configs/datasets/food101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Food101" 3 | -------------------------------------------------------------------------------- /apt/configs/datasets/imagenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNet" 3 | -------------------------------------------------------------------------------- /apt/configs/datasets/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetA" 3 | -------------------------------------------------------------------------------- /apt/configs/datasets/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetR" 3 | -------------------------------------------------------------------------------- /apt/configs/datasets/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetSketch" 3 | -------------------------------------------------------------------------------- /apt/configs/datasets/imagenetv2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetV2" 3 | -------------------------------------------------------------------------------- /apt/configs/datasets/oxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordFlowers" -------------------------------------------------------------------------------- /apt/configs/datasets/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordPets" -------------------------------------------------------------------------------- /apt/configs/datasets/stanford_cars.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "StanfordCars" 3 | -------------------------------------------------------------------------------- /apt/configs/datasets/sun397.yaml: 
-------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SUN397" 3 | -------------------------------------------------------------------------------- /apt/configs/datasets/ucf101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "UCF101" 3 | -------------------------------------------------------------------------------- /apt/configs/trainers/APT/rn101.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /apt/configs/trainers/APT/rn101_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /apt/configs/trainers/APT/rn50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | ROBUST: true -------------------------------------------------------------------------------- /apt/configs/trainers/APT/rn50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | 
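# The TRAINER.COOP.CTX_INIT entry below (the "ctxv1" config variants) seeds the learnable
# prompt context from the hand-crafted template "a photo of a" instead of random vectors.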
TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /apt/configs/trainers/APT/rn50_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | ROBUST: true -------------------------------------------------------------------------------- /apt/configs/trainers/APT/rn50_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | ROBUST: true -------------------------------------------------------------------------------- /apt/configs/trainers/APT/rn50_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /apt/configs/trainers/APT/rn50_val.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 200 4 | TEST: 5 | BATCH_SIZE: 200 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | MODEL: 16 | BACKBONE: 17 | NAME: "RN50" -------------------------------------------------------------------------------- /apt/configs/trainers/APT/vit_b16.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | 
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /apt/configs/trainers/APT/vit_b16_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /apt/configs/trainers/APT/vit_b16_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /apt/configs/trainers/APT/vit_b16_ep100_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /apt/configs/trainers/APT/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 
50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /apt/configs/trainers/APT/vit_b16_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /apt/configs/trainers/APT/vit_b32.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" 30 | ROBUST: true -------------------------------------------------------------------------------- /apt/configs/trainers/APT/vit_b32_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" 30 | ROBUST: true -------------------------------------------------------------------------------- /apt/configs/trainers/APT/vit_b32_ep20.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 20 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" 30 | ROBUST: true 
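The APT trainer configs above are plain yacs-style fragments; they differ from the CoOp ones mainly in MAX_EPOCH and in the extra MODEL.BACKBONE.ROBUST flag (presumably selecting the adversarially pre-trained CLIP checkpoints kept under backbone/). A minimal sketch of how a trainer fragment and a dataset fragment could be merged and overridden outside of Dassl is shown below; the permissive CfgNode(new_allowed=True) root is an assumption made only for this standalone example, since in the repo train.py starts from Dassl's full default config instead.

from yacs.config import CfgNode as CN

cfg = CN(new_allowed=True)  # permissive root node, only for this sketch
cfg.merge_from_file("apt/configs/trainers/APT/vit_b32_ep20.yaml")  # backbone + optimisation schedule
cfg.merge_from_file("apt/configs/datasets/imagenet.yaml")          # contributes DATASET.NAME
cfg.merge_from_list(["OPTIM.MAX_EPOCH", 10])                       # command-line style override
print(cfg.MODEL.BACKBONE.NAME, cfg.OPTIM.MAX_EPOCH)                # -> ViT-B/32 10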
-------------------------------------------------------------------------------- /apt/configs/trainers/APT/vit_b32_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" 30 | ROBUST: true -------------------------------------------------------------------------------- /apt/configs/trainers/APT/vit_b32_st.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" 30 | ROBUST: false -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/rn101.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/rn101_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/rn50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 
224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/rn50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/rn50_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/rn50_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/rn50_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | 
LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/rn50_val.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 200 4 | TEST: 5 | BATCH_SIZE: 200 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | MODEL: 16 | BACKBONE: 17 | NAME: "RN50" -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/vit_b16.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/vit_b16_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/vit_b16_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/vit_b16_ep100_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | 
DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/vit_b16_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/vit_b32.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/vit_b32_ep20.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 
0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 20 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /apt/configs/trainers/CoOp/vit_b32_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /apt/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TreeLLi/APT/c2e99e09329b9a2be87d02e5c7ae02092d2ac95d/apt/datasets/__init__.py -------------------------------------------------------------------------------- /apt/datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | IGNORED = ["BACKGROUND_Google", "Faces_easy"] 11 | NEW_CNAMES = { 12 | "airplanes": "airplane", 13 | "Faces": "face", 14 | "Leopards": "leopard", 15 | "Motorbikes": "motorbike", 16 | } 17 | 18 | 19 | @DATASET_REGISTRY.register() 20 | class Caltech101(DatasetBase): 21 | 22 | dataset_dir = "caltech-101" 23 | 24 | def __init__(self, cfg): 25 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 26 | self.dataset_dir = os.path.join(root, self.dataset_dir) 27 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories") 28 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json") 29 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 30 | mkdir_if_missing(self.split_fewshot_dir) 31 | 32 | if os.path.exists(self.split_path): 33 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 34 | else: 35 | train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, 
num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 58 | 59 | super().__init__(train_x=train, val=val, test=test) 60 | -------------------------------------------------------------------------------- /apt/datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import listdir_nohidden, mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class DescribableTextures(DatasetBase): 13 | 14 | dataset_dir = "dtd" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | train, val, test = self.read_and_split_data(self.image_dir) 28 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 29 | 30 | num_shots = cfg.DATASET.NUM_SHOTS 31 | if num_shots >= 1: 32 | seed = cfg.SEED 33 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 34 | 35 | if os.path.exists(preprocessed): 36 | print(f"Loading preprocessed few-shot data from {preprocessed}") 37 | with open(preprocessed, "rb") as file: 38 | data = pickle.load(file) 39 | train, val = data["train"], data["val"] 40 | else: 41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 43 | data = {"train": train, "val": val} 44 | print(f"Saving preprocessed few-shot data to {preprocessed}") 45 | with open(preprocessed, "wb") as file: 46 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 47 | 48 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 49 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 50 | 51 | super().__init__(train_x=train, val=val, test=test) 52 | 53 | @staticmethod 54 | def read_and_split_data(image_dir, p_trn=0.5, p_val=0.2, ignored=[], new_cnames=None): 55 | # The data are supposed to be organized into the following structure 56 | # ============= 57 | # images/ 58 | # dog/ 59 | # cat/ 60 | # horse/ 61 | # ============= 62 | categories = listdir_nohidden(image_dir) 63 | categories = [c for c in categories if c not in ignored] 64 | categories.sort() 65 | 66 | p_tst = 1 - p_trn - p_val 67 | print(f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test") 68 | 69 | def _collate(ims, y, c): 70 | items = [] 71 | for im in ims: 72 | item = Datum(impath=im, label=y, classname=c) # is already 0-based 73 | items.append(item) 74 | return items 75 | 76 | train, val, test = [], [], [] 77 | for label, category in enumerate(categories): 78 | category_dir = 
os.path.join(image_dir, category) 79 | images = listdir_nohidden(category_dir) 80 | images = [os.path.join(category_dir, im) for im in images] 81 | random.shuffle(images) 82 | n_total = len(images) 83 | n_train = round(n_total * p_trn) 84 | n_val = round(n_total * p_val) 85 | n_test = n_total - n_train - n_val 86 | assert n_train > 0 and n_val > 0 and n_test > 0 87 | 88 | if new_cnames is not None and category in new_cnames: 89 | category = new_cnames[category] 90 | 91 | train.extend(_collate(images[:n_train], label, category)) 92 | val.extend(_collate(images[n_train : n_train + n_val], label, category)) 93 | test.extend(_collate(images[n_train + n_val :], label, category)) 94 | 95 | return train, val, test 96 | -------------------------------------------------------------------------------- /apt/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | NEW_CNAMES = { 11 | "AnnualCrop": "Annual Crop Land", 12 | "Forest": "Forest", 13 | "HerbaceousVegetation": "Herbaceous Vegetation Land", 14 | "Highway": "Highway or Road", 15 | "Industrial": "Industrial Buildings", 16 | "Pasture": "Pasture Land", 17 | "PermanentCrop": "Permanent Crop Land", 18 | "Residential": "Residential Buildings", 19 | "River": "River", 20 | "SeaLake": "Sea or Lake", 21 | } 22 | 23 | 24 | @DATASET_REGISTRY.register() 25 | class EuroSAT(DatasetBase): 26 | 27 | dataset_dir = "eurosat" 28 | 29 | def __init__(self, cfg): 30 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 31 | self.dataset_dir = os.path.join(root, self.dataset_dir) 32 | self.image_dir = os.path.join(self.dataset_dir, "2750") 33 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json") 34 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 35 | mkdir_if_missing(self.split_fewshot_dir) 36 | 37 | if os.path.exists(self.split_path): 38 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 39 | else: 40 | train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES) 41 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 42 | 43 | num_shots = cfg.DATASET.NUM_SHOTS 44 | if num_shots >= 1: 45 | seed = cfg.SEED 46 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 47 | 48 | if os.path.exists(preprocessed): 49 | print(f"Loading preprocessed few-shot data from {preprocessed}") 50 | with open(preprocessed, "rb") as file: 51 | data = pickle.load(file) 52 | train, val = data["train"], data["val"] 53 | else: 54 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 55 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 56 | data = {"train": train, "val": val} 57 | print(f"Saving preprocessed few-shot data to {preprocessed}") 58 | with open(preprocessed, "wb") as file: 59 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 60 | 61 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 62 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 63 | 64 | super().__init__(train_x=train, val=val, test=test) 65 | 66 | def update_classname(self, dataset_old): 67 | dataset_new = [] 68 | for item_old in dataset_old: 69 | cname_old = item_old.classname 70 | cname_new 
= NEW_CNAMES[cname_old] 71 | item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new) 72 | dataset_new.append(item_new) 73 | return dataset_new 74 | -------------------------------------------------------------------------------- /apt/datasets/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class FGVCAircraft(DatasetBase): 12 | 13 | dataset_dir = "fgvc_aircraft" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "images") 19 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 20 | mkdir_if_missing(self.split_fewshot_dir) 21 | 22 | classnames = [] 23 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: 24 | lines = f.readlines() 25 | for line in lines: 26 | classnames.append(line.strip()) 27 | cname2lab = {c: i for i, c in enumerate(classnames)} 28 | 29 | train = self.read_data(cname2lab, "images_variant_train.txt") 30 | val = self.read_data(cname2lab, "images_variant_val.txt") 31 | test = self.read_data(cname2lab, "images_variant_test.txt") 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 52 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 53 | 54 | super().__init__(train_x=train, val=val, test=test) 55 | 56 | def read_data(self, cname2lab, split_file): 57 | filepath = os.path.join(self.dataset_dir, split_file) 58 | items = [] 59 | 60 | with open(filepath, "r") as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | line = line.strip().split(" ") 64 | imname = line[0] + ".jpg" 65 | classname = " ".join(line[1:]) 66 | impath = os.path.join(self.image_dir, imname) 67 | label = cname2lab[classname] 68 | item = Datum(impath=impath, label=label, classname=classname) 69 | items.append(item) 70 | 71 | return items 72 | -------------------------------------------------------------------------------- /apt/datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class Food101(DatasetBase): 13 | 14 | dataset_dir = "food-101" 
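# as in the other dataset wrappers, the class-level dataset_dir is relative and is
# joined onto cfg.DATASET.ROOT inside __init__ below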
15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | train, val, test = DTD.read_and_split_data(self.image_dir) 28 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 29 | 30 | num_shots = cfg.DATASET.NUM_SHOTS 31 | if num_shots >= 1: 32 | seed = cfg.SEED 33 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 34 | 35 | if os.path.exists(preprocessed): 36 | print(f"Loading preprocessed few-shot data from {preprocessed}") 37 | with open(preprocessed, "rb") as file: 38 | data = pickle.load(file) 39 | train, val = data["train"], data["val"] 40 | else: 41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 43 | data = {"train": train, "val": val} 44 | print(f"Saving preprocessed few-shot data to {preprocessed}") 45 | with open(preprocessed, "wb") as file: 46 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 47 | 48 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 49 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 50 | 51 | super().__init__(train_x=train, val=val, test=test) 52 | -------------------------------------------------------------------------------- /apt/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import OrderedDict 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import listdir_nohidden, mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNet(DatasetBase): 13 | 14 | dataset_dir = "imagenet" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.preprocessed): 25 | with open(self.preprocessed, "rb") as f: 26 | preprocessed = pickle.load(f) 27 | train = preprocessed["train"] 28 | test = preprocessed["test"] 29 | else: 30 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 31 | classnames = self.read_classnames(text_file) 32 | train = self.read_data(classnames, "train") 33 | # Follow standard practice to perform evaluation on the val set 34 | # Also used as the val set (so evaluate the last-step model) 35 | test = self.read_data(classnames, "val") 36 | 37 | preprocessed = {"train": train, "test": test} 38 | with open(self.preprocessed, "wb") as f: 39 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL) 40 | 41 | num_shots = cfg.DATASET.NUM_SHOTS 42 | if num_shots >= 1: 43 | seed = cfg.SEED 44 | preprocessed = 
os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 45 | 46 | if os.path.exists(preprocessed): 47 | print(f"Loading preprocessed few-shot data from {preprocessed}") 48 | with open(preprocessed, "rb") as file: 49 | data = pickle.load(file) 50 | train = data["train"] 51 | else: 52 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 53 | data = {"train": train} 54 | print(f"Saving preprocessed few-shot data to {preprocessed}") 55 | with open(preprocessed, "wb") as file: 56 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 57 | 58 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 59 | train, test = OxfordPets.subsample_classes(train, test, subsample=subsample) 60 | 61 | super().__init__(train_x=train, val=test, test=test) 62 | 63 | @staticmethod 64 | def read_classnames(text_file): 65 | """Return a dictionary containing 66 | key-value pairs of : . 67 | """ 68 | classnames = OrderedDict() 69 | with open(text_file, "r") as f: 70 | lines = f.readlines() 71 | for line in lines: 72 | line = line.strip().split(" ") 73 | folder = line[0] 74 | classname = " ".join(line[1:]) 75 | classnames[folder] = classname 76 | return classnames 77 | 78 | def read_data(self, classnames, split_dir): 79 | split_dir = os.path.join(self.image_dir, split_dir) 80 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir()) 81 | items = [] 82 | 83 | for label, folder in enumerate(folders): 84 | imnames = listdir_nohidden(os.path.join(split_dir, folder)) 85 | classname = classnames[folder] 86 | for imname in imnames: 87 | impath = os.path.join(split_dir, folder, imname) 88 | item = Datum(impath=impath, label=label, classname=classname) 89 | items.append(item) 90 | 91 | return items 92 | -------------------------------------------------------------------------------- /apt/datasets/imagenet_a.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetA(DatasetBase): 13 | """ImageNet-A(dversarial). 14 | 15 | This dataset is used for testing only. 
16 | """ 17 | 18 | dataset_dir = "imagenet-adversarial" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-a") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /apt/datasets/imagenet_r.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetR(DatasetBase): 13 | """ImageNet-R(endition). 14 | 15 | This dataset is used for testing only. 16 | """ 17 | 18 | dataset_dir = "imagenet-rendition" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-r") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /apt/datasets/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetSketch(DatasetBase): 11 | """ImageNet-Sketch. 12 | 13 | This dataset is used for testing only. 
14 | """ 15 | 16 | dataset_dir = "imagenet-sketch" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "images") 22 | 23 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 24 | classnames = ImageNet.read_classnames(text_file) 25 | 26 | data = self.read_data(classnames) 27 | 28 | super().__init__(train_x=data, test=data) 29 | 30 | def read_data(self, classnames): 31 | image_dir = self.image_dir 32 | folders = listdir_nohidden(image_dir, sort=True) 33 | items = [] 34 | 35 | for label, folder in enumerate(folders): 36 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 37 | classname = classnames[folder] 38 | for imname in imnames: 39 | impath = os.path.join(image_dir, folder, imname) 40 | item = Datum(impath=impath, label=label, classname=classname) 41 | items.append(item) 42 | 43 | return items 44 | -------------------------------------------------------------------------------- /apt/datasets/imagenetv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetV2(DatasetBase): 11 | """ImageNetV2. 12 | 13 | This dataset is used for testing only. 14 | """ 15 | 16 | dataset_dir = "imagenetv2" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | image_dir = "imagenetv2-matched-frequency-format-val" 22 | self.image_dir = os.path.join(self.dataset_dir, image_dir) 23 | 24 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 25 | classnames = ImageNet.read_classnames(text_file) 26 | 27 | data = self.read_data(classnames) 28 | 29 | super().__init__(train_x=data, test=data) 30 | 31 | def read_data(self, classnames): 32 | image_dir = self.image_dir 33 | folders = list(classnames.keys()) 34 | items = [] 35 | 36 | for label in range(1000): 37 | class_dir = os.path.join(image_dir, str(label)) 38 | imnames = listdir_nohidden(class_dir) 39 | folder = folders[label] 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(class_dir, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /apt/datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | from scipy.io import loadmat 5 | from collections import defaultdict 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import read_json, mkdir_if_missing 9 | 10 | from .oxford_pets import OxfordPets 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class OxfordFlowers(DatasetBase): 15 | 16 | dataset_dir = "oxford_flowers" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "jpg") 22 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat") 23 | self.lab2cname_file = os.path.join(self.dataset_dir, 
"cat_to_name.json") 24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.split_path): 29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 30 | else: 31 | train, val, test = self.read_data() 32 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 54 | 55 | super().__init__(train_x=train, val=val, test=test) 56 | 57 | def read_data(self): 58 | tracker = defaultdict(list) 59 | label_file = loadmat(self.label_file)["labels"][0] 60 | for i, label in enumerate(label_file): 61 | imname = f"image_{str(i + 1).zfill(5)}.jpg" 62 | impath = os.path.join(self.image_dir, imname) 63 | label = int(label) 64 | tracker[label].append(impath) 65 | 66 | print("Splitting data into 50% train, 20% val, and 30% test") 67 | 68 | def _collate(ims, y, c): 69 | items = [] 70 | for im in ims: 71 | item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label 72 | items.append(item) 73 | return items 74 | 75 | lab2cname = read_json(self.lab2cname_file) 76 | train, val, test = [], [], [] 77 | for label, impaths in tracker.items(): 78 | random.shuffle(impaths) 79 | n_total = len(impaths) 80 | n_train = round(n_total * 0.5) 81 | n_val = round(n_total * 0.2) 82 | n_test = n_total - n_train - n_val 83 | assert n_train > 0 and n_val > 0 and n_test > 0 84 | cname = lab2cname[str(label)] 85 | train.extend(_collate(impaths[:n_train], label, cname)) 86 | val.extend(_collate(impaths[n_train : n_train + n_val], label, cname)) 87 | test.extend(_collate(impaths[n_train + n_val :], label, cname)) 88 | 89 | return train, val, test 90 | -------------------------------------------------------------------------------- /apt/datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import math 4 | import random 5 | from collections import defaultdict 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import read_json, write_json, mkdir_if_missing 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class OxfordPets(DatasetBase): 13 | 14 | dataset_dir = "oxford_pets" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.anno_dir = os.path.join(self.dataset_dir, "annotations") 
21 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json") 22 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 23 | mkdir_if_missing(self.split_fewshot_dir) 24 | 25 | if os.path.exists(self.split_path): 26 | train, val, test = self.read_split(self.split_path, self.image_dir) 27 | else: 28 | trainval = self.read_data(split_file="trainval.txt") 29 | test = self.read_data(split_file="test.txt") 30 | train, val = self.split_trainval(trainval) 31 | self.save_split(train, val, test, self.split_path, self.image_dir) 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 52 | train, val, test = self.subsample_classes(train, val, test, subsample=subsample) 53 | 54 | super().__init__(train_x=train, val=val, test=test) 55 | 56 | def read_data(self, split_file): 57 | filepath = os.path.join(self.anno_dir, split_file) 58 | items = [] 59 | 60 | with open(filepath, "r") as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | line = line.strip() 64 | imname, label, species, _ = line.split(" ") 65 | breed = imname.split("_")[:-1] 66 | breed = "_".join(breed) 67 | breed = breed.lower() 68 | imname += ".jpg" 69 | impath = os.path.join(self.image_dir, imname) 70 | label = int(label) - 1 # convert to 0-based index 71 | item = Datum(impath=impath, label=label, classname=breed) 72 | items.append(item) 73 | 74 | return items 75 | 76 | @staticmethod 77 | def split_trainval(trainval, p_val=0.2): 78 | p_trn = 1 - p_val 79 | print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val") 80 | tracker = defaultdict(list) 81 | for idx, item in enumerate(trainval): 82 | label = item.label 83 | tracker[label].append(idx) 84 | 85 | train, val = [], [] 86 | for label, idxs in tracker.items(): 87 | n_val = round(len(idxs) * p_val) 88 | assert n_val > 0 89 | random.shuffle(idxs) 90 | for n, idx in enumerate(idxs): 91 | item = trainval[idx] 92 | if n < n_val: 93 | val.append(item) 94 | else: 95 | train.append(item) 96 | 97 | return train, val 98 | 99 | @staticmethod 100 | def save_split(train, val, test, filepath, path_prefix): 101 | def _extract(items): 102 | out = [] 103 | for item in items: 104 | impath = item.impath 105 | label = item.label 106 | classname = item.classname 107 | impath = impath.replace(path_prefix, "") 108 | if impath.startswith("/"): 109 | impath = impath[1:] 110 | out.append((impath, label, classname)) 111 | return out 112 | 113 | train = _extract(train) 114 | val = _extract(val) 115 | test = _extract(test) 116 | 117 | split = {"train": train, "val": val, "test": test} 118 | 119 | write_json(split, filepath) 120 | print(f"Saved split to {filepath}") 121 | 122 | @staticmethod 123 | def read_split(filepath, path_prefix): 124 | def _convert(items): 125 | out = [] 126 | for 
impath, label, classname in items: 127 | impath = os.path.join(path_prefix, impath) 128 | item = Datum(impath=impath, label=int(label), classname=classname) 129 | out.append(item) 130 | return out 131 | 132 | print(f"Reading split from {filepath}") 133 | split = read_json(filepath) 134 | train = _convert(split["train"]) 135 | val = _convert(split["val"]) 136 | test = _convert(split["test"]) 137 | 138 | return train, val, test 139 | 140 | @staticmethod 141 | def subsample_classes(*args, subsample="all"): 142 | """Divide classes into two groups. The first group 143 | represents base classes while the second group represents 144 | new classes. 145 | 146 | Args: 147 | args: a list of datasets, e.g. train, val and test. 148 | subsample (str): what classes to subsample. 149 | """ 150 | assert subsample in ["all", "base", "new"] 151 | 152 | if subsample == "all": 153 | return args 154 | 155 | dataset = args[0] 156 | labels = set() 157 | for item in dataset: 158 | labels.add(item.label) 159 | labels = list(labels) 160 | labels.sort() 161 | n = len(labels) 162 | # Divide classes into two halves 163 | m = math.ceil(n / 2) 164 | 165 | print(f"SUBSAMPLE {subsample.upper()} CLASSES!") 166 | if subsample == "base": 167 | selected = labels[:m] # take the first half 168 | else: 169 | selected = labels[m:] # take the second half 170 | relabeler = {y: y_new for y_new, y in enumerate(selected)} 171 | 172 | output = [] 173 | for dataset in args: 174 | dataset_new = [] 175 | for item in dataset: 176 | if item.label not in selected: 177 | continue 178 | item_new = Datum( 179 | impath=item.impath, 180 | label=relabeler[item.label], 181 | classname=item.classname 182 | ) 183 | dataset_new.append(item_new) 184 | output.append(dataset_new) 185 | 186 | return output 187 | -------------------------------------------------------------------------------- /apt/datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from scipy.io import loadmat 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class StanfordCars(DatasetBase): 13 | 14 | dataset_dir = "stanford_cars" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json") 20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | if os.path.exists(self.split_path): 24 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 25 | else: 26 | trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat") 27 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat") 28 | meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat") 29 | trainval = self.read_data("cars_train", trainval_file, meta_file) 30 | test = self.read_data("cars_test", test_file, meta_file) 31 | train, val = OxfordPets.split_trainval(trainval) 32 | OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if 
os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 54 | 55 | super().__init__(train_x=train, val=val, test=test) 56 | 57 | def read_data(self, image_dir, anno_file, meta_file): 58 | anno_file = loadmat(anno_file)["annotations"][0] 59 | meta_file = loadmat(meta_file)["class_names"][0] 60 | items = [] 61 | 62 | for i in range(len(anno_file)): 63 | imname = anno_file[i]["fname"][0] 64 | impath = os.path.join(self.dataset_dir, image_dir, imname) 65 | label = anno_file[i]["class"][0, 0] 66 | label = int(label) - 1 # convert to 0-based index 67 | classname = meta_file[label][0] 68 | names = classname.split(" ") 69 | year = names.pop(-1) 70 | names.insert(0, year) 71 | classname = " ".join(names) 72 | item = Datum(impath=impath, label=label, classname=classname) 73 | items.append(item) 74 | 75 | return items 76 | -------------------------------------------------------------------------------- /apt/datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SUN397(DatasetBase): 12 | 13 | dataset_dir = "sun397" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "SUN397") 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json") 20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | if os.path.exists(self.split_path): 24 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 25 | else: 26 | classnames = [] 27 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f: 28 | lines = f.readlines() 29 | for line in lines: 30 | line = line.strip()[1:] # remove / 31 | classnames.append(line) 32 | cname2lab = {c: i for i, c in enumerate(classnames)} 33 | trainval = self.read_data(cname2lab, "Training_01.txt") 34 | test = self.read_data(cname2lab, "Testing_01.txt") 35 | train, val = OxfordPets.split_trainval(trainval) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = 
self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 58 | 59 | super().__init__(train_x=train, val=val, test=test) 60 | 61 | def read_data(self, cname2lab, text_file): 62 | text_file = os.path.join(self.dataset_dir, text_file) 63 | items = [] 64 | 65 | with open(text_file, "r") as f: 66 | lines = f.readlines() 67 | for line in lines: 68 | imname = line.strip()[1:] # remove / 69 | classname = os.path.dirname(imname) 70 | label = cname2lab[classname] 71 | impath = os.path.join(self.image_dir, imname) 72 | 73 | names = classname.split("/")[1:] # remove 1st letter 74 | names = names[::-1] # put words like indoor/outdoor at first 75 | classname = " ".join(names) 76 | 77 | item = Datum(impath=impath, label=label, classname=classname) 78 | items.append(item) 79 | 80 | return items 81 | -------------------------------------------------------------------------------- /apt/datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import re 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class UCF101(DatasetBase): 13 | 14 | dataset_dir = "ucf101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | cname2lab = {} 28 | filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt") 29 | with open(filepath, "r") as f: 30 | lines = f.readlines() 31 | for line in lines: 32 | label, classname = line.strip().split(" ") 33 | label = int(label) - 1 # conver to 0-based index 34 | cname2lab[classname] = label 35 | 36 | trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt") 37 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt") 38 | train, val = OxfordPets.split_trainval(trainval) 39 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 40 | 41 | num_shots = cfg.DATASET.NUM_SHOTS 42 | if num_shots >= 1: 43 | seed = cfg.SEED 44 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 45 | 46 | if os.path.exists(preprocessed): 47 | print(f"Loading preprocessed few-shot data from {preprocessed}") 48 | with open(preprocessed, "rb") as file: 49 | data = pickle.load(file) 50 | train, val = data["train"], data["val"] 51 | else: 52 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 53 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 54 | data = {"train": train, "val": val} 55 | 
print(f"Saving preprocessed few-shot data to {preprocessed}") 56 | with open(preprocessed, "wb") as file: 57 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 58 | 59 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 60 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 61 | 62 | super().__init__(train_x=train, val=val, test=test) 63 | 64 | def read_data(self, cname2lab, text_file): 65 | text_file = os.path.join(self.dataset_dir, text_file) 66 | items = [] 67 | 68 | with open(text_file, "r") as f: 69 | lines = f.readlines() 70 | for line in lines: 71 | line = line.strip().split(" ")[0] # trainlist: filename, label 72 | action, filename = line.split("/") 73 | label = cname2lab[action] 74 | 75 | elements = re.findall("[A-Z][^A-Z]*", action) 76 | renamed_action = "_".join(elements) 77 | 78 | filename = filename.replace(".avi", ".jpg") 79 | impath = os.path.join(self.image_dir, renamed_action, filename) 80 | 81 | item = Datum(impath=impath, label=label, classname=renamed_action) 82 | items.append(item) 83 | 84 | return items 85 | -------------------------------------------------------------------------------- /apt/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from warnings import warn 4 | from yacs.config import CfgNode 5 | import yaml 6 | import argparse 7 | from tqdm import tqdm 8 | 9 | from statistics import mean 10 | 11 | from torchvision import transforms 12 | from torchvision.datasets import * 13 | 14 | import torch.nn as nn 15 | from collections import OrderedDict 16 | from typing import Tuple, TypeVar 17 | from torch import Tensor 18 | from torch.autograd import grad, Variable 19 | 20 | from addict import Dict 21 | 22 | from dassl.data import DataManager 23 | 24 | import datasets.oxford_pets 25 | import datasets.oxford_flowers 26 | import datasets.fgvc_aircraft 27 | import datasets.dtd 28 | import datasets.eurosat 29 | import datasets.stanford_cars 30 | import datasets.food101 31 | import datasets.sun397 32 | import datasets.caltech101 33 | import datasets.ucf101 34 | import datasets.imagenet 35 | 36 | 37 | from torchattacks import PGD, TPGD 38 | from autoattack import AutoAttack 39 | 40 | from utils import * 41 | 42 | 43 | def CWLoss(output, target, confidence=0): 44 | """ 45 | CW loss (Marging loss). 46 | """ 47 | num_classes = output.shape[-1] 48 | target = target.data 49 | target_onehot = torch.zeros(target.size() + (num_classes,)) 50 | target_onehot = target_onehot.cuda() 51 | target_onehot.scatter_(1, target.unsqueeze(1), 1.) 52 | target_var = Variable(target_onehot, requires_grad=False) 53 | real = (target_var * output).sum(1) 54 | other = ((1. - target_var) * output - target_var * 10000.).max(1)[0] 55 | loss = - torch.clamp(real - other + confidence, min=0.) 
56 | loss = torch.sum(loss) 57 | return loss 58 | 59 | def input_grad(imgs, targets, model, criterion): 60 | output = model(imgs) 61 | loss = criterion(output, targets) 62 | ig = grad(loss, imgs)[0] 63 | return ig 64 | 65 | def perturb(imgs, targets, model, criterion, eps, eps_step, pert=None, ig=None): 66 | adv = imgs.requires_grad_(True) if pert is None else torch.clamp(imgs+pert, 0, 1).requires_grad_(True) 67 | ig = input_grad(adv, targets, model, criterion) if ig is None else ig 68 | if pert is None: 69 | pert = eps_step*torch.sign(ig) 70 | else: 71 | pert += eps_step*torch.sign(ig) 72 | pert.clamp_(-eps, eps) 73 | adv = torch.clamp(imgs+pert, 0, 1) 74 | pert = adv-imgs 75 | return adv.detach(), pert.detach() 76 | 77 | def pgd(imgs, targets, model, criterion, eps, eps_step, max_iter, pert=None, ig=None): 78 | for i in range(max_iter): 79 | adv, pert = perturb(imgs, targets, model, criterion, eps, eps_step, pert, ig) 80 | ig = None 81 | return adv, pert 82 | 83 | 84 | 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('experiment') 87 | parser.add_argument('-cp','--cls-prompt', default='a photo of a {}') 88 | parser.add_argument('-ap','--atk-prompt', default=None) 89 | parser.add_argument('--best-checkpoint', action='store_true') 90 | 91 | parser.add_argument('--attack', default='pgd') 92 | parser.add_argument('--dataset', default=None) 93 | parser.add_argument('-lp', '--linear-probe', action='store_true') 94 | 95 | 96 | if __name__ == '__main__': 97 | args = parser.parse_args() 98 | 99 | cfg = CfgNode() 100 | cfg.set_new_allowed(True) 101 | cfg_path = os.path.join(args.experiment, 'cfg.yaml') 102 | cfg.merge_from_file(cfg_path) 103 | 104 | train_dataset = cfg.DATASET.NAME 105 | 106 | if args.dataset: 107 | if args.dataset in ['ImageNetR', 'ImageNetA', 'ON']: 108 | cfg.DATASET.NAME = 'ImageNet' 109 | else: 110 | cfg.DATASET.NAME = args.dataset 111 | save_path = os.path.join(cfg.OUTPUT_DIR, 'dist_shift.yaml') 112 | else: 113 | save_path = os.path.join(cfg.OUTPUT_DIR, 'evaluation.yaml') 114 | if os.path.isfile(save_path): 115 | with open(save_path, 'r') as f: 116 | result = Dict(yaml.safe_load(f)) 117 | 118 | result = result if args.dataset is None or args.dataset==train_dataset else result[args.dataset] 119 | tune = 'linear_probe' if args.linear_probe else args.cls_prompt 120 | if result[tune][args.attack] != {}: 121 | print(f'eval result already exists at: {save_path}') 122 | exit() 123 | 124 | dm = DataManager(cfg) 125 | classes = dm.dataset.classnames 126 | loader = dm.test_loader 127 | num_classes = dm.num_classes 128 | 129 | if args.dataset in ['ImageNetR', 'ImageNetA', 'ON'] or (train_dataset == 'ImageNet' and args.dataset is None and args.attack == 'aa'): 130 | from OODRB.imagenet import ImageNet 131 | if args.dataset == 'ImageNetV2': 132 | shift = 'v2' 133 | elif args.dataset == 'ImageNetA': 134 | shift = 'A' 135 | elif args.dataset == 'ImageNetR': 136 | shift = 'R' 137 | elif args.dataset == 'ON': 138 | shift = 'ON' 139 | else: 140 | shift = None 141 | num_classes = 1000 142 | dataset = ImageNet(cfg.DATASET.ROOT, 143 | shift, 144 | 'val', 145 | transform=loader.dataset.transform) 146 | if args.attack == 'aa': 147 | dataset = torch.utils.data.Subset(dataset, list(range(5000))) 148 | loader = torch.utils.data.DataLoader(dataset, 149 | batch_size=100, 150 | shuffle=False, 151 | num_workers=4, 152 | pin_memory=True) 153 | 154 | model, _ = clip.load(cfg.MODEL.BACKBONE.NAME, device='cpu') 155 | 156 | # load pretrained adversarially robust backbone models 157 | ckp_name = 
'vitb32' if cfg.MODEL.BACKBONE.NAME == 'ViT-B/32' else 'rn50' 158 | eps = int(cfg.AT.EPS * 255) 159 | ckp_name += f'_eps{eps}.pth.tar' 160 | ckp = torch.load(os.path.join('backbone', ckp_name)) 161 | model.visual.load_state_dict(ckp['vision_encoder_state_dict']) 162 | 163 | if 'prompter' in (args.cls_prompt, args.atk_prompt): 164 | prompter_path = os.path.join(cfg.OUTPUT_DIR, 'prompt_learner/') 165 | 166 | assert os.path.isdir(prompter_path) 167 | if args.best_checkpoint: 168 | prompter_path += 'best.pth.tar' 169 | else: 170 | ckp = [fname for fname in os.listdir(prompter_path) if 'model.pth.tar' in fname][0] 171 | prompter_path += ckp 172 | 173 | classify_prompt = prompter_path if args.cls_prompt == 'prompter' else args.cls_prompt 174 | attack_prompt = prompter_path if args.atk_prompt == 'prompter' else args.atk_prompt 175 | 176 | if args.linear_probe: 177 | from adv_lp import LinearProbe 178 | model = LinearProbe(model, 512, num_classes, False) 179 | ckp = torch.load(os.path.join(cfg.OUTPUT_DIR, 'linear_probe/linear.pth.tar')) 180 | model.linear.load_state_dict(ckp) 181 | else: 182 | model = CustomCLIP(model, 183 | classes, 184 | cls_prompt=classify_prompt, 185 | atk_prompt=attack_prompt, 186 | cfg=cfg) 187 | 188 | model = model.cuda() 189 | model.eval() 190 | 191 | meters = Dict() 192 | meters.acc = AverageMeter('Clean Acc@1', ':6.2f') 193 | meters.rob = AverageMeter('Robust Acc@1', ':6.2f') 194 | 195 | progress = ProgressMeter( 196 | len(loader), 197 | [meters.acc, meters.rob], 198 | prefix=cfg.DATASET.NAME) 199 | 200 | eps = cfg.AT.EPS 201 | alpha = eps / 4.0 202 | steps = 100 203 | 204 | if args.attack == 'aa': 205 | attack = AutoAttack(model, 206 | norm='Linf', 207 | eps=eps, 208 | version='standard', 209 | verbose=False) 210 | elif args.attack == 'pgd': 211 | attack = PGD(model, eps=eps, alpha=alpha, steps=steps) 212 | elif args.attack == 'tpgd': 213 | attack = TPGD(model, eps=eps, alpha=alpha, steps=steps) 214 | 215 | for i, data in enumerate(loader, start=1): 216 | try: 217 | # few-shot data loader from Dassl 218 | imgs, tgts = data['img'], data['label'] 219 | except: 220 | imgs, tgts = data[:2] 221 | imgs, tgts = imgs.cuda(), tgts.cuda() 222 | bs = imgs.size(0) 223 | 224 | with torch.no_grad(): 225 | output = model(imgs) 226 | 227 | acc = accuracy(output, tgts) 228 | meters.acc.update(acc[0].item(), bs) 229 | 230 | model.mode = 'attack' 231 | if args.attack == 'aa': 232 | adv = attack.run_standard_evaluation(imgs, tgts, bs=bs) 233 | elif args.attack in ['pgd', 'tpgd']: 234 | adv = attack(imgs, tgts) 235 | else: 236 | adv, _ = pgd(imgs, tgts, model, CWLoss, eps, alpha, steps) 237 | 238 | model.mode = 'classification' 239 | 240 | # Calculate features 241 | with torch.no_grad(): 242 | output = model(adv) 243 | 244 | rob = accuracy(output, tgts) 245 | meters.rob.update(rob[0].item(), bs) 246 | 247 | if i == 1 or i % 10 == 0 or i == len(loader): 248 | progress.display(i) 249 | 250 | # save result 251 | if os.path.isfile(save_path): 252 | with open(save_path, 'r') as f: 253 | result = Dict(yaml.safe_load(f)) 254 | else: 255 | result = Dict() 256 | 257 | _result = result if args.dataset is None or args.dataset==train_dataset else result[args.dataset] 258 | tune = 'linear_probe' if args.linear_probe else args.cls_prompt 259 | _result[tune].clean = meters.acc.avg 260 | _result[tune][args.attack] = meters.rob.avg 261 | 262 | with open(save_path, 'w+') as f: 263 | yaml.dump(result.to_dict(), f) 264 | 265 | print(f'result saved at: {save_path}') 266 | 
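The evaluation script above attacks raw pixels in [0, 1]: normalization is folded into the model (cf. ImageNormalizer in trainers/apt.py further below), and when --attack is neither 'aa', 'pgd' nor 'tpgd' it falls back to the hand-rolled pgd()/CWLoss pair defined at the top of the file. Below is a minimal, self-contained sketch of that fallback recipe, using a toy linear classifier and random data in place of CLIP and the benchmark loaders; ToyNet, pgd_cw and the hyperparameters are illustrative stand-ins, not part of this repository.

# --- illustrative sketch (not part of the repository) -------------------------
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Stand-in classifier that normalizes internally, like CustomCLIP does."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.register_buffer("mean", torch.tensor([0.48, 0.46, 0.41]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.27, 0.26, 0.27]).view(1, 3, 1, 1))
        self.head = nn.Linear(3 * 32 * 32, num_classes)

    def forward(self, x):
        x = (x - self.mean) / self.std  # normalize inside the model
        return self.head(x.flatten(1))


def cw_loss(output, target, confidence=0.0):
    # Same margin loss as CWLoss above: true-class logit vs. strongest rival.
    onehot = torch.zeros_like(output).scatter_(1, target.unsqueeze(1), 1.0)
    real = (onehot * output).sum(1)
    other = ((1.0 - onehot) * output - onehot * 1e4).max(1)[0]
    return -torch.clamp(real - other + confidence, min=0.0).sum()


def pgd_cw(model, imgs, targets, eps=4 / 255, alpha=1 / 255, steps=10):
    pert = torch.zeros_like(imgs)
    for _ in range(steps):
        adv = torch.clamp(imgs + pert, 0, 1).requires_grad_(True)
        grad = torch.autograd.grad(cw_loss(model(adv), targets), adv)[0]
        # Gradient ascent on the margin loss, then project back to the eps-ball.
        pert = torch.clamp(pert + alpha * grad.sign(), -eps, eps)
    return torch.clamp(imgs + pert, 0, 1).detach()


if __name__ == "__main__":
    model = ToyNet().eval()
    imgs = torch.rand(8, 3, 32, 32)           # raw pixels in [0, 1]
    targets = torch.randint(0, 10, (8,))
    adv = pgd_cw(model, imgs, targets)
    print("clean acc :", (model(imgs).argmax(1) == targets).float().mean().item())
    print("robust acc:", (model(adv).argmax(1) == targets).float().mean().item())
# -------------------------------------------------------------------------------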
-------------------------------------------------------------------------------- /apt/interpret_prompt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import torch 5 | 6 | from clip.simple_tokenizer import SimpleTokenizer 7 | from clip import clip 8 | 9 | 10 | def load_clip_to_cpu(backbone_name="RN50"): 11 | url = clip._MODELS[backbone_name] 12 | model_path = clip._download(url) 13 | 14 | try: 15 | # loading JIT archive 16 | model = torch.jit.load(model_path, map_location="cpu").eval() 17 | state_dict = None 18 | 19 | except RuntimeError: 20 | state_dict = torch.load(model_path, map_location="cpu") 21 | 22 | model = clip.build_model(state_dict or model.state_dict()) 23 | 24 | return model 25 | 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("fpath", type=str, help="Path to the learned prompt") 29 | parser.add_argument("topk", type=int, help="Select top-k similar words") 30 | args = parser.parse_args() 31 | 32 | fpath = args.fpath 33 | topk = args.topk 34 | 35 | assert os.path.exists(fpath) 36 | 37 | print(f"Return the top-{topk} matched words") 38 | 39 | tokenizer = SimpleTokenizer() 40 | clip_model = load_clip_to_cpu() 41 | token_embedding = clip_model.token_embedding.weight 42 | print(f"Size of token embedding: {token_embedding.shape}") 43 | 44 | prompt_learner = torch.load(fpath, map_location="cpu")["state_dict"] 45 | ctx = prompt_learner["ctx"] 46 | ctx = ctx.float() 47 | print(f"Size of context: {ctx.shape}") 48 | 49 | if ctx.dim() == 2: 50 | # Generic context 51 | distance = torch.cdist(ctx, token_embedding) 52 | print(f"Size of distance matrix: {distance.shape}") 53 | sorted_idxs = torch.argsort(distance, dim=1) 54 | sorted_idxs = sorted_idxs[:, :topk] 55 | 56 | for m, idxs in enumerate(sorted_idxs): 57 | words = [tokenizer.decoder[idx.item()] for idx in idxs] 58 | dist = [f"{distance[m, idx].item():.4f}" for idx in idxs] 59 | print(f"{m+1}: {words} {dist}") 60 | 61 | elif ctx.dim() == 3: 62 | # Class-specific context 63 | raise NotImplementedError 64 | -------------------------------------------------------------------------------- /apt/scripts/APT.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=TODO_replace_with_data_root_path 5 | TRAINER=APT 6 | 7 | DATASET=$1 8 | CFG=$2 # config file 9 | CTP=$3 # class token position (end or middle) 10 | NCTX=$4 # number of context tokens 11 | SHOTS=$5 # number of shots (1, 2, 4, 8, 16) 12 | CSC=$6 # class-specific context (False or True) 13 | EPS=$7 # epsilon for AT 14 | ALPHA=$8 # alpha or step size for AT 15 | STEPS=$9 # number of steps for AT 16 | SEED=${10} 17 | ATP=${11} 18 | PALPHA=${12} 19 | 20 | 21 | if [ ${ATP} == 'perturbed' ] 22 | then 23 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/eps${EPS}_alpha${ALPHA}_step${STEPS}_${ATP}_${PALPHA}/seed${SEED} 24 | elif [ ${ATP} == 'constant' ] 25 | then 26 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/eps${EPS}_alpha${ALPHA}_step${STEPS}_${ATP}/seed${SEED} 27 | else 28 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/eps${EPS}_alpha${ALPHA}_step${STEPS}/seed${SEED} 29 | fi 30 | 31 | 32 | if [ -d "$DIR" ]; then 33 | echo "Oops! 
The results exist at ${DIR} (so skip this job)" 34 | else 35 | python train.py \ 36 | --root ${DATA} \ 37 | --trainer ${TRAINER} \ 38 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 39 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 40 | --output-dir ${DIR} \ 41 | --eps ${EPS} \ 42 | --alpha ${ALPHA} \ 43 | --steps ${STEPS} \ 44 | --adv-prompt ${ATP} \ 45 | --prompt-alpha ${PALPHA} \ 46 | TRAINER.COOP.N_CTX ${NCTX} \ 47 | TRAINER.COOP.CSC ${CSC} \ 48 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 49 | DATASET.NUM_SHOTS ${SHOTS} 50 | fi 51 | -------------------------------------------------------------------------------- /apt/scripts/CoOp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=TODO_replace_with_data_root_path 5 | TRAINER=CoOp 6 | 7 | DATASET=$1 8 | CFG=$2 # config file 9 | CTP=$3 # class token position (end or middle) 10 | NCTX=$4 # number of context tokens 11 | SHOTS=$5 # number of shots (1, 2, 4, 8, 16) 12 | CSC=$6 # class-specific context (False or True) 13 | EPS=$7 # epsilon for AT 14 | ALPHA=$8 # alpha or step size for AT 15 | STEPS=$9 # number of steps for AT 16 | 17 | 18 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/eps${EPS}_alpha${ALPHA}_step${STEPS}/seed${SEED} 19 | 20 | if [ -d "$DIR" ]; then 21 | echo "Oops! The results exist at ${DIR} (so skip this job)" 22 | else 23 | python train.py \ 24 | --root ${DATA} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --eps ${EPS}\ 30 | --alpha ${ALPHA}\ 31 | --steps ${STEPS}\ 32 | TRAINER.COOP.N_CTX ${NCTX} \ 33 | TRAINER.COOP.CSC ${CSC} \ 34 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 35 | DATASET.NUM_SHOTS ${SHOTS} 36 | fi 37 | -------------------------------------------------------------------------------- /apt/scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=TODO_replace_with_data_root_path 5 | TRAINER=APT 6 | SHOTS=16 7 | NCTX=16 8 | CSC=False 9 | CTP=end 10 | 11 | DATASET=$1 12 | CFG=$2 13 | 14 | # --seed ${SEED} \ 15 | 16 | 17 | python train.py \ 18 | --root ${DATA} \ 19 | --trainer ${TRAINER} \ 20 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 21 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 22 | --output-dir output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/${DATASET}/seed${SEED} \ 23 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/eps1_alpha0.67_step3/seed${SEED} \ 24 | --load-epoch 50 \ 25 | --eval-only \ 26 | TRAINER.COOP.N_CTX ${NCTX} \ 27 | TRAINER.COOP.CSC ${CSC} \ 28 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} 29 | -------------------------------------------------------------------------------- /apt/train.py: -------------------------------------------------------------------------------- 1 | import yaml, os 2 | import argparse 3 | import torch 4 | from yacs.config import CfgNode as CN 5 | 6 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 7 | from dassl.config import get_cfg_default 8 | from dassl.engine import build_trainer 9 | 10 | # custom 11 | import datasets.oxford_pets 12 | import datasets.oxford_flowers 13 | import datasets.fgvc_aircraft 14 | import datasets.dtd 15 | import datasets.eurosat 16 | import 
datasets.stanford_cars 17 | import datasets.food101 18 | import datasets.sun397 19 | import datasets.caltech101 20 | import datasets.ucf101 21 | import datasets.imagenet 22 | 23 | import trainers.apt 24 | 25 | def print_args(args, cfg): 26 | print("***************") 27 | print("** Arguments **") 28 | print("***************") 29 | optkeys = list(args.__dict__.keys()) 30 | optkeys.sort() 31 | for key in optkeys: 32 | print("{}: {}".format(key, args.__dict__[key])) 33 | print("************") 34 | print("** Config **") 35 | print("************") 36 | print(cfg) 37 | 38 | 39 | def reset_cfg(cfg, args): 40 | if args.root: 41 | cfg.DATASET.ROOT = args.root 42 | 43 | if args.output_dir: 44 | cfg.OUTPUT_DIR = args.output_dir 45 | 46 | if args.resume: 47 | cfg.RESUME = args.resume 48 | 49 | if args.seed: 50 | cfg.SEED = args.seed 51 | 52 | if args.source_domains: 53 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains 54 | 55 | if args.target_domains: 56 | cfg.DATASET.TARGET_DOMAINS = args.target_domains 57 | 58 | if args.transforms: 59 | cfg.INPUT.TRANSFORMS = args.transforms 60 | 61 | if args.trainer: 62 | cfg.TRAINER.NAME = args.trainer 63 | 64 | if args.backbone: 65 | cfg.MODEL.BACKBONE.NAME = args.backbone 66 | 67 | if args.head: 68 | cfg.MODEL.HEAD.NAME = args.head 69 | 70 | cfg.AT = CN() 71 | cfg.AT.EPS = args.eps / 255.0 72 | cfg.AT.ALPHA = args.alpha / 255.0 73 | cfg.AT.STEPS = args.steps 74 | 75 | cfg.AT.PROMPT = args.adv_prompt 76 | cfg.AT.PALPHA = args.prompt_alpha 77 | 78 | def extend_cfg(cfg): 79 | """ 80 | Add new config variables. 81 | 82 | E.g. 83 | from yacs.config import CfgNode as CN 84 | cfg.TRAINER.MY_MODEL = CN() 85 | cfg.TRAINER.MY_MODEL.PARAM_A = 1. 86 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5 87 | cfg.TRAINER.MY_MODEL.PARAM_C = False 88 | """ 89 | 90 | cfg.TRAINER.COOP = CN() 91 | cfg.TRAINER.COOP.N_CTX = 16 # number of context vectors 92 | cfg.TRAINER.COOP.CSC = False # class-specific context 93 | cfg.TRAINER.COOP.CTX_INIT = "" # initialization words 94 | cfg.TRAINER.COOP.PREC = "fp16" # fp16, fp32, amp 95 | cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front' 96 | 97 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new 98 | 99 | cfg.MODEL.BACKBONE.ROBUST = True 100 | 101 | def setup_cfg(args): 102 | cfg = get_cfg_default() 103 | extend_cfg(cfg) 104 | 105 | # 1. From the dataset config file 106 | if args.dataset_config_file: 107 | cfg.merge_from_file(args.dataset_config_file) 108 | 109 | # 2. From the method config file 110 | if args.config_file: 111 | cfg.merge_from_file(args.config_file) 112 | 113 | # 3. From input arguments 114 | reset_cfg(cfg, args) 115 | 116 | # 4. 
From optional input arguments 117 | cfg.merge_from_list(args.opts) 118 | 119 | cfg.freeze() 120 | 121 | return cfg 122 | 123 | 124 | def main(args): 125 | cfg = setup_cfg(args) 126 | if cfg.SEED >= 0: 127 | print("Setting fixed seed: {}".format(cfg.SEED)) 128 | set_random_seed(cfg.SEED) 129 | setup_logger(cfg.OUTPUT_DIR) 130 | 131 | if torch.cuda.is_available() and cfg.USE_CUDA: 132 | torch.backends.cudnn.benchmark = True 133 | 134 | # print_args(args, cfg) 135 | # print("Collecting env info ...") 136 | # print("** System info **\n{}\n".format(collect_env_info())) 137 | 138 | trainer = build_trainer(cfg) 139 | 140 | with open(os.path.join(cfg.OUTPUT_DIR, 'cfg.yaml'), 'w+') as f: 141 | f.write(cfg.dump()) 142 | 143 | if args.eval_only: 144 | trainer.load_model(args.model_dir, epoch=args.load_epoch) 145 | trainer.test() 146 | return 147 | 148 | if not args.no_train: 149 | trainer.train() 150 | 151 | 152 | if __name__ == "__main__": 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument("--root", type=str, default="", help="path to dataset") 155 | parser.add_argument("--output-dir", type=str, default="", help="output directory") 156 | parser.add_argument( 157 | "--resume", 158 | type=str, 159 | default="", 160 | help="checkpoint directory (from which the training resumes)", 161 | ) 162 | parser.add_argument( 163 | "--seed", type=int, default=-1, help="only positive value enables a fixed seed" 164 | ) 165 | parser.add_argument( 166 | "--source-domains", type=str, nargs="+", help="source domains for DA/DG" 167 | ) 168 | parser.add_argument( 169 | "--target-domains", type=str, nargs="+", help="target domains for DA/DG" 170 | ) 171 | parser.add_argument( 172 | "--transforms", type=str, nargs="+", help="data augmentation methods" 173 | ) 174 | parser.add_argument( 175 | "--config-file", type=str, default="", help="path to config file" 176 | ) 177 | parser.add_argument( 178 | "--dataset-config-file", 179 | type=str, 180 | default="", 181 | help="path to config file for dataset setup", 182 | ) 183 | parser.add_argument("--trainer", type=str, default="", help="name of trainer") 184 | parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone") 185 | parser.add_argument("--head", type=str, default="", help="name of head") 186 | parser.add_argument("--eval-only", action="store_true", help="evaluation only") 187 | parser.add_argument("--eps", type=float, default=1) 188 | parser.add_argument("--alpha", type=float, default=2.0/3) 189 | parser.add_argument("--steps", type=int, default=3) 190 | parser.add_argument("--adv-prompt", type=str, default='onfly') 191 | parser.add_argument("--prompt-alpha", type=float, default=None) 192 | parser.add_argument( 193 | "--model-dir", 194 | type=str, 195 | default="", 196 | help="load model from this directory for eval-only mode", 197 | ) 198 | parser.add_argument( 199 | "--load-epoch", type=int, help="load model weights at this epoch for evaluation" 200 | ) 201 | parser.add_argument( 202 | "--no-train", action="store_true", help="do not call trainer.train()" 203 | ) 204 | parser.add_argument( 205 | "opts", 206 | default=None, 207 | nargs=argparse.REMAINDER, 208 | help="modify config options using the command-line", 209 | ) 210 | args = parser.parse_args() 211 | main(args) 212 | -------------------------------------------------------------------------------- /apt/trainers/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TreeLLi/APT/c2e99e09329b9a2be87d02e5c7ae02092d2ac95d/apt/trainers/__init__.py -------------------------------------------------------------------------------- /apt/trainers/apt.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from tqdm import tqdm 3 | import copy 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | from torch.cuda.amp import GradScaler, autocast 9 | 10 | from dassl.engine import TRAINER_REGISTRY, TrainerX 11 | from dassl.metrics import compute_accuracy 12 | from dassl.utils import load_pretrained_weights, load_checkpoint 13 | from dassl.optim import build_optimizer, build_lr_scheduler 14 | 15 | from clip import clip 16 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 17 | 18 | from torchattacks import PGD 19 | 20 | 21 | _tokenizer = _Tokenizer() 22 | 23 | 24 | def load_clip_to_cpu(cfg): 25 | backbone_name = cfg.MODEL.BACKBONE.NAME 26 | url = clip._MODELS[backbone_name] 27 | model_path = clip._download(url) 28 | 29 | try: 30 | # loading JIT archive 31 | model = torch.jit.load(model_path, map_location="cpu").eval() 32 | state_dict = None 33 | 34 | except RuntimeError: 35 | state_dict = torch.load(model_path, map_location="cpu") 36 | 37 | model = clip.build_model(state_dict or model.state_dict()) 38 | 39 | return model 40 | 41 | 42 | class TextEncoder(nn.Module): 43 | def __init__(self, clip_model): 44 | super().__init__() 45 | self.transformer = clip_model.transformer 46 | self.positional_embedding = clip_model.positional_embedding 47 | self.ln_final = clip_model.ln_final 48 | self.text_projection = clip_model.text_projection 49 | self.dtype = clip_model.dtype 50 | 51 | def forward(self, prompts, tokenized_prompts): 52 | x = prompts + self.positional_embedding.type(self.dtype) 53 | x = x.permute(1, 0, 2) # NLD -> LND 54 | x = self.transformer(x) 55 | x = x.permute(1, 0, 2) # LND -> NLD 56 | x = self.ln_final(x).type(self.dtype) 57 | 58 | # x.shape = [batch_size, n_ctx, transformer.width] 59 | # take features from the eot embedding (eot_token is the highest number in each sequence) 60 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 61 | 62 | return x 63 | 64 | 65 | class PromptLearner(nn.Module): 66 | def __init__(self, cfg, classnames, clip_model): 67 | super().__init__() 68 | n_cls = len(classnames) 69 | n_ctx = cfg.TRAINER.COOP.N_CTX 70 | ctx_init = cfg.TRAINER.COOP.CTX_INIT 71 | dtype = clip_model.dtype 72 | ctx_dim = clip_model.ln_final.weight.shape[0] 73 | clip_imsize = clip_model.visual.input_resolution 74 | cfg_imsize = cfg.INPUT.SIZE[0] 75 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 76 | 77 | if ctx_init: 78 | # use given words to initialize context vectors 79 | ctx_init = ctx_init.replace("_", " ") 80 | n_ctx = len(ctx_init.split(" ")) 81 | prompt = clip.tokenize(ctx_init) 82 | with torch.no_grad(): 83 | embedding = clip_model.token_embedding(prompt).type(dtype) 84 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :] 85 | prompt_prefix = ctx_init 86 | 87 | else: 88 | # random initialization 89 | if cfg.TRAINER.COOP.CSC: 90 | print("Initializing class-specific contexts") 91 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) 92 | else: 93 | print("Initializing a generic context") 94 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 95 | nn.init.normal_(ctx_vectors, std=0.02) 96 | prompt_prefix = " 
".join(["X"] * n_ctx) 97 | 98 | print(f'Initial context: "{prompt_prefix}"') 99 | print(f"Number of context words (tokens): {n_ctx}") 100 | 101 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized 102 | 103 | classnames = [name.replace("_", " ") for name in classnames] 104 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 105 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 106 | 107 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) 108 | with torch.no_grad(): 109 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 110 | 111 | # These token vectors will be saved when in save_model(), 112 | # but they should be ignored in load_model() as we want to use 113 | # those computed using the current class names 114 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 115 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS 116 | 117 | self.n_cls = n_cls 118 | self.n_ctx = n_ctx 119 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 120 | self.name_lens = name_lens 121 | self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION 122 | 123 | def forward(self): 124 | ctx = self.ctx 125 | if ctx.dim() == 2: 126 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 127 | 128 | prefix = self.token_prefix 129 | suffix = self.token_suffix 130 | 131 | if self.class_token_position == "end": 132 | prompts = torch.cat( 133 | [ 134 | prefix, # (n_cls, 1, dim) 135 | ctx, # (n_cls, n_ctx, dim) 136 | suffix, # (n_cls, *, dim) 137 | ], 138 | dim=1, 139 | ) 140 | 141 | elif self.class_token_position == "middle": 142 | half_n_ctx = self.n_ctx // 2 143 | prompts = [] 144 | for i in range(self.n_cls): 145 | name_len = self.name_lens[i] 146 | prefix_i = prefix[i : i + 1, :, :] 147 | class_i = suffix[i : i + 1, :name_len, :] 148 | suffix_i = suffix[i : i + 1, name_len:, :] 149 | ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :] 150 | ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :] 151 | prompt = torch.cat( 152 | [ 153 | prefix_i, # (1, 1, dim) 154 | ctx_i_half1, # (1, n_ctx//2, dim) 155 | class_i, # (1, name_len, dim) 156 | ctx_i_half2, # (1, n_ctx//2, dim) 157 | suffix_i, # (1, *, dim) 158 | ], 159 | dim=1, 160 | ) 161 | prompts.append(prompt) 162 | prompts = torch.cat(prompts, dim=0) 163 | 164 | elif self.class_token_position == "front": 165 | prompts = [] 166 | for i in range(self.n_cls): 167 | name_len = self.name_lens[i] 168 | prefix_i = prefix[i : i + 1, :, :] 169 | class_i = suffix[i : i + 1, :name_len, :] 170 | suffix_i = suffix[i : i + 1, name_len:, :] 171 | ctx_i = ctx[i : i + 1, :, :] 172 | prompt = torch.cat( 173 | [ 174 | prefix_i, # (1, 1, dim) 175 | class_i, # (1, name_len, dim) 176 | ctx_i, # (1, n_ctx, dim) 177 | suffix_i, # (1, *, dim) 178 | ], 179 | dim=1, 180 | ) 181 | prompts.append(prompt) 182 | prompts = torch.cat(prompts, dim=0) 183 | 184 | else: 185 | raise ValueError 186 | 187 | return prompts 188 | 189 | class ImageNormalizer(nn.Module): 190 | 191 | def __init__(self, mean, std): 192 | super(ImageNormalizer, self).__init__() 193 | 194 | self.register_buffer('mean', torch.as_tensor(mean).view(1, 3, 1, 1)) 195 | self.register_buffer('std', torch.as_tensor(std).view(1, 3, 1, 1)) 196 | 197 | def forward(self, input): 198 | return (input - self.mean) / self.std 199 | 200 | def __repr__(self): 201 | return f'ImageNormalizer(mean={self.mean.squeeze()}, std={self.std.squeeze()})' # type: ignore 202 | 203 | class CustomCLIP(nn.Module): 204 | def __init__(self, cfg, classnames, 
clip_model, mode='updated', device='cuda'):
205 | super().__init__()
206 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
207 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
208 | self.image_encoder = clip_model.visual
209 | self.text_encoder = TextEncoder(clip_model)
210 | self.logit_scale = clip_model.logit_scale
211 | self.dtype = clip_model.dtype
212 | self.device = device
213 | self.mode = mode
214 | if mode == 'constant': # attack uses text features of a fixed hand-crafted prompt
215 | prompts = torch.cat([clip.tokenize('a photo of a {}'.format(c))
216 | for c in classnames]).to(device)
217 | clip_model = clip_model.to(device)
218 | self.text_features = clip_model.encode_text(prompts).detach().clone()
219 |
220 | self.normalize = ImageNormalizer(cfg.INPUT.PIXEL_MEAN,
221 | cfg.INPUT.PIXEL_STD).to(device)
222 |
223 | num_params = sum(p.numel() for p in self.prompt_learner.parameters())
224 | print(f'Prompt learner parameters: {num_params}')
225 |
226 | def forward(self, image, attack=False):
227 | image_features = self.image_encoder(self.normalize(image).type(self.dtype))
228 |
229 | if self.mode == 'constant' and attack:
230 | text_features = self.text_features.detach().clone()
231 | else:
232 | prompts = self.prompt_learner()
233 | tokenized_prompts = self.tokenized_prompts
234 | text_features = self.text_encoder(prompts, tokenized_prompts)
235 |
236 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
237 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
238 |
239 | logit_scale = self.logit_scale.exp()
240 | logits = logit_scale * image_features @ text_features.t()
241 |
242 | return logits
243 |
244 |
245 | @TRAINER_REGISTRY.register()
246 | class CoOp(TrainerX):
247 | """Context Optimization (CoOp).
248 |
249 | Learning to Prompt for Vision-Language Models
250 | https://arxiv.org/abs/2109.01134
251 | """
252 |
253 | def check_cfg(self, cfg):
254 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
255 |
256 | def build_model(self):
257 | cfg = self.cfg
258 | classnames = self.dm.dataset.classnames
259 |
260 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
261 | clip_model = load_clip_to_cpu(cfg)
262 |
263 | if cfg.MODEL.BACKBONE.ROBUST:
264 | ckp_name = 'vitb32' if cfg.MODEL.BACKBONE.NAME == 'ViT-B/32' else 'rn50'
265 | eps = int(cfg.AT.EPS * 255)
266 | ckp_name += f'_eps{eps}.pth.tar'
267 | ckp = torch.load(osp.join('backbone', ckp_name))
268 | clip_model.visual.load_state_dict(ckp['vision_encoder_state_dict'])
269 |
270 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
271 | # CLIP's default precision is fp16
272 | clip_model.float()
273 |
274 | print("Building custom CLIP")
275 | self.model = CustomCLIP(cfg, classnames, clip_model, cfg.AT.PROMPT, self.device)
276 |
277 | print("Turning off gradients in both the image and the text encoder")
278 | for name, param in self.model.named_parameters():
279 | if "prompt_learner" not in name:
280 | param.requires_grad_(False)
281 |
282 | if cfg.MODEL.INIT_WEIGHTS:
283 | load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)
284 |
285 | self.model.to(self.device)
286 | # NOTE: only give prompt_learner to the optimizer
287 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
288 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
289 | self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
290 |
291 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None
292 |
293 | # Note that multi-gpu
training could be slow because CLIP's size is 294 | # big, which slows down the copy operation in DataParallel 295 | device_count = torch.cuda.device_count() 296 | if device_count > 1: 297 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 298 | self.model = nn.DataParallel(self.model) 299 | 300 | def forward_backward(self, batch): 301 | image, label = self.parse_batch_train(batch) 302 | 303 | prec = self.cfg.TRAINER.COOP.PREC 304 | if prec == "amp": 305 | with autocast(): 306 | output = self.model(image) 307 | loss = F.cross_entropy(output, label) 308 | self.optim.zero_grad() 309 | self.scaler.scale(loss).backward() 310 | self.scaler.step(self.optim) 311 | self.scaler.update() 312 | else: 313 | output = self.model(image) 314 | loss = F.cross_entropy(output, label) 315 | self.model_backward_and_update(loss) 316 | 317 | loss_summary = { 318 | "loss": loss.item(), 319 | "acc": compute_accuracy(output, label)[0].item(), 320 | } 321 | 322 | if (self.batch_idx + 1) == self.num_batches: 323 | self.update_lr() 324 | 325 | return loss_summary 326 | 327 | def parse_batch_train(self, batch): 328 | input = batch["img"] 329 | label = batch["label"] 330 | input = input.to(self.device) 331 | label = label.to(self.device) 332 | return input, label 333 | 334 | def load_model(self, directory, epoch=None): 335 | if not directory: 336 | print("Note that load_model() is skipped as no pretrained model is given") 337 | return 338 | 339 | names = self.get_model_names() 340 | 341 | # By default, the best model is loaded 342 | model_file = "model-best.pth.tar" 343 | 344 | if epoch is not None: 345 | model_file = "model.pth.tar-" + str(epoch) 346 | 347 | for name in names: 348 | model_path = osp.join(directory, name, model_file) 349 | 350 | if not osp.exists(model_path): 351 | raise FileNotFoundError('Model not found at "{}"'.format(model_path)) 352 | 353 | checkpoint = load_checkpoint(model_path) 354 | state_dict = checkpoint["state_dict"] 355 | epoch = checkpoint["epoch"] 356 | 357 | # Ignore fixed token vectors 358 | if "token_prefix" in state_dict: 359 | del state_dict["token_prefix"] 360 | 361 | if "token_suffix" in state_dict: 362 | del state_dict["token_suffix"] 363 | 364 | print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 365 | # set strict=False 366 | self._models[name].load_state_dict(state_dict, strict=False) 367 | 368 | 369 | 370 | @TRAINER_REGISTRY.register() 371 | class APT(CoOp): 372 | """Context Optimization (CoOp) for Adversarial Training. 
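Training images are perturbed by an L-inf PGD attack (budget cfg.AT.EPS, step size cfg.AT.ALPHA, cfg.AT.STEPS steps), and cfg.AT.PROMPT selects the text prompt used to craft the attack: 'constant' (a fixed hand-crafted prompt), 'updated' (the current learned prompt), or 'perturbed' (the learned prompt is itself adversarially perturbed during the attack and restored afterwards).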
373 | 374 | """ 375 | 376 | def forward_backward(self, batch): 377 | image, label = self.parse_batch_train(batch) 378 | eps=self.cfg.AT.EPS 379 | alpha=self.cfg.AT.ALPHA 380 | steps=self.cfg.AT.STEPS 381 | 382 | attack = PGD(self.model, 383 | eps=self.cfg.AT.EPS, 384 | alpha=self.cfg.AT.ALPHA, 385 | steps=self.cfg.AT.STEPS) 386 | 387 | prompter_optim = torch.optim.SGD(self.model.prompt_learner.parameters(), 388 | lr=self.cfg.AT.PALPHA, 389 | weight_decay=0, 390 | momentum=0) 391 | 392 | prec = self.cfg.TRAINER.COOP.PREC 393 | if prec == "amp": 394 | with autocast(): 395 | output = self.model(image) 396 | loss = F.cross_entropy(output, label) 397 | self.optim.zero_grad() 398 | self.scaler.scale(loss).backward() 399 | self.scaler.step(self.optim) 400 | self.scaler.update() 401 | elif self.cfg.AT.PROMPT == 'perturbed': 402 | state = copy.deepcopy(self.model.prompt_learner.state_dict()) 403 | delta = torch.zeros_like(image).uniform_(-eps, eps) 404 | for _ in range(steps): 405 | adv = torch.clamp(image+delta, 0, 1).requires_grad_(True) 406 | output = self.model(adv) 407 | loss = -F.cross_entropy(output, label) 408 | prompter_optim.zero_grad() 409 | loss.backward() 410 | delta -= alpha * torch.sign(adv.grad) 411 | delta = torch.clamp(delta, -eps, eps).detach() 412 | prompter_optim.step() 413 | 414 | self.model.prompt_learner.load_state_dict(state) 415 | prompter_optim.zero_grad() 416 | adv = torch.clamp(image+delta, 0, 1).detach() 417 | output = self.model(adv) 418 | loss = F.cross_entropy(output, label) 419 | self.model_backward_and_update(loss) 420 | else: 421 | delta = torch.zeros_like(image).uniform_(-eps, eps) 422 | for _ in range(steps): 423 | adv = torch.clamp(image+delta, 0, 1).requires_grad_(True) 424 | output = self.model(adv, attack=True) 425 | loss = -F.cross_entropy(output, label) 426 | loss.backward() 427 | delta -= alpha * torch.sign(adv.grad) 428 | delta = torch.clamp(delta, -eps, eps).detach() 429 | 430 | adv = torch.clamp(image+delta, 0, 1).detach() 431 | output = self.model(adv) 432 | loss = F.cross_entropy(output, label) 433 | self.model_backward_and_update(loss) 434 | 435 | loss_summary = { 436 | "loss": loss.item(), 437 | "acc": compute_accuracy(output, label)[0].item(), 438 | } 439 | 440 | if (self.batch_idx + 1) == self.num_batches: 441 | self.update_lr() 442 | 443 | return loss_summary 444 | 445 | def test(self, split=None): 446 | """A generic testing pipeline.""" 447 | self.set_model_mode("eval") 448 | self.evaluator.reset() 449 | 450 | attack = PGD(self.model, 451 | eps=self.cfg.AT.EPS, 452 | alpha=self.cfg.AT.ALPHA, 453 | steps=10) 454 | 455 | if split is None: 456 | split = self.cfg.TEST.SPLIT 457 | 458 | if split == "val" and self.val_loader is not None: 459 | data_loader = self.val_loader 460 | else: 461 | split = "test" # in case val_loader is None 462 | data_loader = self.test_loader 463 | 464 | print(f"Evaluate on the *{split}* set") 465 | 466 | for batch_idx, batch in enumerate(tqdm(data_loader)): 467 | input, label = self.parse_batch_test(batch) 468 | # input = attack(input, label).detach() 469 | with torch.no_grad(): 470 | output = self.model_inference(input).detach() 471 | self.evaluator.process(output, label) 472 | 473 | results = self.evaluator.evaluate() 474 | 475 | for k, v in results.items(): 476 | tag = f"{split}/{k}" 477 | self.write_scalar(tag, v, self.epoch) 478 | 479 | return list(results.values())[0] 480 | -------------------------------------------------------------------------------- /apt/trainers/imagenet_templates.py: 
-------------------------------------------------------------------------------- 1 | # source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb 2 | 3 | IMAGENET_TEMPLATES = [ 4 | "a bad photo of a {}.", 5 | "a photo of many {}.", 6 | "a sculpture of a {}.", 7 | "a photo of the hard to see {}.", 8 | "a low resolution photo of the {}.", 9 | "a rendering of a {}.", 10 | "graffiti of a {}.", 11 | "a bad photo of the {}.", 12 | "a cropped photo of the {}.", 13 | "a tattoo of a {}.", 14 | "the embroidered {}.", 15 | "a photo of a hard to see {}.", 16 | "a bright photo of a {}.", 17 | "a photo of a clean {}.", 18 | "a photo of a dirty {}.", 19 | "a dark photo of the {}.", 20 | "a drawing of a {}.", 21 | "a photo of my {}.", 22 | "the plastic {}.", 23 | "a photo of the cool {}.", 24 | "a close-up photo of a {}.", 25 | "a black and white photo of the {}.", 26 | "a painting of the {}.", 27 | "a painting of a {}.", 28 | "a pixelated photo of the {}.", 29 | "a sculpture of the {}.", 30 | "a bright photo of the {}.", 31 | "a cropped photo of a {}.", 32 | "a plastic {}.", 33 | "a photo of the dirty {}.", 34 | "a jpeg corrupted photo of a {}.", 35 | "a blurry photo of the {}.", 36 | "a photo of the {}.", 37 | "a good photo of the {}.", 38 | "a rendering of the {}.", 39 | "a {} in a video game.", 40 | "a photo of one {}.", 41 | "a doodle of a {}.", 42 | "a close-up photo of the {}.", 43 | "a photo of a {}.", 44 | "the origami {}.", 45 | "the {} in a video game.", 46 | "a sketch of a {}.", 47 | "a doodle of the {}.", 48 | "a origami {}.", 49 | "a low resolution photo of a {}.", 50 | "the toy {}.", 51 | "a rendition of the {}.", 52 | "a photo of the clean {}.", 53 | "a photo of a large {}.", 54 | "a rendition of a {}.", 55 | "a photo of a nice {}.", 56 | "a photo of a weird {}.", 57 | "a blurry photo of a {}.", 58 | "a cartoon {}.", 59 | "art of a {}.", 60 | "a sketch of the {}.", 61 | "a embroidered {}.", 62 | "a pixelated photo of a {}.", 63 | "itap of the {}.", 64 | "a jpeg corrupted photo of the {}.", 65 | "a good photo of a {}.", 66 | "a plushie {}.", 67 | "a photo of the nice {}.", 68 | "a photo of the small {}.", 69 | "a photo of the weird {}.", 70 | "the cartoon {}.", 71 | "art of the {}.", 72 | "a drawing of the {}.", 73 | "a photo of the large {}.", 74 | "a black and white photo of a {}.", 75 | "the plushie {}.", 76 | "a dark photo of a {}.", 77 | "itap of a {}.", 78 | "graffiti of the {}.", 79 | "a toy {}.", 80 | "itap of my {}.", 81 | "a photo of a cool {}.", 82 | "a photo of a small {}.", 83 | "a tattoo of the {}.", 84 | ] 85 | 86 | IMAGENET_TEMPLATES_SELECT = [ 87 | "itap of a {}.", 88 | "a bad photo of the {}.", 89 | "a origami {}.", 90 | "a photo of the large {}.", 91 | "a {} in a video game.", 92 | "art of the {}.", 93 | "a photo of the small {}.", 94 | ] 95 | -------------------------------------------------------------------------------- /apt/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from collections import OrderedDict 5 | from typing import Tuple, TypeVar 6 | from torch import Tensor 7 | 8 | from clip import clip 9 | from trainers.apt import PromptLearner, TextEncoder 10 | 11 | 12 | mu = (0.48145466, 0.4578275, 0.40821073) 13 | std = (0.26862954, 0.26130258, 0.27577711) 14 | 15 | class ImageNormalizer(nn.Module): 16 | 17 | def __init__(self, mean: Tuple[float, float, float], 18 | std: Tuple[float, float, float]) -> None: 19 | 
super(ImageNormalizer, self).__init__() 20 | 21 | self.register_buffer('mean', torch.as_tensor(mean).view(1, 3, 1, 1)) 22 | self.register_buffer('std', torch.as_tensor(std).view(1, 3, 1, 1)) 23 | 24 | def forward(self, input: Tensor) -> Tensor: 25 | return (input - self.mean) / self.std 26 | 27 | def __repr__(self): 28 | return f'ImageNormalizer(mean={self.mean.squeeze()}, std={self.std.squeeze()})' # type: ignore 29 | 30 | 31 | class CustomCLIP(nn.Module): 32 | def __init__(self, 33 | model, 34 | classnames, 35 | cls_prompt='a photo of a {}', 36 | atk_prompt=None, 37 | cfg=None): 38 | super().__init__() 39 | 40 | self.cfg = cfg 41 | self.logit_scale = model.logit_scale 42 | self.classnames = classnames 43 | self.model = model 44 | self.mode = 'classification' 45 | 46 | self.normalize = ImageNormalizer(mu, std).cuda() 47 | 48 | self.set_prompts(cls_prompt, atk_prompt) 49 | 50 | def _prompt_text_features(self, prompt): 51 | if '{}' in prompt: 52 | # manual prompt template 53 | prompts = torch.cat([clip.tokenize(prompt.format(c)) 54 | for c in self.classnames]) 55 | self.model = self.model 56 | text_features = self.model.encode_text(prompts) 57 | else: 58 | # optimized prompt vector 59 | prompter_ckp = prompt 60 | assert os.path.isfile(prompter_ckp) 61 | prompter = PromptLearner(self.cfg, self.classnames, self.model) 62 | 63 | state_dict = torch.load(prompter_ckp)["state_dict"] 64 | 65 | # Ignore fixed token vectors 66 | if "token_prefix" in state_dict: 67 | del state_dict["token_prefix"] 68 | 69 | if "token_suffix" in state_dict: 70 | del state_dict["token_suffix"] 71 | 72 | prompter.load_state_dict(state_dict, strict=False) 73 | text_encoder = TextEncoder(self.model) 74 | prompts = prompter() 75 | text_features = text_encoder(prompts, prompter.tokenized_prompts) 76 | 77 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 78 | return text_features.detach() 79 | 80 | def set_prompts(self, cls_prompt, atk_prompt=None): 81 | print(f'classification prompt: {cls_prompt}') 82 | self.cls_tfeatures = self._prompt_text_features(cls_prompt).cuda() 83 | 84 | if atk_prompt is None or cls_prompt == atk_prompt: 85 | print(f'attack prompt: {cls_prompt}') 86 | self.atk_tfeatures = self.cls_tfeatures 87 | else: 88 | print(f'attack prompt: {atk_prompt}') 89 | self.atk_tfeatures = self._prompt_text_features(atk_prompt).cuda() 90 | 91 | def forward(self, image): 92 | image_features = self.model.encode_image(self.normalize(image)) 93 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 94 | 95 | logit_scale = self.logit_scale.exp() 96 | 97 | text_features = self.cls_tfeatures if self.mode == 'classification' else self.atk_tfeatures 98 | logits = logit_scale * image_features @ text_features.t() 99 | 100 | return logits 101 | 102 | 103 | class AverageMeter(object): 104 | """Computes and stores the average and current value""" 105 | def __init__(self, name, fmt=':f'): 106 | self.name = name 107 | self.fmt = fmt 108 | self.reset() 109 | 110 | def reset(self): 111 | self.val = 0 112 | self.avg = 0 113 | self.sum = 0 114 | self.count = 0 115 | 116 | def update(self, val, n=1): 117 | self.val = val 118 | self.sum += val * n 119 | self.count += n 120 | self.avg = self.sum / self.count 121 | 122 | def __str__(self): 123 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 124 | return fmtstr.format(**self.__dict__) 125 | 126 | 127 | class ProgressMeter(object): 128 | def __init__(self, num_batches, meters, prefix=""): 129 | self.batch_fmtstr = 
self._get_batch_fmtstr(num_batches) 130 | self.meters = meters 131 | self.prefix = prefix 132 | 133 | def display(self, batch): 134 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 135 | entries += [str(meter) for meter in self.meters] 136 | print('\t'.join(entries)) 137 | 138 | def _get_batch_fmtstr(self, num_batches): 139 | num_digits = len(str(num_batches // 1)) 140 | fmt = '{:' + str(num_digits) + 'd}' 141 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 142 | 143 | def accuracy(output, target, topk=(1,)): 144 | """Computes the accuracy over the k top predictions for the specified values of k""" 145 | with torch.no_grad(): 146 | maxk = max(topk) 147 | batch_size = target.size(0) 148 | 149 | _, pred = output.topk(maxk, 1, True, True) 150 | pred = pred.t() 151 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 152 | 153 | res = [] 154 | for k in topk: 155 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 156 | res.append(correct_k.mul_(100.0 / batch_size)) 157 | return res 158 | -------------------------------------------------------------------------------- /assets/one_word_boost.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TreeLLi/APT/c2e99e09329b9a2be87d02e5c7ae02092d2ac95d/assets/one_word_boost.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | torchattacks 5 | git+https://github.com/fra31/auto-attack 6 | --------------------------------------------------------------------------------
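The manual attack loop in apt/trainers/apt.py (APT.forward_backward) is standard L-inf PGD with a random start. Below is a minimal standalone sketch of that inner loop; it is illustrative only and not a file in this repository — the name pgd_linf and the default eps/alpha/steps values are assumptions, while the trainer reads these from cfg.AT.EPS, cfg.AT.ALPHA and cfg.AT.STEPS.

import torch
import torch.nn.functional as F

def pgd_linf(model, images, labels, eps=4/255, alpha=1/255, steps=3):
    # Random start inside the L-inf ball, as in APT.forward_backward.
    delta = torch.zeros_like(images).uniform_(-eps, eps)
    for _ in range(steps):
        adv = torch.clamp(images + delta, 0, 1).requires_grad_(True)
        loss = F.cross_entropy(model(adv), labels)
        grad, = torch.autograd.grad(loss, adv)
        # Ascend the cross-entropy loss, then project back onto the L-inf ball.
        delta = torch.clamp(delta + alpha * grad.sign(), -eps, eps).detach()
    return torch.clamp(images + delta, 0, 1).detach()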