├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── data └── dataset │ ├── birds_525.py │ ├── car.py │ ├── cub.py │ ├── dog.py │ ├── pet.py │ └── utils.py ├── demo.ipynb ├── engine ├── optimizer.py └── trainer.py ├── env_setup.sh ├── experiment ├── build_loader.py ├── build_model.py ├── config │ └── prompt_cam │ │ ├── dino │ │ ├── birds_525 │ │ │ └── args.yaml │ │ ├── car │ │ │ └── args.yaml │ │ ├── cub │ │ │ └── args.yaml │ │ ├── dog │ │ │ └── args.yaml │ │ └── pet │ │ │ └── args.yaml │ │ └── dinov2 │ │ ├── cub │ │ └── args.yaml │ │ ├── dog │ │ └── args.yaml │ │ └── pet │ │ └── args.yaml ├── run.py └── visualize_run.py ├── main.py ├── model ├── attention.py ├── block.py ├── mlp.py ├── patch_embed.py ├── utils.py ├── vision_transformer.py └── vpt.py ├── samples ├── Baltimore_Oriole.jpg ├── Brewer_Blackbird.jpg ├── Orchard_Oriole.jpg ├── Scott_Oriole.jpg ├── red_winged_blackbird.jpg ├── rusty_Blackbird.jpg ├── sample_image.png ├── trait_manipulation.jpg └── yellow_headed_blackbird.jpg ├── utils ├── distributed.py ├── file_io.py ├── global_var.py ├── log_utils.py ├── misc.py ├── setup_logging.py └── visual_utils.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | checkpoints/ 3 | *.pyc 4 | **/__pycache__/ 5 | output/ 6 | pretrained_weights/ 7 | data/images/ 8 | visualization/ 9 | data/annotations/ -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # This CITATION.cff file was generated with cffinit. 2 | # Visit https://bit.ly/cffinit to generate yours today! 3 | 4 | cff-version: 1.2.0 5 | title: >- 6 | Prompt-CAM: Making Vision Transformers Interpretable for 7 | Fine-Grained Analysis 8 | message: >- 9 | If you use this software, please cite it using the 10 | metadata from this file. 11 | type: software 12 | authors: 13 | - given-names: Arpita 14 | family-names: Chowdhury 15 | - given-names: Dipanjyoti 16 | family-names: Paul 17 | - given-names: Zheda 18 | family-names: Mai 19 | - given-names: Jianyang 20 | family-names: Gu 21 | - given-names: Ziheng 22 | family-names: Zhang 23 | - given-names: Kazi Sajeed 24 | family-names: Mehrab 25 | - given-names: Elizabeth G. 26 | family-names: Campolongo 27 | - given-names: Daniel 28 | family-names: Rubenstein 29 | - given-names: Charles V. 30 | family-names: Stewart 31 | - given-names: Anuj 32 | family-names: Karpatne 33 | - given-names: Tanya 34 | family-names: Berger-Wolf 35 | - given-names: Yu 36 | family-names: Su 37 | - given-names: Wei-Lun 38 | family-names: Chao 39 | identifiers: 40 | - type: url 41 | value: 'https://arxiv.org/pdf/2501.09333' 42 | repository-code: 'https://github.com/Imageomics/Prompt_CAM' 43 | abstract: >- 44 | We present a simple usage of pre-trained Vision 45 | Transformers (ViTs) for fine-grained analysis, aiming to 46 | identify and localize the traits that distinguish visually 47 | similar categories, such as different bird species or dog 48 | breeds. 
49 | keywords: 50 | - explainable-ai 51 | - interpretable-ai 52 | - imageomics 53 | - fine-grained-classification 54 | - vision-transformer 55 | - interpretable 56 | license: MIT 57 | commit: Prompt-CAM 58 | version: 1.0.0 59 | date-released: '2025-03-24' 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Arpita Chowdhury 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # :mag: Prompt-CAM: Making Vision Transformers Interpretable for Fine-Grained Analysis(CVPR'25) 2 | 3 | This is an official implementation for [PROMPT-CAM: Making Vision Transformers Interpretable for Fine-Grained Analysis](https://arxiv.org/pdf/2501.09333) (CVPR'25) 4 | 5 | Introducing **Prompt-CAM**, a $${\textcolor{red}{\text{simple yet effective}}}$$ **interpretable transformer** that requires no architectural modifications to pre-trained ViTs, we just have to inject **class-specific prompts** into any ViT to make them interpretable. 6 | 7 | Prompt CAM lets us explore: 8 | - 🧠 What the model thinks is important for each class? 9 | - ✨ Which traits are shared between two bird species? 10 | - 🎨 How different classes ‘see’ the same image differently! 11 | 12 |
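To make the idea above concrete, here is a *conceptual* sketch (not the repository's actual modules) of what "injecting class-specific prompts" means: one learnable prompt per class is prepended to the patch tokens of a frozen ViT, each prompt's output token is mapped to a single logit for its class, and that prompt's attention over the patches serves as the class activation map. The name `ClassPromptScorer` and the identity encoder are placeholders; in this repository the prompts are realized through the VPT machinery (`vpt_num` equals the number of classes).

```python
import torch
import torch.nn as nn

class ClassPromptScorer(nn.Module):
    """Conceptual sketch: C learnable class prompts prepended to patch tokens."""

    def __init__(self, num_classes: int, embed_dim: int, encoder: nn.Module):
        super().__init__()
        self.prompts = nn.Parameter(torch.zeros(num_classes, embed_dim))
        nn.init.trunc_normal_(self.prompts, std=0.02)
        self.encoder = encoder               # stands in for the frozen ViT blocks
        self.head = nn.Linear(embed_dim, 1)  # shared head: one logit per prompt

    def forward(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        # patch_tokens: (B, N, D) from the frozen patch embedding
        B = patch_tokens.shape[0]
        prompts = self.prompts.unsqueeze(0).expand(B, -1, -1)   # (B, C, D)
        tokens = self.encoder(torch.cat([prompts, patch_tokens], dim=1))
        logits = self.head(tokens[:, : self.prompts.shape[0]])  # (B, C, 1)
        return logits.squeeze(-1)                               # (B, C)

# Toy usage: 200 classes, ViT-B width, identity encoder just to show the shapes.
scorer = ClassPromptScorer(num_classes=200, embed_dim=768, encoder=nn.Identity())
print(scorer(torch.randn(2, 196, 768)).shape)  # torch.Size([2, 200])
```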

13 | 14 |

15 | 16 | ## Quick Start: Try out the demo 17 | 🔍 Ever wondered what traits stand out when a model looks at an image of one class but searches with another class in mind? 🤔 18 | Witness the important traits of different classes through the lens of Prompt-CAM with our interactive demos! 19 | 20 | 👉 Try our demo **without installing anything** in Google Colab [![](https://img.shields.io/badge/Google_Colab-blue)](https://colab.research.google.com/drive/1co1P5LXSVb-g0hqv8Selfjq4WGxSpIFe?usp=sharing) 21 | 22 | 👉 Try our demo locally in [![](https://img.shields.io/badge/notebook-orange 23 | )](demo.ipynb) 24 | - Set up the [environment](#environment-setup) 25 | - Download the pre-trained model from the link below. 26 | - Run the demo. 27 | 28 | 29 | 👉 You can extend this code base to include: [New datasets](#to-add-a-new-dataset) and [New backbones](#to-add-a-new-backbone) 30 | 31 | 32 | 33 | ## Environment Setup 34 | ```bash 35 | conda create -n prompt_cam python=3.7 36 | conda activate prompt_cam 37 | source env_setup.sh 38 | ``` 39 | 40 | 41 | ## Data Preparation 42 | You can put all the data in a folder and pass the path to the `--data_path` argument. 43 | 44 | The structure of `data/images/` should be organized as follows: 45 | 46 | ``` 47 | cub/ 48 | ├── train/ 49 | │ ├── 001.Black_footed_Albatross/ 50 | │ │ ├── image_1.jpg 51 | │ │ ├── image_2.jpg 52 | │ │ └── ... 53 | │ ├── 002.Laysan_Albatross/ 54 | │ │ ├── image_1.jpg 55 | │ │ ├── image_2.jpg 56 | │ │ └── ... 57 | │ └── ... 58 | └── val/ 59 | ├── 001.Black_footed_Albatross/ 60 | │ ├── image_1.jpg 61 | │ ├── image_2.jpg 62 | │ └── ... 63 | ├── 002.Laysan_Albatross/ 64 | │ ├── image_1.jpg 65 | ``` 66 | 67 |
68 | Prepare CUB dataset 69 | 70 | ## CUB 71 | 72 | - Download the prepared dataset 73 | - From [![](https://img.shields.io/badge/google_drive-yellow)](https://drive.google.com/drive/folders/1X3ikQEk_D7cKcyCnxbF3kJTsZ0LZfvVO?usp=sharing) 74 | - `Or` prepare the dataset yourself 75 | - You can download the CUB dataset from [the original website](https://www.vision.caltech.edu/datasets/cub_200_2011/) and put it in the `data/images/` folder. 76 | - Use the dataset's provided train/val split to create the `train/` and `val/` folders, and prefix each image folder name with its class number (starting from 1). 77 | - The code will automatically create train and val annotation files in the `data/annotations/` folder for each dataset if they are not provided (see the annotation sketch below). 78 | 79 |
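If you prefer to generate the annotation files ahead of time instead of letting the code create them on first run, the helper in `data/dataset/utils.py` can be called directly. A minimal sketch, assuming the CUB layout above and that you run it from the repository root:

```python
from data.dataset.utils import create_annotation_file

# Writes one "relative/image/path label" line per image, e.g.
#   train/001.Black_footed_Albatross/image_1.jpg 0
# Labels are integers assigned in the sorted order the class folders are found.
create_annotation_file(
    data_path="./data/images/cub",                       # root containing train/ and val/
    mode=["train"],                                      # split folder(s) to scan
    output_file="data/annotations/cub/cub_combine.txt",  # matches get_cub()'s default path
)
```

The test split works the same way with `mode=["val"]` and `output_file="data/annotations/cub/test.txt"`.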
80 |
81 | Prepare Oxford Pet dataset 82 | 83 | ## Pet Dataset 84 | - Download prepared dataset 85 | - From [![](https://img.shields.io/badge/google_drive-yellow 86 | )](https://drive.google.com/drive/folders/1X3ikQEk_D7cKcyCnxbF3kJTsZ0LZfvVO?usp=sharing) 87 |
88 | 89 | **To add a new dataset, see [Extensions](#extensions)** 90 | 91 | ## Results + Checkpoints: 92 | - Download from the links below and put them in the `checkpoints/{model}/{dataset}/` folder. 93 | 94 | Backbone | Dataset | Prompt-CAM (top-1 Acc %) | Checkpoint Link | 95 | --- | --- | --- | --- | 96 | dino | cub (CUB) | 73.2 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | 97 | dino | car (Stanford Cars) | 83.2 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | 98 | dino | dog (Stanford Dogs) | 81.1 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | 99 | dino | pet (Oxford Pet) | 91.3 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | 100 | dino | birds_525 (Birds-525) | 98.8 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | 101 | 102 | Backbone | Dataset | Prompt-CAM (top-1 Acc %) | Checkpoint Link | 103 | --- | --- | --- | --- | 104 | dinov2 | cub (CUB) | 74.1 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | 105 | dinov2 | dog (Stanford Dogs) | 81.3 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | 106 | dinov2 | pet (Oxford Pet) | 92.7 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | 107 | 108 | ## Evaluation and Visualization 109 | - Download the checkpoint from the URL in the [table](#results--checkpoints) above and put it in the `checkpoints/{model}/{dataset}/` folder. 110 | 111 | For example, to visualize the attention map of the DINO model on the class `024.Red_faced_Cormorant` of the CUB dataset, put the checkpoint in the `checkpoints/dino/cub/` folder and run the following command: 112 | 113 | ```python 114 | CUDA_VISIBLE_DEVICES=0 python visualize.py --config ./experiment/config/prompt_cam/dino/cub/args.yaml --checkpoint ./checkpoints/dino/cub/model.pt --vis_cls 23 115 | ``` 116 | 117 | - The output will be saved in the `visualization/dino/cub/class_23/` folder. 118 | - Inside each image folder, the `top_traits` heatmaps for the target class are concatenated if the prediction is correct; otherwise, all traits are concatenated. (The prediction for each image can be found in `concatenated_prediction_{predicted_class}.jpg`.) 119 |
120 | Visualization Configuration Meaning 121 | 122 | - `config`: path to the config file. 123 | - `checkpoint`: path to the checkpoint file. 124 | - `vis_cls`: class number to visualize. (default: 23) 125 | - `vis_attn`: set to True to visualize the attention map. (default: True) 126 | - `top_traits`: number of traits to visualize. (default: 4) 127 | - `nmbr_samples`: number of images from the `vis_cls` class to visualize. (default: 10) 128 | - `vis_outdir`: output directory. (default: visualization/) 129 |
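For reference, each saved heatmap comes from the attention that the target class's prompt token pays to the patch tokens (see `experiment/visualize_run.py`, where the map is sliced as `attn_map[:, :, target_cls, (vpt_num + 1):]`). The repository's `create_overlay_images` handles the colormapping and saving; the sketch below only illustrates the reshape-and-upsample step, assuming a ViT-B/16 backbone at 224×224 (a 14×14 patch grid):

```python
import torch
import torch.nn.functional as F

def attention_to_heatmap(attn_row: torch.Tensor, patch_size: int = 16,
                         image_size: int = 224) -> torch.Tensor:
    """Turn one head's prompt-to-patch attention row into an image-sized heatmap.

    attn_row: (num_patches,) attention of the target class prompt over the
    patch tokens (prompt and CLS tokens already sliced away).
    """
    grid = image_size // patch_size                        # 14 for ViT-B/16 at 224
    heat = attn_row.reshape(1, 1, grid, grid).float()
    heat = F.interpolate(heat, size=(image_size, image_size),
                         mode="bilinear", align_corners=False)
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)  # scale to [0, 1]
    return heat[0, 0]                                      # (224, 224), ready to overlay

# Example with dummy weights for a 14 x 14 patch grid:
print(attention_to_heatmap(torch.rand(14 * 14)).shape)  # torch.Size([224, 224])
```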
130 | 131 | 132 | 133 | ## :fire: Training 134 | 135 | ### :one: Pretrained weights 136 | --- 137 | 138 | Download the pretrained weights from the following links and put them in the `pretrained_weights` folder. 139 | 1. [ViT-B-DINO](https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth) and rename it to `dino_vitbase16_pretrain.pth` 140 | 2. [ViT-B-DINOV2](https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth) and rename it to `dinov2_vitb14_pretrain.pth` 141 | ### :two: Load dataset 142 | --- 143 | 144 | See [Data Preparation](#data-preparation) above. 145 | ### :three: Start training 146 | --- 147 | 148 | 👉 To train the model on the `CUB dataset` using the `DINO` model, run the following command: 149 | ```python 150 | CUDA_VISIBLE_DEVICES=0 python main.py --config ./experiment/config/prompt_cam/dino/cub/args.yaml 151 | ``` 152 | The checkpoint will be saved in the `output/vit_base_patch16_dino/cub/` folder. Copy the checkpoint `model.pt` to the `checkpoints/dino/cub/` folder. 153 | 154 | --- 155 | 156 | 👉 To train the model on the `Oxford Pet dataset` using the `DINO` model, run the following command: 157 | ```python 158 | CUDA_VISIBLE_DEVICES=0 python main.py --config ./experiment/config/prompt_cam/dino/pet/args.yaml 159 | ``` 160 | The checkpoint will be saved in the `output/vit_base_patch16_dino/pet/` folder. Copy the checkpoint `model.pt` to the `checkpoints/dino/pet/` folder. 161 | 162 | --- 163 | 164 | 👉 To train the model on the `Oxford Pet dataset` using the `DINOv2` model, run the following command: 165 | ```python 166 | CUDA_VISIBLE_DEVICES=0 python main.py --config ./experiment/config/prompt_cam/dinov2/pet/args.yaml 167 | ``` 168 | 169 | The checkpoint will be saved in the `output/vit_base_patch14_dinov2/pet/` folder. Copy the checkpoint `model.pt` to the `checkpoints/dinov2/pet/` folder. 170 | 171 | --- 172 | 173 | ### :four: :mag: Visualize the attention map 174 | --- 175 | 176 | See [Visualization](#evaluation-and-visualization) above. 177 | 178 | ## Extensions 179 | ### To add a new dataset 180 | 1. Prepare the dataset using the [instructions](#data-preparation) above. 181 | 2. Add a new dataset file in `data/dataset/`. [Look at the existing dataset files for reference.](data/dataset/cub.py) 182 | 3. Modify [build_loader.py](experiment/build_loader.py) to include the new dataset. 183 | 4. Create a new config file in `experiment/config/prompt_cam/{model}/{dataset}/args.yaml` 184 | - See `experiment/config/prompt_cam/dino/cub/args.yaml` for reference and what to modify. 185 | 186 | ### To add a new backbone 187 | - Modify `get_base_model()` in [build_model.py](experiment/build_model.py), as in the sketch below. 188 | - Register the new backbone in [vision_transformer.py](model/vision_transformer.py) by creating a new function. 189 | - Add another option to `--pretrained_weights` and `--model` in the `setup_parser()` function of [main.py](main.py) to include the new backbone. 
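As a concrete illustration of the first bullet above, the new branch can simply mirror the existing ones in `get_base_model()`. Everything named `mynewssl` below is a placeholder: the timm entry `vit_base_patch16_mynewssl_petl` is something you would register yourself in `model/vision_transformer.py`, and the weight file name is hypothetical. This is a fragment to splice into the existing `if`/`elif` chain, not standalone code.

```python
    elif params.pretrained_weights == "vit_base_patch16_mynewssl":   # hypothetical name
        params.patch_size = 16
        model = timm.create_model(
            "vit_base_patch16_mynewssl_petl",   # entry you register in model/vision_transformer.py
            drop_path_rate=params.drop_path_rate,
            pretrained=False,
            params=params,
        )
        if not visualize:
            # hypothetical weight file placed in pretrained_weights/
            model.load_pretrained("pretrained_weights/mynewssl_vitb16_pretrain.pth")
            model.reset_classifier(params.class_num)
```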
190 | 191 | 192 | # Citation [![Paper](https://img.shields.io/badge/paper-2501.09333-blue)](https://arxiv.org/pdf/2501.09333) 193 | If you find this repository useful, please consider citing our work :pencil: and giving a star :star2: : 194 | ``` 195 | @article{chowdhury2025prompt, 196 | title={Prompt-CAM: A Simpler Interpretable Transformer for Fine-Grained Analysis}, 197 | author={Chowdhury, Arpita and Paul, Dipanjyoti and Mai, Zheda and Gu, Jianyang and Zhang, Ziheng and Mehrab, Kazi Sajeed and Campolongo, Elizabeth G and Rubenstein, Daniel and Stewart, Charles V and Karpatne, Anuj and others}, 198 | journal={arXiv preprint arXiv:2501.09333}, 199 | year={2025} 200 | } 201 | ``` 202 | ### Acknowledgement 203 | 204 | - VPT: https://github.com/KMnP/vpt 205 | - PETL_VISION: https://github.com/OSU-MLB/PETL_Vision 206 | 207 | Thanks for their wonderful works. 208 | 209 | 🛠 create an issue for any contributions. 210 | -------------------------------------------------------------------------------- /data/dataset/birds_525.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets.folder import ImageFolder, default_loader 2 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 3 | import torchvision.transforms as transforms 4 | from timm.data.transforms import str_to_interp_mode 5 | from data.dataset.utils import add_samples, create_annotation_file, get_transformation 6 | import os 7 | 8 | class Birds_525_Dataset(ImageFolder): ## For new data ## Change this:: the name of the class to match the dataset name 9 | def __init__(self, root, data_list, transform=None): 10 | self.data_root = root 11 | self.loader = default_loader 12 | self.transform = transform 13 | self.target_transform = None 14 | self.samples = [] 15 | 16 | add_samples(self.samples, data_list, root) 17 | 18 | 19 | def get_birds_525(params, mode='trainval_combined'): ## For new data ## Change this:: the name of the dataset name 20 | params.class_num = 525 ## For new data ## Change this:: the number of classes in the dataset 21 | mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 22 | 23 | transform_train = get_transformation('train', mean, std) 24 | transform_val = get_transformation('val', mean, std) 25 | 26 | if mode == 'trainval_combined': 27 | train_data_list = f'data/annotations/birds_525/{params.data}_combine.txt' ## For new data ## Change this:: the name of the data in the path 28 | 29 | if not os.path.exists(train_data_list): 30 | create_annotation_file(params.data_path, ['train'],train_data_list) 31 | return Birds_525_Dataset(params.data_path, train_data_list, transform_train) ## For new data ## Change this:: the class to call 32 | 33 | elif mode == 'test': 34 | test_data_list = f'data/annotations/birds_525/test.txt' ## For new data ## Change this:: the name of the data in the path 35 | if not os.path.exists(test_data_list): 36 | create_annotation_file(params.data_path,['val'],test_data_list) 37 | return Birds_525_Dataset(params.data_path, test_data_list, transform_val) ## For new data ## Change this:: the class to call 38 | else: 39 | raise NotImplementedError 40 | 41 | 42 | -------------------------------------------------------------------------------- /data/dataset/car.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets.folder import ImageFolder, default_loader 2 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 3 | import torchvision.transforms as 
transforms 4 | from timm.data.transforms import str_to_interp_mode 5 | from data.dataset.utils import add_samples, create_annotation_file, get_transformation 6 | import os 7 | 8 | class CarDataset(ImageFolder): 9 | def __init__(self, root, data_list, transform=None): 10 | self.data_root = root 11 | self.loader = default_loader 12 | self.transform = transform 13 | self.target_transform = None 14 | self.samples = [] 15 | 16 | add_samples(self.samples, data_list, root) 17 | 18 | 19 | def get_car(params, mode='trainval_combined'): 20 | params.class_num = 196 21 | mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 22 | 23 | transform_train = get_transformation('train', mean, std) 24 | transform_val = get_transformation('val', mean, std) 25 | 26 | if mode == 'trainval_combined': 27 | train_data_list = f'data/annotations/car/{params.data}_combine.txt' 28 | 29 | if not os.path.exists(train_data_list): 30 | create_annotation_file(params.data_path, ['train'],train_data_list) 31 | return CarDataset(params.data_path, train_data_list, transform_train) 32 | 33 | elif mode == 'test': 34 | test_data_list = f'data/annotations/car/test.txt' 35 | if not os.path.exists(test_data_list): 36 | create_annotation_file(params.data_path,['val'],test_data_list) 37 | return CarDataset(params.data_path, test_data_list, transform_val) 38 | else: 39 | raise NotImplementedError 40 | 41 | 42 | -------------------------------------------------------------------------------- /data/dataset/cub.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets.folder import ImageFolder, default_loader 2 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 3 | import torchvision.transforms as transforms 4 | from timm.data.transforms import str_to_interp_mode 5 | from data.dataset.utils import add_samples, create_annotation_file, get_transformation 6 | import os 7 | 8 | class CubDataset(ImageFolder): ## For new data ## Change this:: the name of the class to match the dataset name 9 | def __init__(self, root, data_list, transform=None): 10 | self.data_root = root 11 | self.loader = default_loader 12 | self.transform = transform 13 | self.target_transform = None 14 | self.samples = [] 15 | 16 | add_samples(self.samples, data_list, root) 17 | 18 | 19 | def get_cub(params, mode='trainval_combined'): ## For new data ## Change this:: the name of the dataset name 20 | params.class_num = 200 ## For new data ## Change this:: the number of classes in the dataset 21 | mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 22 | 23 | transform_train = get_transformation('train', mean, std) 24 | transform_val = get_transformation('val', mean, std) 25 | 26 | if mode == 'trainval_combined': 27 | train_data_list = f'data/annotations/cub/{params.data}_combine.txt' ## For new data ## Change this:: the name of the data in the path 28 | 29 | if not os.path.exists(train_data_list): 30 | create_annotation_file(params.data_path, ['train'],train_data_list) 31 | return CubDataset(params.data_path, train_data_list, transform_train) ## For new data ## Change this:: the class to call 32 | 33 | elif mode == 'test': 34 | test_data_list = f'data/annotations/cub/test.txt' ## For new data ## Change this:: the name of the data in the path 35 | if not os.path.exists(test_data_list): 36 | create_annotation_file(params.data_path,['val'],test_data_list) 37 | return CubDataset(params.data_path, test_data_list, transform_val) ## For new data ## Change this:: the class to call 38 | else: 39 
| raise NotImplementedError 40 | 41 | 42 | -------------------------------------------------------------------------------- /data/dataset/dog.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets.folder import ImageFolder, default_loader 2 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 3 | import torchvision.transforms as transforms 4 | from timm.data.transforms import str_to_interp_mode 5 | from data.dataset.utils import add_samples, create_annotation_file, get_transformation 6 | import os 7 | 8 | class DogDataset(ImageFolder): 9 | def __init__(self, root, data_list, transform=None): 10 | self.data_root = root 11 | self.loader = default_loader 12 | self.transform = transform 13 | self.target_transform = None 14 | self.samples = [] 15 | 16 | add_samples(self.samples, data_list, root) 17 | 18 | 19 | def get_dog(params, mode='trainval_combined'): 20 | params.class_num = 120 21 | mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 22 | 23 | transform_train = get_transformation('train', mean, std) 24 | transform_val = get_transformation('val', mean, std) 25 | 26 | if mode == 'trainval_combined': 27 | train_data_list = f'data/annotations/dog/{params.data}_combine.txt' 28 | 29 | if not os.path.exists(train_data_list): 30 | create_annotation_file(params.data_path, ['train'],train_data_list) 31 | return DogDataset(params.data_path, train_data_list, transform_train) 32 | 33 | elif mode == 'test': 34 | test_data_list = f'data/annotations/dog/test.txt' 35 | if not os.path.exists(test_data_list): 36 | create_annotation_file(params.data_path,['val'],test_data_list) 37 | return DogDataset(params.data_path, test_data_list, transform_val) 38 | else: 39 | raise NotImplementedError 40 | 41 | 42 | -------------------------------------------------------------------------------- /data/dataset/pet.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets.folder import ImageFolder, default_loader 2 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 3 | import torchvision.transforms as transforms 4 | from timm.data.transforms import str_to_interp_mode 5 | from data.dataset.utils import add_samples, create_annotation_file, get_transformation 6 | import os 7 | 8 | class PetDataset(ImageFolder): 9 | def __init__(self, root, data_list, transform=None): 10 | self.data_root = root 11 | self.loader = default_loader 12 | self.transform = transform 13 | self.target_transform = None 14 | self.samples = [] 15 | 16 | add_samples(self.samples, data_list, root) 17 | 18 | 19 | def get_pet(params, mode='trainval_combined'): 20 | params.class_num = 37 21 | mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 22 | 23 | transform_train = get_transformation('train', mean, std) 24 | transform_val = get_transformation('val', mean, std) 25 | 26 | if mode == 'trainval_combined': 27 | train_data_list = f'data/annotations/pet/{params.data}_combine.txt' 28 | 29 | if not os.path.exists(train_data_list): 30 | create_annotation_file(params.data_path, ['train'],train_data_list) 31 | return PetDataset(params.data_path, train_data_list, transform_train) 32 | 33 | elif mode == 'test': 34 | test_data_list = f'data/annotations/pet/test.txt' 35 | if not os.path.exists(test_data_list): 36 | create_annotation_file(params.data_path,['val'],test_data_list) 37 | return PetDataset(params.data_path, test_data_list, transform_val) 38 | else: 39 | raise NotImplementedError 40 | 41 | 
42 | -------------------------------------------------------------------------------- /data/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import torchvision as tv 4 | 5 | import torch.utils.data as data 6 | from PIL import Image 7 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 8 | 9 | def add_samples(sample_list, list_path, root): 10 | with open(list_path, 'r') as f: 11 | for line in f: 12 | img_name = line.rsplit(' ', 1)[0] 13 | label = int(line.rsplit(' ', 1)[1]) 14 | sample_list.append((os.path.join(root, img_name), label)) 15 | 16 | def default_loader(path): 17 | return Image.open(path).convert('RGB') 18 | 19 | 20 | def default_flist_reader(flist): 21 | """ 22 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 23 | """ 24 | imlist = [] 25 | with open(flist, 'r') as rf: 26 | for line in rf.readlines(): 27 | impath, imlabel = line.strip().split() 28 | imlist.append((impath, int(imlabel))) 29 | 30 | return imlist 31 | 32 | 33 | class ImageFilelist(data.Dataset): 34 | def __init__(self, root, flist, name, transform=None, target_transform=None, 35 | flist_reader=default_flist_reader, loader=default_loader): 36 | self.root = root 37 | self.name = name 38 | self.imlist = flist_reader(flist) 39 | self.transform = transform 40 | self.target_transform = target_transform 41 | self.loader = loader 42 | 43 | def __getitem__(self, index): 44 | impath, target = self.imlist[index] 45 | img = self.loader(os.path.join(self.root, impath)) 46 | if self.transform is not None: 47 | img = self.transform(img) 48 | if self.target_transform is not None: 49 | target = self.target_transform(target) 50 | 51 | return img, target 52 | 53 | def __len__(self): 54 | return len(self.imlist) 55 | 56 | def get_transformation(mode, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD): 57 | if mode == 'train': 58 | return tv.transforms.Compose([ 59 | tv.transforms.Resize((240,240)), 60 | tv.transforms.RandomCrop((224,224)), 61 | tv.transforms.RandomHorizontalFlip(), 62 | tv.transforms.ToTensor(), 63 | tv.transforms.Normalize(mean, std), 64 | ]) 65 | elif mode == 'val' or mode == 'test': 66 | return tv.transforms.Compose([ 67 | tv.transforms.Resize((224,224)), 68 | tv.transforms.ToTensor(), 69 | tv.transforms.Normalize(mean, std), 70 | ]) 71 | else: 72 | raise ValueError("Invalid mode. Use 'train' or 'val' or 'test'.") 73 | 74 | def create_annotation_file(data_path, mode, output_file): 75 | """ 76 | Creates a file listing images and their corresponding labels based on the directory structure. 77 | 78 | Args: 79 | data_path: The root directory containing the dataset (e.g., dataset_name/). 80 | output_file: The path to the file where the image list will be written. 
81 | """ 82 | 83 | # Create the output directory if it doesn't exist 84 | output_dir = os.path.dirname(output_file) # Get the directory part of the path 85 | if output_dir and not os.path.exists(output_dir): 86 | os.makedirs(output_dir) 87 | 88 | image_list = [] 89 | label_map = {} # Store species names and corresponding numerical labels 90 | 91 | # Assign numerical labels to species names in the order they're encountered 92 | label_counter = 0 93 | 94 | for split in mode: # Iterate through train and val splits 95 | split_path = os.path.join(data_path, split) 96 | if not os.path.exists(split_path): 97 | continue # Skip if split does not exist 98 | 99 | for species_dir in sorted(os.listdir(split_path)): # Iterate through species directories 100 | if not os.path.isdir(os.path.join(split_path, species_dir)): 101 | continue # Skip if not a directory 102 | species_name = species_dir.split('.', 1)[1] if '.' in species_dir else species_dir # extract species name 103 | if species_name not in label_map: 104 | label_map[species_name] = label_counter 105 | label_counter += 1 106 | label = label_map[species_name] 107 | 108 | image_dir = os.path.join(split_path, species_dir) 109 | for image_file in os.listdir(image_dir): # Iterate through image files 110 | if os.path.isfile(os.path.join(image_dir, image_file)): # ensure it is a file 111 | image_path = os.path.join(split, species_dir, image_file) # Relative path from dataset root 112 | image_list.append(f"{image_path} {label}") 113 | 114 | # Write the image list to the output file 115 | with open(output_file, "w") as f: 116 | f.write("\n".join(image_list)) 117 | -------------------------------------------------------------------------------- /engine/optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | In VPT repo, they handle bias term very carefully when weight decay > 0 for different optimizers. 3 | They also used a special implementation of Adamw from huggingface with weight decay fix. Check their repo's optimizer.py 4 | for additional information. 
5 | We skip them for now as AdaptFormer and ConvPass skip this as either 6 | 7 | """ 8 | 9 | import torch.optim as optim 10 | from typing import Any, Callable, Iterable, List, Tuple, Optional 11 | 12 | from utils.setup_logging import get_logger 13 | 14 | logger = get_logger("Prompt_CAM") 15 | 16 | 17 | def make_optimizer(tune_parameters, params): 18 | if params.optimizer == 'adam': 19 | optimizer = optim.Adam( 20 | tune_parameters, 21 | lr=params.lr, 22 | weight_decay=params.wd, 23 | ) 24 | 25 | elif params.optimizer == 'adamw': 26 | optimizer = optim.AdamW( 27 | tune_parameters, 28 | lr=params.lr, 29 | weight_decay=params.wd, 30 | ) 31 | else: 32 | optimizer = optim.SGD( 33 | tune_parameters, 34 | lr=params.lr, 35 | weight_decay=params.wd, 36 | momentum=params.momentum, 37 | ) 38 | return optimizer 39 | 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /engine/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.nn as nn 4 | from timm.scheduler.cosine_lr import CosineLRScheduler 5 | from collections import OrderedDict 6 | from engine.optimizer import make_optimizer 7 | from utils.misc import AverageMeter, EarlyStop 8 | from utils.setup_logging import get_logger 9 | from timm.utils import accuracy, update_summary 10 | import numpy as np 11 | logger = get_logger("Prompt_CAM") 12 | torch.backends.cudnn.benchmark = False 13 | 14 | class Trainer(): 15 | """ 16 | a trainer with below logics: 17 | 18 | 1. Build optimizer, scheduler 19 | 2. Load checkpoints if provided 20 | 3. Train and eval at each epoch 21 | """ 22 | 23 | def __init__( 24 | self, 25 | model, tune_parameters, params 26 | ) -> None: 27 | self.params = params 28 | self.model = model 29 | self.device = params.device 30 | self.cls_criterion = nn.CrossEntropyLoss() 31 | 32 | if 'test_data' not in params: 33 | # solver related 34 | logger.info("\tSetting up the optimizer...") 35 | self.optimizer = make_optimizer(tune_parameters, params) 36 | self.scheduler = CosineLRScheduler(self.optimizer, t_initial=params.epoch, 37 | warmup_t=params.warmup_epoch, lr_min=params.lr_min, 38 | warmup_lr_init=params.warmup_lr_init) 39 | self.total_epoch = self.params.epoch 40 | if self.params.early_patience > 0: 41 | self.early_stop_check = EarlyStop(self.params.early_patience) 42 | 43 | def forward_one_batch(self, samples, targets, is_train): 44 | """Train a single (full) epoch on the model using the given 45 | data loader. 46 | 47 | Args: 48 | samples 49 | targets 50 | is_train: bool 51 | Returns: 52 | loss 53 | outputs: output logits 54 | """ 55 | # move data to device 56 | samples = samples.to(self.device, non_blocking=True) # (batchsize, 2048) 57 | targets = targets.to(self.device, non_blocking=True) # (batchsize, ) 58 | 59 | # forward 60 | with torch.set_grad_enabled(is_train): 61 | outputs,_ = self.model(samples) # (batchsize, num_cls) 62 | if self.params.train_type == 'prompt_cam': 63 | outputs = outputs.squeeze(-1) 64 | loss = self.cls_criterion(outputs, targets) 65 | 66 | if loss == float('inf'): 67 | logger.info( 68 | "encountered infinite loss, skip gradient updating for this batch!" 69 | ) 70 | return -1, -1, (-1, -1) 71 | elif torch.isnan(loss).any(): 72 | logger.info( 73 | "encountered nan loss, skip gradient updating for this batch!" 74 | ) 75 | return -1, -1, (-1, -1) 76 | 77 | acc1, acc5 = accuracy(outputs, targets, topk=(1, 5)) 78 | # =======backward and optim step only if in training phase... 
========= 79 | if is_train: 80 | self.optimizer.zero_grad() 81 | loss.backward() 82 | #for name, param in self.model.named_parameters(): 83 | # if param.grad is not None: 84 | # print(f"After backward: {name} has grad with mean {param.grad.mean().item()}") 85 | 86 | self.optimizer.step() 87 | 88 | return loss, outputs, (acc1, acc5) 89 | 90 | def train_one_epoch(self, epoch, loader): 91 | loss_m = AverageMeter() 92 | top1_m = AverageMeter() 93 | top5_m = AverageMeter() 94 | lr = self.scheduler._get_lr(epoch) 95 | logger.info( 96 | "Training {} / {} epoch, with learning rate {}".format( 97 | epoch + 1, self.total_epoch, lr 98 | ) 99 | ) 100 | # Enable training mode 101 | self.model.train() 102 | 103 | num_updates = epoch * len(loader) 104 | for idx, (samples, targets) in enumerate(loader): 105 | train_loss, _, (acc1, acc5) = self.forward_one_batch(samples, targets, True) 106 | if not isinstance(train_loss, int): 107 | loss_m.update(train_loss.item(), samples.shape[0]) 108 | top1_m.update(acc1.item(), samples.shape[0]) 109 | top5_m.update(acc5.item(), samples.shape[0]) 110 | del train_loss, acc1, acc5, _, samples, targets 111 | num_updates += 1 112 | self.scheduler.step_update(num_updates=num_updates, metric=loss_m.avg) 113 | 114 | logger.info( 115 | "Epoch {} / {}: ".format(epoch + 1, self.total_epoch) 116 | + "average train loss: {:.2f}, ".format(loss_m.avg) 117 | + "average train top1: {:.2f} ".format(top1_m.avg) 118 | + "average train top5: {:.2f}".format(top5_m.avg)) 119 | 120 | return OrderedDict( 121 | [('loss', round(loss_m.avg, 2)), ('top1', round(top1_m.avg, 2)), ('top5', round(top5_m.avg, 2))]) 122 | 123 | def train_classifier(self, train_loader, val_loader, test_loader): 124 | """ 125 | Train a classifier using epoch 126 | """ 127 | 128 | for epoch in range(self.total_epoch): 129 | train_metrics = self.train_one_epoch(epoch, train_loader) 130 | 131 | if (epoch % self.params.eval_freq == 0) or epoch == self.total_epoch - 1: 132 | if test_loader is not None: 133 | eval_metrics = self.eval_classifier( 134 | test_loader, "test") 135 | elif val_loader is not None: 136 | eval_metrics = self.eval_classifier(val_loader, "val") 137 | else: 138 | raise Exception('Both val and test loaders are missing. 
') 139 | 140 | if self.params.early_patience > 0: 141 | stop, save_model = self.early_stop_check.early_stop(eval_metrics) 142 | if save_model and self.params.store_ckp: 143 | torch.save({'model_state_dict': self.model.state_dict()}, 144 | os.path.join(self.params.output_dir, 'model.pt')) 145 | if stop: 146 | return train_metrics, self.early_stop_check.max_metrics, eval_metrics 147 | if self.params.debug: 148 | update_summary( 149 | epoch, train_metrics, eval_metrics, os.path.join(self.params.output_dir, 'summary.csv'), 150 | write_header=epoch == 0) 151 | self.scheduler.step(epoch) 152 | 153 | if self.params.store_ckp and not os.path.isfile(os.path.join(self.params.output_dir, 'model.pt')): 154 | torch.save({'model_state_dict': self.model.state_dict()}, os.path.join(self.params.output_dir, 'model.pt')) 155 | return train_metrics, self.early_stop_check.max_metrics, eval_metrics 156 | 157 | @torch.no_grad() 158 | def eval_classifier(self, loader, prefix): 159 | """evaluate classifier""" 160 | 161 | loss_m = AverageMeter() 162 | top1_m = AverageMeter() 163 | top5_m = AverageMeter() 164 | 165 | # Enable eval mode 166 | self.model.eval() 167 | 168 | with torch.no_grad(): 169 | for batch_idx, (samples, targets) in enumerate(loader): 170 | loss, outputs, (acc1, acc5) = self.forward_one_batch(samples, targets, False) 171 | if not isinstance(loss, int): 172 | loss_m.update(loss.item(), samples.shape[0]) 173 | top1_m.update(acc1.item(), samples.shape[0]) 174 | top5_m.update(acc5.item(), samples.shape[0]) 175 | del loss, outputs, acc1, acc5 176 | logger.info( 177 | f"Inference ({prefix}):" 178 | + "average loss: {:.2f}, ".format(loss_m.avg) 179 | + "average top1: {:.2f} ".format(top1_m.avg) 180 | + "average top5: {:.2f}".format(top5_m.avg)) 181 | return OrderedDict( 182 | [('loss', round(loss_m.avg, 2)), ('top1', round(top1_m.avg, 2)), ('top5', round(top5_m.avg, 2))]) 183 | 184 | def load_weight(self): 185 | self.model.load_state_dict(torch.load(self.params.output_dir + '/model.pt')['model_state_dict']) 186 | 187 | @torch.no_grad() 188 | def collect_logits(self, loader): 189 | self.model.eval() 190 | all_logits = [] 191 | gt = [] 192 | with torch.no_grad(): 193 | for batch_idx, (samples, targets) in enumerate(loader): 194 | loss, outputs, (acc1, acc5) = self.forward_one_batch(samples, targets, False) 195 | all_logits.append(outputs.cpu().detach().numpy()) 196 | gt.append(targets.cpu().detach().numpy()) 197 | return np.concatenate(all_logits, axis=0), np.concatenate(gt, axis=0) -------------------------------------------------------------------------------- /env_setup.sh: -------------------------------------------------------------------------------- 1 | 2 | pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 --no-cache-dir 3 | 4 | ## timm 5 | pip install timm==0.9.12 --no-cache-dir 6 | # 7 | ###VTAB 8 | pip install tensorflow==2.11.0 --no-cache-dir 9 | # specifying tfds versions is important to reproduce our results 10 | pip install tfds-nightly==4.4.0.dev202201080107 --no-cache-dir 11 | pip install tensorflow-addons==0.19.0 --no-cache-dir 12 | pip install opencv-python --no-cache-dir 13 | 14 | ## CLIP 15 | pip install git+https://github.com/openai/CLIP.git --no-cache-dir 16 | 17 | ####utils 18 | pip install dotwiz --no-cache-dir 19 | pip install pyyaml --no-cache-dir 20 | pip install tabulate --no-cache-dir 21 | pip install termcolor --no-cache-dir 22 | pip install iopath --no-cache-dir 23 | pip install scikit-learn --no-cache-dir 24 | 
25 | pip install ftfy regex tqdm --no-cache-dir 26 | pip install pandas --no-cache-dir 27 | pip install matplotlib --no-cache-dir 28 | pip install ipykernel --no-cache-dir 29 | -------------------------------------------------------------------------------- /experiment/build_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data.dataset.cub import get_cub 3 | from data.dataset.dog import get_dog 4 | from data.dataset.pet import get_pet 5 | from data.dataset.car import get_car 6 | from data.dataset.birds_525 import get_birds_525 7 | 8 | 9 | def get_dataset(data, params, logger): 10 | dataset_train, dataset_val, dataset_test = None, None, None 11 | 12 | if data.startswith("cub"): 13 | logger.info("Loading CUB data ...") 14 | if params.final_run: 15 | logger.info("Loading training data (final training data for cub)...") 16 | dataset_train = get_cub(params, 'trainval_combined') 17 | dataset_test = get_cub(params, 'test') 18 | else: 19 | raise NotImplementedError 20 | elif data.startswith("dog"): 21 | logger.info("Loading Standford Dogs data ...") 22 | if params.final_run: 23 | logger.info("Loading training data (final training data for dog)...") 24 | dataset_train = get_dog(params, 'trainval_combined') 25 | dataset_test = get_dog(params, 'test') 26 | else: 27 | raise NotImplementedError 28 | elif data.startswith("pet"): 29 | logger.info("Loading Oxford Pet data ...") 30 | if params.final_run: 31 | logger.info("Loading training data (final training data for pet)...") 32 | dataset_train = get_pet(params, 'trainval_combined') 33 | dataset_test = get_pet(params, 'test') 34 | else: 35 | raise NotImplementedError 36 | elif data.startswith("car"): 37 | logger.info("Loading Stanford Car data ...") 38 | if params.final_run: 39 | logger.info("Loading training data (final training data for car)...") 40 | dataset_train = get_car(params, 'trainval_combined') 41 | dataset_test = get_car(params, 'test') 42 | else: 43 | raise NotImplementedError 44 | elif data.startswith("birds_525"): 45 | logger.info("Loading Birds 525 data ...") 46 | if params.final_run: 47 | logger.info("Loading training data (final training data for birds_525)...") 48 | dataset_train = get_birds_525(params, 'trainval_combined') 49 | dataset_test = get_birds_525(params, 'test') 50 | else: 51 | raise NotImplementedError 52 | else: 53 | raise Exception("Dataset '{}' not supported".format(data)) 54 | return dataset_train, dataset_val, dataset_test 55 | 56 | 57 | def get_loader(params, logger): 58 | if 'test_data' in params: 59 | dataset_train, dataset_val, dataset_test = get_dataset(params.test_data, params, logger) 60 | else: 61 | dataset_train, dataset_val, dataset_test = get_dataset(params.data, params, logger) 62 | 63 | if isinstance(dataset_train, list): 64 | train_loader, val_loader, test_loader = [], [], [] 65 | for i in range(len(dataset_train)): 66 | tmp_train, tmp_val, tmp_test = gen_loader(params, dataset_train[i], dataset_val[i], None) 67 | train_loader.append(tmp_train) 68 | val_loader.append(tmp_val) 69 | test_loader.append(tmp_test) 70 | else: 71 | train_loader, val_loader, test_loader = gen_loader(params, dataset_train, dataset_val, dataset_test) 72 | 73 | logger.info("Finish setup loaders") 74 | return train_loader, val_loader, test_loader 75 | 76 | 77 | def gen_loader(params, dataset_train, dataset_val, dataset_test): 78 | train_loader, val_loader, test_loader = None, None, None 79 | if params.debug: 80 | num_workers = 1 81 | else: 82 | num_workers = 4 83 | 
if dataset_train is not None: 84 | train_loader = torch.utils.data.DataLoader( 85 | dataset_train, 86 | batch_size=params.batch_size, 87 | shuffle=True, 88 | num_workers=num_workers, 89 | pin_memory=True, 90 | drop_last=True 91 | ) 92 | if dataset_val is not None: 93 | val_loader = torch.utils.data.DataLoader( 94 | dataset_val, 95 | batch_size=params.test_batch_size, 96 | shuffle=False, 97 | num_workers=num_workers, 98 | pin_memory=True 99 | ) 100 | if dataset_test is not None: 101 | test_loader = torch.utils.data.DataLoader( 102 | dataset_test, 103 | batch_size=params.test_batch_size, 104 | shuffle=False, 105 | num_workers=num_workers, 106 | pin_memory=True 107 | 108 | ) 109 | return train_loader, val_loader, test_loader 110 | -------------------------------------------------------------------------------- /experiment/build_model.py: -------------------------------------------------------------------------------- 1 | from tkinter.constants import RAISED 2 | 3 | import timm 4 | import torch 5 | from model.vision_transformer import VisionTransformerPETL 6 | from utils.log_utils import log_model_info 7 | from timm.data import resolve_data_config 8 | from utils.setup_logging import get_logger 9 | 10 | logger = get_logger("Prompt_CAM") 11 | 12 | TUNE_MODULES = ['vpt'] 13 | def get_model(params,visualize=False): 14 | params.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | print(f"Using device: {params.device}") 16 | 17 | model = get_base_model(params,visualize=visualize) 18 | 19 | ########## 20 | tune_parameters = [] 21 | if params.debug: 22 | logger.info("Trainable params:") 23 | 24 | for name, parameter in model.named_parameters(): 25 | if any(m in name for m in TUNE_MODULES): 26 | parameter.requires_grad = True 27 | tune_parameters.append(parameter) 28 | if params.debug: 29 | logger.info("\t{}, {}, {}".format(name, parameter.numel(), parameter.shape)) 30 | else: 31 | parameter.requires_grad = False 32 | 33 | model_grad_params_no_head = log_model_info(model, logger) 34 | 35 | if not visualize: 36 | model = model.cuda(device=params.device) 37 | return model, tune_parameters, model_grad_params_no_head 38 | 39 | 40 | def get_base_model(params,visualize=False): 41 | if params.pretrained_weights == "vit_base_patch16_224_in21k": 42 | params.patch_size = 16 43 | model = timm.create_model("vit_base_patch16_224_in21k_petl", drop_path_rate=params.drop_path_rate, 44 | pretrained=False, params=params) 45 | if not visualize: 46 | model.load_pretrained( 47 | 'pretrained_weights/ViT-B_16_in21k.npz') 48 | model.reset_classifier(params.class_num) 49 | elif params.pretrained_weights == "vit_base_mae": 50 | model = timm.create_model("vit_base_patch16_224_in21k_petl", drop_path_rate=params.drop_path_rate, 51 | pretrained=False, 52 | params=params) 53 | if not visualize: 54 | model.load_pretrained( 55 | 'pretrained_weights/mae_pretrain_vit_base.pth') 56 | model.reset_classifier(params.class_num) 57 | elif params.pretrained_weights == "vit_base_patch14_dinov2": 58 | params.patch_size = 14 59 | model = timm.create_model("vit_base_patch14_dinov2_petl", drop_path_rate=params.drop_path_rate, 60 | pretrained=False, 61 | params=params) 62 | if not visualize: 63 | model.load_pretrained( 64 | 'pretrained_weights/dinov2_vitb14_pretrain.pth') 65 | model.reset_classifier(params.class_num) 66 | elif params.pretrained_weights == "vit_base_patch16_dino": 67 | model = timm.create_model("vit_base_patch16_dino_petl", drop_path_rate=params.drop_path_rate, 68 | pretrained=False, 69 | params=params) 70 | if 
not visualize: 71 | model.load_pretrained( 72 | 'pretrained_weights/dino_vitbase16_pretrain.pth') 73 | model.reset_classifier(params.class_num) 74 | elif params.pretrained_weights == 'vit_base_patch16_clip_224': 75 | params.patch_size = 16 76 | model = timm.create_model("vit_base_patch16_clip_224_petl", drop_path_rate=params.drop_path_rate, 77 | pretrained=False, 78 | params=params) 79 | if not visualize: 80 | model.load_pretrained( 81 | 'pretrained_weights/ViT-B_16_clip.bin') 82 | 83 | fc = init_imagenet_clip(params.device) 84 | proj = get_clip_proj(params.device) 85 | model.head = torch.nn.Sequential(*[proj, fc]) 86 | else: 87 | raise NotImplementedError 88 | 89 | # data_config = resolve_data_config(vars(params), model=model, verbose=False) 90 | 91 | return model 92 | -------------------------------------------------------------------------------- /experiment/config/prompt_cam/dino/birds_525/args.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | crop_size: 224 3 | data: birds_525 4 | data_path: ./data/images/birds_525 5 | debug: true 6 | drop_path_rate: 0.1 7 | early_patience: 101 8 | epoch: 100 9 | eval_freq: 5 10 | final_acc_hp: true 11 | final_run: true 12 | full: false 13 | gpu_num: 1 14 | lr: 0.01 15 | lr_min: 1.0e-06 16 | model: dino 17 | momentum: 0.9 18 | normalized: true 19 | optimizer: sgd 20 | pretrained_weights: vit_base_patch16_dino 21 | random_seed: 42 22 | store_ckp: true 23 | test_batch_size: 32 24 | train_type: prompt_cam 25 | vpt_dropout: 0.1 26 | vpt_layer: null 27 | vpt_mode: null 28 | vpt_num: 525 29 | warmup_epoch: 20 30 | warmup_lr_init: 0 31 | wd: 0.001 32 | -------------------------------------------------------------------------------- /experiment/config/prompt_cam/dino/car/args.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | crop_size: 224 3 | data: car 4 | data_path: ./data/images/car 5 | debug: true 6 | drop_path_rate: 0.1 7 | early_patience: 101 8 | epoch: 100 9 | eval_freq: 10 10 | final_acc_hp: true 11 | final_run: true 12 | full: false 13 | gpu_num: 1 14 | lr: 0.01 15 | lr_min: 1.0e-06 16 | model: dino 17 | momentum: 0.9 18 | normalized: true 19 | optimizer: sgd 20 | pretrained_weights: vit_base_patch16_dino 21 | random_seed: 42 22 | store_ckp: true 23 | test_batch_size: 32 24 | train_type: prompt_cam 25 | vpt_dropout: 0.1 26 | vpt_layer: null 27 | vpt_mode: null 28 | vpt_num: 196 29 | warmup_epoch: 20 30 | warmup_lr_init: 0 31 | wd: 0.001 32 | -------------------------------------------------------------------------------- /experiment/config/prompt_cam/dino/cub/args.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | crop_size: 224 3 | data: cub #Change this for new dataset 4 | data_path: ./data/images/cub #Change this for new dataset 5 | debug: true 6 | drop_path_rate: 0.1 7 | early_patience: 101 8 | epoch: 130 #Tune this for new dataset 9 | eval_freq: 10 10 | final_acc_hp: true 11 | final_run: true 12 | full: false 13 | gpu_num: 1 14 | lr: 0.005 #Tune this for new dataset 15 | lr_min: 1.0e-06 16 | model: dino 17 | momentum: 0.9 18 | normalized: true 19 | optimizer: sgd 20 | pretrained_weights: vit_base_patch16_dino 21 | random_seed: 42 22 | store_ckp: true 23 | test_batch_size: 32 24 | train_type: prompt_cam 25 | vpt_dropout: 0.1 26 | vpt_layer: null 27 | vpt_mode: null 28 | vpt_num: 200 #Change this for new dataset 29 | warmup_epoch: 20 30 | warmup_lr_init: 0 31 | wd: 0.001 
32 | -------------------------------------------------------------------------------- /experiment/config/prompt_cam/dino/dog/args.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | crop_size: 224 3 | data: dog 4 | data_path: ./data/images/dog 5 | debug: true 6 | drop_path_rate: 0.1 7 | early_patience: 101 8 | epoch: 100 9 | eval_freq: 10 10 | final_acc_hp: true 11 | final_run: true 12 | full: false 13 | gpu_num: 1 14 | lr: 0.005 15 | lr_min: 1.0e-06 16 | model: dino 17 | momentum: 0.9 18 | normalized: true 19 | optimizer: sgd 20 | pretrained_weights: vit_base_patch16_dino 21 | random_seed: 42 22 | store_ckp: true 23 | test_batch_size: 32 24 | train_type: prompt_cam 25 | vpt_dropout: 0.1 26 | vpt_layer: null 27 | vpt_mode: null 28 | vpt_num: 120 29 | warmup_epoch: 20 30 | warmup_lr_init: 0 31 | wd: 0.001 32 | -------------------------------------------------------------------------------- /experiment/config/prompt_cam/dino/pet/args.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | crop_size: 224 3 | data: pet 4 | data_path: ./data/images/pet 5 | debug: true 6 | drop_path_rate: 0.1 7 | early_patience: 101 8 | epoch: 100 9 | eval_freq: 10 10 | final_acc_hp: true 11 | final_run: true 12 | full: false 13 | gpu_num: 1 14 | lr: 0.005 15 | lr_min: 1.0e-06 16 | model: dino 17 | momentum: 0.9 18 | normalized: true 19 | optimizer: sgd 20 | pretrained_weights: vit_base_patch16_dino 21 | random_seed: 42 22 | store_ckp: true 23 | test_batch_size: 32 24 | train_type: prompt_cam 25 | vpt_dropout: 0.1 26 | vpt_layer: null 27 | vpt_mode: null 28 | vpt_num: 37 29 | warmup_epoch: 20 30 | warmup_lr_init: 0 31 | wd: 0.001 32 | -------------------------------------------------------------------------------- /experiment/config/prompt_cam/dinov2/cub/args.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | crop_size: 224 3 | data: cub 4 | data_path: ./data/images/cub 5 | debug: true 6 | drop_path_rate: 0.0 7 | early_patience: 101 8 | epoch: 130 9 | eval_freq: 10 10 | final_acc_hp: true 11 | final_run: true 12 | full: false 13 | gpu_num: 1 14 | lr: 0.001 15 | lr_min: 1.0e-06 16 | model: dinov2 17 | momentum: 0.9 18 | normalized: true 19 | optimizer: sgd 20 | pretrained_weights: vit_base_patch14_dinov2 21 | random_seed: 42 22 | store_ckp: true 23 | test_batch_size: 32 24 | train_type: prompt_cam 25 | vpt_dropout: 0.0 26 | vpt_layer: null 27 | vpt_mode: null 28 | vpt_num: 200 29 | warmup_epoch: 20 30 | warmup_lr_init: 1.0e-06 31 | wd: 0.001 32 | -------------------------------------------------------------------------------- /experiment/config/prompt_cam/dinov2/dog/args.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | crop_size: 224 3 | data: dog 4 | data_path: ./data/images/dog 5 | debug: true 6 | drop_path_rate: 0.0 7 | early_patience: 101 8 | epoch: 100 9 | eval_freq: 10 10 | final_acc_hp: true 11 | final_run: true 12 | full: false 13 | gpu_num: 1 14 | lr: 0.005 15 | lr_min: 1.0e-06 16 | model: dinov2 17 | momentum: 0.9 18 | normalized: true 19 | optimizer: sgd 20 | pretrained_weights: vit_base_patch14_dinov2 21 | random_seed: 42 22 | store_ckp: true 23 | test_batch_size: 32 24 | train_type: prompt_cam 25 | vpt_dropout: 0.0 26 | vpt_layer: null 27 | vpt_mode: null 28 | vpt_num: 120 29 | warmup_epoch: 20 30 | warmup_lr_init: 1.0e-06 31 | wd: 0.001 32 | 
-------------------------------------------------------------------------------- /experiment/config/prompt_cam/dinov2/pet/args.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | crop_size: 224 3 | data: pet 4 | data_path: ./data/images/pet 5 | debug: true 6 | drop_path_rate: 0.0 7 | early_patience: 101 8 | epoch: 100 9 | eval_freq: 10 10 | final_acc_hp: true 11 | final_run: true 12 | full: false 13 | gpu_num: 1 14 | lr: 0.005 15 | lr_min: 1.0e-06 16 | model: dinov2 17 | momentum: 0.9 18 | normalized: true 19 | optimizer: sgd 20 | pretrained_weights: vit_base_patch14_dinov2 21 | random_seed: 42 22 | store_ckp: true 23 | test_batch_size: 32 24 | train_type: prompt_cam 25 | vpt_dropout: 0.0 26 | vpt_layer: null 27 | vpt_mode: null 28 | vpt_num: 37 29 | warmup_epoch: 20 30 | warmup_lr_init: 1.0e-06 31 | wd: 0.001 32 | -------------------------------------------------------------------------------- /experiment/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | from experiment.build_model import get_model 3 | from experiment.build_loader import get_loader 4 | from utils.global_var import OUTPUT_DIR, TUNE_DIR, TUNE_DIR_TEST 5 | from engine.trainer import Trainer 6 | from timm.utils import get_outdir 7 | from utils.log_utils import logging_env_setup 8 | from utils.misc import method_name 9 | from datetime import datetime 10 | import yaml 11 | import torch 12 | from utils.misc import set_seed 13 | from collections import OrderedDict 14 | from statistics import mean 15 | import json 16 | import time 17 | import csv 18 | import numpy as np 19 | from utils.setup_logging import get_logger 20 | 21 | logger = get_logger("Prompt_CAM") 22 | 23 | 24 | def train(params, train_loader, val_loader, test_loader): 25 | model, tune_parameters, model_grad_params_no_head = get_model(params) 26 | trainer = Trainer(model, tune_parameters, params) 27 | train_metrics, best_eval_metrics, eval_metrics = trainer.train_classifier(train_loader, val_loader, test_loader) 28 | return train_metrics, best_eval_metrics, eval_metrics, model_grad_params_no_head, trainer.model 29 | 30 | def basic_run(params): 31 | if torch.cuda.is_available(): 32 | torch.cuda.empty_cache() 33 | data_name = params.data.split("-")[-1] 34 | dataset_name = params.data.split("-")[0] 35 | method = method_name(params) 36 | start_time = datetime.now().strftime("%Y-%m-%d-%H:%M") 37 | output_dir = os.path.join(OUTPUT_DIR, params.pretrained_weights, dataset_name, method, data_name, start_time) 38 | params.output_dir = get_outdir(output_dir) 39 | params_text = yaml.safe_dump(params.__dict__, default_flow_style=False) 40 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: 41 | f.write(params_text) 42 | logging_env_setup(params) 43 | logger.info(f'Start loading {data_name}') 44 | train_loader, val_loader, test_loader = get_loader(params, logger) 45 | 46 | train(params, train_loader, val_loader, test_loader) 47 | 48 | 49 | def update_output_dir(default_params, test): 50 | logger.info(f'start running {default_params.method_name}') 51 | if torch.cuda.is_available(): 52 | torch.cuda.empty_cache() 53 | data_name = default_params.data.split("-")[-1] 54 | dataset_name = default_params.data.split("-")[0] 55 | method = default_params.method_name 56 | if test: 57 | output_dir = os.path.join(TUNE_DIR_TEST, default_params.experiment_name, dataset_name, data_name, method) 58 | else: 59 | output_dir = os.path.join(TUNE_DIR, default_params.experiment_name, 
dataset_name, data_name, method) 60 | default_params.output_dir = output_dir 61 | 62 | logging_env_setup(default_params) 63 | return output_dir, data_name 64 | 65 | 66 | 67 | def evaluate(default_params): 68 | _, _, test_loader = get_loader(default_params, logger) 69 | if 'eval' in default_params.test_data: 70 | result_name = f'{default_params.test_data.split("_")[1]}_result.json' 71 | else: 72 | result_name = f'{default_params.test_data}_result.json' 73 | if not os.path.isfile(os.path.join(default_params.output_dir, result_name)): 74 | if not os.path.isfile(os.path.join(default_params.output_dir, 'final_result.json')): 75 | logger.info('no final_result.json, the model is not fine-tuned, show model zero shot performance') 76 | best_tune = () 77 | result_name = 'zero_shot_' + result_name 78 | else: 79 | result = json.load(open(os.path.join(default_params.output_dir, 'final_result.json'))) 80 | best_tune = result['best_tune'] 81 | default_params.update(best_tune) 82 | 83 | model, tune_parameters, model_grad_params_no_head = get_model(default_params) 84 | trainer = Trainer(model, tune_parameters, default_params) 85 | if not os.path.isfile(os.path.join(default_params.output_dir, 'model.pt')): 86 | assert not os.path.isfile(os.path.join(default_params.output_dir, 'final_result.json')) 87 | logger.info('no model.pt, shows zero shot performance') 88 | else: 89 | trainer.load_weight() 90 | eval_metrics = trainer.eval_classifier(test_loader, 'test') 91 | json.dump( 92 | {"avg_acc": eval_metrics['top1'], "inserted_parameters": model_grad_params_no_head, 93 | 'best_tune': best_tune}, 94 | open(os.path.join(default_params.output_dir, result_name), 'w')) 95 | else: 96 | logger.info(f'finish {result_name} for {default_params.method_name}') 97 | return 98 | 99 | 100 | def result_tracker(first_col, train_metrics, eval_metrics, best_eval_metrics, filename, write_header=False, first_col_name='param_set', 101 | eval_name='val_'): 102 | rowd = OrderedDict([(first_col_name, first_col)]) 103 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) 104 | rowd.update([(eval_name + k, v) for k, v in eval_metrics.items()]) 105 | rowd.update([(eval_name + "best_" + k, v) for k, v in best_eval_metrics.items()]) 106 | with open(filename, mode='a') as cf: 107 | dw = csv.DictWriter(cf, fieldnames=rowd.keys()) 108 | if write_header: 109 | dw.writeheader() 110 | dw.writerow(rowd) 111 | -------------------------------------------------------------------------------- /experiment/visualize_run.py: -------------------------------------------------------------------------------- 1 | import os 2 | from experiment.build_model import get_model 3 | from experiment.build_loader import get_loader 4 | from timm.utils import get_outdir,accuracy 5 | from utils.log_utils import logging_env_setup 6 | from utils.misc import AverageMeter 7 | import torch 8 | import numpy as np 9 | from utils.setup_logging import get_logger 10 | 11 | from utils.visual_utils import combine_images,create_overlay_images 12 | 13 | 14 | logger = get_logger("Prompt_CAM") 15 | 16 | 17 | def basic_vis(params): 18 | 19 | if torch.cuda.is_available(): 20 | torch.cuda.empty_cache() 21 | dataset_name = params.data.split("-")[0] 22 | top_traits = f"top_traits_{params.top_traits}" 23 | output_dir = os.path.join(params.vis_outdir, params.model, dataset_name, "class_"+str(params.vis_cls),top_traits) 24 | params.output_dir = get_outdir(output_dir) 25 | logging_env_setup(params) 26 | 27 | 28 | logger.info(f'Start loading test data: {dataset_name}') 29 | _, _, 
test_loader = get_loader(params, logger) 30 | 31 | logger.info(f'Start loading model: {params.model}') 32 | model, _ , _ = get_model(params) 33 | model.load_state_dict(torch.load(params.checkpoint)['model_state_dict']) 34 | logger.info (f'Model loaded from {params.checkpoint}') 35 | 36 | top1_m = AverageMeter() 37 | 38 | model.eval() 39 | 40 | 41 | params.test_batch_size= 1 42 | _, _, test_loader = get_loader(params, logger) 43 | 44 | 45 | smpl_count = 0 46 | with torch.no_grad(): 47 | for batch_idx, (samples, targets) in enumerate(test_loader): 48 | # move data to device 49 | samples = samples.to(params.device, non_blocking=True) # (batchsize, channel, height, width) 50 | targets = targets.to(params.device, non_blocking=True) # (batchsize, ) 51 | 52 | if targets[0].item() == params.vis_cls: 53 | smpl_count += 1 54 | 55 | outputs, attn_map = model(samples) 56 | predicted_class = torch.argmax(outputs, dim=1).item() 57 | 58 | if predicted_class == targets[0].item(): 59 | logger.info(f"Predicted class: {predicted_class}, Target class: {targets[0].item()}") 60 | prune_attn_heads(model,samples,targets, predicted_class,smpl_count, params) 61 | else: 62 | attn_map = attn_map[:, :, targets[0].item(), (params.vpt_num+1):] 63 | create_overlay_images(samples, 64 | model.patch_size, 65 | attn_map, 66 | f'{params.output_dir}/img_{smpl_count}') 67 | 68 | combine_images(path=f'{params.output_dir}/img_{smpl_count}', pred_class=predicted_class) 69 | 70 | 71 | if smpl_count == params.nmbr_samples: 72 | break 73 | 74 | 75 | 76 | 77 | #TODO: ADD Later 78 | with torch.no_grad(): 79 | for batch_idx, (samples, targets) in enumerate(test_loader): 80 | # move data to device 81 | samples = samples.to(params.device, non_blocking=True) # (batchsize, 2048) 82 | targets = targets.to(params.device, non_blocking=True) # (batchsize, ) 83 | 84 | outputs,_ = model(samples) 85 | acc1,_= accuracy(outputs.squeeze(-1), targets, topk=(1,5)) 86 | top1_m.update(acc1.item(), samples.shape[0]) 87 | 88 | del outputs, acc1 89 | 90 | logger.info("Evaluate: average top1: {:.2f}".format(top1_m.avg)) 91 | 92 | 93 | def prune_attn_heads(model,inputs,target, prediction,smpl_count, params): 94 | remaining_head_list = list(range(model.num_heads)) 95 | pruned_head_index = None 96 | blur_head_lst = [] 97 | 98 | while len(remaining_head_list) > 0 and len(remaining_head_list) > params.top_traits: 99 | highest_score=-1e8 100 | remaining_head_scores= [] 101 | 102 | for head_idx in remaining_head_list: 103 | output,_ = model(inputs, 104 | blur_head_lst=blur_head_lst+[head_idx], 105 | target_cls=prediction) 106 | 107 | probabilities = torch.softmax(output.squeeze(-1), dim=-1) 108 | 109 | remaining_head_scores.append(probabilities[0,prediction].item()) 110 | 111 | if remaining_head_scores[-1] > highest_score: 112 | highest_score=remaining_head_scores[-1] 113 | pruned_head_index=head_idx 114 | 115 | if pruned_head_index is not None: 116 | blur_head_lst.append(pruned_head_index) 117 | remaining_head_list.remove(pruned_head_index) 118 | print(f'best head to prune is {pruned_head_index+1} with score {highest_score}') 119 | 120 | sorted_remaining_heads = [head for _, head in sorted(zip(remaining_head_scores, remaining_head_list))] 121 | 122 | _,attn_map=model(inputs, 123 | blur_head_lst=blur_head_lst, 124 | target_cls=prediction) 125 | attn_map = attn_map[:, sorted_remaining_heads, prediction, (params.vpt_num+1):] 126 | create_overlay_images(inputs, 127 | model.patch_size, 128 | attn_map, 129 | f'{params.output_dir}/img_{smpl_count}') 130 | 131 | 
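# Stitch the per-head attention overlays saved above into a single figure labeled with the predicted class.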
combine_images(path=f'{params.output_dir}/img_{smpl_count}', pred_class=prediction) 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from experiment.run import basic_run 3 | from utils.setup_logging import get_logger 4 | from utils.misc import set_seed,load_yaml,override_args_with_yaml 5 | import time 6 | 7 | logger = get_logger("Prompt_CAM") 8 | 9 | 10 | def main(): 11 | args = setup_parser().parse_args() 12 | 13 | if args.config: 14 | yaml_config = load_yaml(args.config) 15 | if yaml_config: 16 | args = override_args_with_yaml(args, yaml_config) 17 | 18 | set_seed(args.random_seed) 19 | start = time.time() 20 | args.vis_attn= False 21 | basic_run(args) 22 | end = time.time() 23 | logger.info(f'----------- Total Run time : {(end - start) / 60} mins-----------') 24 | 25 | 26 | def setup_parser(): 27 | parser = argparse.ArgumentParser(description='Prompt_CAM') 28 | 29 | ########################Pretrained Model######################### 30 | parser.add_argument('--pretrained_weights', type=str, default='vit_base_patch16_224_in21k', 31 | choices=['vit_base_patch16_224_in21k', 'vit_base_mae', 'vit_base_patch14_dinov2','vit_base_patch16_dino', 32 | 'vit_base_patch16_clip_224'], 33 | help='pretrained weights name') 34 | parser.add_argument('--drop_path_rate', default=0.1, 35 | type=float, 36 | help='Drop Path Rate (default: %(default)s)') 37 | parser.add_argument('--model', type=str, default='dinov2', choices=['vit', 'dino', 'dinov2'], 38 | help='pretrained model name') 39 | 40 | parser.add_argument('--train_type', type=str, default='vpt', choices=['vpt', 'prompt_cam', 'linear'], 41 | help='pretrained model name') 42 | 43 | ########################Optimizer Scheduler######################### 44 | parser.add_argument('--optimizer', default='sgd', choices=['sgd', 'adam', 'adamw'], 45 | help='Optimizer (default: %(default)s)') 46 | parser.add_argument('--lr', default=0.005, 47 | type=float, 48 | help='Learning rate (default: %(default)s)') 49 | parser.add_argument('--epoch', default=100, 50 | type=int, 51 | help='The number of total epochs used. (default: %(default)s)') 52 | parser.add_argument('--warmup_epoch', default=20, 53 | type=int, 54 | help='warnup epoch in scheduler. (default: %(default)s)') 55 | parser.add_argument('--lr_min', type=float, default=1e-5, 56 | help='lr_min for scheduler (default: %(default)s)') 57 | parser.add_argument('--warmup_lr_init', type=float, default=1e-6, 58 | help='warmup_lr_init for scheduler (default: %(default)s)') 59 | parser.add_argument('--batch_size', default=16, 60 | type=int, 61 | help='Batch size (default: %(default)s)') 62 | parser.add_argument('--test_batch_size', default=32, 63 | type=int, 64 | help='Test batch size (default: %(default)s)') 65 | parser.add_argument('--wd', type=float, default=0.001, 66 | help='weight_decay (default: %(default)s)') 67 | parser.add_argument('--momentum', type=float, default=0.9, 68 | help='momentum used in sgd (default: %(default)s)') 69 | parser.add_argument('--early_patience', type=int, default=101, 70 | help='early stop patience (default: %(default)s)') 71 | 72 | ########################Data######################### 73 | parser.add_argument('--data', default="processed_vtab-dtd", 74 | help='data name. (default: %(default)s)') 75 | parser.add_argument('--data_path', default="data_folder/vtab_processed", 76 | help='Path to the dataset. 
(default: %(default)s)') 77 | parser.add_argument('--crop_size', default=224, 78 | type=int, 79 | help='Crop size of the input image (default: %(default)s)') 80 | parser.add_argument('--final_run', action='store_false', 81 | help='If final_run is True, use train+val as train data else, use train only') 82 | parser.add_argument('--normalized', action='store_false', 83 | help='If imagees are normalized using ImageNet mean and variance ') 84 | 85 | ########################VPT######################### 86 | parser.add_argument('--vpt_mode', type=str, default=None, choices=['deep', 'shallow'], 87 | help='VPT mode, deep or shallow') 88 | parser.add_argument('--vpt_num', default=10, type=int, 89 | help='Number of prompts (default: %(default)s)') 90 | parser.add_argument('--vpt_layer', default=None, type=int, 91 | help='Number of layers to add prompt, start from the last layer (default: %(default)s)') 92 | parser.add_argument('--vpt_dropout', default=0.1, type=float, 93 | help='VPT dropout rate for deep mode. (default: %(default)s)') 94 | 95 | 96 | ########################full######################### 97 | parser.add_argument('--full', action='store_true', 98 | help='whether turn on full finetune') 99 | 100 | ########################Misc######################### 101 | parser.add_argument('--gpu_num', default=1, 102 | type=int, 103 | help='Number of GPU (default: %(default)s)') 104 | parser.add_argument('--debug', action='store_false', 105 | help='Debug mode to show more information (default: %(default)s)') 106 | parser.add_argument('--random_seed', default=42, 107 | type=int, 108 | help='Random seed (default: %(default)s)') 109 | parser.add_argument('--eval_freq', default=10, 110 | type=int, 111 | help='eval frequency(epoch) testset (default: %(default)s)') 112 | parser.add_argument('--store_ckp', action='store_true', 113 | help='whether store checkpoint') 114 | parser.add_argument('--final_acc_hp', action='store_false', 115 | help='if true, use the best acc during all epochs as criteria to select HP, if false, use the acc at final epoch as criteria to select HP ') 116 | 117 | ######################## YAML Config ######################### 118 | parser.add_argument('--config', type=str, default=None, help='Path to YAML config file') 119 | 120 | 121 | return parser 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | from torch.jit import Final 2 | import torch.nn as nn 3 | from timm.layers import use_fused_attn 4 | import torch 5 | import torch.nn.functional as F 6 | from typing import Tuple 7 | 8 | 9 | class AttentionPETL(nn.Module): 10 | fused_attn: Final[bool] 11 | 12 | def __init__( 13 | self, 14 | dim: int, 15 | num_heads: int = 8, 16 | qkv_bias: bool = False, 17 | qk_norm: bool = False, 18 | attn_drop: float = 0., 19 | proj_drop: float = 0., 20 | norm_layer: nn.Module = nn.LayerNorm, 21 | params=None, 22 | ) -> None: 23 | super().__init__() 24 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 25 | self.num_heads = num_heads 26 | self.head_dim = dim // num_heads 27 | self.scale = self.head_dim ** -0.5 28 | self.fused_attn = use_fused_attn() 29 | 30 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 31 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 32 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 33 | self.attn_drop = 
nn.Dropout(attn_drop) 34 | self.proj = nn.Linear(dim, dim) 35 | self.proj_drop = nn.Dropout(proj_drop) 36 | 37 | ############# Added module ############# 38 | self.params = params 39 | ############# Added module end ############# 40 | 41 | def forward(self, x: torch.Tensor, block_idx, blur_head_lst=[],target_cls=-1) -> Tuple[torch.Tensor,torch.Tensor]: 42 | B, N, C = x.shape 43 | ############# Added module ############# 44 | qkv = self.qkv(x) 45 | ############# Added module end ############# 46 | 47 | qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 48 | q, k, v = qkv.unbind(0) 49 | 50 | q, k = self.q_norm(q), self.k_norm(k) 51 | 52 | q = q * self.scale 53 | attn = q @ k.transpose(-2, -1) 54 | 55 | ############# Added module ############# 56 | if len(blur_head_lst)!=0: 57 | attn[:, blur_head_lst, target_cls, :] = 0 58 | 59 | ############# Added module end ############# 60 | 61 | attn = attn.softmax(dim=-1) 62 | attn = self.attn_drop(attn) 63 | x = attn @ v 64 | 65 | x = x.transpose(1, 2).reshape(B, N, C) 66 | proj = self.proj(x) 67 | x = self.proj_drop(proj) 68 | return x,attn 69 | -------------------------------------------------------------------------------- /model/block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from timm.layers import DropPath 3 | from timm.models.vision_transformer import LayerScale 4 | from timm.layers.trace_utils import _assert 5 | import torch 6 | from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List 7 | from model.mlp import MlpPETL 8 | from model.attention import AttentionPETL 9 | 10 | class BlockPETL(nn.Module): 11 | def __init__( 12 | self, 13 | dim: int, 14 | num_heads: int, 15 | mlp_ratio: float = 4., 16 | qkv_bias: bool = False, 17 | qk_norm: bool = False, 18 | proj_drop: float = 0., 19 | attn_drop: float = 0., 20 | init_values: Optional[float] = None, 21 | drop_path: float = 0., 22 | act_layer: nn.Module = nn.GELU, 23 | norm_layer: nn.Module = nn.LayerNorm, 24 | mlp_layer: nn.Module = MlpPETL, 25 | params=None 26 | ) -> None: 27 | super().__init__() 28 | self.norm1 = norm_layer(dim) 29 | self.attn = AttentionPETL( 30 | dim, 31 | num_heads=num_heads, 32 | qkv_bias=qkv_bias, 33 | qk_norm=qk_norm, 34 | attn_drop=attn_drop, 35 | proj_drop=proj_drop, 36 | norm_layer=norm_layer, 37 | ############# Added module ############# 38 | params=params 39 | ############# Added module end ############# 40 | ) 41 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 42 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 43 | 44 | self.norm2 = norm_layer(dim) 45 | self.mlp = mlp_layer( 46 | in_features=dim, 47 | hidden_features=int(dim * mlp_ratio), 48 | act_layer=act_layer, 49 | drop=proj_drop 50 | ) 51 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 52 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 53 | 54 | ############# Added module ############# 55 | self.params = params 56 | ############# Added module end ############# 57 | 58 | def forward(self, x: torch.Tensor, idx, blur_head_lst=[], target_cls=-1) -> Tuple[torch.Tensor,torch.Tensor]: 59 | output, attn_map = self.attn(self.norm1(x), idx , blur_head_lst=blur_head_lst, target_cls=target_cls) 60 | x = x + self.drop_path1(self.ls1(output)) 61 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 62 | return x,attn_map 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /model/mlp.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from timm.layers.helpers import to_2tuple 3 | from functools import partial 4 | 5 | 6 | class MlpPETL(nn.Module): 7 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 8 | """ 9 | def __init__( 10 | self, 11 | in_features, 12 | hidden_features=None, 13 | out_features=None, 14 | act_layer=nn.GELU, 15 | norm_layer=None, 16 | bias=True, 17 | drop=0., 18 | use_conv=False 19 | ): 20 | super().__init__() 21 | 22 | out_features = out_features or in_features 23 | hidden_features = hidden_features or in_features 24 | bias = to_2tuple(bias) 25 | drop_probs = to_2tuple(drop) 26 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 27 | 28 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) 29 | self.act = act_layer() 30 | self.drop1 = nn.Dropout(drop_probs[0]) 31 | self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() 32 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) 33 | self.drop2 = nn.Dropout(drop_probs[1]) 34 | 35 | 36 | def forward(self, x): 37 | h = self.fc1(x) 38 | x = self.act(h) 39 | x = self.drop1(x) 40 | x = self.norm(x) 41 | h = self.fc2(x) 42 | x = self.drop2(h) 43 | return x -------------------------------------------------------------------------------- /model/patch_embed.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn as nn 5 | import torch.nn.functional as F 6 | 7 | from timm.layers.format import Format, nchw_to 8 | from timm.layers.helpers import to_2tuple 9 | from timm.layers.trace_utils import _assert 10 | 11 | 12 | 13 | class PatchEmbedPETL(nn.Module): 14 | """ 2D Image to Patch Embedding 15 | """ 16 | output_fmt: Format 17 | dynamic_img_pad: torch.jit.Final[bool] 18 | 19 | def __init__( 20 | self, 21 | img_size: Optional[int] = 224, 22 | patch_size: int = 16, 23 | in_chans: int = 3, 24 | embed_dim: int = 768, 25 | norm_layer: Optional[Callable] = None, 26 | flatten: bool = True, 27 | output_fmt: Optional[str] = None, 28 | bias: bool = True, 29 | strict_img_size: bool = True, 30 | dynamic_img_pad: bool = False, 31 | params = None 32 | ): 33 | super().__init__() 34 | self.patch_size = to_2tuple(patch_size) 35 | if img_size is not None: 36 | self.img_size = to_2tuple(img_size) 37 | self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) 38 | self.num_patches = self.grid_size[0] * self.grid_size[1] 39 | else: 40 | self.img_size = None 41 | self.grid_size = None 42 | self.num_patches = None 43 | 44 | if output_fmt is not None: 45 | self.flatten = False 46 | self.output_fmt = Format(output_fmt) 47 | else: 48 | # flatten spatial dim and transpose to channels last, kept for bwd 
compat 49 | self.flatten = flatten 50 | self.output_fmt = Format.NCHW 51 | self.strict_img_size = strict_img_size 52 | self.dynamic_img_pad = dynamic_img_pad 53 | 54 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 55 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 56 | 57 | ############# Added module ############# 58 | self.params = params 59 | self.norm_layer = norm_layer 60 | ############# Added module end ############# 61 | 62 | def forward(self, x): 63 | B, C, H, W = x.shape 64 | if self.img_size is not None: 65 | if self.strict_img_size: 66 | _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).") 67 | _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).") 68 | elif not self.dynamic_img_pad: 69 | _assert( 70 | H % self.patch_size[0] == 0, 71 | f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." 72 | ) 73 | _assert( 74 | W % self.patch_size[1] == 0, 75 | f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." 76 | ) 77 | if self.dynamic_img_pad: 78 | pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] 79 | pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] 80 | x = F.pad(x, (0, pad_w, 0, pad_h)) 81 | x = self.proj(x) 82 | if self.flatten: 83 | x = x.flatten(2).transpose(1, 2) # NCHW -> NLC 84 | elif self.output_fmt != Format.NCHW: 85 | x = nchw_to(x, self.output_fmt) 86 | x = self.norm(x) 87 | return x 88 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def init_weight(down, up, option): 7 | with torch.no_grad(): 8 | if option == 'lora_kaiming': 9 | nn.init.kaiming_uniform_(down.weight, a=math.sqrt(5)) 10 | nn.init.zeros_(up.weight) 11 | nn.init.zeros_(down.bias) 12 | nn.init.zeros_(up.bias) 13 | elif option == 'lora_xavier': 14 | nn.init.xavier_uniform_(down.weight) 15 | nn.init.zeros_(up.weight) 16 | nn.init.zeros_(down.bias) 17 | nn.init.zeros_(up.bias) 18 | elif option == 'xavier': 19 | nn.init.xavier_uniform_(down.weight) 20 | nn.init.xavier_uniform_(up.weight) 21 | nn.init.normal_(down.bias, std=1e-6) 22 | nn.init.normal_(up.bias, std=1e-6) 23 | elif option == 'zero': 24 | nn.init.zeros_(down.weight) 25 | nn.init.zeros_(up.bias) 26 | nn.init.zeros_(down.weight) 27 | nn.init.zeros_(up.bias) 28 | else: 29 | raise NotImplementedError -------------------------------------------------------------------------------- /model/vision_transformer.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List 3 | 4 | try: 5 | from typing import Literal 6 | except ImportError: 7 | from typing_extensions import Literal 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint 12 | from torch.jit import Final 13 | from timm.layers import DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \ 14 | trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ 15 | get_act_layer, get_norm_layer, LayerType 16 | from timm.models._builder import build_model_with_cfg 17 | from timm.models._manipulate import 
named_apply, checkpoint_seq, adapt_input_conv 18 | from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations 19 | from timm.models.vision_transformer import VisionTransformer 20 | from timm.models.vision_transformer import LayerScale, init_weights_vit_timm, get_init_weights_vit, \ 21 | _load_weights, checkpoint_filter_fn 22 | ## added for petl 23 | from utils.setup_logging import get_logger 24 | from model.block import BlockPETL 25 | from model.patch_embed import PatchEmbedPETL 26 | from model.mlp import MlpPETL 27 | from model.vpt import VPT 28 | 29 | logger = get_logger("Prompt_CAM") 30 | 31 | 32 | class VisionTransformerPETL(VisionTransformer): 33 | """ Vision Transformer 34 | 35 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 36 | - https://arxiv.org/abs/2010.11929 37 | """ 38 | dynamic_img_size: Final[bool] 39 | 40 | def __init__( 41 | self, 42 | img_size: Union[int, Tuple[int, int]] = 224, 43 | patch_size: Union[int, Tuple[int, int]] = 16, 44 | in_chans: int = 3, 45 | num_classes: int = 1000, 46 | global_pool: Literal['', 'avg', 'token', 'map'] = 'token', 47 | embed_dim: int = 768, 48 | depth: int = 12, 49 | num_heads: int = 12, 50 | mlp_ratio: float = 4., 51 | qkv_bias: bool = True, 52 | qk_norm: bool = False, 53 | init_values: Optional[float] = None, 54 | class_token: bool = True, 55 | no_embed_class: bool = False, 56 | reg_tokens: int = 0, 57 | pre_norm: bool = False, 58 | fc_norm: Optional[bool] = None, 59 | dynamic_img_size: bool = False, 60 | dynamic_img_pad: bool = False, 61 | drop_rate: float = 0., 62 | pos_drop_rate: float = 0., 63 | patch_drop_rate: float = 0., 64 | proj_drop_rate: float = 0., 65 | attn_drop_rate: float = 0., 66 | drop_path_rate: float = 0., 67 | weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '', 68 | embed_layer: Callable = PatchEmbedPETL, 69 | norm_layer: Optional[LayerType] = None, 70 | act_layer: Optional[LayerType] = None, 71 | block_fn: Type[nn.Module] = BlockPETL, 72 | mlp_layer: Type[nn.Module] = MlpPETL, 73 | params=None 74 | ) -> None: 75 | """ 76 | Args: 77 | img_size: Input image size. 78 | patch_size: Patch size. 79 | in_chans: Number of image input channels. 80 | num_classes: Mumber of classes for classification head. 81 | global_pool: Type of global pooling for final sequence (default: 'token'). 82 | embed_dim: Transformer embedding dimension. 83 | depth: Depth of transformer. 84 | num_heads: Number of attention heads. 85 | mlp_ratio: Ratio of mlp hidden dim to embedding dim. 86 | qkv_bias: Enable bias for qkv projections if True. 87 | init_values: Layer-scale init values (layer-scale enabled if not None). 88 | class_token: Use class token. 89 | no_embed_class: Don't include position embeddings for class (or reg) tokens. 90 | reg_tokens: Number of register tokens. 91 | fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. 92 | drop_rate: Head dropout rate. 93 | pos_drop_rate: Position embedding dropout rate. 94 | attn_drop_rate: Attention dropout rate. 95 | drop_path_rate: Stochastic depth rate. 96 | weight_init: Weight initialization scheme. 97 | embed_layer: Patch embedding layer. 98 | norm_layer: Normalization layer. 99 | act_layer: MLP activation layer. 100 | block_fn: Transformer block layer. 
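mlp_layer: MLP block layer.
            params: Prompt-CAM/VPT runtime arguments (train_type, vpt_num, vpt_mode, ...) consumed by the added modules.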
101 | """ 102 | super().__init__() 103 | assert global_pool in ('', 'avg', 'token', 'map') 104 | assert class_token or global_pool != 'token' 105 | use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm 106 | norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) 107 | act_layer = get_act_layer(act_layer) or nn.GELU 108 | 109 | self.num_classes = num_classes 110 | self.global_pool = global_pool 111 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 112 | self.num_prefix_tokens = 1 if class_token else 0 113 | self.num_prefix_tokens += reg_tokens 114 | self.num_reg_tokens = reg_tokens 115 | self.has_class_token = class_token 116 | self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg) 117 | self.dynamic_img_size = dynamic_img_size 118 | self.grad_checkpointing = False 119 | 120 | embed_args = {} 121 | if dynamic_img_size: 122 | # flatten deferred until after pos embed 123 | embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) 124 | self.patch_embed = embed_layer( 125 | img_size=img_size, 126 | patch_size=patch_size, 127 | in_chans=in_chans, 128 | embed_dim=embed_dim, 129 | bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) 130 | dynamic_img_pad=dynamic_img_pad, 131 | params=params, 132 | **embed_args, 133 | ) 134 | num_patches = self.patch_embed.num_patches 135 | 136 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None 137 | self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None 138 | embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens 139 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) 140 | self.pos_drop = nn.Dropout(p=pos_drop_rate) 141 | if patch_drop_rate > 0: 142 | self.patch_drop = PatchDropout( 143 | patch_drop_rate, 144 | num_prefix_tokens=self.num_prefix_tokens, 145 | ) 146 | else: 147 | self.patch_drop = nn.Identity() 148 | self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() 149 | 150 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 151 | 152 | ############# Added module start ############# 153 | self.patch_size = patch_size 154 | self.params = params 155 | if self.params.train_type in ['vpt','prompt_cam']: 156 | self.vpt = VPT(params, depth, patch_size, embed_dim) 157 | ############# Added module end ############# 158 | 159 | self.blocks = nn.Sequential(*[ 160 | block_fn( 161 | dim=embed_dim, 162 | num_heads=num_heads, 163 | mlp_ratio=mlp_ratio, 164 | qkv_bias=qkv_bias, 165 | qk_norm=qk_norm, 166 | init_values=init_values, 167 | proj_drop=proj_drop_rate, 168 | attn_drop=attn_drop_rate, 169 | drop_path=dpr[i], 170 | norm_layer=norm_layer, 171 | act_layer=act_layer, 172 | mlp_layer=mlp_layer, 173 | ############# Added module start ############# 174 | params=params 175 | ############# Added module end ############# 176 | ) 177 | for i in range(depth)]) 178 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() 179 | 180 | # Classifier Head 181 | if global_pool == 'map': 182 | self.attn_pool = AttentionPoolLatent( 183 | self.embed_dim, 184 | num_heads=num_heads, 185 | mlp_ratio=mlp_ratio, 186 | norm_layer=norm_layer, 187 | ) 188 | else: 189 | self.attn_pool = None 190 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() 191 | self.head_drop = nn.Dropout(drop_rate) 192 | self.num_heads= num_heads 193 | 194 | 
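# Prompt-CAM note: rather than a single num_classes-way classifier on the CLS token, a shared
        # 1-dimensional head (defined below) scores each class-specific prompt token, so the logit
        # for class c is read off prompt c's output (see forward_head).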
############# Added module start ############# 195 | if self.params.train_type == 'vpt' or self.params.train_type == 'linear': 196 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 197 | elif self.params.train_type == 'prompt_cam': 198 | self.head = nn.Linear(self.embed_dim, 1) 199 | ############# Added module end ############# 200 | 201 | 202 | if weight_init != 'skip': 203 | self.init_weights(weight_init) 204 | 205 | @torch.jit.ignore() 206 | def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None: 207 | _load_weights_PETL(self, checkpoint_path, prefix) 208 | 209 | def forward_features(self, x: torch.Tensor, blur_head_lst=[], target_cls=-1) -> Tuple[torch.Tensor,torch.Tensor]: 210 | pcam_outputs = None 211 | x = self.patch_embed(x) 212 | x = self._pos_embed(x) 213 | x = self.patch_drop(x) 214 | x = self.norm_pre(x) 215 | 216 | if self.grad_checkpointing and not torch.jit.is_scripting(): 217 | x = checkpoint_seq(self.blocks, x) 218 | else: 219 | ############# Added module ############# 220 | for idx, block in enumerate(self.blocks): 221 | if self.params.train_type in ['vpt','prompt_cam']: 222 | prompt = self.vpt.retrieve_prompt(idx, x.shape[0]) 223 | if prompt is not None: 224 | x = torch.cat([prompt, x], dim=1) 225 | 226 | # forward block 227 | if idx == len(self.blocks) - 1: 228 | x,attn_map = block(x, idx, blur_head_lst=blur_head_lst, target_cls=target_cls) 229 | else: 230 | x,_ = block(x, idx) 231 | 232 | if self.params.vpt_mode and prompt is not None: 233 | x = x[:, self.params.vpt_num:, :] 234 | elif self.params.train_type == 'prompt_cam': 235 | pcam_outputs = x 236 | x = x[:, self.params.vpt_num:, :] 237 | 238 | if self.params.train_type == 'prompt_cam': 239 | x = pcam_outputs 240 | ############# Added module end ############# 241 | x = self.norm(x) 242 | return x,attn_map 243 | 244 | def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: 245 | #import pdb;pdb.set_trace() 246 | if self.params.train_type == 'prompt_cam': 247 | output_feature = x[:,:self.params.vpt_num] 248 | else: 249 | if self.attn_pool is not None: 250 | output_feature = self.attn_pool(x) 251 | elif self.global_pool == 'avg': 252 | output_feature = x[:, self.num_prefix_tokens:].mean(dim=1) 253 | elif self.global_pool: 254 | output_feature = x[:, 0] # class token 255 | else: 256 | output_feature = x 257 | 258 | output_feature = self.fc_norm(output_feature) 259 | output_feature = self.head_drop(output_feature) 260 | 261 | return output_feature if pre_logits else self.head(output_feature) 262 | 263 | def forward(self, x: torch.Tensor, blur_head_lst=[], target_cls=-1) -> Tuple[torch.Tensor,torch.Tensor]: 264 | attn_maps = None 265 | if self.params.vis_attn: 266 | x,attn_maps = self.forward_features(x, 267 | blur_head_lst=blur_head_lst, 268 | target_cls=target_cls) 269 | else: 270 | x,_ = self.forward_features(x) 271 | 272 | x = self.forward_head(x) 273 | return x, attn_maps 274 | 275 | def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): 276 | self.num_classes = num_classes 277 | if global_pool is not None: 278 | assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') 279 | if global_pool == 'map' and self.attn_pool is None: 280 | assert False, "Cannot currently add attention pooling in reset_classifier()." 
281 | elif global_pool != 'map' and self.attn_pool is not None: 282 | self.attn_pool = None # remove attention pooling 283 | self.global_pool = global_pool 284 | ############# Added module ############# 285 | if self.params.train_type == 'vpt' or self.params.train_type == 'linear': 286 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 287 | else: 288 | self.head = nn.Linear(self.embed_dim, 1) 289 | ############# Added module end ############# 290 | #Original 291 | # self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 292 | 293 | 294 | @torch.no_grad() 295 | def _load_weights_PETL(model: VisionTransformerPETL, checkpoint_path: str, prefix: str = ''): 296 | if checkpoint_path.endswith('.npz'): 297 | _load_weights(model, checkpoint_path, prefix) 298 | elif checkpoint_path.endswith('.pth') or checkpoint_path.endswith('.bin'): 299 | _load_weights_pth(model, checkpoint_path, checkpoint_filter_fn) 300 | 301 | 302 | def _load_weights_pth(model, checkpoint_path, filter_fn=checkpoint_filter_fn): 303 | """ Load weights from .pth checkpoints 304 | """ 305 | state_dict = torch.load(checkpoint_path, map_location='cpu') 306 | if filter_fn is not None: 307 | state_dict = filter_fn(state_dict, model) 308 | if 'head.weight' in state_dict: 309 | state_dict.pop('head.weight', None) 310 | if 'head.bias' in state_dict: 311 | state_dict.pop('head.bias', None) 312 | model.load_state_dict(state_dict, strict=False) 313 | 314 | 315 | def _create_vision_transformer_petl(variant: str, pretrained: bool = False, **kwargs): 316 | if kwargs.get('features_only', None): 317 | raise RuntimeError('features_only not implemented for Vision Transformer models.') 318 | 319 | if 'flexi' in variant: 320 | # Google FlexiViT pretrained models have a strong preference for bilinear patch / embed 321 | # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. 322 | _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) 323 | else: 324 | _filter_fn = checkpoint_filter_fn 325 | 326 | # attn pool (currently only in siglip) params removed if pool disabled, is there a better soln? 327 | strict = True 328 | if 'siglip' in variant and kwargs.get('global_pool', None) != 'map': 329 | strict = False 330 | 331 | return build_model_with_cfg( 332 | VisionTransformerPETL, 333 | variant, 334 | pretrained, 335 | pretrained_filter_fn=checkpoint_filter_fn, 336 | pretrained_strict=strict, 337 | **kwargs, 338 | ) 339 | 340 | 341 | @register_model 342 | def vit_base_patch14_dinov2_petl(pretrained: bool = False, **kwargs): 343 | """ ViT-B/14 for DINOv2 344 | change img_size to 224 345 | """ 346 | model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=224) 347 | model = _create_vision_transformer_petl( 348 | 'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) 349 | return model 350 | 351 | @register_model 352 | def vit_base_patch16_dino_petl(pretrained: bool = False, **kwargs): 353 | """ ViT-B/16 for DINO 354 | """ 355 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) 356 | model = _create_vision_transformer_petl( 357 | 'vit_base_patch16_224.dino', pretrained=pretrained, **dict(model_args, **kwargs)) 358 | return model 359 | 360 | @register_model 361 | def vit_base_patch16_224_in21k_petl(pretrained=False, **kwargs): 362 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
363 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 364 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 365 | """ 366 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) 367 | model = _create_vision_transformer_petl( 368 | 'vit_base_patch16_224_in21k', pretrained=pretrained, **dict(model_args, **kwargs)) 369 | return model 370 | 371 | 372 | @register_model 373 | def vit_base_patch16_clip_224_petl(pretrained: bool = False, **kwargs) -> VisionTransformer: 374 | """ ViT-B/16 CLIP image tower 375 | """ 376 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, 377 | act_layer='quick_gelu') 378 | model = _create_vision_transformer_petl( 379 | 'vit_base_patch16_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) 380 | return model 381 | -------------------------------------------------------------------------------- /model/vpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import reduce 4 | from operator import mul 5 | import math 6 | from torch.nn.modules.utils import _pair 7 | 8 | 9 | class VPT(nn.Module): 10 | 11 | def __init__(self, params, depth, patch_size, embed_dim): 12 | super().__init__() 13 | self.params = params 14 | self.depth = depth 15 | if params.vpt_mode == 'shallow': 16 | prompt_layer = 1 17 | elif params.vpt_mode == 'deep': 18 | if params.vpt_layer: 19 | prompt_layer = params.vpt_layer 20 | else: 21 | prompt_layer = depth 22 | elif params.train_type == 'prompt_cam': 23 | prompt_layer = depth 24 | else: 25 | raise ValueError 26 | val = math.sqrt(6. / float(3 * reduce(mul, _pair(patch_size), 1) + embed_dim)) 27 | self.prompt_embeddings = nn.Parameter(torch.zeros( 28 | prompt_layer, params.vpt_num, embed_dim)) 29 | # xavier_uniform initialization 30 | nn.init.uniform_(self.prompt_embeddings.data, -val, val) 31 | self.prompt_dropout = nn.Dropout(params.vpt_dropout) 32 | 33 | 34 | def retrieve_prompt(self, index, batch_size): 35 | if self.params.vpt_layer: 36 | index = index - (self.depth - self.params.vpt_layer) 37 | if index < 0: 38 | return None 39 | if index < len(self.prompt_embeddings): 40 | return self.prompt_dropout(self.prompt_embeddings[index]).expand(batch_size, -1, -1) 41 | else: 42 | return None 43 | -------------------------------------------------------------------------------- /samples/Baltimore_Oriole.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/Baltimore_Oriole.jpg -------------------------------------------------------------------------------- /samples/Brewer_Blackbird.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/Brewer_Blackbird.jpg -------------------------------------------------------------------------------- /samples/Orchard_Oriole.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/Orchard_Oriole.jpg -------------------------------------------------------------------------------- /samples/Scott_Oriole.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/Scott_Oriole.jpg -------------------------------------------------------------------------------- /samples/red_winged_blackbird.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/red_winged_blackbird.jpg -------------------------------------------------------------------------------- /samples/rusty_Blackbird.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/rusty_Blackbird.jpg -------------------------------------------------------------------------------- /samples/sample_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/sample_image.png -------------------------------------------------------------------------------- /samples/trait_manipulation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/trait_manipulation.jpg -------------------------------------------------------------------------------- /samples/yellow_headed_blackbird.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/yellow_headed_blackbird.jpg -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Distributed helpers.""" 4 | 5 | import torch 6 | import torch.distributed as dist 7 | _LOCAL_PROCESS_GROUP = None 8 | 9 | 10 | def get_world_size() -> int: 11 | if not dist.is_available(): 12 | return 1 13 | if not dist.is_initialized(): 14 | return 1 15 | return dist.get_world_size() 16 | 17 | 18 | def get_rank() -> int: 19 | if not dist.is_available(): 20 | return 0 21 | if not dist.is_initialized(): 22 | return 0 23 | return dist.get_rank() 24 | 25 | 26 | def is_master_process(num_gpus=8): 27 | """ 28 | Determines if the current process is the master process. 29 | """ 30 | if torch.distributed.is_initialized(): 31 | return dist.get_rank() % num_gpus == 0 32 | else: 33 | return True 34 | 35 | 36 | def run( 37 | local_rank, 38 | num_proc, 39 | func, 40 | init_method, 41 | shard_id, 42 | num_shards, 43 | backend, 44 | cfg, 45 | args, 46 | ): 47 | """ 48 | Runs a function from a child process. 49 | Args: 50 | local_rank (int): rank of the current process on the current machine. 51 | num_proc (int): number of processes per machine. 52 | func (function): function to execute on each of the process. 53 | init_method (string): method to initialize the distributed training. 54 | TCP initialization: equiring a network address reachable from all 55 | processes followed by the port. 56 | Shared file-system initialization: makes use of a file system that 57 | is shared and visible from all machines. 
The URL should start with 58 | file:// and contain a path to a non-existent file on a shared file 59 | system. 60 | shard_id (int): the rank of the current machine. 61 | num_shards (int): number of overall machines for the distributed 62 | training job. 63 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are 64 | supports, each with different capabilities. Details can be found 65 | here: 66 | https://pytorch.org/docs/stable/distributed.html 67 | cfg (CfgNode): configs. Details can be found in 68 | loco/config/defaults.py 69 | """ 70 | # Initialize the process group. 71 | # shard_id = get_rank() 72 | world_size = num_proc * num_shards 73 | rank = shard_id * num_proc + local_rank 74 | 75 | try: 76 | torch.distributed.init_process_group( 77 | backend=backend, 78 | init_method=init_method, 79 | world_size=world_size, 80 | rank=rank, 81 | ) 82 | except Exception as e: 83 | raise e 84 | 85 | torch.cuda.set_device(local_rank) 86 | func(cfg, args) 87 | 88 | 89 | def destroy_process_group(): 90 | """Destroys the default process group.""" 91 | torch.distributed.destroy_process_group() 92 | 93 | 94 | def scaled_all_reduce(cfg, tensors): 95 | """Performs the scaled all_reduce operation on the provided tensors. 96 | 97 | The input tensors are modified in-place. Currently supports only the sum 98 | reduction operator. The reduced values are scaled by the inverse size of 99 | the process group (equivalent to cfg.NUM_GPUS). 100 | """ 101 | # Queue the reductions 102 | reductions = [] 103 | for tensor in tensors: 104 | reduction = torch.distributed.all_reduce(tensor, async_op=True) 105 | reductions.append(reduction) 106 | # Wait for reductions to finish 107 | for reduction in reductions: 108 | reduction.wait() 109 | # Scale the results 110 | for tensor in tensors: 111 | tensor.mul_(1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS) 112 | return tensors 113 | 114 | 115 | def cat_all_gather(tensors): 116 | """Performs the concatenated all_gather operation on the provided tensors. 117 | """ 118 | tensors_gather = [ 119 | torch.ones_like(tensors) 120 | for _ in range(torch.distributed.get_world_size()) 121 | ] 122 | torch.distributed.all_gather(tensors_gather, tensors, async_op=False) 123 | 124 | output = torch.cat(tensors_gather, dim=0) 125 | return output 126 | 127 | 128 | def local_cat_all_gather(tensors): 129 | """Performs the concatenated all_gather operation on the provided tensors. 130 | """ 131 | tensors_gather = [ 132 | torch.ones_like(tensors) 133 | for _ in range(get_local_size()) 134 | ] 135 | torch.distributed.all_gather( 136 | tensors_gather, 137 | tensors, 138 | async_op=False, 139 | group=_LOCAL_PROCESS_GROUP, 140 | ) 141 | output = torch.cat(tensors_gather, dim=0) 142 | return output 143 | 144 | 145 | def get_local_size(): 146 | """ 147 | Returns: 148 | The size of the per-machine process group, 149 | i.e. the number of processes per machine. 150 | """ 151 | if not dist.is_available(): 152 | return 1 153 | if not dist.is_initialized(): 154 | return 1 155 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 156 | 157 | 158 | def get_local_rank(): 159 | """ 160 | Returns: 161 | The rank of the current process within the local (per-machine) process group. 
162 | """ 163 | if not dist.is_available(): 164 | return 0 165 | if not dist.is_initialized(): 166 | return 0 167 | assert _LOCAL_PROCESS_GROUP is not None 168 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 169 | -------------------------------------------------------------------------------- /utils/file_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Project specific pathmanagers for a project as recommended by Detectron2 5 | """ 6 | from iopath.common.file_io import PathManager as PathManagerBase 7 | from iopath.common.file_io import HTTPURLHandler 8 | 9 | 10 | PathManager = PathManagerBase() 11 | PathManager.register_handler(HTTPURLHandler()) 12 | -------------------------------------------------------------------------------- /utils/global_var.py: -------------------------------------------------------------------------------- 1 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ 2 | OPENAI_CLIP_MEAN, OPENAI_CLIP_STD 3 | 4 | TFDS_DATASETS = { 5 | 'caltech101': 102, 6 | 'cifar(num_classes=100)': 100, 7 | 'dtd': 47, 8 | 'oxford_flowers102': 102, 9 | 'oxford_iiit_pet': 37, 10 | 'patch_camelyon': 2, 11 | 'sun397': 397, 12 | 'svhn': 10, 13 | 'resisc45': 45, 14 | 'eurosat': 10, 15 | 'dmlab': 6, 16 | 'kitti(task="closest_vehicle_distance")': 4, 17 | 'smallnorb(predicted_attribute="label_azimuth")': 18, 18 | 'smallnorb(predicted_attribute="label_elevation")': 9, 19 | 'dsprites(predicted_attribute="label_x_position",num_classes=16)': 16, 20 | 'dsprites(predicted_attribute="label_orientation",num_classes=16)': 16, 21 | 'clevr(task="closest_object_distance")': 6, 22 | 'clevr(task="count_all")': 8, 23 | 'diabetic_retinopathy(config="btgraham-300")': 5 24 | } 25 | 26 | VTAB_DATASETS = {'caltech101': 102, 'clevr_count': 8, 'dmlab': 6, 'dsprites_ori': 16, 'eurosat': 10, 'oxford_flowers102': 102, 27 | 'patch_camelyon': 2, 28 | 'smallnorb_azi': 18, 'svhn': 10, 'cifar': 100, 'clevr_dist': 6, 'dsprites_loc': 16, 'dtd': 47, 29 | 'kitti': 4, 'oxford_iiit_pet': 37, 'resisc45': 45, 30 | 'smallnorb_ele': 9, 'sun397': 397, 'diabetic_retinopathy': 5} 31 | 32 | MEAN_STD_DICT = { 33 | 'vit_base_patch16_224_in21k': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 34 | 'vit_base_patch14_dinov2': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 35 | 'vit_base_mae': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 36 | 'vit_base_patch16_clip_224': (OPENAI_CLIP_MEAN, OPENAI_CLIP_STD) 37 | } 38 | 39 | CIFAR_VAL_SPLIT = { 40 | 'cifar100-500': 0.2, 41 | 'cifar100-1k': 0.2, 42 | 'cifar100-5k': 0.1, 43 | 'cifar100-10k': 0.1, 44 | 'cifar100-full': 0.1 45 | } 46 | 47 | RETINOPATHY_VAL_SPLIT = { 48 | 'retinopathy-500': 0.2, 49 | 'retinopathy-5k': 0.1, 50 | 'retinopathy-10k': 0.1, 51 | 'retinopathy-full': 0.1 52 | } 53 | 54 | RESISC_VAL_SPLIT = { 55 | 'resisc-225': 0.2, 56 | 'resisc-450': 0.2, 57 | 'resisc-900': 0.1, 58 | 'resisc-2250': 0.1, 59 | 'resisc-4500': 0.1, 60 | 'resisc-9000': 0.1, 61 | 'resisc-full': 0.1 62 | 63 | } 64 | 65 | 66 | OUTPUT_DIR = "./output" 67 | TUNE_DIR = "./tune_output" 68 | TUNE_DIR_TEST = "./tune_output_test" 69 | -------------------------------------------------------------------------------- /utils/log_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pprint 3 | import sys 4 | from collections import defaultdict 5 | import torch 6 | from tabulate import tabulate 7 | 8 | from 
utils.distributed import get_rank, get_world_size 9 | from utils.setup_logging import setup_logging 10 | 11 | 12 | def get_env_module(): 13 | var_name = "ENV_MODULE" 14 | return var_name, os.environ.get(var_name, "") 15 | 16 | 17 | def collect_torch_env() -> str: 18 | try: 19 | import torch.__config__ 20 | 21 | return torch.__config__.show() 22 | except ImportError: 23 | # compatible with older versions of pytorch 24 | from torch.utils.collect_env import get_pretty_env_info 25 | 26 | return get_pretty_env_info() 27 | 28 | 29 | def collect_env_info(): 30 | data = [] 31 | data.append(("Python", sys.version.replace("\n", ""))) 32 | data.append(get_env_module()) 33 | data.append(("PyTorch", torch.__version__)) 34 | data.append(("PyTorch Debug Build", torch.version.debug)) 35 | 36 | has_cuda = torch.cuda.is_available() 37 | data.append(("CUDA available", has_cuda)) 38 | if has_cuda: 39 | data.append(("CUDA ID", os.environ["CUDA_VISIBLE_DEVICES"])) 40 | devices = defaultdict(list) 41 | for k in range(torch.cuda.device_count()): 42 | devices[torch.cuda.get_device_name(k)].append(str(k)) 43 | for name, devids in devices.items(): 44 | data.append(("GPU " + ",".join(devids), name)) 45 | 46 | env_str = tabulate(data) + "\n" 47 | env_str += collect_torch_env() 48 | return env_str 49 | 50 | 51 | def logging_env_setup(params) -> None: 52 | logger = setup_logging( 53 | params.gpu_num, get_world_size(), params.output_dir, name="Prompt_CAM") 54 | 55 | # Log basic information about environment, cmdline arguments, and config 56 | rank = get_rank() 57 | logger.info( 58 | f"Rank of current process: {rank}. World size: {get_world_size()}") 59 | logger.info("Environment info:\n" + collect_env_info()) 60 | 61 | 62 | # Show the config 63 | logger.info("Training with config:") 64 | logger.info(pprint.pformat(params)) 65 | 66 | def log_model_info(model, logger, verbose=False): 67 | """Logs model info""" 68 | if verbose: 69 | logger.info(f"Classification Model:\n{model}") 70 | model_total_params = sum(p.numel() for p in model.parameters()) 71 | model_grad_params = sum( 72 | p.numel() for p in model.parameters() if p.requires_grad) 73 | model_grad_params_no_head = sum(p.numel() for n, p in model.named_parameters() if p.requires_grad and 'head' not in n) 74 | logger.info("Total Parameters: {0}\t Gradient Parameters: {1}\t Gradient Parameters No Head: {2}".format( 75 | model_total_params, model_grad_params, model_grad_params_no_head)) 76 | logger.info(f"total tuned percent:{(model_grad_params/model_total_params*100):.2f} %") 77 | logger.info(f"total tuned percent no head:{(model_grad_params_no_head / model_total_params * 100):.2f} %") 78 | 79 | # Print the names of parameters that require gradient updates 80 | fine_tuned_params = [n for n, p in model.named_parameters() if p.requires_grad] 81 | 82 | logger.info("Fine-tuned Parameters:") 83 | for param_name in fine_tuned_params: 84 | logger.info(param_name) 85 | 86 | return model_grad_params_no_head 87 | 88 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import time 5 | import yaml 6 | from dotwiz import DotWiz 7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | def __init__(self, name=None, fmt=':f'): 11 | self.name = name 12 | self.fmt = fmt 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | 
self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | def __str__(self): 28 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 29 | return fmtstr.format(**self.__dict__) 30 | 31 | 32 | def method_name(params): 33 | name = '' 34 | if params.train_type == 'prompt_cam': 35 | name += 'pcam_' 36 | name += params.train_type + '_' 37 | name += str(params.vpt_num) + '_' 38 | elif params.vpt_mode: 39 | name += 'vpt_' 40 | name += params.vpt_mode + '_' 41 | name += str(params.vpt_num) + '_' 42 | name += str(params.vpt_layer) + '_' 43 | #####if nothing, linear 44 | if name == '': 45 | name += 'linear' + '_' 46 | name += params.optimizer 47 | return name 48 | 49 | 50 | def set_seed(random_seed=42): 51 | np.random.seed(random_seed) 52 | random.seed(random_seed) 53 | torch.manual_seed(random_seed) 54 | if torch.cuda.is_available(): 55 | torch.cuda.manual_seed(random_seed) 56 | torch.cuda.manual_seed_all(random_seed) 57 | torch.backends.cudnn.deterministic = True 58 | torch.backends.cudnn.benchmark = False 59 | 60 | 61 | @torch.no_grad() 62 | def throughput(model,img_size=224,bs=1): 63 | with torch.no_grad(): 64 | x = torch.randn(bs, 3, img_size, img_size).cuda() 65 | batch_size=x.shape[0] 66 | # model=create_model('vit_base_patch16_224_in21k', checkpoint_path='./ViT-B_16.npz', drop_path_rate=0.1) 67 | model.eval() 68 | for i in range(50): 69 | model(x) 70 | torch.cuda.synchronize() 71 | print(f"throughput averaged with 30 times") 72 | tic1 = time.time() 73 | for i in range(30): 74 | model(x) 75 | torch.cuda.synchronize() 76 | tic2 = time.time() 77 | print(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 78 | MB = 1024.0 * 1024.0 79 | print('memory:', torch.cuda.max_memory_allocated() / MB) 80 | 81 | def load_yaml(path): 82 | with open(path, 'r') as stream: 83 | try: 84 | return DotWiz(yaml.load(stream, Loader=yaml.FullLoader)) 85 | except yaml.YAMLError as exc: 86 | print(exc) 87 | 88 | def override_args_with_yaml(args, yaml_config): 89 | """Override argparse args with values from YAML if they exist.""" 90 | for key, value in yaml_config.items(): 91 | if hasattr(args, key): 92 | setattr(args, key, value) 93 | return args 94 | 95 | def load_vis_args_with_yaml(args,yaml_config_path,checkpoint_path): 96 | """Create args with yaml for notebook""" 97 | yaml_config = load_yaml(yaml_config_path) 98 | for key, value in yaml_config.items(): 99 | setattr(args, key, value) 100 | 101 | set_seed(args.random_seed) 102 | args.checkpoint = checkpoint_path 103 | args.test_batch_size= 1 104 | return args 105 | 106 | class EarlyStop: 107 | def __init__(self, patience=1, min_delta=0): 108 | self.patience = patience 109 | self.min_delta = min_delta 110 | self.counter = 0 111 | self.max_metrics = None 112 | 113 | def early_stop(self, eval_metrics): 114 | ''' 115 | 116 | :param val_acc: 117 | :return: bool(if early stop), bool(if save model) 118 | ''' 119 | if self.max_metrics is None: 120 | self.max_metrics = eval_metrics 121 | if eval_metrics['top1'] > self.max_metrics['top1']: 122 | self.max_metrics = eval_metrics 123 | self.counter = 0 124 | return False, True 125 | elif eval_metrics['top1'] < (self.max_metrics['top1'] - self.min_delta): 126 | self.counter += 1 127 | if self.counter >= self.patience: 128 | return True, False 129 | return False, False 130 | -------------------------------------------------------------------------------- 
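A minimal usage sketch for the `EarlyStop` helper defined in `utils/misc.py` above, included to clarify the `(should_stop, should_save)` return contract of `early_stop`. The real loop lives in `engine/trainer.py` (not shown in this excerpt); `evaluate_one_epoch`, `save_checkpoint`, and the patience value are illustrative assumptions, not project APIs.

from utils.misc import EarlyStop

early_stopper = EarlyStop(patience=5, min_delta=0.0)  # assumed patience for illustration

for epoch in range(100):
    eval_metrics = {'top1': evaluate_one_epoch()}  # hypothetical helper returning top-1 accuracy
    should_stop, should_save = early_stopper.early_stop(eval_metrics)
    if should_save:
        save_checkpoint()  # hypothetical helper persisting the best model so far
    if should_stop:
        break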
/utils/setup_logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Logging.""" 4 | 5 | import builtins 6 | import functools 7 | import logging 8 | import sys 9 | import os 10 | from termcolor import colored 11 | 12 | from .distributed import is_master_process 13 | from .file_io import PathManager 14 | 15 | # Show filename and line number in logs 16 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" 17 | 18 | 19 | def _suppress_print(): 20 | """Suppresses printing from the current process.""" 21 | 22 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 23 | pass 24 | 25 | builtins.print = print_pass 26 | 27 | 28 | # cache the opened file object, so that different calls to `setup_logger` 29 | # with the same file name can safely write to the same file. 30 | @functools.lru_cache(maxsize=None) 31 | def _cached_log_stream(filename): 32 | return PathManager.open(filename, "a") 33 | 34 | 35 | @functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers # noqa 36 | def setup_logging( 37 | num_gpu, num_shards, output="", name="Prompt_CAM", color=True): 38 | """Sets up the logging.""" 39 | # Enable logging only for the master process 40 | if is_master_process(num_gpu): 41 | # Clear the root logger to prevent any existing logging config 42 | # (e.g. set by another module) from messing with our setup 43 | logging.root.handlers = [] 44 | # Configure logging 45 | logging.basicConfig( 46 | level=logging.INFO, format=_FORMAT, stream=sys.stdout 47 | ) 48 | else: 49 | _suppress_print() 50 | 51 | if name is None: 52 | name = __name__ 53 | logger = logging.getLogger(name) 54 | # remove any lingering handler 55 | logger.handlers.clear() 56 | 57 | logger.setLevel(logging.INFO) 58 | logger.propagate = False 59 | 60 | plain_formatter = logging.Formatter( 61 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 62 | datefmt="%m/%d %H:%M:%S", 63 | ) 64 | if color: 65 | formatter = _ColorfulFormatter( 66 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 67 | datefmt="%m/%d %H:%M:%S", 68 | root_name=name, 69 | abbrev_name=str(name), 70 | ) 71 | else: 72 | formatter = plain_formatter 73 | 74 | if is_master_process(num_gpu): 75 | ch = logging.StreamHandler(stream=sys.stdout) 76 | ch.setLevel(logging.DEBUG) 77 | ch.setFormatter(formatter) 78 | logger.addHandler(ch) 79 | 80 | if is_master_process(num_gpu * num_shards): 81 | if len(output) > 0: 82 | if output.endswith(".txt") or output.endswith(".log"): 83 | filename = output 84 | else: 85 | filename = os.path.join(output, "logs.txt") 86 | 87 | PathManager.mkdirs(os.path.dirname(filename)) 88 | 89 | fh = logging.StreamHandler(_cached_log_stream(filename)) 90 | fh.setLevel(logging.DEBUG) 91 | fh.setFormatter(plain_formatter) 92 | logger.addHandler(fh) 93 | return logger 94 | 95 | 96 | def setup_single_logging(name, output=""): 97 | """Sets up the logging.""" 98 | # Enable logging only for the master process 99 | # Clear the root logger to prevent any existing logging config 100 | # (e.g. 
set by another module) from messing with our setup 101 | logging.root.handlers = [] 102 | # Configure logging 103 | logging.basicConfig( 104 | level=logging.INFO, format=_FORMAT, stream=sys.stdout 105 | ) 106 | 107 | if len(name) == 0: 108 | name = __name__ 109 | logger = logging.getLogger(name) 110 | logger.setLevel(logging.INFO) 111 | logger.propagate = False 112 | 113 | plain_formatter = logging.Formatter( 114 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 115 | datefmt="%m/%d %H:%M:%S", 116 | ) 117 | formatter = _ColorfulFormatter( 118 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 119 | datefmt="%m/%d %H:%M:%S", 120 | root_name=name, 121 | abbrev_name=str(name), 122 | ) 123 | 124 | ch = logging.StreamHandler(stream=sys.stdout) 125 | ch.setLevel(logging.DEBUG) 126 | ch.setFormatter(formatter) 127 | logger.addHandler(ch) 128 | 129 | if len(output) > 0: 130 | if output.endswith(".txt") or output.endswith(".log"): 131 | filename = output 132 | else: 133 | filename = os.path.join(output, "logs.txt") 134 | 135 | PathManager.mkdirs(os.path.dirname(filename)) 136 | 137 | fh = logging.StreamHandler(_cached_log_stream(filename)) 138 | fh.setLevel(logging.DEBUG) 139 | fh.setFormatter(plain_formatter) 140 | logger.addHandler(fh) 141 | 142 | return logger 143 | 144 | 145 | def get_logger(name): 146 | """Retrieves the logger.""" 147 | return logging.getLogger(name) 148 | 149 | 150 | 151 | class _ColorfulFormatter(logging.Formatter): 152 | # from detectron2 153 | def __init__(self, *args, **kwargs): 154 | self._root_name = kwargs.pop("root_name") + "." 155 | self._abbrev_name = kwargs.pop("abbrev_name", "") 156 | if len(self._abbrev_name): 157 | self._abbrev_name = self._abbrev_name + "." 158 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 159 | 160 | def formatMessage(self, record: logging.LogRecord) -> str: 161 | record.name = record.name.replace(self._root_name, self._abbrev_name) 162 | log = super(_ColorfulFormatter, self).formatMessage(record) 163 | if record.levelno == logging.WARNING: 164 | prefix = colored("WARNING", "red", attrs=["blink"]) 165 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 166 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 167 | else: 168 | return log 169 | return prefix + " " + log 170 | -------------------------------------------------------------------------------- /utils/visual_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import warnings 5 | 6 | import numpy as np 7 | import random 8 | import cv2 9 | import shutil 10 | 11 | 12 | from time import sleep 13 | from random import randint 14 | 15 | from PIL import Image 16 | 17 | import torch.nn as nn 18 | import matplotlib.pyplot as plt 19 | import torch.nn.functional as F 20 | 21 | parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 22 | sys.path.append(parent_dir) 23 | 24 | from data.dataset.utils import get_transformation 25 | 26 | 27 | warnings.filterwarnings("ignore") 28 | 29 | def combine_images(path, pred_class,resize_dim=(200,200)): 30 | images = [os.path.join(path, image) for image in os.listdir(path) if image.endswith('.jpg')] 31 | images.sort(key=lambda x: int(os.path.basename(x).split('.')[0])) 32 | imgs = [Image.open(image).resize(resize_dim) for image in images] 33 | 34 | widths, heights = zip(*(img.size for img in imgs)) 35 | 36 | total_width = sum(widths) 37 | max_height = max(heights) 38 | 
merged_image = Image.new('RGB', (total_width, max_height))
39 |
40 | x_offset = 0
41 | for img in imgs:
42 | merged_image.paste(img, (x_offset, 0))
43 | x_offset += img.width
44 | merged_image.save(os.path.join(path, "concatenated_prediction_" + str(pred_class) + ".jpg"))
45 |
46 | for image in images:
47 | # remove the per-head images once the concatenated image has been saved
48 | os.remove(image)
49 |
50 | def SuperImposeHeatmap(attention, input_image):
51 | alpha = 0.5
52 |
53 | attention_resized = cv2.resize(attention, (input_image.shape[1], input_image.shape[0]), interpolation=cv2.INTER_CUBIC)
54 |
55 | # Normalize the attention map to [0, 1]
56 | min_val = attention_resized.min()
57 | max_val = attention_resized.max()
58 | attention_normalized = (attention_resized - min_val) / (max_val - min_val + 1e-8)  # epsilon guards against a uniform map
59 |
60 |
61 | # Apply Gaussian blur for smoothing
62 | attention_normalized = cv2.GaussianBlur(attention_normalized, (9, 9), 0)
63 |
64 | # Convert to heatmap
65 | heatmap = (attention_normalized * 255).astype(np.uint8)
66 | heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
67 |
68 | # Superimpose the heatmap on the original image
69 | result = (input_image * alpha + heatmap * (1 - alpha)).astype(np.uint8)
70 | return result
71 |
72 | def create_overlay_images(X, patch_size, attentions, output_folder):
73 | w_featmap = X.shape[-1] // patch_size
74 | attentions = attentions[0]
75 |
76 | nh = attentions.shape[0]
77 | if os.path.exists(output_folder):
78 | shutil.rmtree(output_folder)
79 | os.makedirs(output_folder, exist_ok=True)
80 |
81 | image = X[0].detach().cpu().numpy().transpose(1, 2, 0)  # Shape (H, W, C)
82 | image = (image - image.min()) / (image.max() - image.min())
83 | image = (image * 255).astype(np.uint8)
84 |
85 | cv2.imwrite(os.path.join(output_folder, "0.jpg"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
86 |
87 | for head in range(nh):
88 | attention_map = attentions[head].reshape(w_featmap, w_featmap).detach().cpu().numpy()
89 | result_image = SuperImposeHeatmap(attention_map, image)
90 |
91 | # Save the overlayed image
92 | save_path = os.path.join(output_folder, f"{head+1}.jpg")
93 | cv2.imwrite(save_path, result_image)
94 |
95 |
96 | def prune_and_plot_ranked_heads(model, inputs, target, params):
97 | if params.top_traits < 1 or params.top_traits > model.num_heads:
98 | raise ValueError("top_traits must be between 1 and the number of attention heads")
99 |
100 | remaining_head_list = list(range(model.num_heads))
101 | pruned_head_index = None
102 | blur_head_lst = []
103 | blur_head_probs = []
104 |
105 | # Determine the ranking of heads by iteratively finding the head that, when blurred,
106 | # gives the highest probability for the target, i.e. the least important head is pruned first.
107 | while len(remaining_head_list) > 0: 108 | highest_score=-1e8 109 | remaining_head_scores= [] 110 | 111 | for head_idx in remaining_head_list: 112 | output,_ = model(inputs, 113 | blur_head_lst=blur_head_lst+[head_idx], 114 | target_cls=target) 115 | 116 | probabilities = torch.softmax(output.squeeze(-1), dim=-1) 117 | 118 | remaining_head_scores.append(probabilities[0,target].item()) 119 | 120 | if remaining_head_scores[-1] > highest_score: 121 | highest_score=remaining_head_scores[-1] 122 | pruned_head_index=head_idx 123 | 124 | if pruned_head_index is not None: 125 | blur_head_lst.append(pruned_head_index) 126 | remaining_head_list.remove(pruned_head_index) 127 | blur_head_probs.append(highest_score) 128 | 129 | #### Convert the image for overlaying the attention maps 130 | image = inputs[0].detach().cpu().numpy().transpose(1, 2, 0) # Shape (H, W, C) 131 | image = (image - image.min()) / (image.max() - image.min()) 132 | image = (image * 255).astype(np.uint8) 133 | w_featmap = inputs.shape[-1] // model.patch_size 134 | 135 | ### Go through all the attention maps and overlay them on the image 136 | _,attn_maps = model(inputs) 137 | attn_maps = attn_maps[:, :, target, (params.vpt_num+1):] 138 | overlayed_attn_maps = [] 139 | 140 | for head in range(model.num_heads): 141 | attention_map = attn_maps[0,head].reshape(w_featmap, w_featmap).detach().cpu().numpy() 142 | overlayed_attn_maps.append(cv2.cvtColor(SuperImposeHeatmap(attention_map, image),cv2.COLOR_BGR2RGB)) 143 | 144 | #### Create captions for all the plot 145 | captions = ['input image'] 146 | for head in range(model.num_heads): 147 | captions.append(f'rank {head+1}') 148 | 149 | print(f'Head # (from most important to least important):{blur_head_lst[::-1]}') 150 | #### Create plot with overlayed attention maps 151 | for bl_idx in range(1,len(blur_head_lst)+1): 152 | if bl_idx == params.top_traits+1: 153 | break 154 | current_blr_head_lst = blur_head_lst[:-bl_idx] 155 | remaining_head_list = remaining_head_list+ [blur_head_lst[-bl_idx]] 156 | 157 | current_images = [image] 158 | for r_idx in remaining_head_list: 159 | current_images+=[overlayed_attn_maps[r_idx]] 160 | 161 | # Create a grid of images 162 | grid_size = (1,len(current_images)) 163 | fig, axes = plt.subplots(*grid_size, figsize=(10, 10)) 164 | for i, ax in enumerate(axes.flatten()): 165 | ax.imshow(current_images[i]) 166 | ax.axis('off') 167 | ax.set_title(captions[i]) 168 | 169 | plt.tight_layout() 170 | plt.show() 171 | 172 | 173 | def load_image(image_path): 174 | image = Image.open(image_path) 175 | image = image.convert("RGB") 176 | transformation = get_transformation(mode='test') 177 | transformed_image = transformation(image) 178 | return torch.unsqueeze(transformed_image, 0) -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from experiment.visualize_run import basic_vis 3 | from utils.setup_logging import get_logger 4 | from utils.misc import set_seed,load_yaml 5 | import time 6 | 7 | logger = get_logger('Prompt_CAM') 8 | 9 | 10 | def main(): 11 | args = setup_parser().parse_args() 12 | 13 | if args.config: 14 | yaml_config = load_yaml(args.config) 15 | for key, value in yaml_config.items(): 16 | setattr(args, key, value) 17 | 18 | set_seed(args.random_seed) 19 | start = time.time() 20 | basic_vis(args) 21 | end = time.time() 22 | logger.info(f'----------- Total Run time : {(end - start) / 60} 
mins-----------')
23 |
24 | def setup_parser():
25 | parser = argparse.ArgumentParser(description='Prompt_CAM')
26 | ######################## YAML Config #########################
27 | parser.add_argument('--config', type=str, default=None, help='Path to YAML config file')
28 |
29 | ####################### Model #########################
30 | parser.add_argument('--checkpoint', default=None, type=str, help='Path to the model checkpoint')
31 |
32 | ####################### Visualization Configuration #########################
33 | parser.add_argument('--vis_attn', default=True, type=lambda v: str(v).lower() in ('true', '1', 'yes'), help='Visualize the attention maps (true/false)')
34 | parser.add_argument('--vis_cls', default=23, type=int, help='Class index in the current dataset to visualize')
35 | parser.add_argument('--nmbr_samples', default=10, type=int, help='Number of samples to visualize')
36 | parser.add_argument('--top_traits', default=4, type=int, help='Number of top traits per sample to visualize')
37 |
38 | parser.add_argument('--vis_outdir', default="./visualization", type=str, help='Output directory for visualization')
39 | ######################## Misc #########################
40 | parser.add_argument('--gpu_num', default=1,
41 | type=int,
42 | help='Number of GPUs (default: %(default)s)')
43 | parser.add_argument('--random_seed', default=42,
44 | type=int,
45 | help='Random seed (default: %(default)s)')
46 |
47 | return parser
48 |
49 |
50 | if __name__ == '__main__':
51 | main()
52 |
--------------------------------------------------------------------------------
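A minimal sketch (with assumed paths) of the notebook-style flow that load_vis_args_with_yaml and the utilities in utils/visual_utils.py support: build the visualization args from one of the repository's YAML configs plus a trained checkpoint, preprocess a sample image, and rank the attention heads for a target class. The checkpoint location and the model construction are assumptions and are left as placeholders; the YAML config and the sample image ship with the repository.

```python
import argparse
from utils.misc import load_vis_args_with_yaml
from utils.visual_utils import load_image, prune_and_plot_ranked_heads

args = load_vis_args_with_yaml(
    argparse.Namespace(random_seed=42),
    yaml_config_path='experiment/config/prompt_cam/dino/cub/args.yaml',
    checkpoint_path='checkpoints/cub_dino/model.pt',   # assumed checkpoint location
)
args.vis_cls = 23        # class index to explain
args.top_traits = 4      # number of most important heads to plot

inputs = load_image('samples/Baltimore_Oriole.jpg')    # 1 x 3 x H x W tensor with the test-time transform applied

# `model` is the Prompt-CAM ViT built elsewhere (see experiment/build_model.py);
# it returns (logits, attention maps) and is omitted here.
# prune_and_plot_ranked_heads(model, inputs.cuda(), target=args.vis_cls, params=args)
```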