├── LICENSE
├── README.md
├── assets
│   ├── intro.png
│   ├── ratio.png
│   ├── results.png
│   └── visualize.png
├── configs
│   ├── cub.yml
│   ├── cub_stage2.yml
│   ├── imagenet.yml
│   └── imagenet_stage2.yml
├── datasets
│   ├── __init__.py
│   ├── base.py
│   └── evaluation
│       ├── __init__.py
│       └── cam.py
├── main.py
└── models
    ├── __init__.py
    └── attn.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Generative Prompt Model for Weakly Supervised Object Localization
4 |
5 |
6 |
 7 | This is the official implementation of the paper [***Generative Prompt Model for Weakly Supervised Object Localization***](https://openaccess.thecvf.com/content/ICCV2023/html/Zhao_Generative_Prompt_Model_for_Weakly_Supervised_Object_Localization_ICCV_2023_paper.html), which was accepted at ***ICCV 2023***. This repository contains the PyTorch training code, evaluation code, pre-trained models, and visualization method.
8 |
9 |
10 |
11 | [](https://arxiv.org/abs/2307.09756)
12 | 
13 | 
14 | [](LICENSE)
15 |
16 | [](https://paperswithcode.com/sota/weakly-supervised-object-localization-on-2?p=generative-prompt-model-for-weakly-supervised)
17 | [](https://paperswithcode.com/sota/weakly-supervised-object-localization-on-cub?p=generative-prompt-model-for-weakly-supervised)
18 |
19 |
20 |
21 |
22 |

23 |
24 |
25 | ## 1. Contents
26 | - Generative Prompt Model for Weakly Supervised Object Localization
27 | - [1. Contents](#1-contents)
28 | - [2. Introduction](#2-introduction)
29 | - [3. Results](#3-results)
 30 |   - [4. Getting Started](#4-getting-started)
31 | - [4.1 Installation](#41-installation)
32 | - [4.2 Dataset and Files Preparation](#42-dataset-and-files-preparation)
33 | - [4.3 Training](#43-training)
34 | - [4.4 Inference](#44-inference)
35 | - [4.5 Extra Options](#45-extra-options)
36 | - [5. Contacts](#5-contacts)
37 | - [6. Acknowledgment](#6-acknowledgment)
38 | - [7. Citation](#7-citation)
39 |
40 | ## 2. Introduction
41 |
 42 | Weakly supervised object localization (WSOL) remains challenging when learning object localization models from image category labels. Conventional methods that discriminatively train activation models ignore representative yet less discriminative object parts. In this study, we propose a generative prompt model (GenPromp), defining the first generative pipeline to localize less discriminative object parts by formulating WSOL as a conditional image denoising procedure. During training, GenPromp converts image category labels to learnable prompt embeddings which are fed to a generative model to conditionally recover the input image with noise and learn representative embeddings. During inference, GenPromp combines the representative embeddings with discriminative embeddings (queried from an off-the-shelf vision-language model) for both representative and discriminative capacity. The combined embeddings are finally used to generate multi-scale high-quality attention maps, which facilitate localizing the full object extent. Experiments on CUB-200-2011 and ILSVRC show that GenPromp outperforms the best discriminative models on both datasets, setting a solid baseline for WSOL with generative models.
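
For intuition, the combination of the representative embedding $f_r$ (the learned concept token) with the discriminative embedding $f_d$ (queried from the off-the-shelf vision-language model) can be pictured as a weighted blend controlled by the `combine_ratio` option used in the configs. The snippet below is a conceptual sketch only; the exact weighting implemented in `models/` may differ.

```python
import torch

def combine_embeddings(f_r: torch.Tensor, f_d: torch.Tensor, ratio: float = 0.6) -> torch.Tensor:
    """Conceptual sketch: blend the representative embedding f_r with the
    discriminative embedding f_d using a single scalar (cf. `combine_ratio`).
    The actual formulation in the code may differ."""
    return ratio * f_r + (1.0 - ratio) * f_d

# Toy usage with random 768-dimensional embeddings.
f_r, f_d = torch.randn(768), torch.randn(768)
f_combined = combine_embeddings(f_r, f_d, ratio=0.6)
```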
43 |
44 |
45 | ## 3. Results
46 |
47 |
48 |
49 |

50 |
51 |
 52 | We re-trained GenPromp with a better learning schedule on 4 x A100 GPUs, which further improves its performance on CUB-200-2011.
53 |
 54 | | Method | Dataset | Cls Backbone | Top-1 Loc | Top-5 Loc | GT-known Loc |
55 | | ------ | ------- | --------- | --------- | --------- | ------------ |
56 | | GenPromp | CUB-200-2011 | EfficientNet-B7 | 87.0 | 96.1 | 98.0 |
57 | | GenPromp (Re-train) | CUB-200-2011 | EfficientNet-B7 | 87.2 (+0.2) | 96.3 (+0.2) | 98.3 (+0.3) |
58 | | GenPromp | ImageNet | EfficientNet-B7 | 65.2 | 73.4 | 75.0 |
59 |
 60 | ## 4. Getting Started
61 |
62 | ### 4.1 Installation
63 |
 64 | To set up the environment for GenPromp, we use `conda` to manage dependencies. Our experiments use `CUDA 11.3`. Run the following commands to install GenPromp:
65 | ```
66 | conda create -n gpm python=3.8 -y && conda activate gpm
67 | pip install --upgrade pip
68 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
69 | pip install --upgrade diffusers[torch]==0.13.1
70 | pip install transformers==4.29.2 accelerate==0.19.0
71 | pip install matplotlib opencv-python OmegaConf tqdm
72 | ```
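
If the installation succeeded, a quick import check like the one below (illustrative, not part of the repository) should run without errors and report a visible CUDA device:

```python
# Sanity check of the environment installed above (illustrative only).
import torch
import diffusers
import transformers
import accelerate

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("diffusers:", diffusers.__version__)
print("transformers:", transformers.__version__)
print("accelerate:", accelerate.__version__)
```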
73 |
74 | ### 4.2 Dataset and Files Preparation
 75 | To train GenPromp from the pre-trained weights and run inference with the provided weights, download the files in the table below and arrange them according to the file tree that follows. (Uploading)
76 |
77 |
78 | | Dataset & Files | Download | Usage |
79 | | -------------------------------------- | ---------------------------------------------------------------------- | --------------------------------------------------------------------- |
80 | | data/ImageNet_ILSVRC2012 (146GB) | [Official Link](http://image-net.org/) | Benchmark dataset |
81 | | data/CUB_200_2011 (1.2GB) | [Official Link](http://www.vision.caltech.edu/datasets/cub_200_2011/) | Benchmark dataset |
82 | | ckpts/pretrains (5.2GB) | [Official Link](https://huggingface.co/CompVis/stable-diffusion-v1-4), [Google Drive](https://drive.google.com/drive/folders/1z4oqLBhIQsADqQOfnwOmXO-a4M2V4Ca8?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1DC8QUjocIJcmASsA-A1TYA)(o9ei) | Stable Diffusion pretrain weights |
 83 | | ckpts/classifications (2.3GB) | [Google Drive](https://drive.google.com/drive/folders/1z4oqLBhIQsADqQOfnwOmXO-a4M2V4Ca8?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1DC8QUjocIJcmASsA-A1TYA)(o9ei) | Classification results on the benchmark datasets |
 84 | | ckpts/imagenet750 (3.3GB) | [Google Drive](https://drive.google.com/drive/folders/1z4oqLBhIQsADqQOfnwOmXO-a4M2V4Ca8?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1DC8QUjocIJcmASsA-A1TYA)(o9ei) | Weights that achieve 75.0% GT-Known Loc on ImageNet |
 85 | | ckpts/cub983 (3.3GB) | [Google Drive](https://drive.google.com/drive/folders/1z4oqLBhIQsADqQOfnwOmXO-a4M2V4Ca8?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1DC8QUjocIJcmASsA-A1TYA)(o9ei) | Weights that achieve 98.3% GT-Known Loc on CUB |
86 |
87 | ```text
88 | |--GenPromp/
89 | |--data/
90 | |--ImageNet_ILSVRC2012/
91 | |--ILSVRC2012_list/
92 | |--train/
93 | |--val/
94 | |--CUB_200_2011
95 | |--attributes/
96 | |--images/
97 | ...
98 | |--ckpts/
99 | |--pretrains/
100 | |--stable-diffusion-v1-4/
101 | |--classifications/
102 | |--cub_efficientnetb7.json
103 | |--imagenet_efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k.json
104 | |--imagenet750/
105 | |--tokens/
106 | |--49408.bin
107 | |--49409.bin
108 | ...
109 | |--unet/
110 | |--cub983/
111 | |--tokens/
112 | |--49408.bin
113 | |--49409.bin
114 | ...
115 | |--unet/
116 | |--configs/
117 | |--datasets
118 | |--models
119 | |--main.py
120 | ```
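
To catch path mistakes early, a quick check like the one below (illustrative, not part of the repository) verifies that the main directories from the tree above are in place; adjust the paths if you store the data elsewhere.

```python
import os

# Default locations used by the provided configs and the file tree above.
expected_dirs = [
    "data/ImageNet_ILSVRC2012/ILSVRC2012_list",
    "data/ImageNet_ILSVRC2012/train",
    "data/ImageNet_ILSVRC2012/val",
    "data/CUB_200_2011/images",
    "ckpts/pretrains/stable-diffusion-v1-4",
]
for path in expected_dirs:
    status = "found" if os.path.isdir(path) else "MISSING"
    print(f"{path}: {status}")
```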
121 |
122 |
123 | ### 4.3 Training
124 |
125 | Here is a training example of GenPromp on ImageNet.
126 | ```
127 | accelerate config
128 | accelerate launch python main.py --function train_token --config configs/imagenet.yml --opt "{'train': {'save_path': 'ckpts/imagenet/'}}"
129 | accelerate launch python main.py --function train_unet --config configs/imagenet_stage2.yml --opt "{'train': {'load_token_path': 'ckpts/imagenet/tokens/', 'save_path': 'ckpts/imagenet/'}}"
130 | ```
 131 | `accelerate` is used for multi-GPU training. In the first training stage, the concept token weights of the representative embeddings are learned and saved to `ckpts/imagenet/`. In the second training stage, the weights of the learned concept tokens are loaded from `ckpts/imagenet/tokens/`, then the weights of the UNet are fine-tuned and saved to `ckpts/imagenet/`. Other configurations can be seen in the config files (i.e., `configs/imagenet.yml` and `configs/imagenet_stage2.yml`) and can be modified by `--opt` with a parameter dict (see [Extra Options](#45-extra-options) for details).
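
Conceptually, the `--opt` dict is merged on top of the YAML defaults. The snippet below sketches this idea with `OmegaConf` (one of the installed dependencies); the exact parsing logic inside `main.py` may differ.

```python
import ast
from omegaconf import OmegaConf

# Illustrative sketch: load the default config and merge an --opt style dict on top.
cfg = OmegaConf.load("configs/imagenet.yml")
opt = ast.literal_eval("{'train': {'save_path': 'ckpts/imagenet/'}}")
cfg = OmegaConf.merge(cfg, OmegaConf.create(opt))

print(cfg.train.save_path)      # ckpts/imagenet/  (overridden by --opt)
print(cfg.train.learning_rate)  # 5e-05            (default from the yml)
```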
132 |
133 | Here is a training example of GenPromp on CUB_200_2011.
134 | ```
135 | accelerate config
136 | accelerate launch python main.py --function train_token --config configs/cub.yml --opt "{'train': {'save_path': 'ckpts/cub/'}}"
137 | accelerate launch python main.py --function train_unet --config configs/cub_stage2.yml --opt "{'train': {'load_token_path': 'ckpts/cub/tokens/', 'save_path': 'ckpts/cub/'}}"
138 | ```
139 |
140 | ### 4.4 Inference
 141 | Here is an inference example of GenPromp on ImageNet.
142 |
143 | ```
 144 | python main.py --function test --config configs/imagenet_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/imagenet750/tokens/', 'load_unet_path': 'ckpts/imagenet750/unet/', 'save_log_path': 'ckpts/imagenet750/log.txt'}}"
145 | ```
 146 | In the inference stage, the weights of the learned concept tokens are loaded from `ckpts/imagenet750/tokens/`, the weights of the fine-tuned UNet are loaded from `ckpts/imagenet750/unet/`, and the log file is saved to `ckpts/imagenet750/log.txt`. Due to the random noise added to the test images, the results may fluctuate within a small range ($\pm$ 0.1).
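
Because fresh noise is sampled for every test image, repeated runs can differ slightly. If you need runs that are as repeatable as possible, seeding the relevant random number generators helps; the helper below is a generic PyTorch recipe (hypothetical, `main.py` itself may not expose a seed option).

```python
import random
import numpy as np
import torch

def seed_everything(seed: int = 0) -> None:
    """Generic recipe: seed the Python, NumPy and PyTorch RNGs to reduce
    run-to-run variance caused by the sampled noise."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```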
147 |
 148 | Here is an inference example of GenPromp on CUB_200_2011.
149 | ```
150 | python main.py --function test --config configs/cub_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/cub983/tokens/', 'load_unet_path': 'ckpts/cub983/unet/', 'save_log_path': 'ckpts/cub983/log.txt'}}"
151 | ```
152 |
153 | ### 4.5 Extra Options
154 |
 155 | There are many extra options for training and inference. The default options are configured in the `yml` files. `--opt` can be used to add to or override the default options with a parameter dict. Here are the most commonly used options.
156 |
157 | | Option | Scope | Usage |
158 | | -------| ----- | ----- |
159 | | {'data': {'keep_class': [0, 9]}} | data | keep the data with category id in `[0, 1, 2, 3, ..., 9]` |
160 | | {'train': {'batch_size': 2}} | train | train with batch size `2`. |
161 | | {'train': {'num_train_epochs': 1}} | train | train the model for `1` epoch. |
162 | | {'train': {'save_steps': 200}} | train_unet | save trained UNet every `200` steps. |
 163 | | {'train': {'max_train_steps': 600}} | train_unet | stop training after at most `600` steps. |
 164 | | {'train': {'gradient_accumulation_steps': 2}} | train | doubles the effective batch size when GPU memory is limited. |
165 | | {'train': {'learning_rate': 5.0e-08}} | train | the learning rate is `5.0e-8`. |
166 | | {'train': {'scale_lr': True}} | train | the learning rate is multiplied with batch size if `True`. |
 167 | | {'train': {'load_pretrain_path': 'stable-diffusion/'}} | train | the pretrained model is loaded from `stable-diffusion/`. |
 168 | | {'train': {'load_token_path': 'ckpt/tokens/'}} | train | the trained concept tokens are loaded from `ckpt/tokens/`. |
169 | | {'train': {'save_path': 'ckpt/'}} | train | save the trained weights to `ckpt/`. |
170 | | {'test': {'batch_size': 2}} | test | test with batch size `2`. |
171 | | {'test': {'cam_thr': 0.25}} | test | test with cam threshold `0.25`. |
172 | | {'test': {'combine_ratio': 0.6}} | test | combine ratio between $f_r$ and $f_d$ is `0.6`. |
173 | | {'test': {'load_class_path': 'imagenet_efficientnet.json'}} | test | load classification results from `imagenet_efficientnet.json`. |
 174 | | {'test': {'load_pretrain_path': 'stable-diffusion/'}} | test | the pretrained model is loaded from `stable-diffusion/`. |
 175 | | {'test': {'load_token_path': 'ckpt/tokens/'}} | test | the trained concept tokens are loaded from `ckpt/tokens/`. |
 176 | | {'test': {'load_unet_path': 'ckpt/unet/'}} | test | the trained UNet is loaded from `ckpt/unet/`. |
177 | | {'test': {'save_vis_path': 'ckpt/vis/'}} | test | the visualized predictions are saved to `ckpt/vis/`. |
178 | | {'test': {'save_log_path': 'ckpt/log.txt'}} | test | the log file is saved to `ckpt/log.txt`. |
 179 | | {'test': {'eval_mode': 'top1'}} | test | `top1` evaluates localization with the predicted top-1 class of the test image, `top5` with the predicted top-5 classes, and `gtk` with the ground-truth class, which requires no classification results. We use `top1` as the default eval mode. |
180 |
 181 | These options can be combined by simply merging the dicts. For example, to evaluate GenPromp with the config file `configs/imagenet_stage2.yml`, keeping categories `[0, 1, 2, ..., 9]`, loading concept tokens from `ckpts/imagenet750/tokens/`, loading the UNet from `ckpts/imagenet750/unet/`, saving the log file of the evaluated metrics to `ckpts/imagenet750/log0-9.txt`, setting the combine ratio to `0`, and saving visualization results to `ckpts/imagenet750/vis`, use the following command:
182 |
183 | ```
 184 | python main.py --function test --config configs/imagenet_stage2.yml --opt "{'data': {'keep_class': [0, 9]}, 'test': {'load_token_path': 'ckpts/imagenet750/tokens/', 'load_unet_path': 'ckpts/imagenet750/unet/', 'save_log_path':'ckpts/imagenet750/log0-9.txt', 'combine_ratio': 0, 'save_vis_path': 'ckpts/imagenet750/vis'}}"
185 | ```
186 |
187 |
188 |
189 |

190 |
191 |
192 | ## 5. Contacts
 193 | If you have any questions about our work or this repository, please don't hesitate to contact us by email or open an issue in this project.
194 | - [zhaoyuzhong20@mails.ucas.ac.cn](zhaoyuzhong20@mails.ucas.ac.cn)
195 | - [wanfang@ucas.ac.cn](wanfang@ucas.ac.cn)
196 |
197 | ## 6. Acknowledgment
198 |
 199 | - Part of the code is borrowed from [TS-CAM](https://github.com/vasgaowei/TS-CAM), [diffusers](https://github.com/huggingface/diffusers), and [prompt-to-prompt](https://github.com/google/prompt-to-prompt/); we sincerely thank them for their contributions to the community. Thanks to [MzeroMiko](https://github.com/MzeroMiko) for the improved code implementation.
200 |
201 |
202 | ## 7. Citation
203 |
204 | ```text
205 | @article{zhao2023generative,
206 | title={Generative Prompt Model for Weakly Supervised Object Localization},
207 | author={Zhao, Yuzhong and Ye, Qixiang and Wu, Weijia and Shen, Chunhua and Wan, Fang},
208 | journal={arXiv preprint arXiv:2307.09756},
209 | year={2023}
210 | }
211 | ```
212 |
213 | ```text
214 | @InProceedings{Zhao_2023_ICCV,
215 | author = {Zhao, Yuzhong and Ye, Qixiang and Wu, Weijia and Shen, Chunhua and Wan, Fang},
216 | title = {Generative Prompt Model for Weakly Supervised Object Localization},
217 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
218 | month = {October},
219 | year = {2023},
220 | pages = {6351-6361}
221 | }
222 | ```
223 |
--------------------------------------------------------------------------------
/assets/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/callsys/GenPromp/c480e0a106d7158da4a4e86cc27441f181996b04/assets/intro.png
--------------------------------------------------------------------------------
/assets/ratio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/callsys/GenPromp/c480e0a106d7158da4a4e86cc27441f181996b04/assets/ratio.png
--------------------------------------------------------------------------------
/assets/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/callsys/GenPromp/c480e0a106d7158da4a4e86cc27441f181996b04/assets/results.png
--------------------------------------------------------------------------------
/assets/visualize.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/callsys/GenPromp/c480e0a106d7158da4a4e86cc27441f181996b04/assets/visualize.png
--------------------------------------------------------------------------------
/configs/cub.yml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: "cub"
3 | root: "data/CUB_200_2011/"
4 | keep_class: None
5 | num_workers: 4
6 | resize_size: 512
7 | crop_size: 512
8 |
9 | train:
10 | batch_size: 4 #8
11 | save_steps: None
12 | num_train_epochs: 2
13 | max_train_steps: None
14 | gradient_accumulation_steps: 1
15 | learning_rate: 5.0e-05
16 | scale_lr: True
17 | scale_learning_rate: None
18 | lr_scheduler: "constant"
19 | lr_warmup_steps: 0
20 | adam_beta1: 0.9
21 | adam_beta2: 0.999
22 | adam_weight_decay: 1e-2
23 | adam_epsilon: 1e-08
24 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4"
25 | load_token_path: None
26 | save_path: "ckpts/cub"
27 |
28 | test:
29 | batch_size: 2
30 | eval_mode: "top1" #["gtk", "top1", "top5"]
31 | cam_thr: 0.23
32 | combine_ratio: 0.6
33 | load_class_path: "ckpts/classification/cub_efficientnetb7.json"
34 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4"
35 | load_token_path: "ckpts/cub/tokens/"
36 | load_unet_path: None
37 | save_vis_path: None
38 | save_log_path: "ckpts/cub/log.txt"
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/configs/cub_stage2.yml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: "cub"
3 | root: "data/CUB_200_2011/"
4 | keep_class: None
5 | num_workers: 4
6 | resize_size: 512
7 | crop_size: 512
8 |
9 | train:
10 | batch_size: 4 #8
11 | save_steps: 40
12 | num_train_epochs: 100
13 | max_train_steps: 250
14 | gradient_accumulation_steps: 16
15 | learning_rate: 5.0e-06
16 | scale_lr: False
17 | scale_learning_rate: None
18 | lr_scheduler: "constant"
19 | lr_warmup_steps: 0
20 | adam_beta1: 0.9
21 | adam_beta2: 0.999
22 | adam_weight_decay: 1e-2
23 | adam_epsilon: 1e-08
24 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4"
25 | load_token_path: "ckpts/cub/tokens"
26 | save_path: "ckpts/cub"
27 |
28 | test:
29 | batch_size: 2
30 | eval_mode: "top1" #["gtk", "top1", "top5"]
31 | cam_thr: 0.23
32 | combine_ratio: 0.6
33 | load_class_path: "ckpts/classification/cub_efficientnetb7.json"
34 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4"
35 | load_token_path: "ckpts/cub/tokens/"
36 | load_unet_path: "ckpts/cub/unet/"
37 | save_vis_path: None
38 | save_log_path: "ckpts/cub/log.txt"
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/configs/imagenet.yml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: "imagenet"
3 | root: "data/ImageNet_ILSVRC2012/"
4 | keep_class: None
5 | num_workers: 4
6 | resize_size: 512
7 | crop_size: 512
8 |
9 | train:
10 | batch_size: 2 #8
11 | save_steps: None
12 | num_train_epochs: 2
13 | max_train_steps: None
14 | gradient_accumulation_steps: 2
15 | learning_rate: 5.0e-05
16 | scale_lr: True
17 | scale_learning_rate: None
18 | lr_scheduler: "constant"
19 | lr_warmup_steps: 0
20 | adam_beta1: 0.9
21 | adam_beta2: 0.999
22 | adam_weight_decay: 1e-2
23 | adam_epsilon: 1e-08
24 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4"
25 | load_token_path: None
26 | save_path: "ckpts/imagenet"
27 |
28 | test:
29 | batch_size: 2
30 | eval_mode: "top1" #["gtk", "top1", "top5"]
31 | cam_thr: 0.25
32 | combine_ratio: 0.6
33 | load_class_path: "ckpts/classification/imagenet_efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k.json"
34 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4"
35 | load_token_path: "ckpts/imagenet/tokens/"
36 | load_unet_path: None
37 | save_vis_path: None
38 | save_log_path: "ckpts/imagenet/log.txt"
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/configs/imagenet_stage2.yml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: "imagenet"
3 | root: "data/ImageNet_ILSVRC2012/"
4 | keep_class: None
5 | num_workers: 4
6 | resize_size: 512
7 | crop_size: 512
8 |
9 | train:
10 | batch_size: 2 #8
11 | save_steps: 100
12 | num_train_epochs: 1
13 | max_train_steps: 600
14 | gradient_accumulation_steps: 32
15 | learning_rate: 5.0e-08
16 | scale_lr: True
17 | scale_learning_rate: None
18 | lr_scheduler: "constant"
19 | lr_warmup_steps: 0
20 | adam_beta1: 0.9
21 | adam_beta2: 0.999
22 | adam_weight_decay: 1e-2
23 | adam_epsilon: 1e-08
24 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4"
25 | load_token_path: "ckpts/imagenet/tokens"
26 | save_path: "ckpts/imagenet"
27 |
28 | test:
29 | batch_size: 2
30 | eval_mode: "top1" #["gtk", "top1", "top5"]
31 | cam_thr: 0.25
32 | combine_ratio: 0.6
33 | load_class_path: "ckpts/classification/imagenet_efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k.json"
34 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4"
35 | load_token_path: "ckpts/imagenet/tokens/"
36 | load_unet_path: "ckpts/imagenet/unet/"
37 | save_vis_path: None
38 | save_log_path: "ckpts/imagenet/log.txt"
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/callsys/GenPromp/c480e0a106d7158da4a4e86cc27441f181996b04/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 | import json
4 | import pickle
5 | import random
6 | import numpy as np
7 | from PIL import Image
8 | import torch
9 | import torch.utils.checkpoint
10 | from torch.utils.data import Dataset
11 | from torchvision import transforms
12 | from packaging import version
13 | from transformers import CLIPTokenizer
14 | import matplotlib.pyplot as plt
15 |
16 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
17 | PIL_INTERPOLATION = {
18 | "linear": PIL.Image.Resampling.BILINEAR,
19 | "bilinear": PIL.Image.Resampling.BILINEAR,
20 | "bicubic": PIL.Image.Resampling.BICUBIC,
21 | "lanczos": PIL.Image.Resampling.LANCZOS,
22 | "nearest": PIL.Image.Resampling.NEAREST,
23 | }
24 | else:
25 | PIL_INTERPOLATION = {
26 | "linear": PIL.Image.LINEAR,
27 | "bilinear": PIL.Image.BILINEAR,
28 | "bicubic": PIL.Image.BICUBIC,
29 | "lanczos": PIL.Image.LANCZOS,
30 | "nearest": PIL.Image.NEAREST,
31 | }
32 |
33 |
34 | caption_templates = [
35 | "a photo of a {}",
36 | "a rendering of a {}",
37 | "the photo of a {}",
38 | "a photo of my {}",
39 | "a photo of the {}",
40 | "a photo of one {}",
41 | "a rendition of a {}",
42 | ]
43 |
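# Each caption template above is later filled (via str.format) with either a
# meta token (an existing word such as "koala") or a learnable concept token,
# producing captions like "a photo of a koala"; see init_samples() below.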
44 | class BaseDataset(Dataset):
45 | def __init__(self,
46 | root=".",
47 | repeats=1,
48 | crop_size=512,
49 | resize_size=512,
50 | test_mode=False,
51 | keep_class=None,
52 | center_crop=False,
53 | text_encoder=None,
54 | load_class_path=None,
55 | load_token_path=None,
56 | load_pretrain_path=None,
57 | interpolation="bicubic",
58 | token_templates="",
59 | caption_templates=caption_templates,
60 | **kwargs,
61 | ):
62 | self.root = root
63 | self.data_repeat = repeats if not test_mode else 1
64 | self.test_mode = test_mode
65 | self.keep_class = keep_class
66 | self.load_class_path = load_class_path
67 | self.load_token_path = load_token_path
68 | self.load_pretrain_path = load_pretrain_path
69 | self.token_templates = token_templates
70 | self.caption_templates = caption_templates
71 |
72 | self.train_pipelines = self.init_train_pipelines(
73 | center_crop=center_crop,
74 | resize_size=resize_size,
75 | crop_size=crop_size,
76 | interpolation=interpolation,
77 | )
78 | self.test_pipelines = self.init_test_pipelines(
79 | crop_size=crop_size,
80 | interpolation=interpolation,
81 | )
82 |
83 | print(f"INFO: {self.__class__.__name__}:\t load data.", flush=True)
84 | self.load_data()
85 | print(f"INFO: {self.__class__.__name__}:\t init samples.", flush=True)
86 | self.init_samples()
87 | print(f"INFO: {self.__class__.__name__}:\t init text encoders.", flush=True)
88 | self.init_text_encoder(text_encoder)
89 |
90 | def load_data(self):
91 | class_file = os.path.join(self.root, 'ILSVRC2012_list', 'LOC_synset_mapping.txt')
92 | self.categories = []
93 | with open(class_file, 'r') as f:
 94 |             descriptions = f.readlines() # "n01882714 koala..."
 95 |             for id, line in enumerate(descriptions):
96 | tag, description = line.strip().split(' ', maxsplit=1)
97 | self.categories.append(description)
98 | self.num_classes = len(self.categories)
99 |
100 | self.names = []
101 | self.labels = []
102 | self.bboxes = []
103 | self.pred_logits = []
104 | self.image_paths = []
105 |
106 | data_file = os.path.join(self.root, 'ILSVRC2012_list', 'train.txt')
107 | image_dir = os.path.join(self.root, 'train')
108 | if self.test_mode:
109 | data_file = os.path.join(self.root, 'ILSVRC2012_list', 'val_folder_new.txt')
110 | image_dir = os.path.join(self.root, 'val')
111 |
112 | with open(data_file) as f:
 113 |             datamappings = f.readlines() # "n01440764/n01440764_10026.JPEG 0"
114 | for id, line in enumerate(datamappings):
115 | info = line.strip().split()
116 | self.names.append(info[0][:-5]) # "n01440764/n01440764_10026"
117 | self.labels.append(int(info[1])) # "0"
118 | if self.test_mode:
119 | self.bboxes.append(np.array(list(map(float, info[2:]))).reshape(-1, 4))
120 | if self.keep_class is not None:
121 | self.filter_classes()
122 | self.pred_logits = None
123 | if self.test_mode:
124 | with open(self.load_class_path, 'r') as f:
125 | name2result = json.load(f)
126 | self.pred_logits = [torch.Tensor(name2result[name]['pred_scores']) for name in self.names]
127 | self.image_paths = [os.path.join(image_dir, name + '.JPEG') for name in self.names]
128 | self.num_images = len(self.labels)
129 |
130 | def init_samples(self):
131 | # format tokens by category
132 | def select_meta_tokens(cats=[], tokenizer=None):
133 | for c in cats:
134 | token_ids = tokenizer.encode(c, add_special_tokens=False)
 135 |                 if len(token_ids) == 1: # an existing single token can represent this category
136 | return c, token_ids[-1], True
137 | token_ids = tokenizer.encode(cats[0], add_special_tokens=False)
138 | token = tokenizer.decode(token_ids[-1]) # only use the final one
139 | return token, token_ids[-1], False
140 |
141 | tokenizer = CLIPTokenizer.from_pretrained(self.load_pretrain_path, subfolder="tokenizer")
142 | concept_tokens = [self.token_templates.format(id) for id in range(self.num_classes)]
143 | tokenizer.add_tokens(concept_tokens)
144 |
145 | categories = [[t.strip() for t in c.strip().split(',')] for c in self.categories]
146 | categories = [[d.split(' ')[-1] for d in c] for c in categories]
147 | meta_tokens = [select_meta_tokens(c, tokenizer)[0] for c in categories]
148 |
149 | caption_ids_meta_token = [[tokenizer(
150 | template.format(token),
151 | padding="max_length",
152 | truncation=True,
153 | max_length=tokenizer.model_max_length,
154 | return_tensors="pt",
155 | ).input_ids[0] for template in self.caption_templates] for token in meta_tokens]
156 | caption_ids_concept_token = [[tokenizer(
157 | template.format(token),
158 | padding="max_length",
159 | truncation=True,
160 | max_length=tokenizer.model_max_length,
161 | return_tensors="pt",
162 | ).input_ids[0] for template in self.caption_templates] for token in concept_tokens]
163 |
164 | cat2tokens = [dict(meta_token=a, concept_token=b, caption_ids_meta_token=c, caption_ids_concept_token=d)
165 | for a, b, c, d in
166 | zip(meta_tokens, concept_tokens, caption_ids_meta_token, caption_ids_concept_token)]
167 |
168 | # format tokens by sample
169 | load_gt = True
170 | sample2tokens = [cat2tokens[id] for id in self.labels]
171 |
172 | samples = []
173 | for idx, lt in enumerate(sample2tokens):
174 | scores = []
175 | ids_concept_token = []
176 | ids_meta_token = []
177 | if self.pred_logits is not None:
178 | tensor = torch.Tensor(self.pred_logits[idx]).topk(5)
179 | tmp = [cat2tokens[id] for id in tensor.indices.numpy().tolist()]
180 | scores.extend([float(s) for s in tensor.values.numpy().tolist()])
181 | ids_concept_token.extend([t['caption_ids_concept_token'] for t in tmp])
182 | ids_meta_token.extend([t['caption_ids_meta_token'] for t in tmp])
183 | if load_gt:
184 | scores.append(0.0)
185 | ids_concept_token.append(lt['caption_ids_concept_token'])
186 | ids_meta_token.append(lt['caption_ids_meta_token'])
187 | samples.append(dict(
188 | scores=scores,
189 | caption_ids_concept_token=ids_concept_token,
190 | caption_ids_meta_token=ids_meta_token,
191 | ))
192 |
193 | self.tokenizer = tokenizer
194 | self.cat2tokens = cat2tokens
195 | self.samples = samples
196 |
197 | def init_text_encoder(self, text_encoder):
198 | text_encoder.resize_token_embeddings(len(self.tokenizer))
199 | if self.test_mode or (self.load_token_path is not None):
200 | text_encoder = self.load_embeddings(text_encoder)
201 | elif not self.test_mode:
202 | text_encoder = self.init_embeddings(text_encoder)
203 | return text_encoder
204 |
205 | def load_embeddings(self, text_encoder):
206 | missing_token = False
207 | token_embeds = text_encoder.get_input_embeddings().weight.data.clone()
208 | for token in self.cat2tokens:
209 | concept_token = token['concept_token']
210 | concept_token_id = self.tokenizer.encode(concept_token, add_special_tokens=False)[-1]
211 | token_bin = os.path.join(self.load_token_path, f"{concept_token_id}.bin")
212 | if not os.path.exists(token_bin):
213 | missing_token = True
214 | continue
215 | token_embeds[concept_token_id] = torch.load(token_bin)[concept_token]
216 | text_encoder.get_input_embeddings().weight = torch.nn.Parameter(token_embeds)
217 | if missing_token:
218 | print(f"WARN: {self.__class__.__name__}:\t missing token.", flush=True)
219 | return text_encoder
220 |
221 | def init_embeddings(self, text_encoder):
222 | token_embeds = text_encoder.get_input_embeddings().weight.data.clone()
223 | for token in self.cat2tokens:
224 | meta_token_id = self.tokenizer.encode(token['meta_token'], add_special_tokens=False)[0]
225 | concept_token_id = self.tokenizer.encode(token['concept_token'], add_special_tokens=False)[0]
226 | token_embeds[concept_token_id] = token_embeds[meta_token_id]
227 | text_encoder.get_input_embeddings().weight = torch.nn.Parameter(token_embeds)
228 | return text_encoder
229 |
230 | def filter_classes(self):
231 | if isinstance(self.keep_class, int):
232 | mask = np.array(self.labels) == self.keep_class
233 | else:
234 | mask = np.zeros_like(self.labels)
235 | for cls in self.keep_class:
236 | mask = mask + (np.array(self.labels) == cls).astype(float)
237 | mask = mask.astype(bool)
238 |
239 | self.labels = [el for i, el in enumerate(self.labels) if mask[i]]
240 | if self.names is not None and len(self.names) != 0:
241 | self.names = [el for i, el in enumerate(self.names) if mask[i]]
242 | if self.bboxes is not None and len(self.bboxes) != 0:
243 | self.bboxes = [el for i, el in enumerate(self.bboxes) if mask[i]]
244 |
245 | def init_train_pipelines(self, **kwargs):
246 | center_crop = kwargs.pop("center_crop", False)
247 | resize_size = kwargs.pop("resize_size", 512)
248 | crop_size = kwargs.pop("crop_size", 512)
249 | brightness = kwargs.pop("brightness", 0.1)
250 | contrast = kwargs.pop("contrast", 0.1)
251 | saturation = kwargs.pop("saturation", 0.1)
252 | hue = kwargs.pop("hue", 0.1)
253 | interpolation = {
254 | "linear": PIL_INTERPOLATION["linear"],
255 | "bilinear": PIL_INTERPOLATION["bilinear"],
256 | "bicubic": PIL_INTERPOLATION["bicubic"],
257 | "lanczos": PIL_INTERPOLATION["lanczos"],
258 | }[kwargs.pop("interpolation", "bicubic")]
259 |
260 | train_transform = transforms.Compose([
261 | transforms.Resize((resize_size, resize_size)),
262 | transforms.RandomCrop((crop_size, crop_size)),
263 | transforms.RandomHorizontalFlip(),
264 | transforms.ColorJitter(brightness, contrast, saturation, hue)
265 | ])
266 |
267 | def train_pipeline(data):
268 | img = data['img']
269 | if center_crop:
270 | crop = min(img.shape[0], img.shape[1])
271 | h, w, = (
272 | img.shape[0],
273 | img.shape[1],
274 | )
275 | img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
276 | img = Image.fromarray(img)
277 | img = train_transform(img)
278 | img = img.resize((crop_size, crop_size), resample=interpolation)
279 | img = np.array(img).astype(np.uint8)
280 | img = (img / 127.5 - 1.0).astype(np.float32)
281 | data["img"] = torch.from_numpy(img).permute(2, 0, 1)
282 | return data
283 |
284 | return train_pipeline
285 |
286 | def init_test_pipelines(self, **kwargs):
287 | crop_size = kwargs.pop("crop_size", 512)
288 | interpolation = {
289 | "linear": PIL_INTERPOLATION["linear"],
290 | "bilinear": PIL_INTERPOLATION["bilinear"],
291 | "bicubic": PIL_INTERPOLATION["bicubic"],
292 | "lanczos": PIL_INTERPOLATION["lanczos"],
293 | }[kwargs.pop("interpolation", "bicubic")]
294 |
295 | def test_pipeline(data):
296 | img = data['img']
297 | bbox = data['gt_bboxes']
298 | image_width, image_height = data['ori_shape']
299 | img = Image.fromarray(img)
300 | img = img.resize((crop_size, crop_size), resample=interpolation)
301 | img = np.array(img).astype(np.uint8)
302 | img = (img / 127.5 - 1.0).astype(np.float32)
303 | data["img"] = torch.from_numpy(img).permute(2, 0, 1)
304 |
305 | [x1, y1, x2, y2] = np.split(bbox, 4, 1)
306 | resize_size = crop_size
307 | shift_size = 0
308 | left_bottom_x = np.maximum(x1 / image_width * resize_size - shift_size, 0).astype(int)
309 | left_bottom_y = np.maximum(y1 / image_height * resize_size - shift_size, 0).astype(int)
310 | right_top_x = np.minimum(x2 / image_width * resize_size - shift_size, crop_size - 1).astype(int)
311 | right_top_y = np.minimum(y2 / image_height * resize_size - shift_size, crop_size - 1).astype(int)
312 |
313 | gt_bbox = np.concatenate((left_bottom_x, left_bottom_y, right_top_x, right_top_y), axis=1).reshape(-1)
314 | gt_bbox = " ".join(list(map(str, gt_bbox)))
315 | data['gt_bboxes'] = gt_bbox
316 |
317 |
318 | return data
319 |
320 | return test_pipeline
321 |
322 | def prepare_data(self, idx):
323 | pipelines = self.train_pipelines if not self.test_mode else self.test_pipelines
324 | bboxes = [] if not self.test_mode else self.bboxes[idx]
325 |
326 | name = self.names[idx]
327 | label = self.labels[idx]
328 | image_path = self.image_paths[idx]
329 | image = Image.open(image_path).convert('RGB')
330 | image_size = list(image.size)
331 | img = np.array(image).astype(np.uint8)
332 | data = pipelines(dict(
333 | img=img, ori_shape=image_size,
334 | gt_labels=label, gt_names=name, gt_bboxes=bboxes,
335 | ))
336 |
337 | data = dict(
338 | img=data["img"],
339 | ori_shape=data["ori_shape"],
340 | gt_labels=data["gt_labels"],
341 | gt_bboxes=data["gt_bboxes"],
342 | name=data["gt_names"],
343 | )
344 |
345 | if self.test_mode:
346 | pred_scores = self.pred_logits
347 | data.update(dict(
348 | pred_logits=pred_scores[idx],
349 | pred_top5_ids=pred_scores[idx].topk(5).indices,
350 | ))
351 |
352 | return data
353 |
354 | def prepare_tokens(self, idx):
355 | sample = self.samples[idx]
356 |
357 | if not self.test_mode:
358 | choice = random.choice(range(len(self.caption_templates)))
359 | sample = dict(
360 | caption_ids_concept_token=sample['caption_ids_concept_token'][-1][choice],
361 | caption_ids_meta_token=sample['caption_ids_meta_token'][-1][choice],
362 | )
363 |
364 | return sample
365 |
366 | def __len__(self):
367 | return self.num_images * self.data_repeat
368 |
369 | def __getitem__(self, idx):
370 | data = self.prepare_data(idx % self.num_images)
371 | data.update(self.prepare_tokens(idx % self.num_images))
372 | return data
373 |
374 | class SubDataset(BaseDataset):
375 | def __init__(self, keep_class, dataset):
376 | self.dataset = dataset
377 |
378 | if isinstance(keep_class, int):
379 | keep_class = [keep_class]
380 | self.keep_class = keep_class
381 |
382 | if isinstance(self.keep_class, int):
383 | mask = np.array(self.dataset.labels) == self.keep_class
384 | else:
385 | mask = np.zeros_like(self.dataset.labels)
386 | for cls in self.keep_class:
387 | mask = mask + (np.array(self.dataset.labels) == cls).astype(float)
388 | mask = mask.astype(bool)
389 |
390 | self.names = self.dataset.names
391 | self.image_paths = self.dataset.image_paths
392 | self.data_repeat = self.dataset.data_repeat
393 | self.labels = self.dataset.labels
394 | self.samples = self.dataset.samples
395 | self.caption_templates = self.dataset.caption_templates
396 | self.test_mode = self.dataset.test_mode
397 | self.num_images = self.dataset.num_images
398 | self.train_pipelines = self.dataset.train_pipelines
399 | self.test_pipelines = self.dataset.test_pipelines
400 |
401 | self.names = [el for i, el in enumerate(self.names) if mask[i]]
402 | self.image_paths = [el for i, el in enumerate(self.image_paths) if mask[i]]
403 | self.labels = [el for i, el in enumerate(self.labels) if mask[i]]
404 | self.samples = [el for i, el in enumerate(self.samples) if mask[i]]
405 | self.num_images = len(self.image_paths)
406 |
407 | class ImagenetDataset(BaseDataset):
408 | def __init__(self,
409 | root=".",
410 | repeats=1,
411 | crop_size=512,
412 | resize_size=512,
413 | test_mode=False,
414 | keep_class=None,
415 | center_crop=False,
416 | text_encoder=None,
417 | load_class_path=None,
418 | load_token_path=None,
419 | load_pretrain_path=None,
420 | interpolation="bicubic",
421 | token_templates="",
422 | caption_templates=caption_templates,
423 | **kwargs,
424 | ):
425 | super().__init__(root=root,
426 | repeats=repeats,
427 | crop_size=crop_size,
428 | resize_size=resize_size,
429 | test_mode=test_mode,
430 | keep_class=keep_class,
431 | center_crop=center_crop,
432 | text_encoder=text_encoder,
433 | load_class_path=load_class_path,
434 | load_token_path=load_token_path,
435 | load_pretrain_path=load_pretrain_path,
436 | interpolation=interpolation,
437 | token_templates=token_templates,
438 | caption_templates=caption_templates,
439 | **kwargs)
440 |
441 | def load_data(self):
442 | """
443 | returns:
444 | self.names: []
445 | self.labels: []
446 | self.bboxes: []; (x1,y1,x2,y2)
447 | self.pred_logits: [] | None
448 | self.image_paths: []
449 | """
450 | class_file = os.path.join(self.root, 'ILSVRC2012_list', 'LOC_synset_mapping.txt')
451 | self.categories = []
452 | with open(class_file, 'r') as f:
 453 |             descriptions = f.readlines() # "n01882714 koala..."
 454 |             for id, line in enumerate(descriptions):
455 | tag, description = line.strip().split(' ', maxsplit=1)
456 | self.categories.append(description)
457 | self.num_classes = len(self.categories)
458 |
459 | self.names = []
460 | self.labels = []
461 | self.bboxes = []
462 | self.pred_logits = []
463 | self.image_paths = []
464 |
465 | data_file = os.path.join(self.root, 'ILSVRC2012_list', 'train.txt')
466 | image_dir = os.path.join(self.root, 'train')
467 | if self.test_mode:
468 | data_file = os.path.join(self.root, 'ILSVRC2012_list', 'val_folder_new.txt')
469 | image_dir = os.path.join(self.root, 'val')
470 |
471 | with open(data_file) as f:
 472 |             datamappings = f.readlines() # "n01440764/n01440764_10026.JPEG 0"
473 | for id, line in enumerate(datamappings):
474 | info = line.strip().split()
475 | self.names.append(info[0][:-5]) # "n01440764/n01440764_10026"
476 | self.labels.append(int(info[1])) # "0"
477 | if self.test_mode:
478 | self.bboxes.append(np.array(list(map(float, info[2:]))).reshape(-1, 4))
479 | if self.keep_class is not None:
480 | self.filter_classes()
481 | self.pred_logits = None
482 | if self.test_mode:
483 | with open(self.load_class_path, 'r') as f:
484 | name2result = json.load(f)
485 | self.pred_logits = [torch.Tensor(name2result[name]['pred_scores']) for name in self.names]
486 | self.image_paths = [os.path.join(image_dir, name + '.JPEG') for name in self.names]
487 | self.num_images = len(self.labels)
488 |
489 | class CUBDataset(BaseDataset):
490 | def __init__(self,
491 | root=".",
492 | repeats=1,
493 | crop_size=512,
494 | resize_size=512,
495 | test_mode=False,
496 | keep_class=None,
497 | center_crop=False,
498 | text_encoder=None,
499 | load_class_path=None,
500 | load_token_path=None,
501 | load_pretrain_path=None,
502 | interpolation="bicubic",
503 | token_templates="",
504 | caption_templates=caption_templates,
505 | **kwargs,
506 | ):
507 | super().__init__(root=root,
508 | repeats=repeats,
509 | crop_size=crop_size,
510 | resize_size=resize_size,
511 | test_mode=test_mode,
512 | keep_class=keep_class,
513 | center_crop=center_crop,
514 | text_encoder=text_encoder,
515 | load_class_path=load_class_path,
516 | load_token_path=load_token_path,
517 | load_pretrain_path=load_pretrain_path,
518 | interpolation=interpolation,
519 | token_templates=token_templates,
520 | caption_templates=caption_templates,
521 | **kwargs)
522 |
523 | def load_data(self):
524 | self.categories = ["bird"] * 200
525 | self.num_classes = len(self.categories)
526 |
527 | images_file = os.path.join(self.root, 'images.txt')
528 | labels_file = os.path.join(self.root, 'image_class_labels.txt')
529 | splits_file = os.path.join(self.root, 'train_test_split.txt')
530 | bboxes_file = os.path.join(self.root, 'bounding_boxes.txt')
531 |
532 | with open(images_file, 'r') as f:
533 | lines = f.readlines()
534 | image_list = [line.strip().split(' ')[1] for line in lines]
535 | with open(labels_file, 'r') as f:
536 | lines = f.readlines()
537 | labels_list = [line.strip().split(' ')[1] for line in lines]
538 | with open(splits_file, 'r') as f:
539 | lines = f.readlines()
540 | splits_list = [line.strip().split(' ')[1] for line in lines]
541 | with open(bboxes_file, 'r') as f:
542 | lines = f.readlines()
543 | bboxes_list = [line.strip().split(' ')[1:] for line in lines]
544 |
545 | train_index = [i for i, v in enumerate(splits_list) if v == '1']
546 | test_index = [i for i, v in enumerate(splits_list) if v == '0']
547 | index_list = train_index if not self.test_mode else test_index
548 | self.names = [image_list[i] for i in index_list]
549 | self.labels = [int(labels_list[i]) - 1 for i in index_list]
550 | self.bboxes = [np.array(list(map(float, bboxes_list[i]))).reshape(-1, 4)
551 | for i in index_list]
552 | for i in range(len(self.bboxes)):
553 | self.bboxes[i][:, 2:4] = self.bboxes[i][:, 0:2] + self.bboxes[i][:, 2:4]
554 |
555 | if self.keep_class is not None:
556 | self.filter_classes()
557 | self.pred_logits = None
558 | if self.test_mode:
559 | with open(self.load_class_path, 'r') as f:
560 | name2result = json.load(f)
561 | self.pred_logits = [torch.Tensor(name2result[name]['pred_scores']) for name in self.names]
562 | self.image_paths = [os.path.join(self.root, 'images', name) for name in self.names]
563 | self.num_images = len(self.labels)
564 |
565 |
566 |
--------------------------------------------------------------------------------
/datasets/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .cam import evaluate_cls_loc, list2acc
2 |
3 | # Evaluation code from TS-CAM
4 | class Evaluator():
5 | def __init__(self, logfile='log.txt', len_dataloader=1):
6 | self.cls_top1 = []
7 | self.cls_top5 = []
8 | self.loc_top1 = []
9 | self.loc_top5 = []
10 | self.loc_gt_known = []
11 | self.top1_loc_right = []
12 | self.top1_loc_cls = []
13 | self.top1_loc_mins = []
14 | self.top1_loc_part = []
15 | self.top1_loc_more = []
16 | self.top1_loc_wrong = []
17 | self.logfile = logfile
18 | self.len_dataloader = len_dataloader
19 |
20 | def __call__(self, input, target, bbox, logits, pad_cams, image_names, cfg, step):
21 | cls_top1_b, cls_top5_b, loc_top1_b, loc_top5_b, loc_gt_known_b, top1_loc_right_b, \
22 | top1_loc_cls_b, top1_loc_mins_b, top1_loc_part_b, top1_loc_more_b, top1_loc_wrong_b = \
23 | evaluate_cls_loc(input, target, bbox, logits, pad_cams, image_names, cfg)
24 | self.cls_top1.extend(cls_top1_b)
25 | self.cls_top5.extend(cls_top5_b)
26 | self.loc_top1.extend(loc_top1_b)
27 | self.loc_top5.extend(loc_top5_b)
28 | self.top1_loc_right.extend(top1_loc_right_b)
29 | self.top1_loc_cls.extend(top1_loc_cls_b)
30 | self.top1_loc_mins.extend(top1_loc_mins_b)
31 | self.top1_loc_more.extend(top1_loc_more_b)
32 | self.top1_loc_part.extend(top1_loc_part_b)
33 | self.top1_loc_wrong.extend(top1_loc_wrong_b)
34 |
35 | self.loc_gt_known.extend(loc_gt_known_b)
36 |
37 | if step != 0 and (step % 100 == 0 or step == self.len_dataloader - 1):
38 | str1 = 'Val Epoch: [{0}][{1}/{2}]\t'.format(0, step + 1, self.len_dataloader)
39 | str2 = 'Cls@1:{0:.3f}\tCls@5:{1:.3f}\tLoc@1:{2:.3f}\tLoc@5:{3:.3f}\tLoc_gt:{4:.3f}'.format(
40 | list2acc(self.cls_top1), list2acc(self.cls_top5),list2acc(self.loc_top1), list2acc(self.loc_top5), list2acc(self.loc_gt_known))
41 | str3 = 'M-ins:{0:.3f}\tPart:{1:.3f}\tMore:{2:.3f}\tRight:{3:.3f}\tWrong:{4:.3f}\tCls:{5:.3f}'.format(
42 | list2acc(self.top1_loc_mins), list2acc(self.top1_loc_part), list2acc(self.top1_loc_more),
43 | list2acc(self.top1_loc_right), list2acc(self.top1_loc_wrong), list2acc(self.top1_loc_cls))
44 |
45 | if self.logfile is not None:
46 | with open(self.logfile, 'a') as fw:
47 | fw.write('\n'+str1+'\n')
48 | fw.write(str2+'\n')
49 | fw.write(str3+'\n')
50 |
51 | print(str1)
52 | print(str2)
53 | print(str3)
54 |
55 |
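# Typical usage (illustrative): construct one Evaluator per evaluation run and
# call it once per batch, e.g.
#   evaluator = Evaluator(logfile=cfg.test.save_log_path, len_dataloader=len(loader))
#   evaluator(input, target, bbox, logits, pad_cams, image_names, cfg, step)
# Accumulated Cls/Loc metrics are printed (and appended to the log file) every
# 100 steps and at the last step of the loader.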
--------------------------------------------------------------------------------
/datasets/evaluation/cam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import torch
5 | import matplotlib.pyplot as plt
6 |
7 | def resize_cam(cam, size=(224, 224)):
8 | cam = cv2.resize(cam, (size[0], size[1]))
9 | # normalize to [0, 1]; the small epsilon below guards against a constant map,
10 | # which would otherwise divide by zero
11 | cam_min, cam_max = cam.min(), cam.max()
12 | cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)
13 | return cam
14 |
15 | def blend_cam(image, cam, es_box=[0,0,1,1]):
16 | I = np.zeros_like(cam)
17 | x1, y1, x2, y2 = es_box
18 | I[y1:y2, x1:x2] = 1
19 | cam = cam * I
20 | cam = (cam * 255.).astype(np.uint8)
21 | heatmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
22 | blend = image * 0.5 + heatmap * 0.5
23 |
24 | blend = blend.astype(np.uint8)
25 |
26 | return blend, heatmap
27 |
28 | def get_bboxes(cam, cam_thr=0.2):
29 | """
30 | cam: single image with shape (h, w, 1)
31 | thr_val: float value (0~1)
32 | return estimated bounding box
33 | """
34 | cam = (cam * 255.).astype(np.uint8)
35 | map_thr = cam_thr * np.max(cam)
36 |
37 | _, thr_gray_heatmap = cv2.threshold(cam,
38 | int(map_thr), 255,
39 | cv2.THRESH_TOZERO)
40 | #thr_gray_heatmap = (thr_gray_heatmap*255.).astype(np.uint8)
41 |
42 | contours, _ = cv2.findContours(thr_gray_heatmap,
43 | cv2.RETR_TREE,
44 | cv2.CHAIN_APPROX_SIMPLE)
45 | if len(contours) != 0:
46 | c = max(contours, key=cv2.contourArea)
47 | x, y, w, h = cv2.boundingRect(c)
48 | estimated_bbox = [x, y, x + w, y + h]
49 | else:
50 | estimated_bbox = [0, 0, 1, 1]
51 |
52 | return estimated_bbox #, thr_gray_heatmap, len(contours)
53 |
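# Example (sketch): thresholding a synthetic CAM at 20% of its maximum recovers the
# box of the activated region (coordinates are approximate at the contour boundary):
#
#   cam = np.zeros((224, 224, 1), dtype=np.float32)
#   cam[60:160, 40:140, 0] = 1.0
#   get_bboxes(cam, cam_thr=0.2)   # -> roughly [40, 60, 140, 160]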
54 | def count_max(x):
55 | count_dict = {}
56 | for xlist in x:
57 | for item in xlist:
58 | if item==0:
59 | continue
60 | if item not in count_dict.keys():
61 | count_dict[item] = 0
62 | count_dict[item] += 1
63 | if count_dict == {}:
64 | return -1
65 | count_dict = sorted(count_dict.items(), key=lambda d:d[1], reverse=True)
66 | return count_dict[0][0]
67 |
68 | def tensor2image(input, image_mean, image_std):
69 | image_mean = torch.reshape(torch.tensor(image_mean), (1, 3, 1, 1))
70 | image_std = torch.reshape(torch.tensor(image_std), (1, 3, 1, 1))
71 | image = input * image_std + image_mean
72 | image = image.numpy().transpose(0, 2, 3, 1)
73 | image = image[:, :, :, ::-1]
74 | return image
75 |
76 | def calculate_IOU(boxA, boxB):
77 | xA = max(boxA[0], boxB[0])
78 | yA = max(boxA[1], boxB[1])
79 | xB = min(boxA[2], boxB[2])
80 | yB = min(boxA[3], boxB[3])
81 |
82 | # compute the area of intersection rectangle
83 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
84 |
85 | # compute the area of both the prediction and ground-truth
86 | # rectangles
87 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
88 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
89 |
90 | # compute the intersection over union by taking the intersection
91 | # area and dividing it by the sum of prediction + ground-truth
92 | # areas - the intersection area
93 | iou = interArea / float(boxAArea + boxBArea - interArea)
94 |
95 | # return the intersection over union value
96 | return iou
97 |
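# Worked example for calculate_IOU: boxA = [0, 0, 9, 9] and boxB = [5, 5, 14, 14]
# (inclusive pixel coordinates) intersect in a 5x5 region, so
# interArea = 25, boxAArea = boxBArea = 100 and IoU = 25 / (100 + 100 - 25) ~= 0.143.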
98 | def draw_bbox(image, iou, gt_box, pred_box, gt_score, is_top1=False):
99 |
100 | def _draw_boxes(img, box1, box2, color1=(0, 0, 255), color2=(0, 255, 0)):
101 | # for i in range(len(box1)):
102 | # cv2.rectangle(img, (box1[i,0], box1[i,1]), (box1[i,2], box1[i,3]), color1, 4)
103 | cv2.rectangle(img, (box2[0], box2[1]), (box2[2], box2[3]), color2, 4)
104 | return img
105 |
106 | def mark_target(img, text='target', pos=(25, 25), size=2):
107 | cv2.putText(img, text, pos, cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), size)
108 | return img
109 |
110 | boxed_image = image.copy()
111 |
112 | # draw bbox on image
113 | boxed_image = _draw_boxes(boxed_image, gt_box, pred_box)
114 |
115 | # mark the iou
116 | # mark_target(boxed_image, '%.1f' % (iou * 100), (140, 30), 2)
117 | # mark_target(boxed_image, 'IOU%.2f' % (iou), (80, 30), 2)
118 | # # mark the top1
119 | # if is_top1:
120 | # mark_target(boxed_image, 'Top1', (10, 30))
121 | # mark_target(boxed_image, 'GT_Score%.2f' % (gt_score), (10, 200), 2)
122 |
123 | return boxed_image
124 |
125 | def evaluate_cls_loc(input, cls_label, bbox_label, logits, cams, image_names, config):
126 | """
127 | :param input: input tensors of the model
128 | :param cls_label: class label
129 | :param bbox_label: bounding box label
130 | :param logits: classification scores
131 | :param cams: cam of all the classes
132 | :param image_names: names of images
133 | :param config: configurations
134 | :return: per-image evaluation results (lists of 0/1 flags)
136 | """
137 | cls_top1 = []
138 | cls_top5 = []
139 | loc_top1 = []
140 | loc_top5 = []
141 | loc_gt_known = []
142 | top1_loc_right = []
143 | top1_loc_cls = []
144 | top1_loc_mins = []
145 | top1_loc_part = []
146 | top1_loc_more = []
147 | top1_loc_wrong = []
148 |
149 | # label, top1 and top5 results
150 | cls_label = cls_label.tolist()
151 | cls_scores = logits.tolist()
152 | _, top1_idx = logits.topk(1, 1, True, True)
153 | top1_idx = top1_idx.tolist()
154 | _, top5_idx = logits.topk(5, 1, True, True)
155 | top5_idx = top5_idx.tolist()
156 |
157 | batch = cams.shape[0]
158 | image_mean = [127.5, 127.5, 127.5]
159 | image_std = [127.5, 127.5, 127.5]
160 | image = tensor2image(input.clone().detach().cpu(), image_mean, image_std).astype(np.uint8)
161 |
162 | for b in range(batch):
163 | gt_bbox = bbox_label[b].strip().split(' ')
164 | gt_bbox = list(map(float, gt_bbox))
165 | crop_size = config["data"]["test"]["dataset"]["crop_size"]
166 | top_bboxes, top_mask = get_topk_boxes(top5_idx[b], cams[b], crop_size=crop_size, threshold=config["test"]["cam_thr"])
167 | topk_cls, topk_loc, wrong_details = cls_loc_err(top_bboxes, cls_label[b], gt_bbox, topk=(1, 5))
168 | cls_top1_b, cls_top5_b = topk_cls
169 | loc_top1_b, loc_top5_b = topk_loc
170 | cls_top1.append(cls_top1_b)
171 | cls_top5.append(cls_top5_b)
172 | loc_top1.append(loc_top1_b)
173 | loc_top5.append(loc_top5_b)
174 | cls_wrong, multi_instances, region_part, region_more, region_wrong = wrong_details
175 | right = 1 - (cls_wrong + multi_instances + region_part + region_more + region_wrong)
176 | top1_loc_right.append(right)
177 | top1_loc_cls.append(cls_wrong)
178 | top1_loc_mins.append(multi_instances)
179 | top1_loc_part.append(region_part)
180 | top1_loc_more.append(region_more)
181 | top1_loc_wrong.append(region_wrong)
182 | # gt_known
183 | # mean top k
184 | cam_b = cams[b, [cls_label[b]], :, :]
185 | cam_b = torch.mean(cam_b, dim=0, keepdim=True)
186 |
187 | cam_b = cam_b.detach().cpu().numpy().transpose(1, 2, 0)
188 |
189 | # Resize and Normalize CAM
190 | cam_b = resize_cam(cam_b, size=(crop_size, crop_size))
191 |
192 | # Estimate BBOX
193 | estimated_bbox = get_bboxes(cam_b, cam_thr=config["test"]["cam_thr"])
194 |
195 | # Calculate IoU
196 | gt_box_cnt = len(gt_bbox) // 4
197 | max_iou = 0
198 | for i in range(gt_box_cnt):
199 | gt_box = gt_bbox[i * 4:(i + 1) * 4]
200 | iou_i = cal_iou(estimated_bbox, gt_box)
201 | if iou_i > max_iou:
202 | max_iou = iou_i
203 |
204 | iou = max_iou
205 | # iou = calculate_IOU(bbox_label[b].numpy(), estimated_bbox)
206 |
207 | # print('cam_b shape', cam_b.shape, 'cam_b max', cam_b.max(), 'cam_b min', cam_b.min(), 'thre', cfg.MODEL.CAM_THR, 'iou ', iou)
208 | #if iou < 0.5:
209 | # pdb.set_trace()
210 | # gt known
211 | if iou >= 0.5:
212 | loc_gt_known.append(1)
213 | else:
214 | loc_gt_known.append(0)
215 |
216 | # Get blended image
217 | box = [0, 0, 512, 512]
218 | blend, heatmap = blend_cam(image[b], cam_b, estimated_bbox)
219 | # Get boxed image
220 | gt_score = cls_scores[b][top1_idx[b][0]] # score of gt class
221 | boxed_image = draw_bbox(blend, iou, np.array(gt_bbox).reshape(-1,4).astype(int), estimated_bbox, gt_score, False)
222 |
223 | # save result
224 | save_vis_path = config["test"]["save_vis_path"]
225 | if save_vis_path is not None:
226 | image_name = image_names[b]
227 | boxed_image_dir = save_vis_path
228 | if not os.path.exists(boxed_image_dir):
229 | os.mkdir(boxed_image_dir)
230 | class_dir = os.path.join(boxed_image_dir, image_name.split('/')[0])
231 | save_path = os.path.join(class_dir, image_name.split('/')[-1] + '.jpg')
232 |
233 | if not os.path.exists(class_dir):
234 | os.mkdir(class_dir)
235 | cv2.imwrite(save_path, boxed_image)
236 |
237 | return cls_top1, cls_top5, loc_top1, loc_top5, loc_gt_known, top1_loc_right, top1_loc_cls, top1_loc_mins, \
238 | top1_loc_part, top1_loc_more, top1_loc_wrong
239 |
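# Notes on the metrics returned above: Cls@k counts whether the ground-truth class
# appears among the top-k predictions; Loc@k additionally requires that the box
# estimated from that class's CAM overlaps a ground-truth box with IoU above 0.5
# (see cls_loc_err below); Loc_gt ("GT-known") scores the box obtained from the
# ground-truth class's CAM regardless of the classification result.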
240 | def get_topk_boxes(cls_inds, cam_map, crop_size, topk=(1, 5), threshold=0.2):
241 | maxk_boxes = []
242 | maxk_maps = []
243 | for cls in cls_inds:
244 | cam_map_ = cam_map[[cls], :, :]
245 | cam_map_ = cam_map_.detach().cpu().numpy().transpose(1, 2, 0)
246 | # Resize and Normalize CAM
247 | cam_map_ = resize_cam(cam_map_, size=(crop_size, crop_size))
248 | maxk_maps.append(cam_map_.copy())
249 |
250 | # Estimate BBOX
251 | estimated_bbox = get_bboxes(cam_map_, cam_thr=threshold)
252 | maxk_boxes.append([cls] + estimated_bbox)
253 |
254 | result = [maxk_boxes[:k] for k in topk]
255 |
256 | return result, maxk_maps
257 |
258 | def cls_loc_err(topk_boxes, gt_label, gt_boxes, topk=(1,), iou_th=0.5):
259 | assert len(topk_boxes) == len(topk)
260 | wrong_details = (0, 0, 0, 0, 0)  # initialized so the return below is defined even if topk excludes 1
261 | gt_box_cnt = len(gt_boxes) // 4
262 | topk_loc = []
263 | topk_cls = []
264 | for topk_box in topk_boxes:
265 | loc_acc = 0
266 | cls_acc = 0
267 | for cls_box in topk_box:
268 | max_iou = 0
269 | max_gt_id = 0
270 | for i in range(gt_box_cnt):
271 | gt_box = gt_boxes[i*4:(i+1)*4]
272 | iou_i = cal_iou(cls_box[1:], gt_box)
273 | if iou_i> max_iou:
274 | max_iou = iou_i
275 | max_gt_id = i
276 | if len(topk_box) == 1:
277 | wrong_details = get_badcase_detail(cls_box, gt_boxes, gt_label, max_iou, max_gt_id)
278 | if cls_box[0] == gt_label:
279 | cls_acc = 1
280 | if cls_box[0] == gt_label and max_iou > iou_th:
281 | loc_acc = 1
282 | break
283 | topk_loc.append(float(loc_acc))
284 | topk_cls.append(float(cls_acc))
285 | return topk_cls, topk_loc, wrong_details
286 |
287 | def cal_iou(box1, box2, method='iou'):
288 | """
289 | support:
290 | 1. box1 and box2 are the same shape: [N, 4]
291 | 2.
292 | :param box1:
293 | :param box2:
294 | :return:
295 | """
296 | box1 = np.asarray(box1, dtype=float)
297 | box2 = np.asarray(box2, dtype=float)
298 | if box1.ndim == 1:
299 | box1 = box1[np.newaxis, :]
300 | if box2.ndim == 1:
301 | box2 = box2[np.newaxis, :]
302 |
303 | iw = np.minimum(box1[:, 2], box2[:, 2]) - np.maximum(box1[:, 0], box2[:, 0]) + 1
304 | ih = np.minimum(box1[:, 3], box2[:, 3]) - np.maximum(box1[:, 1], box2[:, 1]) + 1
305 |
306 | i_area = np.maximum(iw, 0.0) * np.maximum(ih, 0.0)
307 | box1_area = (box1[:, 2] - box1[:, 0] + 1) * (box1[:, 3] - box1[:, 1] + 1)
308 | box2_area = (box2[:, 2] - box2[:, 0] + 1) * (box2[:, 3] - box2[:, 1] + 1)
309 |
310 | if method == 'iog':
311 | iou_val = i_area / (box2_area)
312 | elif method == 'iob':
313 | iou_val = i_area / (box1_area)
314 | else:
315 | iou_val = i_area / (box1_area + box2_area - i_area)
316 | return iou_val
317 |
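# Example (sketch): for pred = [0, 0, 9, 9] and gt = [5, 5, 14, 14] the 5x5 overlap gives
#   cal_iou(pred, gt)               # -> ~0.14 (25 / 175), returned as a length-1 array
#   cal_iou(pred, gt, method='iog') # -> 0.25  (25 / 100, normalised by the gt box)
#   cal_iou(pred, gt, method='iob') # -> 0.25  (25 / 100, normalised by the predicted box)
# get_badcase_detail below uses 'iob'/'iog' to separate "part" errors (predicted box
# mostly inside the object) from "more" errors (predicted box covering the object and more).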
318 | def get_badcase_detail(top1_bbox, gt_bboxes, gt_label, max_iou, max_gt_id):
319 | cls_wrong = 0
320 | multi_instances = 0
321 | region_part = 0
322 | region_more = 0
323 | region_wrong = 0
324 |
325 | pred_cls = top1_bbox[0]
326 | pred_bbox = top1_bbox[1:]
327 |
328 | if not int(pred_cls) == gt_label:
329 | cls_wrong = 1
330 | return cls_wrong, multi_instances, region_part, region_more, region_wrong
331 |
332 | if max_iou > 0.5:
333 | return 0, 0, 0, 0, 0
334 |
335 | # multi_instances error
336 | gt_box_cnt = len(gt_bboxes) // 4
337 | if gt_box_cnt > 1:
338 | iogs = []
339 | for i in range(gt_box_cnt):
340 | gt_box = gt_bboxes[i * 4:(i + 1) * 4]
341 | iog = cal_iou(pred_bbox, gt_box, method='iog')
342 | iogs.append(iog)
343 | if sum(np.array(iogs) > 0.3)> 1:
344 | multi_instances = 1
345 | return cls_wrong, multi_instances, region_part, region_more, region_wrong
346 | # -region part error
347 | iog = cal_iou(pred_bbox, gt_bboxes[max_gt_id*4:(max_gt_id+1)*4], method='iog')
348 | iob = cal_iou(pred_bbox, gt_bboxes[max_gt_id*4:(max_gt_id+1)*4], method='iob')
349 | if iob >0.5:
350 | region_part = 1
351 | return cls_wrong, multi_instances, region_part, region_more, region_wrong
352 | if iog >= 0.7:
353 | region_more = 1
354 | return cls_wrong, multi_instances, region_part, region_more, region_wrong
355 | region_wrong = 1
356 | return cls_wrong, multi_instances, region_part, region_more, region_wrong
357 |
358 | def accuracy(output, target, topk=(1,)):
359 | """ Computes the precision@k for the specified values of k
360 | :param output: tensor of shape B x K, predicted logits of image from model
361 | :param target: tensor of shape B, ground-truth class labels of the images
362 | :param topk: top predictions
363 | :return: list of precision@k
364 | """
365 | maxk = max(topk)
366 | batch_size = target.size(0)
367 |
368 | _, pred = output.topk(maxk, 1, True, True)
369 | pred = pred.t()
370 | correct = pred.eq(target.view(1, -1).expand_as(pred))
371 |
372 | res = []
373 | for k in topk:
374 | correct_k = correct[:k].reshape(-1).float().sum(0)
375 | res.append(correct_k.mul_(100.0 / batch_size))
376 | return res
377 |
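# Example (sketch): logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]]) with
# target = torch.tensor([1, 1]) gives accuracy(logits, target, topk=(1, 2))
# -> [tensor(50.), tensor(100.)] (percentages).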
378 | def list2acc(results_list):
379 | """
380 | :param results_list: list contains 0 and 1
381 | :return: accuarcy
382 | """
383 | accuarcy = results_list.count(1)/len(results_list)
384 | return accuarcy
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import yaml
5 | import argparse
6 | import tqdm
7 | import copy
8 | import itertools
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | import torch.utils.checkpoint
15 | from torch.utils.data import Dataset
16 | from accelerate import Accelerator
17 | from accelerate.utils import set_seed
18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
19 | from diffusers.optimization import get_scheduler
20 | from transformers import CLIPTextModel, CLIPTokenizer
21 |
22 | from models.attn import AttentionStore
23 | from datasets.base import CUBDataset, ImagenetDataset, SubDataset
24 | from datasets.evaluation import Evaluator
25 |
26 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
27 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
28 |
29 | DATASETS = dict(cub=CUBDataset, imagenet=ImagenetDataset)
30 | OPTIMIZER = dict(AdamW=torch.optim.AdamW)
31 |
32 | def set_env(benchmark=True):
33 | # get config
34 | parser = argparse.ArgumentParser(description=globals()["__doc__"])
35 | parser.add_argument("--function", type=str, default="test", required=True)
36 | parser.add_argument("--config", type=str, default="configs/cub.yml", help="Config file", required=True)
37 | parser.add_argument("--opt", type=str, default="dict()", help="Override options.")
38 | parser.add_argument("--seed", type=int, default=1234, help="Random seed")
39 | args = parser.parse_args()
40 |
41 | def load_yaml_conf(config):
42 | def dict_fix(d):
43 | for k, v in d.items():
44 | if isinstance(v, dict):
45 | v = dict_fix(v)
46 | elif v == "None":
47 | v = None
48 | elif v == "False":
49 | v = False
50 | elif v == "True":
51 | v = True
52 | elif isinstance(v, str):
53 | v = float(v) if v.isdigit() else v
54 | d[k] = v
55 | return d
56 |
57 | assert os.path.exists(config), "ERROR: no config file found."
58 | with open(config, "r") as f:
59 | config = yaml.safe_load(f)
60 | config = dict_fix(config)
61 | return config
62 |
63 | def format_conf(config):
64 | if config["train"]["max_train_steps"] is None:
65 | config["train"]["max_train_steps"] = 0
66 |
67 | keep_class = config["data"]["keep_class"]
68 | if keep_class is not None:
69 | if isinstance(keep_class, int):
70 | keep_class = [keep_class]
71 | elif isinstance(keep_class, list) and len(keep_class)==2:
72 | keep_class = list(range(keep_class[0], keep_class[1]+1))
73 | else:
74 | assert isinstance(keep_class, list)
75 |
76 | data = dict(
77 | train=dict(
78 | batch_size=config["train"]["batch_size"],
79 | shuffle=True,
80 | dataset=dict(type=config["data"]["dataset"],
81 | root=config["data"]["root"],
82 | keep_class=keep_class,
83 | crop_size=config["data"]["crop_size"],
84 | resize_size=config["data"]["resize_size"],
85 | load_pretrain_path=config["train"]["load_pretrain_path"],
86 | load_token_path=config["train"]["load_token_path"],
87 | save_path=config["train"]["save_path"],
88 | ),
89 | ),
90 | test=dict(
91 | batch_size=config["test"]["batch_size"],
92 | shuffle=False,
93 | dataset=dict(type=config["data"]["dataset"],
94 | root=config["data"]["root"],
95 | keep_class=keep_class,
96 | crop_size=config["data"]["crop_size"],
97 | resize_size=config["data"]["resize_size"],
98 | load_pretrain_path=config["test"]["load_pretrain_path"],
99 | load_class_path=config["test"]["load_class_path"],
100 | load_token_path=config["test"]["load_token_path"],
101 | ),
102 | ),
103 | )
104 |
105 | optimizer = dict(
106 | type="AdamW",
107 | lr=config["train"]["learning_rate"],
108 | betas=(config["train"]["adam_beta1"], config["train"]["adam_beta2"]),
109 | weight_decay=eval(config["train"]["adam_weight_decay"]),
110 | eps=eval(config["train"]["adam_epsilon"]),
111 | )
112 |
113 | lr_scheduler = dict(
114 | type=config["train"]["lr_scheduler"],
115 | num_warmup_steps=config["train"]["lr_warmup_steps"] * config["train"]["gradient_accumulation_steps"],
116 | num_training_steps=config["train"]["max_train_steps"] * config["train"]["gradient_accumulation_steps"],
117 | )
118 |
119 | accelerator = dict(
120 | # logging_dir=os.path.join(config["train"]["save_path"], "logs"),
121 | gradient_accumulation_steps=config["train"]["gradient_accumulation_steps"],
122 | mixed_precision="no",
123 | log_with=None
124 | )
125 |
126 | model = dict(
127 |
128 | )
129 |
130 | train = dict(
131 | epochs=config["train"]["num_train_epochs"],
132 | scale_lr=config["train"]["scale_lr"],
133 | push_to_hub=False,
134 | save_path=config["train"]["save_path"],
135 | save_step=config["train"]["save_steps"],
136 | load_pretrain_path=config["train"]["load_pretrain_path"],
137 | load_token_path=config["train"]["load_token_path"],
138 | )
139 |
140 | test = dict(
141 | cam_thr=config["test"]["cam_thr"],
142 | eval_mode=config["test"]["eval_mode"],
143 | combine_ratio=config["test"]["combine_ratio"],
144 | load_pretrain_path=config["test"]["load_pretrain_path"],
145 | load_token_path=config["test"]["load_token_path"],
146 | load_unet_path=config["test"]["load_unet_path"] if config["test"]["load_unet_path"] is not None else os.path.join(config["test"]["load_pretrain_path"], "unet"),
147 | save_vis_path=config["test"]["save_vis_path"],
148 | save_log_path=config["test"]["save_log_path"]
149 | )
150 |
151 | config = dict(
152 | model=model,
153 | data=data,
154 | optimizer=optimizer,
155 | lr_scheduler=lr_scheduler,
156 | accelerator=accelerator,
157 | train=train,
158 | test=test,
159 | )
160 |
161 | return config
162 |
163 | def merge_conf(config, extra_config):
164 | for key, value in extra_config.items():
165 | if not isinstance(value, dict):
166 | config[key] = value
167 | else:
168 | merge_value = merge_conf(config.get(key, dict()), value)
169 | config[key] = merge_value
170 | return config
171 |
172 | # parse config file
173 | config = load_yaml_conf(args.config)
174 |
175 | # override options
176 | extra_config = eval(args.opt)
177 | merge_conf(config, extra_config)
178 |
179 | config = format_conf(config)
180 |
181 | # set random seed
182 | if args.seed is not None:
183 | set_seed(args.seed, device_specific=False)
184 |
185 | # set benchmark
186 | torch.backends.cudnn.benchmark = benchmark
187 |
188 | return args, config
189 |
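# test() evaluates localization under one of three modes (config["test"]["eval_mode"]):
#   - "gtk":  CAMs are extracted only from the ground-truth class tokens (the last entry
#             of caption_ids_concept_token / caption_ids_meta_token).
#   - "top1": CAMs are extracted for the top-1 predicted class and for the ground-truth
#             class; the second pass is skipped when the two coincide.
#   - "top5": CAMs are extracted for each of the top-5 predicted classes plus the
#             ground-truth class and scattered into a (B, num_classes, H, W) tensor.
# In every mode the text condition is a weighted mix of the learned concept token
# ("representative") and the class-name token ("discriminative") embeddings:
#   combine = combine_ratio * representative + (1 - combine_ratio) * discriminative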
190 | def test(config):
191 | split = "test"
192 | device = 'cuda'
193 | torch_dtype = torch.float16
194 |
195 | eval_mode = config["test"]["eval_mode"]
196 | load_pretrain_path = config["test"]["load_pretrain_path"]
197 | keep_class = config["data"][split]["dataset"]["keep_class"]
198 | combine_ratio = config["test"]["combine_ratio"]
199 | save_log_path = config["test"]["save_log_path"]
200 | load_unet_path = config["test"]["load_unet_path"]
201 | batch_size = config["data"][split]["batch_size"]
202 |
203 | text_encoder = CLIPTextModel.from_pretrained(load_pretrain_path, subfolder="text_encoder").to(device)
204 | vae = AutoencoderKL.from_pretrained(load_pretrain_path, subfolder="vae", torch_dtype=torch_dtype).to(device)
205 | unet = UNet2DConditionModel.from_pretrained(load_unet_path, torch_dtype=torch_dtype).to(device)
206 | noise_scheduler = DDPMScheduler.from_pretrained(load_pretrain_path, subfolder="scheduler")
207 |
208 | data_configs = config["data"][split].copy()
209 | dataset_config = data_configs.pop("dataset", None)
210 | dataset_type = dataset_config.pop("type", "imagenet")
211 | dataset_config.update(dict(
212 | test_mode=(split == "val" or split == "test"),
213 | text_encoder=text_encoder))
214 | dataset = DATASETS[dataset_type](**dataset_config)
215 | dataloader = torch.utils.data.DataLoader(dataset, **data_configs)
216 |
217 | vae.eval()
218 | unet.eval()
219 | text_encoder.eval()
220 |
221 | evaluator = Evaluator(logfile=save_log_path, len_dataloader=len(dataloader))
222 | controller = AttentionStore(batch_size=batch_size)
223 | AttentionStore.register_attention_control(controller, unet)
224 |
225 | if keep_class is None:
226 | keep_class = list(range(dataset.num_classes))
227 | print(f"INFO: Test Save:\t [log: {str(config['test']['save_log_path'])}] [vis: {str(config['test']['save_vis_path'])}]", flush=True)
228 | print(f"INFO: Test CheckPoint:\t [token: {str(config['test']['load_token_path'])}] [unet: {str(config['test']['load_unet_path'])}]", flush=True)
229 | print(f"INFO: Test Class [{keep_class[0]}-{keep_class[-1]}]:\t [dataset: {dataset_type}] [eval mode: {eval_mode}] "
230 | f"[cam thr: {config['test']['cam_thr']}] [combine ratio: {combine_ratio}]", flush=True)
231 |
232 | for step, data in enumerate(tqdm.tqdm(dataloader)):
233 | if eval_mode == "gtk":
234 | image = data["img"].to(torch_dtype).to(device)
235 | latents = vae.encode(image).latent_dist.sample().detach() * 0.18215
236 | noise = torch.randn(latents.shape).to(latents.device)
237 | timesteps = torch.randint(
238 | 0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
239 | ).long()
240 |
241 | representative_embeddings = [text_encoder(ids.to(device))[0] for ids in data["caption_ids_concept_token"][-1]]
242 | representative_embeddings = sum(representative_embeddings)/len(data["caption_ids_concept_token"][-1])
243 |
244 | discriminative_embeddings = [text_encoder(ids.to(device))[0] for ids in data["caption_ids_meta_token"][-1]]
245 | discriminative_embeddings = sum(discriminative_embeddings) / len(data["caption_ids_meta_token"][-1])
246 | combine_embeddings = combine_ratio * representative_embeddings + (1-combine_ratio) * discriminative_embeddings
247 | combine_embeddings = combine_embeddings.to(torch_dtype)
248 |
249 | for t in [0, 99]:
250 | timesteps = torch.ones_like(timesteps) * t
251 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(torch_dtype)
252 | noise_pred = unet(noisy_latents, timesteps, combine_embeddings).sample
253 |
254 | cams = controller.diffusion_cam(idx=5)
255 |
256 | controller.reset()
257 |
258 | cams_tensor = torch.from_numpy(cams).to(device).unsqueeze(dim=1)
259 | pad_cams = cams_tensor.repeat(1, dataset.num_classes, 1, 1)
260 |
261 | evaluator(data["img"], data["gt_labels"], data['gt_bboxes'], data["pred_logits"], pad_cams, data["name"], config, step)
262 | elif eval_mode == "top1":
263 | cams_all = []
264 |
265 | image = data["img"].to(torch_dtype).to(device)
266 | latents = vae.encode(image).latent_dist.sample().detach() * 0.18215
267 | noise = torch.randn(latents.shape).to(latents.device)
268 | timesteps = torch.randint(
269 | 0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
270 | ).long()
271 |
272 | for top_idx in [0, 5]:
273 |
274 | # to save inference cost, skip the extra gt-known pass when the top-1 prediction already matches the ground truth
275 | if top_idx == 5:
276 | top1_idx = data["pred_top5_ids"][:, 0]
277 | gtk_idx = torch.LongTensor(data["gt_labels"]).to(torch.int64)
278 | if torch.all(top1_idx == gtk_idx):
279 | cams_all = cams_all*2
280 | break
281 |
282 | representative_embeddings = [text_encoder(ids.to(device))[0] for ids in
283 | data["caption_ids_concept_token"][top_idx]]
284 | representative_embeddings = sum(representative_embeddings) / len(data["caption_ids_concept_token"][top_idx])
285 |
286 | discriminative_embeddings = [text_encoder(ids.to(device))[0] for ids in
287 | data["caption_ids_meta_token"][top_idx]]
288 | discriminative_embeddings = sum(discriminative_embeddings) / len(data["caption_ids_meta_token"][top_idx])
289 | combine_embeddings = combine_ratio * representative_embeddings + (
290 | 1 - combine_ratio) * discriminative_embeddings
291 | combine_embeddings = combine_embeddings.to(torch_dtype)
292 |
293 | for t in [0, 99]:
294 | timesteps = torch.ones_like(timesteps) * t
295 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(torch_dtype)
296 | noise_pred = unet(noisy_latents, timesteps, combine_embeddings).sample
297 |
298 | controller.batch_size = len(noise_pred)
299 | cams = controller.diffusion_cam(idx=5)
300 | controller.reset()
301 | cams_all.append(cams)
302 |
303 | cams_tensor = torch.from_numpy(cams_all[0]).to(device).unsqueeze(dim=1)
304 | pad_cams = cams_tensor.repeat(1, dataset.num_classes, 1, 1)
305 | for i, pad_cam in enumerate(pad_cams):
306 | pad_cam[data["gt_labels"][i]] = torch.from_numpy(cams_all[-1])[i]
307 |
308 | evaluator(data["img"], data["gt_labels"], data['gt_bboxes'], data["pred_logits"], pad_cams, data["name"], config, step)
309 | elif eval_mode == "top5":
310 | cams_all = []
311 |
312 | image = data["img"].to(torch_dtype).to(device)
313 | latents = vae.encode(image).latent_dist.sample().detach() * 0.18215
314 | noise = torch.randn(latents.shape).to(latents.device)
315 | timesteps = torch.randint(
316 | 0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
317 | ).long()
318 |
319 | for top_idx in range(6):
320 | representative_embeddings = [text_encoder(ids.to(device))[0] for ids in
321 | data["caption_ids_concept_token"][top_idx]]
322 | representative_embeddings = sum(representative_embeddings) / len(data["caption_ids_concept_token"][top_idx])
323 |
324 | discriminative_embeddings = [text_encoder(ids.to(device))[0] for ids in
325 | data["caption_ids_meta_token"][top_idx]]
326 | discriminative_embeddings = sum(discriminative_embeddings) / len(data["caption_ids_meta_token"][top_idx])
327 | combine_embeddings = combine_ratio * representative_embeddings + (
328 | 1 - combine_ratio) * discriminative_embeddings
329 | combine_embeddings = combine_embeddings.to(torch_dtype)
330 |
331 | for t in [0, 99]:
332 | timesteps = torch.ones_like(timesteps) * t
333 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(torch_dtype)
334 | noise_pred = unet(noisy_latents, timesteps, combine_embeddings).sample
335 |
336 | cams = controller.diffusion_cam(idx=5)
337 | controller.reset()
338 | cams_all.append(cams)
339 |
340 | cams = torch.from_numpy(np.stack([cam for cam in cams_all[:-1]], 1))
341 | pad_cams = torch.zeros((batch_size, dataset.num_classes, *cams.shape[-2:]))
342 | for i, pad_cam in enumerate(pad_cams):
343 | pad_cam[data["pred_top5_ids"][i]] = cams[i]
344 | pad_cam[data["gt_labels"][i]] = torch.from_numpy(cams_all[-1])[i]
345 |
346 | evaluator(data["img"], data["gt_labels"], data['gt_bboxes'], data["pred_logits"], pad_cams, data["name"], config, step)
347 | else:
348 | raise ValueError("select eval_mode in [gtk, top1, top5].")
349 |
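# train_token() is the token-learning stage: VAE, UNet and all text-encoder weights
# except the input token embeddings are frozen, and for each class only the embedding
# row of that class's concept token receives a gradient (index_grads_to_zero masks the
# rest). The learned embedding is saved per class as <token_id>.bin under
# save_path/tokens (or tokens_e{epoch} for intermediate epochs).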
350 | def train_token(config):
351 | split = "train"
352 | device = 'cuda'
353 | torch_dtype = torch.float32
354 |
355 | load_pretrain_path = config["train"]["load_pretrain_path"]
356 | keep_class = config["data"][split]["dataset"]["keep_class"]
357 |
358 | text_encoder = CLIPTextModel.from_pretrained(load_pretrain_path, subfolder="text_encoder", torch_dtype=torch_dtype)
359 | vae = AutoencoderKL.from_pretrained(load_pretrain_path, subfolder="vae", torch_dtype=torch_dtype)
360 | unet = UNet2DConditionModel.from_pretrained(load_pretrain_path, subfolder="unet", torch_dtype=torch_dtype)
361 | noise_scheduler = DDPMScheduler.from_pretrained(load_pretrain_path, subfolder="scheduler")
362 |
363 | def freeze_params(params):
364 | for param in params:
365 | param.requires_grad = False
366 |
367 | freeze_params(itertools.chain(
368 | vae.parameters(),
369 | unet.parameters(),
370 | text_encoder.text_model.encoder.parameters(),
371 | text_encoder.text_model.final_layer_norm.parameters(),
372 | text_encoder.text_model.embeddings.position_embedding.parameters(),
373 | ))
374 |
375 | data_configs = config["data"][split].copy()
376 | dataset_config = data_configs.pop("dataset", None)
377 | dataset_type = dataset_config.pop("type", "imagenet")
378 | dataset_config.update(dict(
379 | test_mode=(split == "val" or split == "test"),
380 | text_encoder=text_encoder))
381 | dataset = DATASETS[dataset_type](**dataset_config)
382 |
383 | def train_loop(config, class_id, text_encoder=None, unet=None, vae=None, dataset=None):
384 |
385 | def get_grads_to_zero(class_id, dataset):
386 | tokenizer = dataset.tokenizer
387 | index_grads_to_zero = torch.ones((len(tokenizer))).bool()
388 | concept_token = dataset.cat2tokens[class_id]["concept_token"]
389 | token_id = tokenizer.encode(concept_token, add_special_tokens=False)[0]
390 | index_grads_to_zero[token_id] = False
391 | return index_grads_to_zero, token_id
392 |
393 | index_grads_to_zero, token_id = get_grads_to_zero(class_id, dataset)
394 | config = copy.deepcopy(config)
395 |
396 | subdataset = SubDataset(keep_class=[class_id], dataset=dataset)
397 | data_configs = config["data"][split].copy()
398 | data_configs.pop("dataset", None)
399 | dataloader = torch.utils.data.DataLoader(subdataset, **data_configs)
400 |
401 | accelerator_config = config.get("accelerator", None)
402 | accelerator = Accelerator(**accelerator_config)
403 | if accelerator.is_main_process:
404 | accelerator.init_trackers("wsol", config=config)
405 |
406 | save_path = config['train']['save_path']
407 | batch_size = config['data']['train']['batch_size']
408 | gradient_accumulation_steps = config['accelerator']['gradient_accumulation_steps']
409 | num_train_epochs = config['train']['epochs']
410 | max_train_steps = config['lr_scheduler']['num_training_steps'] // gradient_accumulation_steps
411 | total_batch_size = batch_size * accelerator.num_processes * gradient_accumulation_steps
412 |
413 | if config['train']['scale_lr']:
414 | config['optimizer']['lr'] = config['optimizer']['lr'] * total_batch_size
415 | if (max_train_steps is None) or (max_train_steps == 0):
416 | num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps)
417 | max_train_steps = num_train_epochs * num_update_steps_per_epoch // accelerator.num_processes
418 |
419 | config['lr_scheduler']['num_training_steps'] = max_train_steps * gradient_accumulation_steps
420 | optimizer_config = config.get("optimizer", None)
421 | optimizer_type = optimizer_config.pop("type", "AdamW")
422 | lr_scheduler_config = config.get("lr_scheduler", None)
423 | lr_scheduler_type = lr_scheduler_config.pop("type", "constant")
424 |
425 | optimizer = OPTIMIZER[optimizer_type](text_encoder.get_input_embeddings().parameters(), **optimizer_config)
426 | lr_scheduler = get_scheduler(name=lr_scheduler_type, optimizer=optimizer, **lr_scheduler_config)
427 |
428 | if accelerator.is_main_process:
429 | print(f"INFO: Train Save:\t [ckpt: {save_path}]", flush=True)
430 | print(f"INFO: Train Class [{class_id}]:\t [num samples: {len(dataloader)}] "
431 | f"[num epochs: {num_train_epochs}] [batch size: {total_batch_size}] "
432 | f"[total steps: {max_train_steps}]", flush=True)
433 |
434 | vae, unet, text_encoder, optimizer, lr_scheduler, dataloader = accelerator.prepare(vae, unet, text_encoder,
435 | optimizer, lr_scheduler,
436 | dataloader)
437 | vae.eval()
438 | unet.eval()
439 |
440 | global_step = 0
441 | progress_bar = tqdm.tqdm(range(max_train_steps), disable=(not accelerator.is_local_main_process))
442 | for epoch in range(num_train_epochs):
443 | text_encoder.train()
444 | progress_bar.set_description(f"Epoch[{epoch+1}/{num_train_epochs}] ")
445 | for step, data in enumerate(dataloader):
446 | with accelerator.accumulate(text_encoder):
447 | combine_embeddings = text_encoder(data["caption_ids_concept_token"])[0]
448 |
449 | image = data["img"].to(torch_dtype)  # cast to the training dtype (float32 here)
450 | latents = vae.encode(image).latent_dist.sample().detach() * 0.18215
451 |
452 | noise = torch.randn(latents.shape, device=latents.device, dtype=torch_dtype)
453 | timesteps = torch.randint(low=0, high=noise_scheduler.config.num_train_timesteps,
454 | size=(latents.shape[0],), device=latents.device).long()
455 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
456 |
457 | noise_pred = unet(noisy_latents, timesteps, combine_embeddings).sample
458 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
459 |
460 | accelerator.backward(loss)
461 |
462 | if accelerator.num_processes > 1:
463 | grads = text_encoder.module.get_input_embeddings().weight.grad
464 | else:
465 | grads = text_encoder.get_input_embeddings().weight.grad
466 | grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
467 |
468 | optimizer.step()
469 | lr_scheduler.step()
470 | optimizer.zero_grad()
471 |
472 | # Checks if the accelerator has performed an optimization step behind the scenes
473 | if accelerator.sync_gradients:
474 | progress_bar.update(1)
475 | global_step += 1
476 |
477 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
478 | progress_bar.set_postfix(refresh=False, **logs)
479 | accelerator.log(logs, step=global_step)
480 |
481 | if global_step >= max_train_steps:
482 | break
483 |
484 | if global_step >= max_train_steps:
485 | break
486 |
487 | # save concept token embeddings per epoch
488 | accelerator.wait_for_everyone()
489 | if accelerator.is_main_process:
490 | out_dir = os.path.join(save_path, "tokens")
491 | if epoch != num_train_epochs - 1:
492 | out_dir = os.path.join(save_path, f"tokens_e{epoch}")
493 | os.makedirs(out_dir, exist_ok=True)
494 | unwrap_text_encoder = accelerator.unwrap_model(text_encoder)
495 | concept_token = dataset.tokenizer.decode(token_id)
496 | concept_token_embeddings = unwrap_text_encoder.get_input_embeddings().weight[token_id]
497 | dct = {concept_token: concept_token_embeddings.detach().cpu()}
498 | torch.save(dct, os.path.join(out_dir, f"{token_id}.bin"))
499 |
500 | accelerator.end_training()
501 |
502 | if keep_class is None:
503 | keep_class = list(range(dataset.num_classes))
504 | for class_id in keep_class:
505 | train_loop(config, class_id, text_encoder, unet, vae, dataset)
506 |
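# train_unet() is the fine-tuning stage: VAE and text encoder are frozen and the UNet is
# trained with the standard noise-prediction MSE loss, conditioned on an equal 0.5 / 0.5
# mix of the concept-token and class-name (meta) token embeddings. Checkpoints are written
# every save_step optimizer steps to save_path/unet_s{global_step}.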
507 | def train_unet(config):
508 | split = "train"
509 | device = 'cuda'
510 | torch_dtype = torch.float32
511 |
512 | load_pretrain_path = config["train"]["load_pretrain_path"]
513 |
514 | text_encoder = CLIPTextModel.from_pretrained(load_pretrain_path, subfolder="text_encoder", torch_dtype=torch_dtype)
515 | vae = AutoencoderKL.from_pretrained(load_pretrain_path, subfolder="vae", torch_dtype=torch_dtype)
516 | unet = UNet2DConditionModel.from_pretrained(load_pretrain_path, subfolder="unet", torch_dtype=torch_dtype)
517 | noise_scheduler = DDPMScheduler.from_pretrained(load_pretrain_path, subfolder="scheduler")
518 |
519 | def freeze_params(params):
520 | for param in params:
521 | param.requires_grad = False
522 |
523 | freeze_params(itertools.chain(
524 | vae.parameters(),
525 | text_encoder.parameters(),
526 | ))
527 |
528 | data_configs = config["data"][split].copy()
529 | dataset_config = data_configs.pop("dataset", None)
530 | dataset_type = dataset_config.pop("type", "imagenet")
531 | dataset_config.update(dict(
532 | test_mode=(split == "val" or split == "test"),
533 | text_encoder=text_encoder))
534 | dataset = DATASETS[dataset_type](**dataset_config)
535 | dataloader = torch.utils.data.DataLoader(dataset, **data_configs)
536 |
537 | def train_loop(config, text_encoder=None, unet=None, vae=None, dataloader=None):
538 | config = copy.deepcopy(config)
539 |
540 | accelerator_config = config.get("accelerator", None)
541 | accelerator = Accelerator(**accelerator_config)
542 | if accelerator.is_main_process:
543 | accelerator.init_trackers("wsol", config=config)
544 |
545 | save_path = config['train']['save_path']
546 | save_step = config['train']['save_step']
547 | batch_size = config['data']['train']['batch_size']
548 | gradient_accumulation_steps = config['accelerator']['gradient_accumulation_steps']
549 | num_train_epochs = config['train']['epochs']
550 | max_train_steps = config['lr_scheduler']['num_training_steps'] // gradient_accumulation_steps
551 | total_batch_size = batch_size * accelerator.num_processes * gradient_accumulation_steps
552 |
553 | if config['train']['scale_lr']:
554 | config['optimizer']['lr'] = config['optimizer']['lr'] * total_batch_size
555 | if (max_train_steps is None) or (max_train_steps == 0):
556 | num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps)
557 | max_train_steps = num_train_epochs * num_update_steps_per_epoch // accelerator.num_processes
558 |
559 | config['lr_scheduler']['num_training_steps'] = max_train_steps * gradient_accumulation_steps
560 | optimizer_config = config.get("optimizer", None)
561 | optimizer_type = optimizer_config.pop("type", "AdamW")
562 | lr_scheduler_config = config.get("lr_scheduler", None)
563 | lr_scheduler_type = lr_scheduler_config.pop("type", "constant")
564 |
565 | optimizer = OPTIMIZER[optimizer_type](accelerator.unwrap_model(unet).parameters(), **optimizer_config)
566 | lr_scheduler = get_scheduler(name=lr_scheduler_type, optimizer=optimizer, **lr_scheduler_config)
567 |
568 | if accelerator.is_main_process:
569 | print(f"INFO: Train Save:\t [ckpt: {save_path}]", flush=True)
570 | print(f"INFO: Train UNet:\t [num samples: {len(dataloader)}] "
571 | f"[num epochs: {num_train_epochs}] [batch size: {total_batch_size}] "
572 | f"[total steps: {max_train_steps}] [save step: {save_step}]", flush=True)
573 |
574 | vae, unet, text_encoder, optimizer, lr_scheduler, dataloader = accelerator.prepare(vae, unet, text_encoder,
575 | optimizer, lr_scheduler,
576 | dataloader)
577 | vae.eval()
578 | text_encoder.eval()
579 |
580 | global_step = 0
581 | progress_bar = tqdm.tqdm(range(max_train_steps), disable=(not accelerator.is_local_main_process))
582 | for epoch in range(num_train_epochs):
583 | unet.train()
584 | progress_bar.set_description(f"Epoch[{epoch + 1}/{num_train_epochs}] ")
585 | for step, data in enumerate(dataloader):
586 | with accelerator.accumulate(unet):
587 | representative_embeddings = text_encoder(data["caption_ids_concept_token"])[0]
588 | discriminative_embeddings = text_encoder(data["caption_ids_meta_token"])[0]
589 | combine_embeddings = 0.5 * representative_embeddings + 0.5 * discriminative_embeddings
590 |
591 | image = data["img"].to(torch_dtype)
592 | latents = vae.encode(image).latent_dist.sample().detach() * 0.18215
593 |
594 | noise = torch.randn(latents.shape, device=latents.device, dtype=torch_dtype)
595 | timesteps = torch.randint(low=0, high=noise_scheduler.config.num_train_timesteps,
596 | size=(latents.shape[0],), device=latents.device).long()
597 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
598 |
599 | noise_pred = unet(noisy_latents, timesteps, combine_embeddings).sample
600 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
601 |
602 | accelerator.backward(loss)
603 |
604 | optimizer.step()
605 | lr_scheduler.step()
606 | optimizer.zero_grad()
607 |
608 | # Checks if the accelerator has performed an optimization step behind the scenes
609 | if accelerator.sync_gradients:
610 | progress_bar.update(1)
611 | global_step += 1
612 |
613 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
614 | progress_bar.set_postfix(refresh=False, **logs)
615 | accelerator.log(logs, step=global_step)
616 |
617 | if global_step >= max_train_steps:
618 | break
619 | elif (global_step + 1) % save_step == 0:
620 | if accelerator.sync_gradients:
621 | accelerator.wait_for_everyone()
622 | if accelerator.is_main_process:
623 | out_dir = os.path.join(save_path, f"unet_s{global_step}")
624 | os.makedirs(out_dir, exist_ok=True)
625 | try:
626 | unet.module.save_pretrained(save_directory=out_dir)
627 | except AttributeError:  # unet is not DDP-wrapped (single-process training)
628 | unet.save_pretrained(save_directory=out_dir)
629 |
630 |
631 | if global_step >= max_train_steps:
632 | break
633 |
634 | accelerator.end_training()
635 |
636 | train_loop(config, text_encoder, unet, vae, dataloader)
637 |
638 | if __name__ == "__main__":
639 | args, config = set_env(benchmark=True)
640 | print(args)
641 | eval(args.function)(config)
642 |
643 |
644 |
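# Typical invocations (illustrative; the pairing of configs with stages is inferred
# from the config file names, not stated elsewhere in this file):
#   python main.py --function train_token --config configs/cub.yml
#   python main.py --function train_unet  --config configs/cub_stage2.yml
#   python main.py --function test        --config configs/cub.yml \
#       --opt "dict(test=dict(eval_mode='gtk'))"
# --opt is evaluated as a Python dict and recursively merged over the YAML config.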
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/callsys/GenPromp/c480e0a106d7158da4a4e86cc27441f181996b04/models/__init__.py
--------------------------------------------------------------------------------
/models/attn.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | import torch.utils.checkpoint
5 | import matplotlib.pyplot as plt
6 |
7 | class AttentionStore():
8 | def __init__(self, batch_size=2):
9 | self.cur_step = 0
10 | self.num_att_layers = -1
11 | self.cur_att_layer = 0
12 | self.step_store = self.get_empty_store()
13 | self.attention_store = {}
14 | self.active = True
15 | self.batch_size = batch_size
16 |
17 | def reset(self):
18 | self.cur_step = 0
19 | self.cur_att_layer = 0
20 | self.step_store = self.get_empty_store()
21 | self.attention_store = {}
22 |
23 | @property
24 | def num_uncond_att_layers(self):
25 | return 0
26 |
27 | def step_callback(self, x_t):
28 | return x_t
29 |
30 | @staticmethod
31 | def get_empty_store():
32 | return {"down_cross": [], "mid_cross": [], "up_cross": [],
33 | "down_self": [], "mid_self": [], "up_self": []}
34 |
35 | def forward(self, attn, is_cross: bool, place_in_unet: str):
36 | if self.active:
37 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
38 | self.step_store[key].append(attn)
39 | return attn
40 |
41 | def between_steps(self):
42 | if self.active:
43 | if len(self.attention_store) == 0:
44 | self.attention_store = self.step_store
45 | else:
46 | for key in self.attention_store:
47 | for i in range(len(self.attention_store[key])):
48 | self.attention_store[key][i] += self.step_store[key][i]
49 | self.step_store = self.get_empty_store()
50 |
51 | def get_average_attention(self):
52 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
53 | self.attention_store}
54 | return average_attention
55 |
56 | def __call__(self, attn, is_cross: bool, place_in_unet: str):
57 | if self.cur_att_layer >= self.num_uncond_att_layers:
58 | attn = self.forward(attn, is_cross, place_in_unet)
59 | self.cur_att_layer += 1
60 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
61 | self.cur_att_layer = 0
62 | self.cur_step += 1
63 | self.between_steps()
64 | return attn
65 |
66 | @staticmethod
67 | def register_attention_control(controller, model):
68 |
69 | def ca_attention(self, place_in_unet):
70 |
71 | def get_attention_scores(query, key, attention_mask=None):
72 | dtype = query.dtype
73 |
74 | if self.upcast_attention:
75 | query = query.float()
76 | key = key.float()
77 |
78 | if attention_mask is None:
79 | baddbmm_input = torch.empty(
80 | query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
81 | )
82 | beta = 0
83 | else:
84 | baddbmm_input = attention_mask
85 | beta = 1
86 |
87 | attention_scores = torch.baddbmm(
88 | baddbmm_input,
89 | query,
90 | key.transpose(-1, -2),
91 | beta=beta,
92 | alpha=self.scale,
93 | )
94 |
95 | if self.upcast_softmax:
96 | attention_scores = attention_scores.float()
97 |
98 | attention_probs = attention_scores.softmax(dim=-1)
99 | attention_probs = attention_probs.to(dtype)
100 |
101 | # cross-attention keys come from the text tokens, so their sequence
102 | # length differs from that of the image queries
103 | is_cross = query.shape != key.shape
105 |
106 | attention_probs = controller(attention_probs, is_cross, place_in_unet)
107 |
108 | return attention_probs
109 |
110 | return get_attention_scores
111 |
112 | def register_recr(net_, count, place_in_unet):
113 | if net_.__class__.__name__ == 'CrossAttention':
114 | # net_._attention = ca_attention(net_, place_in_unet)
115 | net_.get_attention_scores = ca_attention(net_, place_in_unet)
116 | return count + 1
117 | elif hasattr(net_, 'children'):
118 | for net__ in net_.children():
119 | count = register_recr(net__, count, place_in_unet)
120 | return count
121 |
122 | cross_att_count = 0
123 | sub_nets = model.named_children()
124 | for net in sub_nets:
125 | if "down" in net[0]:
126 | cross_att_count += register_recr(net[1], 0, "down")
127 | elif "up" in net[0]:
128 | cross_att_count += register_recr(net[1], 0, "up")
129 | elif "mid" in net[0]:
130 | cross_att_count += register_recr(net[1], 0, "mid")
131 | controller.num_att_layers = cross_att_count
132 |
133 | def aggregate_attention(self, res, from_where, is_cross, bz):
134 | out = []
135 | attention_maps = self.get_average_attention()
136 | num_pixels = res ** 2
137 | for location in from_where:
138 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
139 | if item.shape[1] == num_pixels:
140 | cross_maps = item.reshape(bz, -1, res, res, item.shape[-1])
141 | out.append(cross_maps)
142 | out = torch.cat(out, dim=1)
143 | return out.cpu()
144 |
145 | def self_attention_map(self, res, from_where, bz, max_com=10, out_res=64):
146 | attention_maps = self.aggregate_attention(res, from_where, False, bz)
147 | maps = []
148 | for b in range(bz):
149 | attention_map = attention_maps[b].detach().numpy().astype(np.float32).mean(0).reshape((res**2, res**2))
150 | u, s, vh = np.linalg.svd(attention_map - np.mean(attention_map, axis=1, keepdims=True))
151 | images = []
152 | for i in range(max_com):
153 | image = vh[i].reshape(res, res)
154 | # image = image/image.max()
155 | # image = (image - image.min()) / (image.max() - image.min())
156 | image = cv2.resize(image, (out_res, out_res), interpolation=cv2.INTER_CUBIC)
157 | images.append(image)
158 | map = np.stack(images, 0).max(0)
159 | maps.append(map)
160 | return np.stack(maps, 0)
161 |
162 | def cross_attention_map(self, res, from_where, bz, out_res=64, idx=5):
163 | attention_maps = self.aggregate_attention(res, from_where, True, bz)
164 | attention_maps = attention_maps[..., idx]
165 | attention_maps = attention_maps.sum(1) / attention_maps.shape[1]
166 |
167 | maps = []
168 | for b in range(bz):
169 | map = attention_maps[b, :, :]
170 | map = cv2.resize(map.detach().numpy().astype(np.float32), (out_res, out_res),
171 | interpolation=cv2.INTER_CUBIC)
172 | # map = map / map.max()
173 | maps.append(map)
174 | return np.stack(maps, 0)
175 |
176 | def diffusion_cam(self, idx=5):
177 | bz = self.batch_size
178 | attention_maps_8_ca = self.cross_attention_map(8, ("up", "mid", "down"), bz, idx=idx)
179 | attention_maps_16_up_ca = self.cross_attention_map(16, ("up",), bz, idx=idx)
180 | attention_maps_16_down_ca = self.cross_attention_map(16, ("down",), bz, idx=idx)
181 | attention_maps_ca = (attention_maps_8_ca + attention_maps_16_up_ca + attention_maps_16_down_ca) / 3
182 | cams = attention_maps_ca
183 | cams = cams / cams.max((1,2))[:, None, None]
184 | return cams
185 |
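# Minimal usage sketch (mirrors how main.py drives this class; `unet` is a diffusers
# UNet2DConditionModel and the latent/embedding tensors are placeholders):
#
#   controller = AttentionStore(batch_size=noisy_latents.shape[0])
#   AttentionStore.register_attention_control(controller, unet)
#   _ = unet(noisy_latents, timesteps, text_embeddings).sample  # attention is recorded
#   cams = controller.diffusion_cam(idx=5)  # (B, 64, 64) maps for text-token position idx
#   controller.reset()                      # clear the store before the next batch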
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
--------------------------------------------------------------------------------