├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.md
├── data
│   └── dataset
│       ├── birds_525.py
│       ├── car.py
│       ├── cub.py
│       ├── dog.py
│       ├── pet.py
│       └── utils.py
├── demo.ipynb
├── engine
│   ├── optimizer.py
│   └── trainer.py
├── env_setup.sh
├── experiment
│   ├── build_loader.py
│   ├── build_model.py
│   ├── config
│   │   └── prompt_cam
│   │       ├── dino
│   │       │   ├── birds_525
│   │       │   │   └── args.yaml
│   │       │   ├── car
│   │       │   │   └── args.yaml
│   │       │   ├── cub
│   │       │   │   └── args.yaml
│   │       │   ├── dog
│   │       │   │   └── args.yaml
│   │       │   └── pet
│   │       │       └── args.yaml
│   │       └── dinov2
│   │           ├── cub
│   │           │   └── args.yaml
│   │           ├── dog
│   │           │   └── args.yaml
│   │           └── pet
│   │               └── args.yaml
│   ├── run.py
│   └── visualize_run.py
├── main.py
├── model
│   ├── attention.py
│   ├── block.py
│   ├── mlp.py
│   ├── patch_embed.py
│   ├── utils.py
│   ├── vision_transformer.py
│   └── vpt.py
├── samples
│   ├── Baltimore_Oriole.jpg
│   ├── Brewer_Blackbird.jpg
│   ├── Orchard_Oriole.jpg
│   ├── Scott_Oriole.jpg
│   ├── red_winged_blackbird.jpg
│   ├── rusty_Blackbird.jpg
│   ├── sample_image.png
│   ├── trait_manipulation.jpg
│   └── yellow_headed_blackbird.jpg
├── utils
│   ├── distributed.py
│   ├── file_io.py
│   ├── global_var.py
│   ├── log_utils.py
│   ├── misc.py
│   ├── setup_logging.py
│   └── visual_utils.py
└── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | checkpoints/
3 | *.pyc
4 | **/__pycache__/
5 | output/
6 | pretrained_weights/
7 | data/images/
8 | visualization/
9 | data/annotations/
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | # This CITATION.cff file was generated with cffinit.
2 | # Visit https://bit.ly/cffinit to generate yours today!
3 |
4 | cff-version: 1.2.0
5 | title: >-
6 | Prompt-CAM: Making Vision Transformers Interpretable for
7 | Fine-Grained Analysis
8 | message: >-
9 | If you use this software, please cite it using the
10 | metadata from this file.
11 | type: software
12 | authors:
13 | - given-names: Arpita
14 | family-names: Chowdhury
15 | - given-names: Dipanjyoti
16 | family-names: Paul
17 | - given-names: Zheda
18 | family-names: Mai
19 | - given-names: Jianyang
20 | family-names: Gu
21 | - given-names: Ziheng
22 | family-names: Zhang
23 | - given-names: Kazi Sajeed
24 | family-names: Mehrab
25 | - given-names: Elizabeth G.
26 | family-names: Campolongo
27 | - given-names: Daniel
28 | family-names: Rubenstein
29 | - given-names: Charles V.
30 | family-names: Stewart
31 | - given-names: Anuj
32 | family-names: Karpatne
33 | - given-names: Tanya
34 | family-names: Berger-Wolf
35 | - given-names: Yu
36 | family-names: Su
37 | - given-names: Wei-Lun
38 | family-names: Chao
39 | identifiers:
40 | - type: url
41 | value: 'https://arxiv.org/pdf/2501.09333'
42 | repository-code: 'https://github.com/Imageomics/Prompt_CAM'
43 | abstract: >-
44 | We present a simple usage of pre-trained Vision
45 | Transformers (ViTs) for fine-grained analysis, aiming to
46 | identify and localize the traits that distinguish visually
47 | similar categories, such as different bird species or dog
48 | breeds.
49 | keywords:
50 | - explainable-ai
51 | - interpretable-ai
52 | - imageomics
53 | - fine-grained-classification
54 | - vision-transformer
55 | - interpretable
56 | license: MIT
57 | commit: Prompt-CAM
58 | version: 1.0.0
59 | date-released: '2025-03-24'
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Arpita Chowdhury
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # :mag: Prompt-CAM: Making Vision Transformers Interpretable for Fine-Grained Analysis (CVPR'25)
2 |
3 | This is the official implementation of [PROMPT-CAM: Making Vision Transformers Interpretable for Fine-Grained Analysis](https://arxiv.org/pdf/2501.09333) (CVPR'25).
4 |
5 | Introducing **Prompt-CAM**, a $${\textcolor{red}{\text{simple yet effective}}}$$ **interpretable transformer** that requires no architectural modifications to pre-trained ViTs: we simply inject **class-specific prompts** into any ViT to make it interpretable.
6 |
7 | Prompt-CAM lets us explore:
8 | - 🧠 What does the model think is important for each class?
9 | - ✨ Which traits are shared between two bird species?
10 | - 🎨 How do different classes ‘see’ the same image differently?
11 |
12 |
13 |
14 |
15 |
16 | ## Quick Start: Try out the demo
17 | 🔍 Ever wondered what traits stand out when a model looks at an image of one class but searches with another class in mind? 🤔
18 | Witness the important traits of different classes through the lens of Prompt-CAM with our interactive demos!
19 |
20 | 👉 Try our demo **without installing anything** in Google Colab: [Open in Colab](https://colab.research.google.com/drive/1co1P5LXSVb-g0hqv8Selfjq4WGxSpIFe?usp=sharing)
21 |
22 | 👉 Try our demo locally in [demo.ipynb](demo.ipynb):
24 | - Set up the [environment](#environment-setup)
25 | - Download a pre-trained model from the links below
26 | - Run the demo
27 |
28 |
29 | 👉 You can extend this code base to include: [New datasets](#to-add-a-new-dataset) and [New backbones](#to-add-a-new-backbone)
30 |
31 |
32 |
33 | ## Environment Setup
34 | ```bash
35 | conda create -n prompt_cam python=3.7
36 | conda activate prompt_cam
37 | source env_setup.sh
38 | ```
39 |
40 |
41 | ## Data Preparation
42 | You can put all the data in one folder and pass its path to the `--data_path` argument.
43 |
44 | The structure of `data/images/` should be organized as follows:
45 |
46 | ```
47 | cub/
48 | ├── train/
49 | │ ├── 001.Black_footed_Albatross/
50 | │ │ ├── image_1.jpg
51 | │ │ ├── image_2.jpg
52 | │ │ └── ...
53 | │ ├── 002.Laysan_Albatross/
54 | │ │ ├── image_1.jpg
55 | │ │ ├── image_2.jpg
56 | │ │ └── ...
57 | │ └── ...
58 | └── val/
59 | ├── 001.Black_footed_Albatross/
60 | │ ├── image_1.jpg
61 | │ ├── image_2.jpg
62 | │ └── ...
63 | ├── 002.Laysan_Albatross/
64 | │ ├── image_1.jpg
65 | ```
66 |
67 |
69 |
70 | ## CUB
71 |
72 | - Download the prepared dataset
73 |   - From [this Google Drive folder](https://drive.google.com/drive/folders/1X3ikQEk_D7cKcyCnxbF3kJTsZ0LZfvVO?usp=sharing)
74 | - Or prepare the dataset yourself:
75 |   - Download the CUB dataset from [the original website](https://www.vision.caltech.edu/datasets/cub_200_2011/) and put it in the `data/images/` folder.
76 |   - Use the dataset's provided train/val split to create the `train/` and `val/` folders, keeping each class number as the prefix of the respective image folder name (starting from 1).
77 |   - The code will automatically create train and val annotation files in the `data/annotations/` folder for each dataset if they are not provided.
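
For reference, the annotation files generated by `create_annotation_file` in [data/dataset/utils.py](data/dataset/utils.py) (e.g. `data/annotations/cub/cub_combine.txt` for training and `data/annotations/cub/test.txt` for evaluation) list one `relative/image/path label` pair per line, with labels starting from 0 in sorted class-folder order:

```
train/001.Black_footed_Albatross/image_1.jpg 0
train/001.Black_footed_Albatross/image_2.jpg 0
train/002.Laysan_Albatross/image_1.jpg 1
...
```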
78 |
79 |
80 |
82 |
83 | ## Pet Dataset
84 | - Download the prepared dataset
85 |   - From [this Google Drive folder](https://drive.google.com/drive/folders/1X3ikQEk_D7cKcyCnxbF3kJTsZ0LZfvVO?usp=sharing)
87 |
88 |
89 | **To add a new dataset, see [Extensions](#extensions).**
90 |
91 | ## Results + Checkpoints:
92 | - Download the checkpoints from the links below and put them in the `checkpoints/{model}/{dataset}/` folder.
93 |
94 | Backbone | Dataset | Prompt-CAM (top-1 acc. %) | Checkpoint Link |
95 | --- | --- | --- | --- |
96 | dino | cub (CUB)| 73.2 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) |
97 | dino | car (Stanford Cars) | 83.2 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) |
98 | dino | dog (Stanford Dogs) | 81.1 |[url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) |
99 | dino | pet (Oxford Pet) | 91.3 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) |
100 | dino | birds_525 (Birds-525) | 98.8 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) |
101 |
102 | Backbone | Dataset | Prompt-CAM (top-1 acc. %) | Checkpoint Link |
103 | --- | --- | --- | --- |
104 | dinov2 | cub (CUB) | 74.1 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) |
105 | dinov2 | dog (Stanford Dogs) | 81.3| [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) |
106 | dinov2 | pet (Oxford Pet) | 92.7 | [url](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) |
107 |
108 | ## Evaluation and Visualization
109 | - Download the checkpoint from the URL in the [table](#results--checkpoints) above and put it in the `checkpoints/{model}/{dataset}/` folder.
110 |
111 | For example, to visualize the attention maps of the DINO model on the class `024.Red_faced_Cormorant` of the CUB dataset, put the checkpoint in the `checkpoints/dino/cub/` folder and run the following command:
112 |
113 | ```bash
114 | CUDA_VISIBLE_DEVICES=0 python visualize.py --config ./experiment/config/prompt_cam/dino/cub/args.yaml --checkpoint ./checkpoints/dino/cub/model.pt --vis_cls 23
115 | ```
116 |
117 | - The output will be saved in the `visualization/dino/cub/class_23/` folder.
118 | - Inside each image's folder, the heatmaps of the top `top_traits` traits for the target class are concatenated if the prediction is correct; otherwise, all traits are concatenated. (The prediction for the respective image can be found in `concatenated_prediction_{predicted_class}.jpg`.)
119 |
120 | Visualization configuration options (an example command combining them is shown after this list):
121 |
122 | - `config`: path to the config file.
123 | - `checkpoint`: path to the checkpoint file.
124 | - `vis_cls`: class number to visualize. (default: 23)
125 | - `vis_attn`: set to True to visualize the attention map. (default: True)
126 | - `top_traits`: number of traits to visualize. (default: 4)
127 | - `nmbr_samples`: number of images from `vis_cls` to visualize. (default: 10)
128 | - `vis_outdir`: output directory. (default: visualization/)
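
For example, to generate heatmaps for the top 4 traits of 10 images from class 23 (assuming each option above is exposed as a command-line flag of `visualize.py`, as `--vis_cls` is):

```bash
CUDA_VISIBLE_DEVICES=0 python visualize.py \
    --config ./experiment/config/prompt_cam/dino/cub/args.yaml \
    --checkpoint ./checkpoints/dino/cub/model.pt \
    --vis_cls 23 --top_traits 4 --nmbr_samples 10 --vis_outdir visualization/
```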
129 |
130 |
131 |
132 |
133 | ## :fire: Training
134 |
135 | ### :one: Pretrained weights
136 | ---
137 |
138 | Download the pretrained weights from the following links and put them in the `pretrained_weights` folder.
139 | 1. [ViT-B-DINO](https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth) and rename it to `dino_vitbase16_pretrain.pth`
140 | 2. [ViT-B-DINOV2](https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth) and rename it to `dinov2_vitb14_pretrain.pth`
141 | ### :two: Load dataset
142 | ---
143 |
144 | See [Data Preparation](#data-preparation) above.
145 | ### :three: Start training
146 | ---
147 |
148 | 👉 To train the model on the `CUB dataset` using the `DINO` model, run the following command:
149 | ```bash
150 | CUDA_VISIBLE_DEVICES=0 python main.py --config ./experiment/config/prompt_cam/dino/cub/args.yaml
151 | ```
152 | The checkpoint will be saved in the `output/vit_base_patch16_dino/cub/` folder. Copy the checkpoint `model.pt` to the `checkpoints/dino/cub/` folder.
153 |
154 | ---
155 |
156 | 👉 To train the model on the `Oxford Pet dataset` using the `DINO` model, run the following command:
157 | ```bash
158 | CUDA_VISIBLE_DEVICES=0 python main.py --config ./experiment/config/prompt_cam/dino/pet/args.yaml
159 | ```
160 | The checkpoint will be saved in the `output/vit_base_patch16_dino/pet/` folder. Copy the checkpoint `model.pt` to the `checkpoints/dino/pet/` folder.
161 |
162 | ---
163 |
164 | 👉 To train the model on the `Oxford Pet dataset` using the `DINOv2` model, run the following command:
165 | ```bash
166 | CUDA_VISIBLE_DEVICES=0 python main.py --config ./experiment/config/prompt_cam/dinov2/pet/args.yaml
167 | ```
168 |
169 | The checkpoint will be saved in the `output/vit_base_patch14_dinov2/pet/` folder. Copy the checkpoint `model.pt` to the `checkpoints/dinov2/pet/` folder.
170 |
171 | ---
172 |
173 | ### :four: :mag: Visualize the attention map
174 | ---
175 |
176 | See [Visualization](#evaluation-and-visualization) above.
177 |
178 | ## Extensions
179 | ### To add a new dataset
180 | 1. Prepare the dataset using the [instructions](#data-preparation) above.
181 | 2. Add a new dataset file in `data/dataset/` (a minimal sketch is shown after this list; see [cub.py](data/dataset/cub.py) for a full reference).
182 | 3. Modify [build_loader.py](experiment/build_loader.py) to include the new dataset.
183 | 4. Create a new config file at `experiment/config/prompt_cam/{model}/{dataset}/args.yaml`.
184 |    - See `experiment/config/prompt_cam/dino/cub/args.yaml` for reference and what to modify.
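
A minimal sketch of step 2, mirroring [cub.py](data/dataset/cub.py) (the dataset name `flower` and its class count are hypothetical placeholders):

```python
# data/dataset/flower.py -- hypothetical example mirroring data/dataset/cub.py
import os
from torchvision.datasets.folder import ImageFolder, default_loader
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from data.dataset.utils import add_samples, create_annotation_file, get_transformation


class FlowerDataset(ImageFolder):
    def __init__(self, root, data_list, transform=None):
        self.data_root = root
        self.loader = default_loader
        self.transform = transform
        self.target_transform = None
        self.samples = []
        add_samples(self.samples, data_list, root)


def get_flower(params, mode='trainval_combined'):
    params.class_num = 102  # hypothetical number of classes
    mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

    if mode == 'trainval_combined':
        data_list = f'data/annotations/flower/{params.data}_combine.txt'
        if not os.path.exists(data_list):
            create_annotation_file(params.data_path, ['train'], data_list)
        return FlowerDataset(params.data_path, data_list, get_transformation('train', mean, std))
    elif mode == 'test':
        data_list = 'data/annotations/flower/test.txt'
        if not os.path.exists(data_list):
            create_annotation_file(params.data_path, ['val'], data_list)
        return FlowerDataset(params.data_path, data_list, get_transformation('val', mean, std))
    else:
        raise NotImplementedError
```

Then register it in `get_dataset()` inside [build_loader.py](experiment/build_loader.py) with a `data.startswith("flower")` branch (step 3), and set `data`, `data_path`, and `vpt_num` (equal to the number of classes) in the new `args.yaml` (step 4).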
185 |
186 | ### To add a new backbone
187 | - Modify `get_base_model()` in [build_model.py](experiment/build_model.py) (a hedged sketch of such a branch follows this list).
188 | - Register the new backbone in [vision_transformer.py](model/vision_transformer.py) by creating a new function.
189 | - Add another option to `--pretrained_weights` and `--model` in the `setup_parser()` function of [main.py](main.py) to include the new backbone.
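
A hypothetical sketch of such a branch, mirroring the existing DINO case in `get_base_model()` (the model name `vit_base_patch16_newbackbone_petl` and the weight file below are placeholders; the `*_petl` model must first be registered in `model/vision_transformer.py`):

```python
import timm

def build_new_backbone(params, visualize=False):
    # Placeholder branch mirroring the DINO case in experiment/build_model.py.
    params.patch_size = 16
    model = timm.create_model("vit_base_patch16_newbackbone_petl",  # hypothetical registered name
                              drop_path_rate=params.drop_path_rate,
                              pretrained=False, params=params)
    if not visualize:
        model.load_pretrained('pretrained_weights/newbackbone_vitb16_pretrain.pth')  # placeholder file
        model.reset_classifier(params.class_num)
    return model
```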
190 |
191 |
192 | # Citation [arXiv](https://arxiv.org/pdf/2501.09333)
193 | If you find this repository useful, please consider citing our work :pencil: and giving a star :star2: :
194 | ```
195 | @article{chowdhury2025prompt,
196 | title={Prompt-CAM: A Simpler Interpretable Transformer for Fine-Grained Analysis},
197 | author={Chowdhury, Arpita and Paul, Dipanjyoti and Mai, Zheda and Gu, Jianyang and Zhang, Ziheng and Mehrab, Kazi Sajeed and Campolongo, Elizabeth G and Rubenstein, Daniel and Stewart, Charles V and Karpatne, Anuj and others},
198 | journal={arXiv preprint arXiv:2501.09333},
199 | year={2025}
200 | }
201 | ```
202 | ### Acknowledgement
203 |
204 | - VPT: https://github.com/KMnP/vpt
205 | - PETL_VISION: https://github.com/OSU-MLB/PETL_Vision
206 |
207 | Thanks for their wonderful work.
208 | 
209 | 🛠 Please create an issue for any questions or contributions.
210 |
--------------------------------------------------------------------------------
/data/dataset/birds_525.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets.folder import ImageFolder, default_loader
2 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
3 | import torchvision.transforms as transforms
4 | from timm.data.transforms import str_to_interp_mode
5 | from data.dataset.utils import add_samples, create_annotation_file, get_transformation
6 | import os
7 |
8 | class Birds_525_Dataset(ImageFolder): ## For new data ## Change this:: the name of the class to match the dataset name
9 | def __init__(self, root, data_list, transform=None):
10 | self.data_root = root
11 | self.loader = default_loader
12 | self.transform = transform
13 | self.target_transform = None
14 | self.samples = []
15 |
16 | add_samples(self.samples, data_list, root)
17 |
18 |
19 | def get_birds_525(params, mode='trainval_combined'): ## For new data ## Change this:: the name of the dataset name
20 | params.class_num = 525 ## For new data ## Change this:: the number of classes in the dataset
21 | mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
22 |
23 | transform_train = get_transformation('train', mean, std)
24 | transform_val = get_transformation('val', mean, std)
25 |
26 | if mode == 'trainval_combined':
27 | train_data_list = f'data/annotations/birds_525/{params.data}_combine.txt' ## For new data ## Change this:: the name of the data in the path
28 |
29 | if not os.path.exists(train_data_list):
30 | create_annotation_file(params.data_path, ['train'],train_data_list)
31 | return Birds_525_Dataset(params.data_path, train_data_list, transform_train) ## For new data ## Change this:: the class to call
32 |
33 | elif mode == 'test':
34 | test_data_list = f'data/annotations/birds_525/test.txt' ## For new data ## Change this:: the name of the data in the path
35 | if not os.path.exists(test_data_list):
36 | create_annotation_file(params.data_path,['val'],test_data_list)
37 | return Birds_525_Dataset(params.data_path, test_data_list, transform_val) ## For new data ## Change this:: the class to call
38 | else:
39 | raise NotImplementedError
40 |
41 |
42 |
--------------------------------------------------------------------------------
/data/dataset/car.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets.folder import ImageFolder, default_loader
2 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
3 | import torchvision.transforms as transforms
4 | from timm.data.transforms import str_to_interp_mode
5 | from data.dataset.utils import add_samples, create_annotation_file, get_transformation
6 | import os
7 |
8 | class CarDataset(ImageFolder):
9 | def __init__(self, root, data_list, transform=None):
10 | self.data_root = root
11 | self.loader = default_loader
12 | self.transform = transform
13 | self.target_transform = None
14 | self.samples = []
15 |
16 | add_samples(self.samples, data_list, root)
17 |
18 |
19 | def get_car(params, mode='trainval_combined'):
20 | params.class_num = 196
21 | mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
22 |
23 | transform_train = get_transformation('train', mean, std)
24 | transform_val = get_transformation('val', mean, std)
25 |
26 | if mode == 'trainval_combined':
27 | train_data_list = f'data/annotations/car/{params.data}_combine.txt'
28 |
29 | if not os.path.exists(train_data_list):
30 | create_annotation_file(params.data_path, ['train'],train_data_list)
31 | return CarDataset(params.data_path, train_data_list, transform_train)
32 |
33 | elif mode == 'test':
34 | test_data_list = f'data/annotations/car/test.txt'
35 | if not os.path.exists(test_data_list):
36 | create_annotation_file(params.data_path,['val'],test_data_list)
37 | return CarDataset(params.data_path, test_data_list, transform_val)
38 | else:
39 | raise NotImplementedError
40 |
41 |
42 |
--------------------------------------------------------------------------------
/data/dataset/cub.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets.folder import ImageFolder, default_loader
2 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
3 | import torchvision.transforms as transforms
4 | from timm.data.transforms import str_to_interp_mode
5 | from data.dataset.utils import add_samples, create_annotation_file, get_transformation
6 | import os
7 |
8 | class CubDataset(ImageFolder): ## For new data ## Change this:: the name of the class to match the dataset name
9 | def __init__(self, root, data_list, transform=None):
10 | self.data_root = root
11 | self.loader = default_loader
12 | self.transform = transform
13 | self.target_transform = None
14 | self.samples = []
15 |
16 | add_samples(self.samples, data_list, root)
17 |
18 |
19 | def get_cub(params, mode='trainval_combined'): ## For new data ## Change this:: the name of the dataset name
20 | params.class_num = 200 ## For new data ## Change this:: the number of classes in the dataset
21 | mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
22 |
23 | transform_train = get_transformation('train', mean, std)
24 | transform_val = get_transformation('val', mean, std)
25 |
26 | if mode == 'trainval_combined':
27 | train_data_list = f'data/annotations/cub/{params.data}_combine.txt' ## For new data ## Change this:: the name of the data in the path
28 |
29 | if not os.path.exists(train_data_list):
30 | create_annotation_file(params.data_path, ['train'],train_data_list)
31 | return CubDataset(params.data_path, train_data_list, transform_train) ## For new data ## Change this:: the class to call
32 |
33 | elif mode == 'test':
34 | test_data_list = f'data/annotations/cub/test.txt' ## For new data ## Change this:: the name of the data in the path
35 | if not os.path.exists(test_data_list):
36 | create_annotation_file(params.data_path,['val'],test_data_list)
37 | return CubDataset(params.data_path, test_data_list, transform_val) ## For new data ## Change this:: the class to call
38 | else:
39 | raise NotImplementedError
40 |
41 |
42 |
--------------------------------------------------------------------------------
/data/dataset/dog.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets.folder import ImageFolder, default_loader
2 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
3 | import torchvision.transforms as transforms
4 | from timm.data.transforms import str_to_interp_mode
5 | from data.dataset.utils import add_samples, create_annotation_file, get_transformation
6 | import os
7 |
8 | class DogDataset(ImageFolder):
9 | def __init__(self, root, data_list, transform=None):
10 | self.data_root = root
11 | self.loader = default_loader
12 | self.transform = transform
13 | self.target_transform = None
14 | self.samples = []
15 |
16 | add_samples(self.samples, data_list, root)
17 |
18 |
19 | def get_dog(params, mode='trainval_combined'):
20 | params.class_num = 120
21 | mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
22 |
23 | transform_train = get_transformation('train', mean, std)
24 | transform_val = get_transformation('val', mean, std)
25 |
26 | if mode == 'trainval_combined':
27 | train_data_list = f'data/annotations/dog/{params.data}_combine.txt'
28 |
29 | if not os.path.exists(train_data_list):
30 | create_annotation_file(params.data_path, ['train'],train_data_list)
31 | return DogDataset(params.data_path, train_data_list, transform_train)
32 |
33 | elif mode == 'test':
34 | test_data_list = f'data/annotations/dog/test.txt'
35 | if not os.path.exists(test_data_list):
36 | create_annotation_file(params.data_path,['val'],test_data_list)
37 | return DogDataset(params.data_path, test_data_list, transform_val)
38 | else:
39 | raise NotImplementedError
40 |
41 |
42 |
--------------------------------------------------------------------------------
/data/dataset/pet.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets.folder import ImageFolder, default_loader
2 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
3 | import torchvision.transforms as transforms
4 | from timm.data.transforms import str_to_interp_mode
5 | from data.dataset.utils import add_samples, create_annotation_file, get_transformation
6 | import os
7 |
8 | class PetDataset(ImageFolder):
9 | def __init__(self, root, data_list, transform=None):
10 | self.data_root = root
11 | self.loader = default_loader
12 | self.transform = transform
13 | self.target_transform = None
14 | self.samples = []
15 |
16 | add_samples(self.samples, data_list, root)
17 |
18 |
19 | def get_pet(params, mode='trainval_combined'):
20 | params.class_num = 37
21 | mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
22 |
23 | transform_train = get_transformation('train', mean, std)
24 | transform_val = get_transformation('val', mean, std)
25 |
26 | if mode == 'trainval_combined':
27 | train_data_list = f'data/annotations/pet/{params.data}_combine.txt'
28 |
29 | if not os.path.exists(train_data_list):
30 | create_annotation_file(params.data_path, ['train'],train_data_list)
31 | return PetDataset(params.data_path, train_data_list, transform_train)
32 |
33 | elif mode == 'test':
34 | test_data_list = f'data/annotations/pet/test.txt'
35 | if not os.path.exists(test_data_list):
36 | create_annotation_file(params.data_path,['val'],test_data_list)
37 | return PetDataset(params.data_path, test_data_list, transform_val)
38 | else:
39 | raise NotImplementedError
40 |
41 |
42 |
--------------------------------------------------------------------------------
/data/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import torchvision as tv
4 |
5 | import torch.utils.data as data
6 | from PIL import Image
7 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
8 |
9 | def add_samples(sample_list, list_path, root):
10 | with open(list_path, 'r') as f:
11 | for line in f:
12 | img_name = line.rsplit(' ', 1)[0]
13 | label = int(line.rsplit(' ', 1)[1])
14 | sample_list.append((os.path.join(root, img_name), label))
15 |
16 | def default_loader(path):
17 | return Image.open(path).convert('RGB')
18 |
19 |
20 | def default_flist_reader(flist):
21 | """
22 |     flist format: impath label\nimpath label\n ... (same as caffe's filelist)
23 | """
24 | imlist = []
25 | with open(flist, 'r') as rf:
26 | for line in rf.readlines():
27 | impath, imlabel = line.strip().split()
28 | imlist.append((impath, int(imlabel)))
29 |
30 | return imlist
31 |
32 |
33 | class ImageFilelist(data.Dataset):
34 | def __init__(self, root, flist, name, transform=None, target_transform=None,
35 | flist_reader=default_flist_reader, loader=default_loader):
36 | self.root = root
37 | self.name = name
38 | self.imlist = flist_reader(flist)
39 | self.transform = transform
40 | self.target_transform = target_transform
41 | self.loader = loader
42 |
43 | def __getitem__(self, index):
44 | impath, target = self.imlist[index]
45 | img = self.loader(os.path.join(self.root, impath))
46 | if self.transform is not None:
47 | img = self.transform(img)
48 | if self.target_transform is not None:
49 | target = self.target_transform(target)
50 |
51 | return img, target
52 |
53 | def __len__(self):
54 | return len(self.imlist)
55 |
56 | def get_transformation(mode, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD):
57 | if mode == 'train':
58 | return tv.transforms.Compose([
59 | tv.transforms.Resize((240,240)),
60 | tv.transforms.RandomCrop((224,224)),
61 | tv.transforms.RandomHorizontalFlip(),
62 | tv.transforms.ToTensor(),
63 | tv.transforms.Normalize(mean, std),
64 | ])
65 | elif mode == 'val' or mode == 'test':
66 | return tv.transforms.Compose([
67 | tv.transforms.Resize((224,224)),
68 | tv.transforms.ToTensor(),
69 | tv.transforms.Normalize(mean, std),
70 | ])
71 | else:
72 | raise ValueError("Invalid mode. Use 'train' or 'val' or 'test'.")
73 |
74 | def create_annotation_file(data_path, mode, output_file):
75 | """
76 | Creates a file listing images and their corresponding labels based on the directory structure.
77 |
78 |     Args:
79 |         data_path: The root directory containing the dataset (e.g., dataset_name/).
80 |         mode: A list of split sub-directories to scan (e.g., ['train'] or ['val']).
81 |         output_file: The path to the file where the image list will be written.
82 |     """
82 |
83 | # Create the output directory if it doesn't exist
84 | output_dir = os.path.dirname(output_file) # Get the directory part of the path
85 | if output_dir and not os.path.exists(output_dir):
86 | os.makedirs(output_dir)
87 |
88 | image_list = []
89 | label_map = {} # Store species names and corresponding numerical labels
90 |
91 | # Assign numerical labels to species names in the order they're encountered
92 | label_counter = 0
93 |
94 | for split in mode: # Iterate through train and val splits
95 | split_path = os.path.join(data_path, split)
96 | if not os.path.exists(split_path):
97 | continue # Skip if split does not exist
98 |
99 | for species_dir in sorted(os.listdir(split_path)): # Iterate through species directories
100 | if not os.path.isdir(os.path.join(split_path, species_dir)):
101 | continue # Skip if not a directory
102 | species_name = species_dir.split('.', 1)[1] if '.' in species_dir else species_dir # extract species name
103 | if species_name not in label_map:
104 | label_map[species_name] = label_counter
105 | label_counter += 1
106 | label = label_map[species_name]
107 |
108 | image_dir = os.path.join(split_path, species_dir)
109 | for image_file in os.listdir(image_dir): # Iterate through image files
110 | if os.path.isfile(os.path.join(image_dir, image_file)): # ensure it is a file
111 | image_path = os.path.join(split, species_dir, image_file) # Relative path from dataset root
112 | image_list.append(f"{image_path} {label}")
113 |
114 | # Write the image list to the output file
115 | with open(output_file, "w") as f:
116 | f.write("\n".join(image_list))
117 |
--------------------------------------------------------------------------------
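
A quick usage sketch of the helpers above (the CUB paths are only illustrative; the dataset classes in `data/dataset/*.py` build their sample lists with `add_samples` rather than `ImageFilelist`):

```python
from data.dataset.utils import create_annotation_file, get_transformation, ImageFilelist

# Build an annotation file from data/images/cub/train/<class_dir>/<image>,
# one "relative/path label" pair per line.
create_annotation_file('data/images/cub', ['train'], 'data/annotations/cub/cub_combine.txt')

# Wrap it in a dataset with the standard train-time augmentation.
train_tf = get_transformation('train')
dataset = ImageFilelist(root='data/images/cub',
                        flist='data/annotations/cub/cub_combine.txt',
                        name='cub',
                        transform=train_tf)
img, label = dataset[0]  # (3, 224, 224) tensor and an integer class label
```
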
/engine/optimizer.py:
--------------------------------------------------------------------------------
1 | """
2 | In the VPT repo, the bias term is handled very carefully when weight decay > 0 for different optimizers.
3 | They also use a special AdamW implementation from huggingface with a weight-decay fix; check their repo's optimizer.py
4 | for additional information.
5 | We skip this for now, as AdaptFormer and ConvPass skip it as well.
6 |
7 | """
8 |
9 | import torch.optim as optim
10 | from typing import Any, Callable, Iterable, List, Tuple, Optional
11 |
12 | from utils.setup_logging import get_logger
13 |
14 | logger = get_logger("Prompt_CAM")
15 |
16 |
17 | def make_optimizer(tune_parameters, params):
18 | if params.optimizer == 'adam':
19 | optimizer = optim.Adam(
20 | tune_parameters,
21 | lr=params.lr,
22 | weight_decay=params.wd,
23 | )
24 |
25 | elif params.optimizer == 'adamw':
26 | optimizer = optim.AdamW(
27 | tune_parameters,
28 | lr=params.lr,
29 | weight_decay=params.wd,
30 | )
31 | else:
32 | optimizer = optim.SGD(
33 | tune_parameters,
34 | lr=params.lr,
35 | weight_decay=params.wd,
36 | momentum=params.momentum,
37 | )
38 | return optimizer
39 |
40 |
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/engine/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import torch.nn as nn
4 | from timm.scheduler.cosine_lr import CosineLRScheduler
5 | from collections import OrderedDict
6 | from engine.optimizer import make_optimizer
7 | from utils.misc import AverageMeter, EarlyStop
8 | from utils.setup_logging import get_logger
9 | from timm.utils import accuracy, update_summary
10 | import numpy as np
11 | logger = get_logger("Prompt_CAM")
12 | torch.backends.cudnn.benchmark = False
13 |
14 | class Trainer():
15 | """
16 | a trainer with below logics:
17 |
18 | 1. Build optimizer, scheduler
19 | 2. Load checkpoints if provided
20 | 3. Train and eval at each epoch
21 | """
22 |
23 | def __init__(
24 | self,
25 | model, tune_parameters, params
26 | ) -> None:
27 | self.params = params
28 | self.model = model
29 | self.device = params.device
30 | self.cls_criterion = nn.CrossEntropyLoss()
31 |
32 | if 'test_data' not in params:
33 | # solver related
34 | logger.info("\tSetting up the optimizer...")
35 | self.optimizer = make_optimizer(tune_parameters, params)
36 | self.scheduler = CosineLRScheduler(self.optimizer, t_initial=params.epoch,
37 | warmup_t=params.warmup_epoch, lr_min=params.lr_min,
38 | warmup_lr_init=params.warmup_lr_init)
39 | self.total_epoch = self.params.epoch
40 | if self.params.early_patience > 0:
41 | self.early_stop_check = EarlyStop(self.params.early_patience)
42 |
43 | def forward_one_batch(self, samples, targets, is_train):
44 | """Train a single (full) epoch on the model using the given
45 | data loader.
46 |
47 | Args:
48 | samples
49 | targets
50 | is_train: bool
51 | Returns:
52 | loss
53 | outputs: output logits
54 | """
55 | # move data to device
56 |         samples = samples.to(self.device, non_blocking=True)  # (batchsize, 3, crop_size, crop_size)
57 | targets = targets.to(self.device, non_blocking=True) # (batchsize, )
58 |
59 | # forward
60 | with torch.set_grad_enabled(is_train):
61 | outputs,_ = self.model(samples) # (batchsize, num_cls)
62 | if self.params.train_type == 'prompt_cam':
63 | outputs = outputs.squeeze(-1)
64 | loss = self.cls_criterion(outputs, targets)
65 |
66 | if loss == float('inf'):
67 | logger.info(
68 | "encountered infinite loss, skip gradient updating for this batch!"
69 | )
70 | return -1, -1, (-1, -1)
71 | elif torch.isnan(loss).any():
72 | logger.info(
73 | "encountered nan loss, skip gradient updating for this batch!"
74 | )
75 | return -1, -1, (-1, -1)
76 |
77 | acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
78 | # =======backward and optim step only if in training phase... =========
79 | if is_train:
80 | self.optimizer.zero_grad()
81 | loss.backward()
82 | #for name, param in self.model.named_parameters():
83 | # if param.grad is not None:
84 | # print(f"After backward: {name} has grad with mean {param.grad.mean().item()}")
85 |
86 | self.optimizer.step()
87 |
88 | return loss, outputs, (acc1, acc5)
89 |
90 | def train_one_epoch(self, epoch, loader):
91 | loss_m = AverageMeter()
92 | top1_m = AverageMeter()
93 | top5_m = AverageMeter()
94 | lr = self.scheduler._get_lr(epoch)
95 | logger.info(
96 | "Training {} / {} epoch, with learning rate {}".format(
97 | epoch + 1, self.total_epoch, lr
98 | )
99 | )
100 | # Enable training mode
101 | self.model.train()
102 |
103 | num_updates = epoch * len(loader)
104 | for idx, (samples, targets) in enumerate(loader):
105 | train_loss, _, (acc1, acc5) = self.forward_one_batch(samples, targets, True)
106 | if not isinstance(train_loss, int):
107 | loss_m.update(train_loss.item(), samples.shape[0])
108 | top1_m.update(acc1.item(), samples.shape[0])
109 | top5_m.update(acc5.item(), samples.shape[0])
110 | del train_loss, acc1, acc5, _, samples, targets
111 | num_updates += 1
112 | self.scheduler.step_update(num_updates=num_updates, metric=loss_m.avg)
113 |
114 | logger.info(
115 | "Epoch {} / {}: ".format(epoch + 1, self.total_epoch)
116 | + "average train loss: {:.2f}, ".format(loss_m.avg)
117 | + "average train top1: {:.2f} ".format(top1_m.avg)
118 | + "average train top5: {:.2f}".format(top5_m.avg))
119 |
120 | return OrderedDict(
121 | [('loss', round(loss_m.avg, 2)), ('top1', round(top1_m.avg, 2)), ('top5', round(top5_m.avg, 2))])
122 |
123 | def train_classifier(self, train_loader, val_loader, test_loader):
124 | """
125 | Train a classifier using epoch
126 | """
127 |
128 | for epoch in range(self.total_epoch):
129 | train_metrics = self.train_one_epoch(epoch, train_loader)
130 |
131 | if (epoch % self.params.eval_freq == 0) or epoch == self.total_epoch - 1:
132 | if test_loader is not None:
133 | eval_metrics = self.eval_classifier(
134 | test_loader, "test")
135 | elif val_loader is not None:
136 | eval_metrics = self.eval_classifier(val_loader, "val")
137 | else:
138 | raise Exception('Both val and test loaders are missing. ')
139 |
140 | if self.params.early_patience > 0:
141 | stop, save_model = self.early_stop_check.early_stop(eval_metrics)
142 | if save_model and self.params.store_ckp:
143 | torch.save({'model_state_dict': self.model.state_dict()},
144 | os.path.join(self.params.output_dir, 'model.pt'))
145 | if stop:
146 | return train_metrics, self.early_stop_check.max_metrics, eval_metrics
147 | if self.params.debug:
148 | update_summary(
149 | epoch, train_metrics, eval_metrics, os.path.join(self.params.output_dir, 'summary.csv'),
150 | write_header=epoch == 0)
151 | self.scheduler.step(epoch)
152 |
153 | if self.params.store_ckp and not os.path.isfile(os.path.join(self.params.output_dir, 'model.pt')):
154 | torch.save({'model_state_dict': self.model.state_dict()}, os.path.join(self.params.output_dir, 'model.pt'))
155 | return train_metrics, self.early_stop_check.max_metrics, eval_metrics
156 |
157 | @torch.no_grad()
158 | def eval_classifier(self, loader, prefix):
159 | """evaluate classifier"""
160 |
161 | loss_m = AverageMeter()
162 | top1_m = AverageMeter()
163 | top5_m = AverageMeter()
164 |
165 | # Enable eval mode
166 | self.model.eval()
167 |
168 | with torch.no_grad():
169 | for batch_idx, (samples, targets) in enumerate(loader):
170 | loss, outputs, (acc1, acc5) = self.forward_one_batch(samples, targets, False)
171 | if not isinstance(loss, int):
172 | loss_m.update(loss.item(), samples.shape[0])
173 | top1_m.update(acc1.item(), samples.shape[0])
174 | top5_m.update(acc5.item(), samples.shape[0])
175 | del loss, outputs, acc1, acc5
176 | logger.info(
177 | f"Inference ({prefix}):"
178 | + "average loss: {:.2f}, ".format(loss_m.avg)
179 | + "average top1: {:.2f} ".format(top1_m.avg)
180 | + "average top5: {:.2f}".format(top5_m.avg))
181 | return OrderedDict(
182 | [('loss', round(loss_m.avg, 2)), ('top1', round(top1_m.avg, 2)), ('top5', round(top5_m.avg, 2))])
183 |
184 | def load_weight(self):
185 | self.model.load_state_dict(torch.load(self.params.output_dir + '/model.pt')['model_state_dict'])
186 |
187 | @torch.no_grad()
188 | def collect_logits(self, loader):
189 | self.model.eval()
190 | all_logits = []
191 | gt = []
192 | with torch.no_grad():
193 | for batch_idx, (samples, targets) in enumerate(loader):
194 | loss, outputs, (acc1, acc5) = self.forward_one_batch(samples, targets, False)
195 | all_logits.append(outputs.cpu().detach().numpy())
196 | gt.append(targets.cpu().detach().numpy())
197 | return np.concatenate(all_logits, axis=0), np.concatenate(gt, axis=0)
--------------------------------------------------------------------------------
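
For orientation, a minimal sketch of how `Trainer` is driven elsewhere in this repo (mirroring `train()` in `experiment/run.py`; `params` is assumed to be the config object parsed from an `args.yaml`):

```python
from experiment.build_model import get_model
from experiment.build_loader import get_loader
from engine.trainer import Trainer
from utils.setup_logging import get_logger

logger = get_logger("Prompt_CAM")

def run_training(params):
    # Build the prompt-tuned ViT and collect only the trainable parameters.
    model, tune_parameters, _ = get_model(params)
    # Build the data loaders for the dataset named in params.data.
    train_loader, val_loader, test_loader = get_loader(params, logger)
    # Train with the cosine schedule and early stopping configured in params.
    trainer = Trainer(model, tune_parameters, params)
    return trainer.train_classifier(train_loader, val_loader, test_loader)
```
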
/env_setup.sh:
--------------------------------------------------------------------------------
1 |
2 | pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 --no-cache-dir
3 |
4 | ## timm
5 | pip install timm==0.9.12 --no-cache-dir
6 | #
7 | ###VTAB
8 | pip install tensorflow==2.11.0 --no-cache-dir
9 | # specifying tfds versions is important to reproduce our results
10 | pip install tfds-nightly==4.4.0.dev202201080107 --no-cache-dir
11 | pip install tensorflow-addons==0.19.0 --no-cache-dir
12 | pip install opencv-python --no-cache-dir
13 |
14 | ## CLIP
15 | pip install git+https://github.com/openai/CLIP.git --no-cache-dir
16 |
17 | ####utils
18 | pip install dotwiz --no-cache-dir
19 | pip install pyyaml --no-cache-dir
20 | pip install tabulate --no-cache-dir
21 | pip install termcolor --no-cache-dir
22 | pip install iopath --no-cache-dir
23 | pip install scikit-learn --no-cache-dir
24 |
25 | pip install ftfy regex tqdm --no-cache-dir
26 | pip install pandas --no-cache-dir
27 | pip install matplotlib --no-cache-dir
28 | pip install ipykernel --no-cache-dir
29 |
--------------------------------------------------------------------------------
/experiment/build_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from data.dataset.cub import get_cub
3 | from data.dataset.dog import get_dog
4 | from data.dataset.pet import get_pet
5 | from data.dataset.car import get_car
6 | from data.dataset.birds_525 import get_birds_525
7 |
8 |
9 | def get_dataset(data, params, logger):
10 | dataset_train, dataset_val, dataset_test = None, None, None
11 |
12 | if data.startswith("cub"):
13 | logger.info("Loading CUB data ...")
14 | if params.final_run:
15 | logger.info("Loading training data (final training data for cub)...")
16 | dataset_train = get_cub(params, 'trainval_combined')
17 | dataset_test = get_cub(params, 'test')
18 | else:
19 | raise NotImplementedError
20 | elif data.startswith("dog"):
21 | logger.info("Loading Standford Dogs data ...")
22 | if params.final_run:
23 | logger.info("Loading training data (final training data for dog)...")
24 | dataset_train = get_dog(params, 'trainval_combined')
25 | dataset_test = get_dog(params, 'test')
26 | else:
27 | raise NotImplementedError
28 | elif data.startswith("pet"):
29 | logger.info("Loading Oxford Pet data ...")
30 | if params.final_run:
31 | logger.info("Loading training data (final training data for pet)...")
32 | dataset_train = get_pet(params, 'trainval_combined')
33 | dataset_test = get_pet(params, 'test')
34 | else:
35 | raise NotImplementedError
36 | elif data.startswith("car"):
37 | logger.info("Loading Stanford Car data ...")
38 | if params.final_run:
39 | logger.info("Loading training data (final training data for car)...")
40 | dataset_train = get_car(params, 'trainval_combined')
41 | dataset_test = get_car(params, 'test')
42 | else:
43 | raise NotImplementedError
44 | elif data.startswith("birds_525"):
45 | logger.info("Loading Birds 525 data ...")
46 | if params.final_run:
47 | logger.info("Loading training data (final training data for birds_525)...")
48 | dataset_train = get_birds_525(params, 'trainval_combined')
49 | dataset_test = get_birds_525(params, 'test')
50 | else:
51 | raise NotImplementedError
52 | else:
53 | raise Exception("Dataset '{}' not supported".format(data))
54 | return dataset_train, dataset_val, dataset_test
55 |
56 |
57 | def get_loader(params, logger):
58 | if 'test_data' in params:
59 | dataset_train, dataset_val, dataset_test = get_dataset(params.test_data, params, logger)
60 | else:
61 | dataset_train, dataset_val, dataset_test = get_dataset(params.data, params, logger)
62 |
63 | if isinstance(dataset_train, list):
64 | train_loader, val_loader, test_loader = [], [], []
65 | for i in range(len(dataset_train)):
66 | tmp_train, tmp_val, tmp_test = gen_loader(params, dataset_train[i], dataset_val[i], None)
67 | train_loader.append(tmp_train)
68 | val_loader.append(tmp_val)
69 | test_loader.append(tmp_test)
70 | else:
71 | train_loader, val_loader, test_loader = gen_loader(params, dataset_train, dataset_val, dataset_test)
72 |
73 | logger.info("Finish setup loaders")
74 | return train_loader, val_loader, test_loader
75 |
76 |
77 | def gen_loader(params, dataset_train, dataset_val, dataset_test):
78 | train_loader, val_loader, test_loader = None, None, None
79 | if params.debug:
80 | num_workers = 1
81 | else:
82 | num_workers = 4
83 | if dataset_train is not None:
84 | train_loader = torch.utils.data.DataLoader(
85 | dataset_train,
86 | batch_size=params.batch_size,
87 | shuffle=True,
88 | num_workers=num_workers,
89 | pin_memory=True,
90 | drop_last=True
91 | )
92 | if dataset_val is not None:
93 | val_loader = torch.utils.data.DataLoader(
94 | dataset_val,
95 | batch_size=params.test_batch_size,
96 | shuffle=False,
97 | num_workers=num_workers,
98 | pin_memory=True
99 | )
100 | if dataset_test is not None:
101 | test_loader = torch.utils.data.DataLoader(
102 | dataset_test,
103 | batch_size=params.test_batch_size,
104 | shuffle=False,
105 | num_workers=num_workers,
106 | pin_memory=True
107 |
108 | )
109 | return train_loader, val_loader, test_loader
110 |
--------------------------------------------------------------------------------
/experiment/build_model.py:
--------------------------------------------------------------------------------
2 |
3 | import timm
4 | import torch
5 | from model.vision_transformer import VisionTransformerPETL
6 | from utils.log_utils import log_model_info
7 | from timm.data import resolve_data_config
8 | from utils.setup_logging import get_logger
9 |
10 | logger = get_logger("Prompt_CAM")
11 |
12 | TUNE_MODULES = ['vpt']
13 | def get_model(params,visualize=False):
14 | params.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | print(f"Using device: {params.device}")
16 |
17 | model = get_base_model(params,visualize=visualize)
18 |
19 | ##########
20 | tune_parameters = []
21 | if params.debug:
22 | logger.info("Trainable params:")
23 |
24 | for name, parameter in model.named_parameters():
25 | if any(m in name for m in TUNE_MODULES):
26 | parameter.requires_grad = True
27 | tune_parameters.append(parameter)
28 | if params.debug:
29 | logger.info("\t{}, {}, {}".format(name, parameter.numel(), parameter.shape))
30 | else:
31 | parameter.requires_grad = False
32 |
33 | model_grad_params_no_head = log_model_info(model, logger)
34 |
35 | if not visualize:
36 | model = model.cuda(device=params.device)
37 | return model, tune_parameters, model_grad_params_no_head
38 |
39 |
40 | def get_base_model(params,visualize=False):
41 | if params.pretrained_weights == "vit_base_patch16_224_in21k":
42 | params.patch_size = 16
43 | model = timm.create_model("vit_base_patch16_224_in21k_petl", drop_path_rate=params.drop_path_rate,
44 | pretrained=False, params=params)
45 | if not visualize:
46 | model.load_pretrained(
47 | 'pretrained_weights/ViT-B_16_in21k.npz')
48 | model.reset_classifier(params.class_num)
49 | elif params.pretrained_weights == "vit_base_mae":
50 | model = timm.create_model("vit_base_patch16_224_in21k_petl", drop_path_rate=params.drop_path_rate,
51 | pretrained=False,
52 | params=params)
53 | if not visualize:
54 | model.load_pretrained(
55 | 'pretrained_weights/mae_pretrain_vit_base.pth')
56 | model.reset_classifier(params.class_num)
57 | elif params.pretrained_weights == "vit_base_patch14_dinov2":
58 | params.patch_size = 14
59 | model = timm.create_model("vit_base_patch14_dinov2_petl", drop_path_rate=params.drop_path_rate,
60 | pretrained=False,
61 | params=params)
62 | if not visualize:
63 | model.load_pretrained(
64 | 'pretrained_weights/dinov2_vitb14_pretrain.pth')
65 | model.reset_classifier(params.class_num)
66 | elif params.pretrained_weights == "vit_base_patch16_dino":
67 | model = timm.create_model("vit_base_patch16_dino_petl", drop_path_rate=params.drop_path_rate,
68 | pretrained=False,
69 | params=params)
70 | if not visualize:
71 | model.load_pretrained(
72 | 'pretrained_weights/dino_vitbase16_pretrain.pth')
73 | model.reset_classifier(params.class_num)
74 | elif params.pretrained_weights == 'vit_base_patch16_clip_224':
75 | params.patch_size = 16
76 | model = timm.create_model("vit_base_patch16_clip_224_petl", drop_path_rate=params.drop_path_rate,
77 | pretrained=False,
78 | params=params)
79 | if not visualize:
80 | model.load_pretrained(
81 | 'pretrained_weights/ViT-B_16_clip.bin')
82 |
83 | fc = init_imagenet_clip(params.device)
84 | proj = get_clip_proj(params.device)
85 | model.head = torch.nn.Sequential(*[proj, fc])
86 | else:
87 | raise NotImplementedError
88 |
89 | # data_config = resolve_data_config(vars(params), model=model, verbose=False)
90 |
91 | return model
92 |
--------------------------------------------------------------------------------
/experiment/config/prompt_cam/dino/birds_525/args.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 16
2 | crop_size: 224
3 | data: birds_525
4 | data_path: ./data/images/birds_525
5 | debug: true
6 | drop_path_rate: 0.1
7 | early_patience: 101
8 | epoch: 100
9 | eval_freq: 5
10 | final_acc_hp: true
11 | final_run: true
12 | full: false
13 | gpu_num: 1
14 | lr: 0.01
15 | lr_min: 1.0e-06
16 | model: dino
17 | momentum: 0.9
18 | normalized: true
19 | optimizer: sgd
20 | pretrained_weights: vit_base_patch16_dino
21 | random_seed: 42
22 | store_ckp: true
23 | test_batch_size: 32
24 | train_type: prompt_cam
25 | vpt_dropout: 0.1
26 | vpt_layer: null
27 | vpt_mode: null
28 | vpt_num: 525
29 | warmup_epoch: 20
30 | warmup_lr_init: 0
31 | wd: 0.001
32 |
--------------------------------------------------------------------------------
/experiment/config/prompt_cam/dino/car/args.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 16
2 | crop_size: 224
3 | data: car
4 | data_path: ./data/images/car
5 | debug: true
6 | drop_path_rate: 0.1
7 | early_patience: 101
8 | epoch: 100
9 | eval_freq: 10
10 | final_acc_hp: true
11 | final_run: true
12 | full: false
13 | gpu_num: 1
14 | lr: 0.01
15 | lr_min: 1.0e-06
16 | model: dino
17 | momentum: 0.9
18 | normalized: true
19 | optimizer: sgd
20 | pretrained_weights: vit_base_patch16_dino
21 | random_seed: 42
22 | store_ckp: true
23 | test_batch_size: 32
24 | train_type: prompt_cam
25 | vpt_dropout: 0.1
26 | vpt_layer: null
27 | vpt_mode: null
28 | vpt_num: 196
29 | warmup_epoch: 20
30 | warmup_lr_init: 0
31 | wd: 0.001
32 |
--------------------------------------------------------------------------------
/experiment/config/prompt_cam/dino/cub/args.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 16
2 | crop_size: 224
3 | data: cub #Change this for new dataset
4 | data_path: ./data/images/cub #Change this for new dataset
5 | debug: true
6 | drop_path_rate: 0.1
7 | early_patience: 101
8 | epoch: 130 #Tune this for new dataset
9 | eval_freq: 10
10 | final_acc_hp: true
11 | final_run: true
12 | full: false
13 | gpu_num: 1
14 | lr: 0.005 #Tune this for new dataset
15 | lr_min: 1.0e-06
16 | model: dino
17 | momentum: 0.9
18 | normalized: true
19 | optimizer: sgd
20 | pretrained_weights: vit_base_patch16_dino
21 | random_seed: 42
22 | store_ckp: true
23 | test_batch_size: 32
24 | train_type: prompt_cam
25 | vpt_dropout: 0.1
26 | vpt_layer: null
27 | vpt_mode: null
28 | vpt_num: 200 #Change this for new dataset
29 | warmup_epoch: 20
30 | warmup_lr_init: 0
31 | wd: 0.001
32 |
--------------------------------------------------------------------------------
/experiment/config/prompt_cam/dino/dog/args.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 16
2 | crop_size: 224
3 | data: dog
4 | data_path: ./data/images/dog
5 | debug: true
6 | drop_path_rate: 0.1
7 | early_patience: 101
8 | epoch: 100
9 | eval_freq: 10
10 | final_acc_hp: true
11 | final_run: true
12 | full: false
13 | gpu_num: 1
14 | lr: 0.005
15 | lr_min: 1.0e-06
16 | model: dino
17 | momentum: 0.9
18 | normalized: true
19 | optimizer: sgd
20 | pretrained_weights: vit_base_patch16_dino
21 | random_seed: 42
22 | store_ckp: true
23 | test_batch_size: 32
24 | train_type: prompt_cam
25 | vpt_dropout: 0.1
26 | vpt_layer: null
27 | vpt_mode: null
28 | vpt_num: 120
29 | warmup_epoch: 20
30 | warmup_lr_init: 0
31 | wd: 0.001
32 |
--------------------------------------------------------------------------------
/experiment/config/prompt_cam/dino/pet/args.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 16
2 | crop_size: 224
3 | data: pet
4 | data_path: ./data/images/pet
5 | debug: true
6 | drop_path_rate: 0.1
7 | early_patience: 101
8 | epoch: 100
9 | eval_freq: 10
10 | final_acc_hp: true
11 | final_run: true
12 | full: false
13 | gpu_num: 1
14 | lr: 0.005
15 | lr_min: 1.0e-06
16 | model: dino
17 | momentum: 0.9
18 | normalized: true
19 | optimizer: sgd
20 | pretrained_weights: vit_base_patch16_dino
21 | random_seed: 42
22 | store_ckp: true
23 | test_batch_size: 32
24 | train_type: prompt_cam
25 | vpt_dropout: 0.1
26 | vpt_layer: null
27 | vpt_mode: null
28 | vpt_num: 37
29 | warmup_epoch: 20
30 | warmup_lr_init: 0
31 | wd: 0.001
32 |
--------------------------------------------------------------------------------
/experiment/config/prompt_cam/dinov2/cub/args.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 16
2 | crop_size: 224
3 | data: cub
4 | data_path: ./data/images/cub
5 | debug: true
6 | drop_path_rate: 0.0
7 | early_patience: 101
8 | epoch: 130
9 | eval_freq: 10
10 | final_acc_hp: true
11 | final_run: true
12 | full: false
13 | gpu_num: 1
14 | lr: 0.001
15 | lr_min: 1.0e-06
16 | model: dinov2
17 | momentum: 0.9
18 | normalized: true
19 | optimizer: sgd
20 | pretrained_weights: vit_base_patch14_dinov2
21 | random_seed: 42
22 | store_ckp: true
23 | test_batch_size: 32
24 | train_type: prompt_cam
25 | vpt_dropout: 0.0
26 | vpt_layer: null
27 | vpt_mode: null
28 | vpt_num: 200
29 | warmup_epoch: 20
30 | warmup_lr_init: 1.0e-06
31 | wd: 0.001
32 |
--------------------------------------------------------------------------------
/experiment/config/prompt_cam/dinov2/dog/args.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 16
2 | crop_size: 224
3 | data: dog
4 | data_path: ./data/images/dog
5 | debug: true
6 | drop_path_rate: 0.0
7 | early_patience: 101
8 | epoch: 100
9 | eval_freq: 10
10 | final_acc_hp: true
11 | final_run: true
12 | full: false
13 | gpu_num: 1
14 | lr: 0.005
15 | lr_min: 1.0e-06
16 | model: dinov2
17 | momentum: 0.9
18 | normalized: true
19 | optimizer: sgd
20 | pretrained_weights: vit_base_patch14_dinov2
21 | random_seed: 42
22 | store_ckp: true
23 | test_batch_size: 32
24 | train_type: prompt_cam
25 | vpt_dropout: 0.0
26 | vpt_layer: null
27 | vpt_mode: null
28 | vpt_num: 120
29 | warmup_epoch: 20
30 | warmup_lr_init: 1.0e-06
31 | wd: 0.001
32 |
--------------------------------------------------------------------------------
/experiment/config/prompt_cam/dinov2/pet/args.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 16
2 | crop_size: 224
3 | data: pet
4 | data_path: ./data/images/pet
5 | debug: true
6 | drop_path_rate: 0.0
7 | early_patience: 101
8 | epoch: 100
9 | eval_freq: 10
10 | final_acc_hp: true
11 | final_run: true
12 | full: false
13 | gpu_num: 1
14 | lr: 0.005
15 | lr_min: 1.0e-06
16 | model: dinov2
17 | momentum: 0.9
18 | normalized: true
19 | optimizer: sgd
20 | pretrained_weights: vit_base_patch14_dinov2
21 | random_seed: 42
22 | store_ckp: true
23 | test_batch_size: 32
24 | train_type: prompt_cam
25 | vpt_dropout: 0.0
26 | vpt_layer: null
27 | vpt_mode: null
28 | vpt_num: 37
29 | warmup_epoch: 20
30 | warmup_lr_init: 1.0e-06
31 | wd: 0.001
32 |
--------------------------------------------------------------------------------
/experiment/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | from experiment.build_model import get_model
3 | from experiment.build_loader import get_loader
4 | from utils.global_var import OUTPUT_DIR, TUNE_DIR, TUNE_DIR_TEST
5 | from engine.trainer import Trainer
6 | from timm.utils import get_outdir
7 | from utils.log_utils import logging_env_setup
8 | from utils.misc import method_name
9 | from datetime import datetime
10 | import yaml
11 | import torch
12 | from utils.misc import set_seed
13 | from collections import OrderedDict
14 | from statistics import mean
15 | import json
16 | import time
17 | import csv
18 | import numpy as np
19 | from utils.setup_logging import get_logger
20 |
21 | logger = get_logger("Prompt_CAM")
22 |
23 |
24 | def train(params, train_loader, val_loader, test_loader):
25 | model, tune_parameters, model_grad_params_no_head = get_model(params)
26 | trainer = Trainer(model, tune_parameters, params)
27 | train_metrics, best_eval_metrics, eval_metrics = trainer.train_classifier(train_loader, val_loader, test_loader)
28 | return train_metrics, best_eval_metrics, eval_metrics, model_grad_params_no_head, trainer.model
29 |
30 | def basic_run(params):
31 | if torch.cuda.is_available():
32 | torch.cuda.empty_cache()
33 | data_name = params.data.split("-")[-1]
34 | dataset_name = params.data.split("-")[0]
35 | method = method_name(params)
36 | start_time = datetime.now().strftime("%Y-%m-%d-%H:%M")
37 | output_dir = os.path.join(OUTPUT_DIR, params.pretrained_weights, dataset_name, method, data_name, start_time)
38 | params.output_dir = get_outdir(output_dir)
39 | params_text = yaml.safe_dump(params.__dict__, default_flow_style=False)
40 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
41 | f.write(params_text)
42 | logging_env_setup(params)
43 | logger.info(f'Start loading {data_name}')
44 | train_loader, val_loader, test_loader = get_loader(params, logger)
45 |
46 | train(params, train_loader, val_loader, test_loader)
47 |
48 |
49 | def update_output_dir(default_params, test):
50 | logger.info(f'start running {default_params.method_name}')
51 | if torch.cuda.is_available():
52 | torch.cuda.empty_cache()
53 | data_name = default_params.data.split("-")[-1]
54 | dataset_name = default_params.data.split("-")[0]
55 | method = default_params.method_name
56 | if test:
57 | output_dir = os.path.join(TUNE_DIR_TEST, default_params.experiment_name, dataset_name, data_name, method)
58 | else:
59 | output_dir = os.path.join(TUNE_DIR, default_params.experiment_name, dataset_name, data_name, method)
60 | default_params.output_dir = output_dir
61 |
62 | logging_env_setup(default_params)
63 | return output_dir, data_name
64 |
65 |
66 |
67 | def evaluate(default_params):
68 | _, _, test_loader = get_loader(default_params, logger)
69 | if 'eval' in default_params.test_data:
70 | result_name = f'{default_params.test_data.split("_")[1]}_result.json'
71 | else:
72 | result_name = f'{default_params.test_data}_result.json'
73 | if not os.path.isfile(os.path.join(default_params.output_dir, result_name)):
74 | if not os.path.isfile(os.path.join(default_params.output_dir, 'final_result.json')):
 75 |             logger.info('no final_result.json; the model is not fine-tuned, showing zero-shot performance')
76 | best_tune = ()
77 | result_name = 'zero_shot_' + result_name
78 | else:
79 | result = json.load(open(os.path.join(default_params.output_dir, 'final_result.json')))
80 | best_tune = result['best_tune']
81 | default_params.update(best_tune)
82 |
83 | model, tune_parameters, model_grad_params_no_head = get_model(default_params)
84 | trainer = Trainer(model, tune_parameters, default_params)
85 | if not os.path.isfile(os.path.join(default_params.output_dir, 'model.pt')):
86 | assert not os.path.isfile(os.path.join(default_params.output_dir, 'final_result.json'))
 87 |             logger.info('no model.pt, showing zero-shot performance')
88 | else:
89 | trainer.load_weight()
90 | eval_metrics = trainer.eval_classifier(test_loader, 'test')
91 | json.dump(
92 | {"avg_acc": eval_metrics['top1'], "inserted_parameters": model_grad_params_no_head,
93 | 'best_tune': best_tune},
94 | open(os.path.join(default_params.output_dir, result_name), 'w'))
95 | else:
96 | logger.info(f'finish {result_name} for {default_params.method_name}')
97 | return
98 |
99 |
100 | def result_tracker(first_col, train_metrics, eval_metrics, best_eval_metrics, filename, write_header=False, first_col_name='param_set',
101 | eval_name='val_'):
102 | rowd = OrderedDict([(first_col_name, first_col)])
103 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
104 | rowd.update([(eval_name + k, v) for k, v in eval_metrics.items()])
105 | rowd.update([(eval_name + "best_" + k, v) for k, v in best_eval_metrics.items()])
106 | with open(filename, mode='a') as cf:
107 | dw = csv.DictWriter(cf, fieldnames=rowd.keys())
108 | if write_header:
109 | dw.writeheader()
110 | dw.writerow(rowd)
111 |
--------------------------------------------------------------------------------
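Note: result_tracker above flattens the train/val/best-val metric dictionaries into one CSV row. A hypothetical usage sketch; only the 'top1' key is known from this file, the 'loss' key and the parameter-set label are illustrative assumptions:

# hypothetical metrics for illustration
from experiment.run import result_tracker

train_metrics = {'loss': 0.42, 'top1': 91.3}
eval_metrics = {'loss': 0.55, 'top1': 88.1}
best_eval_metrics = {'loss': 0.51, 'top1': 89.0}

result_tracker('lr=0.005_wd=0.001', train_metrics, eval_metrics, best_eval_metrics,
               filename='results.csv', write_header=True)
# header: param_set,train_loss,train_top1,val_loss,val_top1,val_best_loss,val_best_top1
# row:    lr=0.005_wd=0.001,0.42,91.3,0.55,88.1,0.51,89.0
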
/experiment/visualize_run.py:
--------------------------------------------------------------------------------
1 | import os
2 | from experiment.build_model import get_model
3 | from experiment.build_loader import get_loader
4 | from timm.utils import get_outdir,accuracy
5 | from utils.log_utils import logging_env_setup
6 | from utils.misc import AverageMeter
7 | import torch
8 | import numpy as np
9 | from utils.setup_logging import get_logger
10 |
11 | from utils.visual_utils import combine_images,create_overlay_images
12 |
13 |
14 | logger = get_logger("Prompt_CAM")
15 |
16 |
17 | def basic_vis(params):
18 |
19 | if torch.cuda.is_available():
20 | torch.cuda.empty_cache()
21 | dataset_name = params.data.split("-")[0]
22 | top_traits = f"top_traits_{params.top_traits}"
23 | output_dir = os.path.join(params.vis_outdir, params.model, dataset_name, "class_"+str(params.vis_cls),top_traits)
24 | params.output_dir = get_outdir(output_dir)
25 | logging_env_setup(params)
26 |
27 |
28 | logger.info(f'Start loading test data: {dataset_name}')
29 | _, _, test_loader = get_loader(params, logger)
30 |
31 | logger.info(f'Start loading model: {params.model}')
32 | model, _ , _ = get_model(params)
33 | model.load_state_dict(torch.load(params.checkpoint)['model_state_dict'])
 34 |     logger.info(f'Model loaded from {params.checkpoint}')
35 |
36 | top1_m = AverageMeter()
37 |
38 | model.eval()
39 |
40 |
41 | params.test_batch_size= 1
42 | _, _, test_loader = get_loader(params, logger)
43 |
44 |
45 | smpl_count = 0
46 | with torch.no_grad():
47 | for batch_idx, (samples, targets) in enumerate(test_loader):
48 | # move data to device
49 | samples = samples.to(params.device, non_blocking=True) # (batchsize, channel, height, width)
50 | targets = targets.to(params.device, non_blocking=True) # (batchsize, )
51 |
52 | if targets[0].item() == params.vis_cls:
53 | smpl_count += 1
54 |
55 | outputs, attn_map = model(samples)
56 | predicted_class = torch.argmax(outputs, dim=1).item()
57 |
58 | if predicted_class == targets[0].item():
59 | logger.info(f"Predicted class: {predicted_class}, Target class: {targets[0].item()}")
60 | prune_attn_heads(model,samples,targets, predicted_class,smpl_count, params)
61 | else:
62 | attn_map = attn_map[:, :, targets[0].item(), (params.vpt_num+1):]
63 | create_overlay_images(samples,
64 | model.patch_size,
65 | attn_map,
66 | f'{params.output_dir}/img_{smpl_count}')
67 |
68 | combine_images(path=f'{params.output_dir}/img_{smpl_count}', pred_class=predicted_class)
69 |
70 |
71 | if smpl_count == params.nmbr_samples:
72 | break
73 |
74 |
75 |
76 |
77 | #TODO: ADD Later
78 | with torch.no_grad():
79 | for batch_idx, (samples, targets) in enumerate(test_loader):
80 | # move data to device
 81 |             samples = samples.to(params.device, non_blocking=True)  # (batchsize, channel, height, width)
82 | targets = targets.to(params.device, non_blocking=True) # (batchsize, )
83 |
84 | outputs,_ = model(samples)
85 | acc1,_= accuracy(outputs.squeeze(-1), targets, topk=(1,5))
86 | top1_m.update(acc1.item(), samples.shape[0])
87 |
88 | del outputs, acc1
89 |
90 | logger.info("Evaluate: average top1: {:.2f}".format(top1_m.avg))
91 |
92 |
93 | def prune_attn_heads(model,inputs,target, prediction,smpl_count, params):
94 | remaining_head_list = list(range(model.num_heads))
95 | pruned_head_index = None
96 | blur_head_lst = []
97 |
98 | while len(remaining_head_list) > 0 and len(remaining_head_list) > params.top_traits:
99 | highest_score=-1e8
100 | remaining_head_scores= []
101 |
102 | for head_idx in remaining_head_list:
103 | output,_ = model(inputs,
104 | blur_head_lst=blur_head_lst+[head_idx],
105 | target_cls=prediction)
106 |
107 | probabilities = torch.softmax(output.squeeze(-1), dim=-1)
108 |
109 | remaining_head_scores.append(probabilities[0,prediction].item())
110 |
111 | if remaining_head_scores[-1] > highest_score:
112 | highest_score=remaining_head_scores[-1]
113 | pruned_head_index=head_idx
114 |
115 | if pruned_head_index is not None:
116 | blur_head_lst.append(pruned_head_index)
117 | remaining_head_list.remove(pruned_head_index)
118 | print(f'best head to prune is {pruned_head_index+1} with score {highest_score}')
119 |
120 | sorted_remaining_heads = [head for _, head in sorted(zip(remaining_head_scores, remaining_head_list))]
121 |
122 | _,attn_map=model(inputs,
123 | blur_head_lst=blur_head_lst,
124 | target_cls=prediction)
125 | attn_map = attn_map[:, sorted_remaining_heads, prediction, (params.vpt_num+1):]
126 | create_overlay_images(inputs,
127 | model.patch_size,
128 | attn_map,
129 | f'{params.output_dir}/img_{smpl_count}')
130 |
131 | combine_images(path=f'{params.output_dir}/img_{smpl_count}', pred_class=prediction)
132 |
133 |
134 |
135 |
136 |
--------------------------------------------------------------------------------
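Note: in prune_attn_heads above, heads are removed greedily: at each step the head whose ablation leaves the highest target-class probability is pruned, until only params.top_traits heads remain; the surviving heads' class-prompt attention over the patch tokens is then rendered as per-trait heatmaps. An illustration of how that attention slice maps back onto the patch grid (224x224 input and the dinov2/pet settings assumed; the actual reshape is presumably handled inside create_overlay_images in utils/visual_utils.py):

# illustration only: shape of the class-prompt attention slice used for the overlays
import torch

vpt_num, num_heads, patch, img = 37, 12, 14, 224     # dinov2/pet values
n_patches = (img // patch) ** 2                      # 16 * 16 = 256
seq_len = vpt_num + 1 + n_patches                    # prompts + CLS + patches = 294

attn_map = torch.rand(1, num_heads, seq_len, seq_len)    # stand-in for the last-block attention
target_cls = 5
trait_maps = attn_map[:, :, target_cls, (vpt_num + 1):]  # same slice as in basic_vis
print(trait_maps.reshape(1, num_heads, img // patch, img // patch).shape)  # torch.Size([1, 12, 16, 16])
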
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from experiment.run import basic_run
3 | from utils.setup_logging import get_logger
4 | from utils.misc import set_seed,load_yaml,override_args_with_yaml
5 | import time
6 |
7 | logger = get_logger("Prompt_CAM")
8 |
9 |
10 | def main():
11 | args = setup_parser().parse_args()
12 |
13 | if args.config:
14 | yaml_config = load_yaml(args.config)
15 | if yaml_config:
16 | args = override_args_with_yaml(args, yaml_config)
17 |
18 | set_seed(args.random_seed)
19 | start = time.time()
20 | args.vis_attn= False
21 | basic_run(args)
22 | end = time.time()
23 | logger.info(f'----------- Total Run time : {(end - start) / 60} mins-----------')
24 |
25 |
26 | def setup_parser():
27 | parser = argparse.ArgumentParser(description='Prompt_CAM')
28 |
29 | ########################Pretrained Model#########################
30 | parser.add_argument('--pretrained_weights', type=str, default='vit_base_patch16_224_in21k',
31 | choices=['vit_base_patch16_224_in21k', 'vit_base_mae', 'vit_base_patch14_dinov2','vit_base_patch16_dino',
32 | 'vit_base_patch16_clip_224'],
33 | help='pretrained weights name')
34 | parser.add_argument('--drop_path_rate', default=0.1,
35 | type=float,
36 | help='Drop Path Rate (default: %(default)s)')
37 | parser.add_argument('--model', type=str, default='dinov2', choices=['vit', 'dino', 'dinov2'],
38 | help='pretrained model name')
39 |
40 | parser.add_argument('--train_type', type=str, default='vpt', choices=['vpt', 'prompt_cam', 'linear'],
 41 |                         help='training type')
42 |
43 | ########################Optimizer Scheduler#########################
44 | parser.add_argument('--optimizer', default='sgd', choices=['sgd', 'adam', 'adamw'],
45 | help='Optimizer (default: %(default)s)')
46 | parser.add_argument('--lr', default=0.005,
47 | type=float,
48 | help='Learning rate (default: %(default)s)')
49 | parser.add_argument('--epoch', default=100,
50 | type=int,
51 | help='The number of total epochs used. (default: %(default)s)')
52 | parser.add_argument('--warmup_epoch', default=20,
53 | type=int,
 54 |                         help='warmup epochs in scheduler. (default: %(default)s)')
55 | parser.add_argument('--lr_min', type=float, default=1e-5,
56 | help='lr_min for scheduler (default: %(default)s)')
57 | parser.add_argument('--warmup_lr_init', type=float, default=1e-6,
58 | help='warmup_lr_init for scheduler (default: %(default)s)')
59 | parser.add_argument('--batch_size', default=16,
60 | type=int,
61 | help='Batch size (default: %(default)s)')
62 | parser.add_argument('--test_batch_size', default=32,
63 | type=int,
64 | help='Test batch size (default: %(default)s)')
65 | parser.add_argument('--wd', type=float, default=0.001,
66 | help='weight_decay (default: %(default)s)')
67 | parser.add_argument('--momentum', type=float, default=0.9,
68 | help='momentum used in sgd (default: %(default)s)')
69 | parser.add_argument('--early_patience', type=int, default=101,
70 | help='early stop patience (default: %(default)s)')
71 |
72 | ########################Data#########################
73 | parser.add_argument('--data', default="processed_vtab-dtd",
74 | help='data name. (default: %(default)s)')
75 | parser.add_argument('--data_path', default="data_folder/vtab_processed",
76 | help='Path to the dataset. (default: %(default)s)')
77 | parser.add_argument('--crop_size', default=224,
78 | type=int,
79 | help='Crop size of the input image (default: %(default)s)')
80 | parser.add_argument('--final_run', action='store_false',
 81 |                         help='If final_run is True, use train+val as training data; otherwise use train only')
82 | parser.add_argument('--normalized', action='store_false',
 83 |                         help='Whether images are normalized using ImageNet mean and std')
84 |
85 | ########################VPT#########################
86 | parser.add_argument('--vpt_mode', type=str, default=None, choices=['deep', 'shallow'],
87 | help='VPT mode, deep or shallow')
88 | parser.add_argument('--vpt_num', default=10, type=int,
89 | help='Number of prompts (default: %(default)s)')
90 | parser.add_argument('--vpt_layer', default=None, type=int,
91 | help='Number of layers to add prompt, start from the last layer (default: %(default)s)')
92 | parser.add_argument('--vpt_dropout', default=0.1, type=float,
93 | help='VPT dropout rate for deep mode. (default: %(default)s)')
94 |
95 |
96 | ########################full#########################
97 | parser.add_argument('--full', action='store_true',
98 | help='whether turn on full finetune')
99 |
100 | ########################Misc#########################
101 | parser.add_argument('--gpu_num', default=1,
102 | type=int,
103 | help='Number of GPU (default: %(default)s)')
104 | parser.add_argument('--debug', action='store_false',
105 | help='Debug mode to show more information (default: %(default)s)')
106 | parser.add_argument('--random_seed', default=42,
107 | type=int,
108 | help='Random seed (default: %(default)s)')
109 | parser.add_argument('--eval_freq', default=10,
110 | type=int,
 111 |                         help='evaluation frequency in epochs on the test set (default: %(default)s)')
112 | parser.add_argument('--store_ckp', action='store_true',
113 | help='whether store checkpoint')
114 | parser.add_argument('--final_acc_hp', action='store_false',
 115 |                         help='If true, use the best accuracy across all epochs to select hyperparameters; if false, use the accuracy at the final epoch')
116 |
117 | ######################## YAML Config #########################
118 | parser.add_argument('--config', type=str, default=None, help='Path to YAML config file')
119 |
120 |
121 | return parser
122 |
123 |
124 | if __name__ == '__main__':
125 | main()
126 |
--------------------------------------------------------------------------------
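Note: final_run, normalized, debug, and final_acc_hp are declared with action='store_false', so they default to True and are switched off by passing the flag, which is easy to misread. A quick sketch of the resulting behavior:

# sketch: store_false flags default to True and are disabled by passing them
from main import setup_parser

args = setup_parser().parse_args([])
print(args.final_run, args.normalized, args.debug, args.final_acc_hp)   # True True True True

args = setup_parser().parse_args(['--final_run'])
print(args.final_run)                                                    # False
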
/model/attention.py:
--------------------------------------------------------------------------------
1 | from torch.jit import Final
2 | import torch.nn as nn
3 | from timm.layers import use_fused_attn
4 | import torch
5 | import torch.nn.functional as F
6 | from typing import Tuple
7 |
8 |
9 | class AttentionPETL(nn.Module):
10 | fused_attn: Final[bool]
11 |
12 | def __init__(
13 | self,
14 | dim: int,
15 | num_heads: int = 8,
16 | qkv_bias: bool = False,
17 | qk_norm: bool = False,
18 | attn_drop: float = 0.,
19 | proj_drop: float = 0.,
20 | norm_layer: nn.Module = nn.LayerNorm,
21 | params=None,
22 | ) -> None:
23 | super().__init__()
24 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
25 | self.num_heads = num_heads
26 | self.head_dim = dim // num_heads
27 | self.scale = self.head_dim ** -0.5
28 | self.fused_attn = use_fused_attn()
29 |
30 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
31 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
32 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
33 | self.attn_drop = nn.Dropout(attn_drop)
34 | self.proj = nn.Linear(dim, dim)
35 | self.proj_drop = nn.Dropout(proj_drop)
36 |
37 | ############# Added module #############
38 | self.params = params
39 | ############# Added module end #############
40 |
41 | def forward(self, x: torch.Tensor, block_idx, blur_head_lst=[],target_cls=-1) -> Tuple[torch.Tensor,torch.Tensor]:
42 | B, N, C = x.shape
43 | ############# Added module #############
44 | qkv = self.qkv(x)
45 | ############# Added module end #############
46 |
47 | qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
48 | q, k, v = qkv.unbind(0)
49 |
50 | q, k = self.q_norm(q), self.k_norm(k)
51 |
52 | q = q * self.scale
53 | attn = q @ k.transpose(-2, -1)
54 |
55 | ############# Added module #############
56 | if len(blur_head_lst)!=0:
57 | attn[:, blur_head_lst, target_cls, :] = 0
58 |
59 | ############# Added module end #############
60 |
61 | attn = attn.softmax(dim=-1)
62 | attn = self.attn_drop(attn)
63 | x = attn @ v
64 |
65 | x = x.transpose(1, 2).reshape(B, N, C)
66 | proj = self.proj(x)
67 | x = self.proj_drop(proj)
68 | return x,attn
69 |
--------------------------------------------------------------------------------
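Note: the blur_head_lst hook above zeroes the pre-softmax logits of the selected heads at the target class-prompt row, so after softmax those heads attend uniformly over every key, i.e. their trait map for that class is flattened. A tiny illustration with made-up shapes:

# illustration: a zeroed logit row becomes a uniform attention row after softmax
import torch

B, H, N = 1, 2, 5
attn = torch.randn(B, H, N, N)
target_cls, blur_head_lst = 0, [1]

attn[:, blur_head_lst, target_cls, :] = 0   # same operation as in AttentionPETL.forward
attn = attn.softmax(dim=-1)
print(attn[0, 1, target_cls])               # tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
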
/model/block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from timm.layers import DropPath
3 | from timm.models.vision_transformer import LayerScale
4 | from timm.layers.trace_utils import _assert
5 | import torch
6 | from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List
7 | from model.mlp import MlpPETL
8 | from model.attention import AttentionPETL
9 |
10 | class BlockPETL(nn.Module):
11 | def __init__(
12 | self,
13 | dim: int,
14 | num_heads: int,
15 | mlp_ratio: float = 4.,
16 | qkv_bias: bool = False,
17 | qk_norm: bool = False,
18 | proj_drop: float = 0.,
19 | attn_drop: float = 0.,
20 | init_values: Optional[float] = None,
21 | drop_path: float = 0.,
22 | act_layer: nn.Module = nn.GELU,
23 | norm_layer: nn.Module = nn.LayerNorm,
24 | mlp_layer: nn.Module = MlpPETL,
25 | params=None
26 | ) -> None:
27 | super().__init__()
28 | self.norm1 = norm_layer(dim)
29 | self.attn = AttentionPETL(
30 | dim,
31 | num_heads=num_heads,
32 | qkv_bias=qkv_bias,
33 | qk_norm=qk_norm,
34 | attn_drop=attn_drop,
35 | proj_drop=proj_drop,
36 | norm_layer=norm_layer,
37 | ############# Added module #############
38 | params=params
39 | ############# Added module end #############
40 | )
41 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
42 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
43 |
44 | self.norm2 = norm_layer(dim)
45 | self.mlp = mlp_layer(
46 | in_features=dim,
47 | hidden_features=int(dim * mlp_ratio),
48 | act_layer=act_layer,
49 | drop=proj_drop
50 | )
51 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
52 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
53 |
54 | ############# Added module #############
55 | self.params = params
56 | ############# Added module end #############
57 |
58 | def forward(self, x: torch.Tensor, idx, blur_head_lst=[], target_cls=-1) -> Tuple[torch.Tensor,torch.Tensor]:
59 | output, attn_map = self.attn(self.norm1(x), idx , blur_head_lst=blur_head_lst, target_cls=target_cls)
60 | x = x + self.drop_path1(self.ls1(output))
61 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
62 | return x,attn_map
63 |
64 |
65 |
66 |
--------------------------------------------------------------------------------
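Note: BlockPETL keeps timm's pre-norm residual layout and simply threads the attention map out next to the features. A minimal shape check under assumed ViT-B sizes (params is stored but unused in forward, so a bare namespace suffices):

# minimal shape check for BlockPETL (ViT-B dimensions assumed)
import torch
from types import SimpleNamespace
from model.block import BlockPETL

block = BlockPETL(dim=768, num_heads=12, qkv_bias=True, params=SimpleNamespace())
x = torch.randn(2, 197, 768)          # (batch, tokens, dim)
out, attn_map = block(x, idx=0)
print(out.shape, attn_map.shape)      # torch.Size([2, 197, 768]) torch.Size([2, 12, 197, 197])
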
/model/mlp.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 | from timm.layers.helpers import to_2tuple
3 | from functools import partial
4 |
5 |
6 | class MlpPETL(nn.Module):
7 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
8 | """
9 | def __init__(
10 | self,
11 | in_features,
12 | hidden_features=None,
13 | out_features=None,
14 | act_layer=nn.GELU,
15 | norm_layer=None,
16 | bias=True,
17 | drop=0.,
18 | use_conv=False
19 | ):
20 | super().__init__()
21 |
22 | out_features = out_features or in_features
23 | hidden_features = hidden_features or in_features
24 | bias = to_2tuple(bias)
25 | drop_probs = to_2tuple(drop)
26 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
27 |
28 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
29 | self.act = act_layer()
30 | self.drop1 = nn.Dropout(drop_probs[0])
31 | self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
32 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
33 | self.drop2 = nn.Dropout(drop_probs[1])
34 |
35 |
36 | def forward(self, x):
37 | h = self.fc1(x)
38 | x = self.act(h)
39 | x = self.drop1(x)
40 | x = self.norm(x)
41 | h = self.fc2(x)
42 | x = self.drop2(h)
43 | return x
--------------------------------------------------------------------------------
/model/patch_embed.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Tuple, Union
2 |
3 | import torch
4 | from torch import nn as nn
5 | import torch.nn.functional as F
6 |
7 | from timm.layers.format import Format, nchw_to
8 | from timm.layers.helpers import to_2tuple
9 | from timm.layers.trace_utils import _assert
10 |
11 |
12 |
13 | class PatchEmbedPETL(nn.Module):
14 | """ 2D Image to Patch Embedding
15 | """
16 | output_fmt: Format
17 | dynamic_img_pad: torch.jit.Final[bool]
18 |
19 | def __init__(
20 | self,
21 | img_size: Optional[int] = 224,
22 | patch_size: int = 16,
23 | in_chans: int = 3,
24 | embed_dim: int = 768,
25 | norm_layer: Optional[Callable] = None,
26 | flatten: bool = True,
27 | output_fmt: Optional[str] = None,
28 | bias: bool = True,
29 | strict_img_size: bool = True,
30 | dynamic_img_pad: bool = False,
31 | params = None
32 | ):
33 | super().__init__()
34 | self.patch_size = to_2tuple(patch_size)
35 | if img_size is not None:
36 | self.img_size = to_2tuple(img_size)
37 | self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
38 | self.num_patches = self.grid_size[0] * self.grid_size[1]
39 | else:
40 | self.img_size = None
41 | self.grid_size = None
42 | self.num_patches = None
43 |
44 | if output_fmt is not None:
45 | self.flatten = False
46 | self.output_fmt = Format(output_fmt)
47 | else:
48 | # flatten spatial dim and transpose to channels last, kept for bwd compat
49 | self.flatten = flatten
50 | self.output_fmt = Format.NCHW
51 | self.strict_img_size = strict_img_size
52 | self.dynamic_img_pad = dynamic_img_pad
53 |
54 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
55 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
56 |
57 | ############# Added module #############
58 | self.params = params
59 | self.norm_layer = norm_layer
60 | ############# Added module end #############
61 |
62 | def forward(self, x):
63 | B, C, H, W = x.shape
64 | if self.img_size is not None:
65 | if self.strict_img_size:
66 | _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
67 | _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
68 | elif not self.dynamic_img_pad:
69 | _assert(
70 | H % self.patch_size[0] == 0,
71 | f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
72 | )
73 | _assert(
74 | W % self.patch_size[1] == 0,
75 | f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
76 | )
77 | if self.dynamic_img_pad:
78 | pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
79 | pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
80 | x = F.pad(x, (0, pad_w, 0, pad_h))
81 | x = self.proj(x)
82 | if self.flatten:
83 | x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
84 | elif self.output_fmt != Format.NCHW:
85 | x = nchw_to(x, self.output_fmt)
86 | x = self.norm(x)
87 | return x
88 |
--------------------------------------------------------------------------------
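Note: with the DINOv2 backbone registered in model/vision_transformer.py (patch_size=14, img_size=224), this embedding produces a 16x16 grid of 256 patch tokens. A short sketch:

# sketch: patch grid for the DINOv2 configuration (patch_size=14, img_size=224)
import torch
from model.patch_embed import PatchEmbedPETL

embed = PatchEmbedPETL(img_size=224, patch_size=14, in_chans=3, embed_dim=768)
tokens = embed(torch.randn(1, 3, 224, 224))
print(embed.grid_size, embed.num_patches, tokens.shape)   # (16, 16) 256 torch.Size([1, 256, 768])
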
/model/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def init_weight(down, up, option):
7 | with torch.no_grad():
8 | if option == 'lora_kaiming':
9 | nn.init.kaiming_uniform_(down.weight, a=math.sqrt(5))
10 | nn.init.zeros_(up.weight)
11 | nn.init.zeros_(down.bias)
12 | nn.init.zeros_(up.bias)
13 | elif option == 'lora_xavier':
14 | nn.init.xavier_uniform_(down.weight)
15 | nn.init.zeros_(up.weight)
16 | nn.init.zeros_(down.bias)
17 | nn.init.zeros_(up.bias)
18 | elif option == 'xavier':
19 | nn.init.xavier_uniform_(down.weight)
20 | nn.init.xavier_uniform_(up.weight)
21 | nn.init.normal_(down.bias, std=1e-6)
22 | nn.init.normal_(up.bias, std=1e-6)
23 | elif option == 'zero':
 24 |             nn.init.zeros_(down.weight)
 25 |             nn.init.zeros_(up.weight)
 26 |             nn.init.zeros_(down.bias)
 27 |             nn.init.zeros_(up.bias)
28 | else:
29 | raise NotImplementedError
--------------------------------------------------------------------------------
/model/vision_transformer.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List
3 |
4 | try:
5 | from typing import Literal
6 | except ImportError:
7 | from typing_extensions import Literal
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.utils.checkpoint
12 | from torch.jit import Final
13 | from timm.layers import DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
14 | trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
15 | get_act_layer, get_norm_layer, LayerType
16 | from timm.models._builder import build_model_with_cfg
17 | from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
18 | from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations
19 | from timm.models.vision_transformer import VisionTransformer
20 | from timm.models.vision_transformer import LayerScale, init_weights_vit_timm, get_init_weights_vit, \
21 | _load_weights, checkpoint_filter_fn
22 | ## added for petl
23 | from utils.setup_logging import get_logger
24 | from model.block import BlockPETL
25 | from model.patch_embed import PatchEmbedPETL
26 | from model.mlp import MlpPETL
27 | from model.vpt import VPT
28 |
29 | logger = get_logger("Prompt_CAM")
30 |
31 |
32 | class VisionTransformerPETL(VisionTransformer):
33 | """ Vision Transformer
34 |
35 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
36 | - https://arxiv.org/abs/2010.11929
37 | """
38 | dynamic_img_size: Final[bool]
39 |
40 | def __init__(
41 | self,
42 | img_size: Union[int, Tuple[int, int]] = 224,
43 | patch_size: Union[int, Tuple[int, int]] = 16,
44 | in_chans: int = 3,
45 | num_classes: int = 1000,
46 | global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
47 | embed_dim: int = 768,
48 | depth: int = 12,
49 | num_heads: int = 12,
50 | mlp_ratio: float = 4.,
51 | qkv_bias: bool = True,
52 | qk_norm: bool = False,
53 | init_values: Optional[float] = None,
54 | class_token: bool = True,
55 | no_embed_class: bool = False,
56 | reg_tokens: int = 0,
57 | pre_norm: bool = False,
58 | fc_norm: Optional[bool] = None,
59 | dynamic_img_size: bool = False,
60 | dynamic_img_pad: bool = False,
61 | drop_rate: float = 0.,
62 | pos_drop_rate: float = 0.,
63 | patch_drop_rate: float = 0.,
64 | proj_drop_rate: float = 0.,
65 | attn_drop_rate: float = 0.,
66 | drop_path_rate: float = 0.,
67 | weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
68 | embed_layer: Callable = PatchEmbedPETL,
69 | norm_layer: Optional[LayerType] = None,
70 | act_layer: Optional[LayerType] = None,
71 | block_fn: Type[nn.Module] = BlockPETL,
72 | mlp_layer: Type[nn.Module] = MlpPETL,
73 | params=None
74 | ) -> None:
75 | """
76 | Args:
77 | img_size: Input image size.
78 | patch_size: Patch size.
79 | in_chans: Number of image input channels.
 80 |             num_classes: Number of classes for classification head.
81 | global_pool: Type of global pooling for final sequence (default: 'token').
82 | embed_dim: Transformer embedding dimension.
83 | depth: Depth of transformer.
84 | num_heads: Number of attention heads.
85 | mlp_ratio: Ratio of mlp hidden dim to embedding dim.
86 | qkv_bias: Enable bias for qkv projections if True.
87 | init_values: Layer-scale init values (layer-scale enabled if not None).
88 | class_token: Use class token.
89 | no_embed_class: Don't include position embeddings for class (or reg) tokens.
90 | reg_tokens: Number of register tokens.
91 | fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
92 | drop_rate: Head dropout rate.
93 | pos_drop_rate: Position embedding dropout rate.
94 | attn_drop_rate: Attention dropout rate.
95 | drop_path_rate: Stochastic depth rate.
96 | weight_init: Weight initialization scheme.
97 | embed_layer: Patch embedding layer.
98 | norm_layer: Normalization layer.
99 | act_layer: MLP activation layer.
100 | block_fn: Transformer block layer.
101 | """
102 | super().__init__()
103 | assert global_pool in ('', 'avg', 'token', 'map')
104 | assert class_token or global_pool != 'token'
105 | use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
106 | norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
107 | act_layer = get_act_layer(act_layer) or nn.GELU
108 |
109 | self.num_classes = num_classes
110 | self.global_pool = global_pool
111 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
112 | self.num_prefix_tokens = 1 if class_token else 0
113 | self.num_prefix_tokens += reg_tokens
114 | self.num_reg_tokens = reg_tokens
115 | self.has_class_token = class_token
116 | self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
117 | self.dynamic_img_size = dynamic_img_size
118 | self.grad_checkpointing = False
119 |
120 | embed_args = {}
121 | if dynamic_img_size:
122 | # flatten deferred until after pos embed
123 | embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
124 | self.patch_embed = embed_layer(
125 | img_size=img_size,
126 | patch_size=patch_size,
127 | in_chans=in_chans,
128 | embed_dim=embed_dim,
129 | bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
130 | dynamic_img_pad=dynamic_img_pad,
131 | params=params,
132 | **embed_args,
133 | )
134 | num_patches = self.patch_embed.num_patches
135 |
136 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
137 | self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
138 | embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
139 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
140 | self.pos_drop = nn.Dropout(p=pos_drop_rate)
141 | if patch_drop_rate > 0:
142 | self.patch_drop = PatchDropout(
143 | patch_drop_rate,
144 | num_prefix_tokens=self.num_prefix_tokens,
145 | )
146 | else:
147 | self.patch_drop = nn.Identity()
148 | self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
149 |
150 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
151 |
152 | ############# Added module start #############
153 | self.patch_size = patch_size
154 | self.params = params
155 | if self.params.train_type in ['vpt','prompt_cam']:
156 | self.vpt = VPT(params, depth, patch_size, embed_dim)
157 | ############# Added module end #############
158 |
159 | self.blocks = nn.Sequential(*[
160 | block_fn(
161 | dim=embed_dim,
162 | num_heads=num_heads,
163 | mlp_ratio=mlp_ratio,
164 | qkv_bias=qkv_bias,
165 | qk_norm=qk_norm,
166 | init_values=init_values,
167 | proj_drop=proj_drop_rate,
168 | attn_drop=attn_drop_rate,
169 | drop_path=dpr[i],
170 | norm_layer=norm_layer,
171 | act_layer=act_layer,
172 | mlp_layer=mlp_layer,
173 | ############# Added module start #############
174 | params=params
175 | ############# Added module end #############
176 | )
177 | for i in range(depth)])
178 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
179 |
180 | # Classifier Head
181 | if global_pool == 'map':
182 | self.attn_pool = AttentionPoolLatent(
183 | self.embed_dim,
184 | num_heads=num_heads,
185 | mlp_ratio=mlp_ratio,
186 | norm_layer=norm_layer,
187 | )
188 | else:
189 | self.attn_pool = None
190 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
191 | self.head_drop = nn.Dropout(drop_rate)
192 | self.num_heads= num_heads
193 |
194 | ############# Added module start #############
195 | if self.params.train_type == 'vpt' or self.params.train_type == 'linear':
196 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
197 | elif self.params.train_type == 'prompt_cam':
198 | self.head = nn.Linear(self.embed_dim, 1)
199 | ############# Added module end #############
200 |
201 |
202 | if weight_init != 'skip':
203 | self.init_weights(weight_init)
204 |
205 | @torch.jit.ignore()
206 | def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None:
207 | _load_weights_PETL(self, checkpoint_path, prefix)
208 |
209 | def forward_features(self, x: torch.Tensor, blur_head_lst=[], target_cls=-1) -> Tuple[torch.Tensor,torch.Tensor]:
210 | pcam_outputs = None
211 | x = self.patch_embed(x)
212 | x = self._pos_embed(x)
213 | x = self.patch_drop(x)
214 | x = self.norm_pre(x)
215 |
216 | if self.grad_checkpointing and not torch.jit.is_scripting():
217 | x = checkpoint_seq(self.blocks, x)
218 | else:
219 | ############# Added module #############
220 | for idx, block in enumerate(self.blocks):
221 | if self.params.train_type in ['vpt','prompt_cam']:
222 | prompt = self.vpt.retrieve_prompt(idx, x.shape[0])
223 | if prompt is not None:
224 | x = torch.cat([prompt, x], dim=1)
225 |
226 | # forward block
227 | if idx == len(self.blocks) - 1:
228 | x,attn_map = block(x, idx, blur_head_lst=blur_head_lst, target_cls=target_cls)
229 | else:
230 | x,_ = block(x, idx)
231 |
232 | if self.params.vpt_mode and prompt is not None:
233 | x = x[:, self.params.vpt_num:, :]
234 | elif self.params.train_type == 'prompt_cam':
235 | pcam_outputs = x
236 | x = x[:, self.params.vpt_num:, :]
237 |
238 | if self.params.train_type == 'prompt_cam':
239 | x = pcam_outputs
240 | ############# Added module end #############
241 | x = self.norm(x)
242 | return x,attn_map
243 |
244 | def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
245 | #import pdb;pdb.set_trace()
246 | if self.params.train_type == 'prompt_cam':
247 | output_feature = x[:,:self.params.vpt_num]
248 | else:
249 | if self.attn_pool is not None:
250 | output_feature = self.attn_pool(x)
251 | elif self.global_pool == 'avg':
252 | output_feature = x[:, self.num_prefix_tokens:].mean(dim=1)
253 | elif self.global_pool:
254 | output_feature = x[:, 0] # class token
255 | else:
256 | output_feature = x
257 |
258 | output_feature = self.fc_norm(output_feature)
259 | output_feature = self.head_drop(output_feature)
260 |
261 | return output_feature if pre_logits else self.head(output_feature)
262 |
263 | def forward(self, x: torch.Tensor, blur_head_lst=[], target_cls=-1) -> Tuple[torch.Tensor,torch.Tensor]:
264 | attn_maps = None
265 | if self.params.vis_attn:
266 | x,attn_maps = self.forward_features(x,
267 | blur_head_lst=blur_head_lst,
268 | target_cls=target_cls)
269 | else:
270 | x,_ = self.forward_features(x)
271 |
272 | x = self.forward_head(x)
273 | return x, attn_maps
274 |
275 | def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
276 | self.num_classes = num_classes
277 | if global_pool is not None:
278 | assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
279 | if global_pool == 'map' and self.attn_pool is None:
280 | assert False, "Cannot currently add attention pooling in reset_classifier()."
281 | elif global_pool != 'map' and self.attn_pool is not None:
282 | self.attn_pool = None # remove attention pooling
283 | self.global_pool = global_pool
284 | ############# Added module #############
285 | if self.params.train_type == 'vpt' or self.params.train_type == 'linear':
286 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
287 | else:
288 | self.head = nn.Linear(self.embed_dim, 1)
289 | ############# Added module end #############
290 | #Original
291 | # self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
292 |
293 |
294 | @torch.no_grad()
295 | def _load_weights_PETL(model: VisionTransformerPETL, checkpoint_path: str, prefix: str = ''):
296 | if checkpoint_path.endswith('.npz'):
297 | _load_weights(model, checkpoint_path, prefix)
298 | elif checkpoint_path.endswith('.pth') or checkpoint_path.endswith('.bin'):
299 | _load_weights_pth(model, checkpoint_path, checkpoint_filter_fn)
300 |
301 |
302 | def _load_weights_pth(model, checkpoint_path, filter_fn=checkpoint_filter_fn):
303 | """ Load weights from .pth checkpoints
304 | """
305 | state_dict = torch.load(checkpoint_path, map_location='cpu')
306 | if filter_fn is not None:
307 | state_dict = filter_fn(state_dict, model)
308 | if 'head.weight' in state_dict:
309 | state_dict.pop('head.weight', None)
310 | if 'head.bias' in state_dict:
311 | state_dict.pop('head.bias', None)
312 | model.load_state_dict(state_dict, strict=False)
313 |
314 |
315 | def _create_vision_transformer_petl(variant: str, pretrained: bool = False, **kwargs):
316 | if kwargs.get('features_only', None):
317 | raise RuntimeError('features_only not implemented for Vision Transformer models.')
318 |
319 | if 'flexi' in variant:
320 | # Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
321 | # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
322 | _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
323 | else:
324 | _filter_fn = checkpoint_filter_fn
325 |
326 | # attn pool (currently only in siglip) params removed if pool disabled, is there a better soln?
327 | strict = True
328 | if 'siglip' in variant and kwargs.get('global_pool', None) != 'map':
329 | strict = False
330 |
331 | return build_model_with_cfg(
332 | VisionTransformerPETL,
333 | variant,
334 | pretrained,
335 | pretrained_filter_fn=checkpoint_filter_fn,
336 | pretrained_strict=strict,
337 | **kwargs,
338 | )
339 |
340 |
341 | @register_model
342 | def vit_base_patch14_dinov2_petl(pretrained: bool = False, **kwargs):
343 | """ ViT-B/14 for DINOv2
344 | change img_size to 224
345 | """
346 | model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=224)
347 | model = _create_vision_transformer_petl(
348 | 'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
349 | return model
350 |
351 | @register_model
352 | def vit_base_patch16_dino_petl(pretrained: bool = False, **kwargs):
353 | """ ViT-B/16 for DINO
354 | """
355 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
356 | model = _create_vision_transformer_petl(
357 | 'vit_base_patch16_224.dino', pretrained=pretrained, **dict(model_args, **kwargs))
358 | return model
359 |
360 | @register_model
361 | def vit_base_patch16_224_in21k_petl(pretrained=False, **kwargs):
362 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
363 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
364 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
365 | """
366 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
367 | model = _create_vision_transformer_petl(
368 | 'vit_base_patch16_224_in21k', pretrained=pretrained, **dict(model_args, **kwargs))
369 | return model
370 |
371 |
372 | @register_model
373 | def vit_base_patch16_clip_224_petl(pretrained: bool = False, **kwargs) -> VisionTransformer:
374 | """ ViT-B/16 CLIP image tower
375 | """
376 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm,
377 | act_layer='quick_gelu')
378 | model = _create_vision_transformer_petl(
379 | 'vit_base_patch16_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
380 | return model
381 |
--------------------------------------------------------------------------------
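Note: in prompt_cam mode, forward_features prepends the vpt_num class prompts at every block, forward_head keeps only those prompt tokens, and the shared head nn.Linear(embed_dim, 1) scores each prompt, so the logits come out as (B, vpt_num, 1); the last block's attention map is returned alongside them when params.vis_attn is True. A small sketch of how those logits become per-class scores, mirroring experiment/visualize_run.py (dinov2/pet sizes assumed):

# sketch: Prompt-CAM head outputs -> per-class scores (vpt_num = 37 assumed)
import torch

B, vpt_num = 4, 37
logits = torch.randn(B, vpt_num, 1)   # head = nn.Linear(embed_dim, 1) applied to each class prompt
scores = logits.squeeze(-1)           # (B, 37), as in outputs.squeeze(-1) in visualize_run.py
pred = scores.argmax(dim=1)           # predicted class
print(scores.shape, pred.shape)       # torch.Size([4, 37]) torch.Size([4])
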
/model/vpt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import reduce
4 | from operator import mul
5 | import math
6 | from torch.nn.modules.utils import _pair
7 |
8 |
9 | class VPT(nn.Module):
10 |
11 | def __init__(self, params, depth, patch_size, embed_dim):
12 | super().__init__()
13 | self.params = params
14 | self.depth = depth
15 | if params.vpt_mode == 'shallow':
16 | prompt_layer = 1
17 | elif params.vpt_mode == 'deep':
18 | if params.vpt_layer:
19 | prompt_layer = params.vpt_layer
20 | else:
21 | prompt_layer = depth
22 | elif params.train_type == 'prompt_cam':
23 | prompt_layer = depth
24 | else:
25 | raise ValueError
26 | val = math.sqrt(6. / float(3 * reduce(mul, _pair(patch_size), 1) + embed_dim))
27 | self.prompt_embeddings = nn.Parameter(torch.zeros(
28 | prompt_layer, params.vpt_num, embed_dim))
 29 |         # uniform initialization with an xavier_uniform-style bound (as in VPT)
30 | nn.init.uniform_(self.prompt_embeddings.data, -val, val)
31 | self.prompt_dropout = nn.Dropout(params.vpt_dropout)
32 |
33 |
34 | def retrieve_prompt(self, index, batch_size):
35 | if self.params.vpt_layer:
36 | index = index - (self.depth - self.params.vpt_layer)
37 | if index < 0:
38 | return None
39 | if index < len(self.prompt_embeddings):
40 | return self.prompt_dropout(self.prompt_embeddings[index]).expand(batch_size, -1, -1)
41 | else:
42 | return None
43 |
--------------------------------------------------------------------------------
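Note: retrieve_prompt hands a prompt to the first prompt_layer blocks, or, when vpt_layer is set, only to the last vpt_layer blocks; every other block receives None and runs unchanged. A minimal sketch with hypothetical deep-VPT settings:

# sketch: which blocks receive prompts for a hypothetical deep-VPT setup
# (depth=12, vpt_layer=4 -> only the last 4 blocks are prompted)
from types import SimpleNamespace
from model.vpt import VPT

params = SimpleNamespace(vpt_mode='deep', vpt_layer=4, vpt_num=10, vpt_dropout=0.0)
vpt = VPT(params, depth=12, patch_size=16, embed_dim=768)
for idx in range(12):
    prompt = vpt.retrieve_prompt(idx, batch_size=2)
    print(idx, None if prompt is None else tuple(prompt.shape))
# blocks 0-7 -> None, blocks 8-11 -> (2, 10, 768)
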
/samples/Baltimore_Oriole.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/Baltimore_Oriole.jpg
--------------------------------------------------------------------------------
/samples/Brewer_Blackbird.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/Brewer_Blackbird.jpg
--------------------------------------------------------------------------------
/samples/Orchard_Oriole.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/Orchard_Oriole.jpg
--------------------------------------------------------------------------------
/samples/Scott_Oriole.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/Scott_Oriole.jpg
--------------------------------------------------------------------------------
/samples/red_winged_blackbird.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/red_winged_blackbird.jpg
--------------------------------------------------------------------------------
/samples/rusty_Blackbird.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/rusty_Blackbird.jpg
--------------------------------------------------------------------------------
/samples/sample_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/sample_image.png
--------------------------------------------------------------------------------
/samples/trait_manipulation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/trait_manipulation.jpg
--------------------------------------------------------------------------------
/samples/yellow_headed_blackbird.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Imageomics/Prompt_CAM/2990095277bdaa7bbdc77a079693243925e8a855/samples/yellow_headed_blackbird.jpg
--------------------------------------------------------------------------------
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """Distributed helpers."""
4 |
5 | import torch
6 | import torch.distributed as dist
7 | _LOCAL_PROCESS_GROUP = None
8 |
9 |
10 | def get_world_size() -> int:
11 | if not dist.is_available():
12 | return 1
13 | if not dist.is_initialized():
14 | return 1
15 | return dist.get_world_size()
16 |
17 |
18 | def get_rank() -> int:
19 | if not dist.is_available():
20 | return 0
21 | if not dist.is_initialized():
22 | return 0
23 | return dist.get_rank()
24 |
25 |
26 | def is_master_process(num_gpus=8):
27 | """
28 | Determines if the current process is the master process.
29 | """
30 | if torch.distributed.is_initialized():
31 | return dist.get_rank() % num_gpus == 0
32 | else:
33 | return True
34 |
35 |
36 | def run(
37 | local_rank,
38 | num_proc,
39 | func,
40 | init_method,
41 | shard_id,
42 | num_shards,
43 | backend,
44 | cfg,
45 | args,
46 | ):
47 | """
48 | Runs a function from a child process.
49 | Args:
50 | local_rank (int): rank of the current process on the current machine.
51 | num_proc (int): number of processes per machine.
52 | func (function): function to execute on each of the process.
53 | init_method (string): method to initialize the distributed training.
 54 |             TCP initialization: requiring a network address reachable from all
55 | processes followed by the port.
56 | Shared file-system initialization: makes use of a file system that
57 | is shared and visible from all machines. The URL should start with
58 | file:// and contain a path to a non-existent file on a shared file
59 | system.
60 | shard_id (int): the rank of the current machine.
61 | num_shards (int): number of overall machines for the distributed
62 | training job.
63 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
 64 |             supported, each with different capabilities. Details can be found
65 | here:
66 | https://pytorch.org/docs/stable/distributed.html
67 | cfg (CfgNode): configs. Details can be found in
68 | loco/config/defaults.py
69 | """
70 | # Initialize the process group.
71 | # shard_id = get_rank()
72 | world_size = num_proc * num_shards
73 | rank = shard_id * num_proc + local_rank
74 |
75 | try:
76 | torch.distributed.init_process_group(
77 | backend=backend,
78 | init_method=init_method,
79 | world_size=world_size,
80 | rank=rank,
81 | )
82 | except Exception as e:
83 | raise e
84 |
85 | torch.cuda.set_device(local_rank)
86 | func(cfg, args)
87 |
88 |
89 | def destroy_process_group():
90 | """Destroys the default process group."""
91 | torch.distributed.destroy_process_group()
92 |
93 |
94 | def scaled_all_reduce(cfg, tensors):
95 | """Performs the scaled all_reduce operation on the provided tensors.
96 |
97 | The input tensors are modified in-place. Currently supports only the sum
98 | reduction operator. The reduced values are scaled by the inverse size of
99 | the process group (equivalent to cfg.NUM_GPUS).
100 | """
101 | # Queue the reductions
102 | reductions = []
103 | for tensor in tensors:
104 | reduction = torch.distributed.all_reduce(tensor, async_op=True)
105 | reductions.append(reduction)
106 | # Wait for reductions to finish
107 | for reduction in reductions:
108 | reduction.wait()
109 | # Scale the results
110 | for tensor in tensors:
111 | tensor.mul_(1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS)
112 | return tensors
113 |
114 |
115 | def cat_all_gather(tensors):
116 | """Performs the concatenated all_gather operation on the provided tensors.
117 | """
118 | tensors_gather = [
119 | torch.ones_like(tensors)
120 | for _ in range(torch.distributed.get_world_size())
121 | ]
122 | torch.distributed.all_gather(tensors_gather, tensors, async_op=False)
123 |
124 | output = torch.cat(tensors_gather, dim=0)
125 | return output
126 |
127 |
128 | def local_cat_all_gather(tensors):
129 | """Performs the concatenated all_gather operation on the provided tensors.
130 | """
131 | tensors_gather = [
132 | torch.ones_like(tensors)
133 | for _ in range(get_local_size())
134 | ]
135 | torch.distributed.all_gather(
136 | tensors_gather,
137 | tensors,
138 | async_op=False,
139 | group=_LOCAL_PROCESS_GROUP,
140 | )
141 | output = torch.cat(tensors_gather, dim=0)
142 | return output
143 |
144 |
145 | def get_local_size():
146 | """
147 | Returns:
148 | The size of the per-machine process group,
149 | i.e. the number of processes per machine.
150 | """
151 | if not dist.is_available():
152 | return 1
153 | if not dist.is_initialized():
154 | return 1
155 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
156 |
157 |
158 | def get_local_rank():
159 | """
160 | Returns:
161 | The rank of the current process within the local (per-machine) process group.
162 | """
163 | if not dist.is_available():
164 | return 0
165 | if not dist.is_initialized():
166 | return 0
167 | assert _LOCAL_PROCESS_GROUP is not None
168 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
169 |
--------------------------------------------------------------------------------
/utils/file_io.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """
4 | Project specific pathmanagers for a project as recommended by Detectron2
5 | """
6 | from iopath.common.file_io import PathManager as PathManagerBase
7 | from iopath.common.file_io import HTTPURLHandler
8 |
9 |
10 | PathManager = PathManagerBase()
11 | PathManager.register_handler(HTTPURLHandler())
12 |
--------------------------------------------------------------------------------
/utils/global_var.py:
--------------------------------------------------------------------------------
1 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
2 | OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
3 |
4 | TFDS_DATASETS = {
5 | 'caltech101': 102,
6 | 'cifar(num_classes=100)': 100,
7 | 'dtd': 47,
8 | 'oxford_flowers102': 102,
9 | 'oxford_iiit_pet': 37,
10 | 'patch_camelyon': 2,
11 | 'sun397': 397,
12 | 'svhn': 10,
13 | 'resisc45': 45,
14 | 'eurosat': 10,
15 | 'dmlab': 6,
16 | 'kitti(task="closest_vehicle_distance")': 4,
17 | 'smallnorb(predicted_attribute="label_azimuth")': 18,
18 | 'smallnorb(predicted_attribute="label_elevation")': 9,
19 | 'dsprites(predicted_attribute="label_x_position",num_classes=16)': 16,
20 | 'dsprites(predicted_attribute="label_orientation",num_classes=16)': 16,
21 | 'clevr(task="closest_object_distance")': 6,
22 | 'clevr(task="count_all")': 8,
23 | 'diabetic_retinopathy(config="btgraham-300")': 5
24 | }
25 |
26 | VTAB_DATASETS = {'caltech101': 102, 'clevr_count': 8, 'dmlab': 6, 'dsprites_ori': 16, 'eurosat': 10, 'oxford_flowers102': 102,
27 | 'patch_camelyon': 2,
28 | 'smallnorb_azi': 18, 'svhn': 10, 'cifar': 100, 'clevr_dist': 6, 'dsprites_loc': 16, 'dtd': 47,
29 | 'kitti': 4, 'oxford_iiit_pet': 37, 'resisc45': 45,
30 | 'smallnorb_ele': 9, 'sun397': 397, 'diabetic_retinopathy': 5}
31 |
32 | MEAN_STD_DICT = {
33 | 'vit_base_patch16_224_in21k': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
34 | 'vit_base_patch14_dinov2': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
35 | 'vit_base_mae': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
36 | 'vit_base_patch16_clip_224': (OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)
37 | }
38 |
39 | CIFAR_VAL_SPLIT = {
40 | 'cifar100-500': 0.2,
41 | 'cifar100-1k': 0.2,
42 | 'cifar100-5k': 0.1,
43 | 'cifar100-10k': 0.1,
44 | 'cifar100-full': 0.1
45 | }
46 |
47 | RETINOPATHY_VAL_SPLIT = {
48 | 'retinopathy-500': 0.2,
49 | 'retinopathy-5k': 0.1,
50 | 'retinopathy-10k': 0.1,
51 | 'retinopathy-full': 0.1
52 | }
53 |
54 | RESISC_VAL_SPLIT = {
55 | 'resisc-225': 0.2,
56 | 'resisc-450': 0.2,
57 | 'resisc-900': 0.1,
58 | 'resisc-2250': 0.1,
59 | 'resisc-4500': 0.1,
60 | 'resisc-9000': 0.1,
61 | 'resisc-full': 0.1
62 |
63 | }
64 |
65 |
66 | OUTPUT_DIR = "./output"
67 | TUNE_DIR = "./tune_output"
68 | TUNE_DIR_TEST = "./tune_output_test"
69 |
--------------------------------------------------------------------------------
/utils/log_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pprint
3 | import sys
4 | from collections import defaultdict
5 | import torch
6 | from tabulate import tabulate
7 |
8 | from utils.distributed import get_rank, get_world_size
9 | from utils.setup_logging import setup_logging
10 |
11 |
12 | def get_env_module():
13 | var_name = "ENV_MODULE"
14 | return var_name, os.environ.get(var_name, "")
15 |
16 |
17 | def collect_torch_env() -> str:
18 | try:
19 | import torch.__config__
20 |
21 | return torch.__config__.show()
22 | except ImportError:
23 | # compatible with older versions of pytorch
24 | from torch.utils.collect_env import get_pretty_env_info
25 |
26 | return get_pretty_env_info()
27 |
28 |
29 | def collect_env_info():
30 | data = []
31 | data.append(("Python", sys.version.replace("\n", "")))
32 | data.append(get_env_module())
33 | data.append(("PyTorch", torch.__version__))
34 | data.append(("PyTorch Debug Build", torch.version.debug))
35 |
36 | has_cuda = torch.cuda.is_available()
37 | data.append(("CUDA available", has_cuda))
38 | if has_cuda:
 39 |         data.append(("CUDA ID", os.environ.get("CUDA_VISIBLE_DEVICES", "unset")))
40 | devices = defaultdict(list)
41 | for k in range(torch.cuda.device_count()):
42 | devices[torch.cuda.get_device_name(k)].append(str(k))
43 | for name, devids in devices.items():
44 | data.append(("GPU " + ",".join(devids), name))
45 |
46 | env_str = tabulate(data) + "\n"
47 | env_str += collect_torch_env()
48 | return env_str
49 |
50 |
51 | def logging_env_setup(params) -> None:
52 | logger = setup_logging(
53 | params.gpu_num, get_world_size(), params.output_dir, name="Prompt_CAM")
54 |
55 | # Log basic information about environment, cmdline arguments, and config
56 | rank = get_rank()
57 | logger.info(
58 | f"Rank of current process: {rank}. World size: {get_world_size()}")
59 | logger.info("Environment info:\n" + collect_env_info())
60 |
61 |
62 | # Show the config
63 | logger.info("Training with config:")
64 | logger.info(pprint.pformat(params))
65 |
66 | def log_model_info(model, logger, verbose=False):
67 | """Logs model info"""
68 | if verbose:
69 | logger.info(f"Classification Model:\n{model}")
70 | model_total_params = sum(p.numel() for p in model.parameters())
71 | model_grad_params = sum(
72 | p.numel() for p in model.parameters() if p.requires_grad)
73 | model_grad_params_no_head = sum(p.numel() for n, p in model.named_parameters() if p.requires_grad and 'head' not in n)
74 | logger.info("Total Parameters: {0}\t Gradient Parameters: {1}\t Gradient Parameters No Head: {2}".format(
75 | model_total_params, model_grad_params, model_grad_params_no_head))
76 | logger.info(f"total tuned percent:{(model_grad_params/model_total_params*100):.2f} %")
77 | logger.info(f"total tuned percent no head:{(model_grad_params_no_head / model_total_params * 100):.2f} %")
78 |
79 | # Print the names of parameters that require gradient updates
80 | fine_tuned_params = [n for n, p in model.named_parameters() if p.requires_grad]
81 |
82 | logger.info("Fine-tuned Parameters:")
83 | for param_name in fine_tuned_params:
84 | logger.info(param_name)
85 |
86 | return model_grad_params_no_head
87 |
88 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 | import time
5 | import yaml
6 | from dotwiz import DotWiz
7 |
8 | class AverageMeter(object):
9 | """Computes and stores the average and current value"""
10 | def __init__(self, name=None, fmt=':f'):
11 | self.name = name
12 | self.fmt = fmt
13 | self.reset()
14 |
15 | def reset(self):
16 | self.val = 0
17 | self.avg = 0
18 | self.sum = 0
19 | self.count = 0
20 |
21 | def update(self, val, n=1):
22 | self.val = val
23 | self.sum += val * n
24 | self.count += n
25 | self.avg = self.sum / self.count
26 |
27 | def __str__(self):
28 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
29 | return fmtstr.format(**self.__dict__)
30 |
31 |
32 | def method_name(params):
33 | name = ''
34 | if params.train_type == 'prompt_cam':
35 | name += 'pcam_'
36 | name += params.train_type + '_'
37 | name += str(params.vpt_num) + '_'
38 | elif params.vpt_mode:
39 | name += 'vpt_'
40 | name += params.vpt_mode + '_'
41 | name += str(params.vpt_num) + '_'
42 | name += str(params.vpt_layer) + '_'
43 | #####if nothing, linear
44 | if name == '':
45 | name += 'linear' + '_'
46 | name += params.optimizer
47 | return name
48 |
49 |
50 | def set_seed(random_seed=42):
51 | np.random.seed(random_seed)
52 | random.seed(random_seed)
53 | torch.manual_seed(random_seed)
54 | if torch.cuda.is_available():
55 | torch.cuda.manual_seed(random_seed)
56 | torch.cuda.manual_seed_all(random_seed)
57 | torch.backends.cudnn.deterministic = True
58 | torch.backends.cudnn.benchmark = False
59 |
60 |
61 | @torch.no_grad()
62 | def throughput(model,img_size=224,bs=1):
63 | with torch.no_grad():
64 | x = torch.randn(bs, 3, img_size, img_size).cuda()
65 | batch_size=x.shape[0]
66 | # model=create_model('vit_base_patch16_224_in21k', checkpoint_path='./ViT-B_16.npz', drop_path_rate=0.1)
67 | model.eval()
68 | for i in range(50):
69 | model(x)
70 | torch.cuda.synchronize()
 71 |         print("throughput averaged over 30 runs")
72 | tic1 = time.time()
73 | for i in range(30):
74 | model(x)
75 | torch.cuda.synchronize()
76 | tic2 = time.time()
77 | print(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
78 | MB = 1024.0 * 1024.0
79 | print('memory:', torch.cuda.max_memory_allocated() / MB)
80 |
81 | def load_yaml(path):
82 | with open(path, 'r') as stream:
83 | try:
84 | return DotWiz(yaml.load(stream, Loader=yaml.FullLoader))
85 | except yaml.YAMLError as exc:
86 | print(exc)
87 |
88 | def override_args_with_yaml(args, yaml_config):
89 | """Override argparse args with values from YAML if they exist."""
90 | for key, value in yaml_config.items():
91 | if hasattr(args, key):
92 | setattr(args, key, value)
93 | return args
94 |
95 | def load_vis_args_with_yaml(args,yaml_config_path,checkpoint_path):
96 | """Create args with yaml for notebook"""
97 | yaml_config = load_yaml(yaml_config_path)
98 | for key, value in yaml_config.items():
99 | setattr(args, key, value)
100 |
101 | set_seed(args.random_seed)
102 | args.checkpoint = checkpoint_path
103 | args.test_batch_size= 1
104 | return args
105 |
106 | class EarlyStop:
107 | def __init__(self, patience=1, min_delta=0):
108 | self.patience = patience
109 | self.min_delta = min_delta
110 | self.counter = 0
111 | self.max_metrics = None
112 |
113 | def early_stop(self, eval_metrics):
114 |         '''
115 |         Decide whether to stop training based on validation top-1 accuracy.
116 |         :param eval_metrics: dict of evaluation metrics; the 'top1' entry drives the criterion
117 |         :return: (should_early_stop, should_save_model) booleans
118 |         '''
119 | if self.max_metrics is None:
120 | self.max_metrics = eval_metrics
121 | if eval_metrics['top1'] > self.max_metrics['top1']:
122 | self.max_metrics = eval_metrics
123 | self.counter = 0
124 | return False, True
125 | elif eval_metrics['top1'] < (self.max_metrics['top1'] - self.min_delta):
126 | self.counter += 1
127 | if self.counter >= self.patience:
128 | return True, False
129 | return False, False
130 |
--------------------------------------------------------------------------------
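A short sketch of how the AverageMeter and EarlyStop helpers above are meant to be driven; the loss and accuracy values are made up for illustration, and the import assumes the repository root is on PYTHONPATH.

    from utils.misc import AverageMeter, EarlyStop

    # Running average of a per-batch loss, weighted by batch size.
    loss_meter = AverageMeter(name="loss", fmt=":.4f")
    for batch_loss, batch_size in [(0.90, 32), (0.70, 32), (0.50, 16)]:
        loss_meter.update(batch_loss, n=batch_size)
    print(loss_meter)  # "loss 0.5000 (0.7400)": last value and running average

    # Stop once top-1 accuracy has not improved for `patience` consecutive evaluations.
    stopper = EarlyStop(patience=2, min_delta=0.0)
    for top1 in [70.0, 72.5, 72.0, 71.8]:
        should_stop, should_save = stopper.early_stop({'top1': top1})
        if should_save:
            print(f"new best top-1: {top1}")  # a real loop would checkpoint the model here
        if should_stop:
            print("early stopping triggered")
            break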
/utils/setup_logging.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """Logging."""
4 |
5 | import builtins
6 | import functools
7 | import logging
8 | import sys
9 | import os
10 | from termcolor import colored
11 |
12 | from .distributed import is_master_process
13 | from .file_io import PathManager
14 |
15 | # Show filename and line number in logs
16 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s"
17 |
18 |
19 | def _suppress_print():
20 | """Suppresses printing from the current process."""
21 |
22 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
23 | pass
24 |
25 | builtins.print = print_pass
26 |
27 |
28 | # cache the opened file object, so that different calls to `setup_logger`
29 | # with the same file name can safely write to the same file.
30 | @functools.lru_cache(maxsize=None)
31 | def _cached_log_stream(filename):
32 | return PathManager.open(filename, "a")
33 |
34 |
35 | @functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers # noqa
36 | def setup_logging(
37 | num_gpu, num_shards, output="", name="Prompt_CAM", color=True):
38 | """Sets up the logging."""
39 | # Enable logging only for the master process
40 | if is_master_process(num_gpu):
41 | # Clear the root logger to prevent any existing logging config
42 | # (e.g. set by another module) from messing with our setup
43 | logging.root.handlers = []
44 | # Configure logging
45 | logging.basicConfig(
46 | level=logging.INFO, format=_FORMAT, stream=sys.stdout
47 | )
48 | else:
49 | _suppress_print()
50 |
51 | if name is None:
52 | name = __name__
53 | logger = logging.getLogger(name)
54 | # remove any lingering handler
55 | logger.handlers.clear()
56 |
57 | logger.setLevel(logging.INFO)
58 | logger.propagate = False
59 |
60 | plain_formatter = logging.Formatter(
61 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
62 | datefmt="%m/%d %H:%M:%S",
63 | )
64 | if color:
65 | formatter = _ColorfulFormatter(
66 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
67 | datefmt="%m/%d %H:%M:%S",
68 | root_name=name,
69 | abbrev_name=str(name),
70 | )
71 | else:
72 | formatter = plain_formatter
73 |
74 | if is_master_process(num_gpu):
75 | ch = logging.StreamHandler(stream=sys.stdout)
76 | ch.setLevel(logging.DEBUG)
77 | ch.setFormatter(formatter)
78 | logger.addHandler(ch)
79 |
80 | if is_master_process(num_gpu * num_shards):
81 | if len(output) > 0:
82 | if output.endswith(".txt") or output.endswith(".log"):
83 | filename = output
84 | else:
85 | filename = os.path.join(output, "logs.txt")
86 |
87 | PathManager.mkdirs(os.path.dirname(filename))
88 |
89 | fh = logging.StreamHandler(_cached_log_stream(filename))
90 | fh.setLevel(logging.DEBUG)
91 | fh.setFormatter(plain_formatter)
92 | logger.addHandler(fh)
93 | return logger
94 |
95 |
96 | def setup_single_logging(name, output=""):
97 | """Sets up the logging."""
98 | # Enable logging only for the master process
99 | # Clear the root logger to prevent any existing logging config
100 | # (e.g. set by another module) from messing with our setup
101 | logging.root.handlers = []
102 | # Configure logging
103 | logging.basicConfig(
104 | level=logging.INFO, format=_FORMAT, stream=sys.stdout
105 | )
106 |
107 | if len(name) == 0:
108 | name = __name__
109 | logger = logging.getLogger(name)
110 | logger.setLevel(logging.INFO)
111 | logger.propagate = False
112 |
113 | plain_formatter = logging.Formatter(
114 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
115 | datefmt="%m/%d %H:%M:%S",
116 | )
117 | formatter = _ColorfulFormatter(
118 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
119 | datefmt="%m/%d %H:%M:%S",
120 | root_name=name,
121 | abbrev_name=str(name),
122 | )
123 |
124 | ch = logging.StreamHandler(stream=sys.stdout)
125 | ch.setLevel(logging.DEBUG)
126 | ch.setFormatter(formatter)
127 | logger.addHandler(ch)
128 |
129 | if len(output) > 0:
130 | if output.endswith(".txt") or output.endswith(".log"):
131 | filename = output
132 | else:
133 | filename = os.path.join(output, "logs.txt")
134 |
135 | PathManager.mkdirs(os.path.dirname(filename))
136 |
137 | fh = logging.StreamHandler(_cached_log_stream(filename))
138 | fh.setLevel(logging.DEBUG)
139 | fh.setFormatter(plain_formatter)
140 | logger.addHandler(fh)
141 |
142 | return logger
143 |
144 |
145 | def get_logger(name):
146 | """Retrieves the logger."""
147 | return logging.getLogger(name)
148 |
149 |
150 |
151 | class _ColorfulFormatter(logging.Formatter):
152 | # from detectron2
153 | def __init__(self, *args, **kwargs):
154 | self._root_name = kwargs.pop("root_name") + "."
155 | self._abbrev_name = kwargs.pop("abbrev_name", "")
156 | if len(self._abbrev_name):
157 | self._abbrev_name = self._abbrev_name + "."
158 | super(_ColorfulFormatter, self).__init__(*args, **kwargs)
159 |
160 | def formatMessage(self, record: logging.LogRecord) -> str:
161 | record.name = record.name.replace(self._root_name, self._abbrev_name)
162 | log = super(_ColorfulFormatter, self).formatMessage(record)
163 | if record.levelno == logging.WARNING:
164 | prefix = colored("WARNING", "red", attrs=["blink"])
165 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
166 | prefix = colored("ERROR", "red", attrs=["blink", "underline"])
167 | else:
168 | return log
169 | return prefix + " " + log
170 |
--------------------------------------------------------------------------------
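A usage sketch for the logging setup above, assuming a single-GPU, single-shard run so that is_master_process returns True; the output directory is illustrative.

    from utils.setup_logging import setup_logging, get_logger  # assumes the repository root is on PYTHONPATH

    # Console handler (colored) plus a plain-text file handler at ./output/logs.txt.
    logger = setup_logging(num_gpu=1, num_shards=1, output="./output", name="Prompt_CAM")
    logger.info("training started")

    # Elsewhere in the code base, fetch the same logger by name instead of re-configuring it.
    log = get_logger("Prompt_CAM")
    log.info("reusing the configured logger")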
/utils/visual_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import warnings
5 |
6 | import numpy as np
7 | import random
8 | import cv2
9 | import shutil
10 |
11 |
12 | from time import sleep
13 | from random import randint
14 |
15 | from PIL import Image
16 |
17 | import torch.nn as nn
18 | import matplotlib.pyplot as plt
19 | import torch.nn.functional as F
20 |
21 | parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
22 | sys.path.append(parent_dir)
23 |
24 | from data.dataset.utils import get_transformation
25 |
26 |
27 | warnings.filterwarnings("ignore")
28 |
29 | def combine_images(path, pred_class,resize_dim=(200,200)):
30 | images = [os.path.join(path, image) for image in os.listdir(path) if image.endswith('.jpg')]
31 | images.sort(key=lambda x: int(os.path.basename(x).split('.')[0]))
32 | imgs = [Image.open(image).resize(resize_dim) for image in images]
33 |
34 | widths, heights = zip(*(img.size for img in imgs))
35 |
36 | total_width = sum(widths)
37 | max_height = max(heights)
38 | merged_image = Image.new('RGB', (total_width, max_height))
39 |
40 | x_offset = 0
41 | for img in imgs:
42 | merged_image.paste(img, (x_offset, 0))
43 | x_offset += img.width
44 | merged_image.save(path + "/" + "concatenated_prediction_"+str(pred_class)+".jpg")
45 |
46 |     # remove the individual per-head tiles now that the concatenated image is saved
47 |     for image in images:
48 |         os.remove(image)
49 |
50 | def SuperImposeHeatmap(attention, input_image):
51 | alpha = 0.5
52 |
53 | attention_resized = cv2.resize(attention, (input_image.shape[1], input_image.shape[0]), interpolation=cv2.INTER_CUBIC)
54 |
55 |     # Min-max normalize the attention map to [0, 1]; the small epsilon guards against a constant map
56 |     min_val = attention_resized.min()
57 |     max_val = attention_resized.max()
58 |     attention_normalized = (attention_resized - min_val) / (max_val - min_val + 1e-8)
59 |
60 |
61 | # Apply Gaussian blur for smoothing
62 | attention_normalized = cv2.GaussianBlur(attention_normalized, (9, 9), 0)
63 |
64 | # Convert to heatmap
65 | heatmap = (attention_normalized * 255).astype(np.uint8)
66 | heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
67 |
68 | # Superimpose the heatmap on the original image
69 | result = (input_image * alpha + heatmap * (1 - alpha)).astype(np.uint8)
70 | return result
71 |
72 | def create_overlay_images(X,patch_size,attentions,output_folder):
73 | w_featmap = X.shape[-1] // patch_size
74 | attentions =attentions[0]
75 |
76 | nh = attentions.shape[0]
77 | if os.path.exists(output_folder):
78 | shutil.rmtree(output_folder)
79 | os.makedirs(output_folder, exist_ok=True)
80 |
81 | image = X[0].detach().cpu().numpy().transpose(1, 2, 0) # Shape (H, W, C)
82 | image = (image - image.min()) / (image.max() - image.min())
83 | image = (image * 255).astype(np.uint8)
84 |
85 | cv2.imwrite(os.path.join(output_folder, "0.jpg"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
86 |
87 | for head in range(nh):
88 | attention_map = attentions[head].reshape(w_featmap, w_featmap).detach().cpu().numpy()
89 | result_image = SuperImposeHeatmap(attention_map, image)
90 |
91 | # Save the overlayed image
92 | save_path = os.path.join(output_folder, f"{head+1}.jpg")
93 | cv2.imwrite(save_path, result_image)
94 |
95 |
96 | def prune_and_plot_ranked_heads(model, inputs, target, params):
97 |     if params.top_traits < 1 or params.top_traits > model.num_heads:
98 |         raise ValueError("top_traits must be between 1 and the number of attention heads")
99 |
100 | remaining_head_list = list(range(model.num_heads))
101 | pruned_head_index = None
102 | blur_head_lst = []
103 | blur_head_probs = []
104 |
105 | # Determine the ranking of heads by iteratively finding the head that, when blurred,
106 | # gives the highest probability for the target.
107 | while len(remaining_head_list) > 0:
108 | highest_score=-1e8
109 | remaining_head_scores= []
110 |
111 | for head_idx in remaining_head_list:
112 | output,_ = model(inputs,
113 | blur_head_lst=blur_head_lst+[head_idx],
114 | target_cls=target)
115 |
116 | probabilities = torch.softmax(output.squeeze(-1), dim=-1)
117 |
118 | remaining_head_scores.append(probabilities[0,target].item())
119 |
120 | if remaining_head_scores[-1] > highest_score:
121 | highest_score=remaining_head_scores[-1]
122 | pruned_head_index=head_idx
123 |
124 | if pruned_head_index is not None:
125 | blur_head_lst.append(pruned_head_index)
126 | remaining_head_list.remove(pruned_head_index)
127 | blur_head_probs.append(highest_score)
128 |
129 | #### Convert the image for overlaying the attention maps
130 | image = inputs[0].detach().cpu().numpy().transpose(1, 2, 0) # Shape (H, W, C)
131 | image = (image - image.min()) / (image.max() - image.min())
132 | image = (image * 255).astype(np.uint8)
133 | w_featmap = inputs.shape[-1] // model.patch_size
134 |
135 | ### Go through all the attention maps and overlay them on the image
136 | _,attn_maps = model(inputs)
137 | attn_maps = attn_maps[:, :, target, (params.vpt_num+1):]
138 | overlayed_attn_maps = []
139 |
140 | for head in range(model.num_heads):
141 | attention_map = attn_maps[0,head].reshape(w_featmap, w_featmap).detach().cpu().numpy()
142 | overlayed_attn_maps.append(cv2.cvtColor(SuperImposeHeatmap(attention_map, image),cv2.COLOR_BGR2RGB))
143 |
144 | #### Create captions for all the plot
145 | captions = ['input image']
146 | for head in range(model.num_heads):
147 | captions.append(f'rank {head+1}')
148 |
149 |     print(f'Head ranking (most important to least important): {blur_head_lst[::-1]}')
150 | #### Create plot with overlayed attention maps
151 | for bl_idx in range(1,len(blur_head_lst)+1):
152 | if bl_idx == params.top_traits+1:
153 | break
154 | current_blr_head_lst = blur_head_lst[:-bl_idx]
155 | remaining_head_list = remaining_head_list+ [blur_head_lst[-bl_idx]]
156 |
157 | current_images = [image]
158 | for r_idx in remaining_head_list:
159 | current_images+=[overlayed_attn_maps[r_idx]]
160 |
161 | # Create a grid of images
162 | grid_size = (1,len(current_images))
163 | fig, axes = plt.subplots(*grid_size, figsize=(10, 10))
164 | for i, ax in enumerate(axes.flatten()):
165 | ax.imshow(current_images[i])
166 | ax.axis('off')
167 | ax.set_title(captions[i])
168 |
169 | plt.tight_layout()
170 | plt.show()
171 |
172 |
173 | def load_image(image_path):
174 | image = Image.open(image_path)
175 | image = image.convert("RGB")
176 | transformation = get_transformation(mode='test')
177 | transformed_image = transformation(image)
178 | return torch.unsqueeze(transformed_image, 0)
--------------------------------------------------------------------------------
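A self-contained sketch of the heatmap-overlay path above, substituting a random map for a real prompt-to-patch attention map; the output filename is illustrative, the sample image ships with the repository, and the import assumes the repository root is on PYTHONPATH.

    import cv2
    import numpy as np

    from utils.visual_utils import load_image, SuperImposeHeatmap

    # (1, 3, H, W) tensor after the test transform, then back to a uint8 image,
    # mirroring the conversion done in create_overlay_images.
    X = load_image("samples/sample_image.png")
    image = X[0].numpy().transpose(1, 2, 0)
    image = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)

    # A random 16x16 map stands in for one head's attention over the image patches.
    fake_attention = np.random.rand(16, 16).astype(np.float32)
    overlay = SuperImposeHeatmap(fake_attention, image)
    cv2.imwrite("overlay_demo.jpg", overlay)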
/visualize.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from experiment.visualize_run import basic_vis
3 | from utils.setup_logging import get_logger
4 | from utils.misc import set_seed,load_yaml
5 | import time
6 |
7 | logger = get_logger('Prompt_CAM')
8 |
9 |
10 | def main():
11 | args = setup_parser().parse_args()
12 |
13 | if args.config:
14 | yaml_config = load_yaml(args.config)
15 | for key, value in yaml_config.items():
16 | setattr(args, key, value)
17 |
18 | set_seed(args.random_seed)
19 | start = time.time()
20 | basic_vis(args)
21 | end = time.time()
22 |     logger.info(f'----------- Total run time: {(end - start) / 60:.2f} mins -----------')
23 |
24 | def setup_parser():
25 | parser = argparse.ArgumentParser(description='Prompt_CAM')
26 | ######################## YAML Config #########################
27 | parser.add_argument('--config', type=str, default=None, help='Path to YAML config file')
28 |
29 | ####################### Model #########################
30 | parser.add_argument('--checkpoint', default=None, type=str, help='Path to the model checkpoint')
31 |
32 | ####################### Visualization Configuration #########################
33 |     parser.add_argument('--vis_attn', default=True, action=argparse.BooleanOptionalAction, help='visualize the attention map (use --no-vis_attn to disable)')
34 | parser.add_argument('--vis_cls', default=23, type=int, help='Class in the current Dataset to visualize')
35 | parser.add_argument('--nmbr_samples', default=10, type=int, help='Number of samples to visualize')
36 | parser.add_argument('--top_traits', default=4, type=int, help='Number of top traits per sample to visualize')
37 |
38 | parser.add_argument('--vis_outdir', default="./visualization", type=str, help='Output directory for visualization')
39 | ########################Misc#########################
40 | parser.add_argument('--gpu_num', default=1,
41 | type=int,
42 | help='Number of GPU (default: %(default)s)')
43 | parser.add_argument('--random_seed', default=42,
44 | type=int,
45 | help='Random seed (default: %(default)s)')
46 |
47 | return parser
48 |
49 |
50 | if __name__ == '__main__':
51 | main()
52 |
--------------------------------------------------------------------------------
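A hedged sketch of driving the same visualization pipeline programmatically instead of through the command line; the config path comes from the repository layout, while the checkpoint path is a hypothetical placeholder.

    from visualize import setup_parser
    from utils.misc import load_yaml, set_seed
    from experiment.visualize_run import basic_vis

    args = setup_parser().parse_args([
        '--config', 'experiment/config/prompt_cam/dino/cub/args.yaml',
        '--checkpoint', 'checkpoints/cub_dino.pt',  # hypothetical checkpoint path
        '--vis_cls', '23',
        '--top_traits', '4',
    ])

    # Same override step as main(): YAML values are copied onto the argparse namespace.
    yaml_config = load_yaml(args.config)
    for key, value in yaml_config.items():
        setattr(args, key, value)

    set_seed(args.random_seed)
    basic_vis(args)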