├── LICENSE ├── README.md ├── assets ├── intro.png ├── ratio.png ├── results.png └── visualize.png ├── configs ├── cub.yml ├── cub_stage2.yml ├── imagenet.yml └── imagenet_stage2.yml ├── datasets ├── __init__.py ├── base.py └── evaluation │ ├── __init__.py │ └── cam.py ├── main.py └── models ├── __init__.py └── attn.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Generative Prompt Model for Weakly Supervised Object Localization 4 |
5 | 6 | 7 | This is the official implementation of the paper [***Generative Prompt Model for Weakly Supervised Object Localization***](https://openaccess.thecvf.com/content/ICCV2023/html/Zhao_Generative_Prompt_Model_for_Weakly_Supervised_Object_Localization_ICCV_2023_paper.html), which was accepted at ***ICCV 2023***. This repository contains the PyTorch training code, evaluation code, pre-trained models, and a visualization method. 8 | 9 |
10 | 11 | [![arXiv preprint](http://img.shields.io/badge/arXiv-2307.09756-b31b1b)](https://arxiv.org/abs/2307.09756) 12 | ![Python 3.8](https://img.shields.io/badge/Python-3.8-green.svg?style=plastic) 13 | ![PyTorch 1.11](https://img.shields.io/badge/PyTorch-1.11-EE4C2C.svg?style=plastic) 14 | [![LICENSE](https://img.shields.io/github/license/vasgaowei/ts-cam.svg)](LICENSE) 15 | 16 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/generative-prompt-model-for-weakly-supervised/weakly-supervised-object-localization-on-2)](https://paperswithcode.com/sota/weakly-supervised-object-localization-on-2?p=generative-prompt-model-for-weakly-supervised) 17 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/generative-prompt-model-for-weakly-supervised/weakly-supervised-object-localization-on-cub)](https://paperswithcode.com/sota/weakly-supervised-object-localization-on-cub?p=generative-prompt-model-for-weakly-supervised) 18 |
19 | 20 |
21 | 22 | 23 |
24 | 25 | ## 1. Contents 26 | - Generative Prompt Model for Weakly Supervised Object Localization 27 | - [1. Contents](#1-contents) 28 | - [2. Introduction](#2-introduction) 29 | - [3. Results](#3-results) 30 | - [4. Get Started](#4-get-started) 31 | - [4.1 Installation](#41-installation) 32 | - [4.2 Dataset and Files Preparation](#42-dataset-and-files-preparation) 33 | - [4.3 Training](#43-training) 34 | - [4.4 Inference](#44-inference) 35 | - [4.5 Extra Options](#45-extra-options) 36 | - [5. Contacts](#5-contacts) 37 | - [6. Acknowledgment](#6-acknowledgment) 38 | - [7. Citation](#7-citation) 39 | 40 | ## 2. Introduction 41 | 42 | Weakly supervised object localization (WSOL) remains challenging when learning object localization models from image category labels. Conventional methods that discriminatively train activation models ignore representative yet less discriminative object parts. In this study, we propose a generative prompt model (GenPromp), defining the first generative pipeline to localize less discriminative object parts by formulating WSOL as a conditional image denoising procedure. During training, GenPromp converts image category labels to learnable prompt embeddings, which are fed to a generative model to conditionally recover the noised input image and learn representative embeddings. During inference, GenPromp combines the representative embeddings with discriminative embeddings (queried from an off-the-shelf vision-language model) for both representative and discriminative capacity. The combined embeddings are finally used to generate multi-scale, high-quality attention maps, which facilitate localizing the full object extent. Experiments on CUB-200-2011 and ILSVRC show that GenPromp outperforms the best discriminative models on both datasets, setting a solid baseline for WSOL with a generative model. 43 | 44 | 45 | ## 3. Results 46 | 47 |
48 | 49 | 50 |
51 | 52 | We re-train GenPromp with a better learning schedule on 4 x A100 GPUs. The performance of GenPromp on CUB-200-2011 is further improved. 53 | 54 | | Method | Dataset | Cls Back. | Top-1 Loc | Top-5 Loc | GT-known Loc | 55 | | ------ | ------- | --------- | --------- | --------- | ------------ | 56 | | GenPromp | CUB-200-2011 | EfficientNet-B7 | 87.0 | 96.1 | 98.0 | 57 | | GenPromp (Re-train) | CUB-200-2011 | EfficientNet-B7 | 87.2 (+0.2) | 96.3 (+0.2) | 98.3 (+0.3) | 58 | | GenPromp | ImageNet | EfficientNet-B7 | 65.2 | 73.4 | 75.0 | 59 | 60 | ## 4. Get Started 61 | 62 | ### 4.1 Installation 63 | 64 | To set up the environment of GenPromp, we use `conda` to manage dependencies. Our experiments are run with `CUDA 11.3`. Run the following commands to install GenPromp: 65 | ``` 66 | conda create -n gpm python=3.8 -y && conda activate gpm 67 | pip install --upgrade pip 68 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 69 | pip install --upgrade diffusers[torch]==0.13.1 70 | pip install transformers==4.29.2 accelerate==0.19.0 71 | pip install matplotlib opencv-python OmegaConf tqdm 72 | ``` 73 | 74 | ### 4.2 Dataset and Files Preparation 75 | To train GenPromp from the pre-trained weights and run inference with the provided weights, download the files in the table and arrange them according to the file tree below. (Uploading) 76 | 77 | 78 | | Dataset & Files | Download | Usage | 79 | | -------------------------------------- | ---------------------------------------------------------------------- | --------------------------------------------------------------------- | 80 | | data/ImageNet_ILSVRC2012 (146GB) | [Official Link](http://image-net.org/) | Benchmark dataset | 81 | | data/CUB_200_2011 (1.2GB) | [Official Link](http://www.vision.caltech.edu/datasets/cub_200_2011/) | Benchmark dataset | 82 | | ckpts/pretrains (5.2GB) | [Official Link](https://huggingface.co/CompVis/stable-diffusion-v1-4), [Google Drive](https://drive.google.com/drive/folders/1z4oqLBhIQsADqQOfnwOmXO-a4M2V4Ca8?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1DC8QUjocIJcmASsA-A1TYA)(o9ei) | Stable Diffusion pretrain weights | 83 | | ckpts/classifications (2.3GB) | [Google Drive](https://drive.google.com/drive/folders/1z4oqLBhIQsADqQOfnwOmXO-a4M2V4Ca8?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1DC8QUjocIJcmASsA-A1TYA)(o9ei) | Classification results on benchmark datasets | 84 | | ckpts/imagenet750 (3.3GB) | [Google Drive](https://drive.google.com/drive/folders/1z4oqLBhIQsADqQOfnwOmXO-a4M2V4Ca8?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1DC8QUjocIJcmASsA-A1TYA)(o9ei) | Weights that achieve 75.0% GT-Known Loc on ImageNet | 85 | | ckpts/cub983 (3.3GB) | [Google Drive](https://drive.google.com/drive/folders/1z4oqLBhIQsADqQOfnwOmXO-a4M2V4Ca8?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1DC8QUjocIJcmASsA-A1TYA)(o9ei) | Weights that achieve 98.3% GT-Known Loc on CUB | 86 | 87 | ```text 88 | |--GenPromp/ 89 | |--data/ 90 | |--ImageNet_ILSVRC2012/ 91 | |--ILSVRC2012_list/ 92 | |--train/ 93 | |--val/ 94 | |--CUB_200_2011 95 | |--attributes/ 96 | |--images/ 97 | ... 98 | |--ckpts/ 99 | |--pretrains/ 100 | |--stable-diffusion-v1-4/ 101 | |--classifications/ 102 | |--cub_efficientnetb7.json 103 | |--imagenet_efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k.json 104 | |--imagenet750/ 105 | |--tokens/ 106 | |--49408.bin 107 | |--49409.bin 108 | ...
109 | |--unet/ 110 | |--cub983/ 111 | |--tokens/ 112 | |--49408.bin 113 | |--49409.bin 114 | ... 115 | |--unet/ 116 | |--configs/ 117 | |--datasets 118 | |--models 119 | |--main.py 120 | ``` 121 | 122 | 123 | ### 4.3 Training 124 | 125 | Here is a training example of GenPromp on ImageNet. 126 | ``` 127 | accelerate config 128 | accelerate launch python main.py --function train_token --config configs/imagenet.yml --opt "{'train': {'save_path': 'ckpts/imagenet/'}}" 129 | accelerate launch python main.py --function train_unet --config configs/imagenet_stage2.yml --opt "{'train': {'load_token_path': 'ckpts/imagenet/tokens/', 'save_path': 'ckpts/imagenet/'}}" 130 | ``` 131 | `accelerate` is used for multi-GPU training. In the first training stage, the weights of the concept tokens of the representative embeddings are learned and saved to `ckpts/imagenet/`. In the second training stage, the weights of the learned concept tokens are loaded from `ckpts/imagenet/tokens/`, then the weights of the UNet are finetuned and saved to `ckpts/imagenet/`. Other configurations can be seen in the config files (i.e. `configs/imagenet.yml` and `configs/imagenet_stage2.yml`) and can be modified by `--opt` with a parameter dict (see [Extra Options](#45-extra-options) for details). 132 | 133 | Here is a training example of GenPromp on CUB_200_2011. 134 | ``` 135 | accelerate config 136 | accelerate launch python main.py --function train_token --config configs/cub.yml --opt "{'train': {'save_path': 'ckpts/cub/'}}" 137 | accelerate launch python main.py --function train_unet --config configs/cub_stage2.yml --opt "{'train': {'load_token_path': 'ckpts/cub/tokens/', 'save_path': 'ckpts/cub/'}}" 138 | ``` 139 | 140 | ### 4.4 Inference 141 | Here is an inference example of GenPromp on ImageNet. 142 | 143 | ``` 144 | python main.py --function test --config configs/imagenet_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/imagenet750/tokens/', 'load_unet_path': 'ckpts/imagenet750/unet/', 'save_log_path': 'ckpts/imagnet750/log.txt'}}" 145 | ``` 146 | In the inference stage, the weights of the learned concept tokens are loaded from `ckpts/imagenet750/tokens/`, the weights of the finetuned UNet are loaded from `ckpts/imagenet750/unet/`, and the log file is saved to `ckpts/imagnet750/log.txt`. Due to the random noise added to the test images, the results may fluctuate within a small range ($\pm$ 0.1). 147 | 148 | Here is an inference example of GenPromp on CUB_200_2011. 149 | ``` 150 | python main.py --function test --config configs/cub_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/cub983/tokens/', 'load_unet_path': 'ckpts/cub983/unet/', 'save_log_path': 'ckpts/cub983/log.txt'}}" 151 | ``` 152 | 153 | ### 4.5 Extra Options 154 | 155 | There are many extra options during training and inference. The default options are configured in the `yml` files. We can use `--opt` to add or override the default options with a parameter dict. Below is the usage of the most commonly used options. 156 | 157 | | Option | Scope | Usage | 158 | | -------| ----- | ----- | 159 | | {'data': {'keep_class': [0, 9]}} | data | keep the data with category id in `[0, 1, 2, 3, ..., 9]` | 160 | | {'train': {'batch_size': 2}} | train | train with batch size `2`. | 161 | | {'train': {'num_train_epochs': 1}} | train | train the model for `1` epoch. | 162 | | {'train': {'save_steps': 200}} | train_unet | save the trained UNet every `200` steps. | 163 | | {'train': {'max_train_steps': 600}} | train_unet | terminate training within `600` steps. 
| 164 | | {'train': {'gradient_accumulation_steps': 2}} | train | batch size `x2` when GPU memory is limited. | 165 | | {'train': {'learning_rate': 5.0e-08}} | train | the learning rate is `5.0e-8`. | 166 | | {'train': {'scale_lr': True}} | train | the learning rate is multiplied by the batch size if `True`. | 167 | | {'train': {'load_pretrain_path': 'stable-diffusion/'}} | train | the pretrained model is loaded from `stable-diffusion/`. | 168 | | {'train': {'load_token_path': 'ckpt/tokens/'}} | train | the trained concept tokens are loaded from `ckpt/tokens/`. | 169 | | {'train': {'save_path': 'ckpt/'}} | train | save the trained weights to `ckpt/`. | 170 | | {'test': {'batch_size': 2}} | test | test with batch size `2`. | 171 | | {'test': {'cam_thr': 0.25}} | test | test with CAM threshold `0.25`. | 172 | | {'test': {'combine_ratio': 0.6}} | test | the combine ratio between $f_r$ and $f_d$ is `0.6`. | 173 | | {'test': {'load_class_path': 'imagenet_efficientnet.json'}} | test | load classification results from `imagenet_efficientnet.json`. | 174 | | {'test': {'load_pretrain_path': 'stable-diffusion/'}} | test | the pretrained model is loaded from `stable-diffusion/`. | 175 | | {'test': {'load_token_path': 'ckpt/tokens/'}} | test | the trained concept tokens are loaded from `ckpt/tokens/`. | 176 | | {'test': {'load_unet_path': 'ckpt/unet/'}} | test | the trained UNet is loaded from `ckpt/unet/`. | 177 | | {'test': {'save_vis_path': 'ckpt/vis/'}} | test | the visualized predictions are saved to `ckpt/vis/`. | 178 | | {'test': {'save_log_path': 'ckpt/log.txt'}} | test | the log file is saved to `ckpt/log.txt`. | 179 | | {'test': {'eval_mode': 'top1'}} | test | `top1` evaluates the predicted top-1 cls category of the test image, `top5` evaluates the predicted top-5 cls categories of the test image, and `gtk` evaluates the gt category of the test image, which can be evaluated without classification results. We use `top1` as the default eval mode. | 180 | 181 | These options can be combined by simply merging the dicts. For example, to evaluate GenPromp with the config file `configs/imagenet_stage2.yml`, with categories `[0, 1, 2, ..., 9]`, concept tokens loaded from `ckpts/imagenet750/tokens/`, the UNet loaded from `ckpts/imagenet750/unet/`, the log file of the evaluated metrics saved to `ckpts/imagnet750/log.txt`, a combine ratio of `0`, and visualization results saved to `ckpts/imagenet750/vis`, use the following command: 182 | 183 | ``` 184 | python main.py --function test --config configs/imagenet_stage2.yml --opt "{'data': {'keep_class': [0, 9]}, 'test': {'load_token_path': 'ckpts/imagenet750/tokens/', 'load_unet_path': 'ckpts/imagenet750/unet/', 'save_log_path':'ckpts/imagnet750/log.txt', 'combine_ratio': 0, 'save_vis_path': 'ckpts/imagenet750/vis'}}" 185 | ``` 186 | 187 |
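For reference, the snippet below sketches how such an `--opt` parameter dict can be merged over a YAML config. It is a minimal illustration, not the exact code in `main.py`: the helper name `merge_opt` is hypothetical, and it assumes the option string is parsed with `ast.literal_eval` and merged with `OmegaConf` (one of the installed dependencies).

```python
import ast
from omegaconf import OmegaConf

def merge_opt(cfg, opt_str):
    # Parse the Python-style dict passed to --opt and recursively merge it
    # into the config loaded from the yml file (values in opt_str win).
    overrides = ast.literal_eval(opt_str)
    return OmegaConf.merge(cfg, OmegaConf.create(overrides))

# Usage: load a default config and apply command-line overrides.
cfg = OmegaConf.load("configs/imagenet_stage2.yml")
cfg = merge_opt(cfg, "{'test': {'cam_thr': 0.25, 'save_vis_path': 'ckpts/imagenet750/vis'}}")
print(cfg.test.cam_thr)  # 0.25
```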
188 | 189 | 190 |
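As a side note on the `combine_ratio` option: during inference, GenPromp blends the learned representative embedding $f_r$ with the discriminative embedding $f_d$ queried from the off-the-shelf vision-language model. The snippet below is a minimal sketch of such a linear combination; it is illustrative only, the function name is made up, and the assumption that `combine_ratio` weights $f_d$ (so a ratio of `0` keeps only $f_r$) is ours rather than taken from this repository's code.

```python
import torch

def combine_embeddings(f_r: torch.Tensor, f_d: torch.Tensor, combine_ratio: float = 0.6) -> torch.Tensor:
    # Linearly blend the representative (f_r) and discriminative (f_d) prompt embeddings.
    # Assumption: combine_ratio weights f_d, so combine_ratio = 0 keeps only f_r.
    return combine_ratio * f_d + (1.0 - combine_ratio) * f_r

# Example with dummy CLIP text-encoder outputs of shape (77 tokens, 768 dims).
f_r = torch.randn(77, 768)
f_d = torch.randn(77, 768)
f_combined = combine_embeddings(f_r, f_d, combine_ratio=0.6)
print(f_combined.shape)  # torch.Size([77, 768])
```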
191 | 192 | ## 5. Contacts 193 | If you have any question about our work or this repository, please don't hesitate to contact us by emails or open an issue under this project. 194 | - [zhaoyuzhong20@mails.ucas.ac.cn](zhaoyuzhong20@mails.ucas.ac.cn) 195 | - [wanfang@ucas.ac.cn](wanfang@ucas.ac.cn) 196 | 197 | ## 6. Acknowledgment 198 | 199 | - Part of the code is borrowed from [TS-CAM](https://github.com/vasgaowei/TS-CAM), [diffusers](https://github.com/huggingface/diffusers), and [prompt-to-prompt](https://github.com/google/prompt-to-prompt/), we sincerely thank them for their contributions to the community. Thank [MzeroMiko](https://github.com/MzeroMiko) for better code implementation. 200 | 201 | 202 | ## 7. Citation 203 | 204 | ```text 205 | @article{zhao2023generative, 206 | title={Generative Prompt Model for Weakly Supervised Object Localization}, 207 | author={Zhao, Yuzhong and Ye, Qixiang and Wu, Weijia and Shen, Chunhua and Wan, Fang}, 208 | journal={arXiv preprint arXiv:2307.09756}, 209 | year={2023} 210 | } 211 | ``` 212 | 213 | ```text 214 | @InProceedings{Zhao_2023_ICCV, 215 | author = {Zhao, Yuzhong and Ye, Qixiang and Wu, Weijia and Shen, Chunhua and Wan, Fang}, 216 | title = {Generative Prompt Model for Weakly Supervised Object Localization}, 217 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 218 | month = {October}, 219 | year = {2023}, 220 | pages = {6351-6361} 221 | } 222 | ``` 223 | -------------------------------------------------------------------------------- /assets/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/callsys/GenPromp/c480e0a106d7158da4a4e86cc27441f181996b04/assets/intro.png -------------------------------------------------------------------------------- /assets/ratio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/callsys/GenPromp/c480e0a106d7158da4a4e86cc27441f181996b04/assets/ratio.png -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/callsys/GenPromp/c480e0a106d7158da4a4e86cc27441f181996b04/assets/results.png -------------------------------------------------------------------------------- /assets/visualize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/callsys/GenPromp/c480e0a106d7158da4a4e86cc27441f181996b04/assets/visualize.png -------------------------------------------------------------------------------- /configs/cub.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "cub" 3 | root: "data/CUB_200_2011/" 4 | keep_class: None 5 | num_workers: 4 6 | resize_size: 512 7 | crop_size: 512 8 | 9 | train: 10 | batch_size: 4 #8 11 | save_steps: None 12 | num_train_epochs: 2 13 | max_train_steps: None 14 | gradient_accumulation_steps: 1 15 | learning_rate: 5.0e-05 16 | scale_lr: True 17 | scale_learning_rate: None 18 | lr_scheduler: "constant" 19 | lr_warmup_steps: 0 20 | adam_beta1: 0.9 21 | adam_beta2: 0.999 22 | adam_weight_decay: 1e-2 23 | adam_epsilon: 1e-08 24 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4" 25 | load_token_path: None 26 | save_path: "ckpts/cub" 27 | 28 | test: 29 | batch_size: 2 30 | eval_mode: "top1" #["gtk", "top1", "top5"] 31 | 
cam_thr: 0.23 32 | combine_ratio: 0.6 33 | load_class_path: "ckpts/classification/cub_efficientnetb7.json" 34 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4" 35 | load_token_path: "ckpts/cub/tokens/" 36 | load_unet_path: None 37 | save_vis_path: None 38 | save_log_path: "ckpts/cub/log.txt" 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /configs/cub_stage2.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "cub" 3 | root: "data/CUB_200_2011/" 4 | keep_class: None 5 | num_workers: 4 6 | resize_size: 512 7 | crop_size: 512 8 | 9 | train: 10 | batch_size: 4 #8 11 | save_steps: 40 12 | num_train_epochs: 100 13 | max_train_steps: 250 14 | gradient_accumulation_steps: 16 15 | learning_rate: 5.0e-06 16 | scale_lr: False 17 | scale_learning_rate: None 18 | lr_scheduler: "constant" 19 | lr_warmup_steps: 0 20 | adam_beta1: 0.9 21 | adam_beta2: 0.999 22 | adam_weight_decay: 1e-2 23 | adam_epsilon: 1e-08 24 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4" 25 | load_token_path: "ckpts/cub/tokens" 26 | save_path: "ckpts/cub" 27 | 28 | test: 29 | batch_size: 2 30 | eval_mode: "top1" #["gtk", "top1", "top5"] 31 | cam_thr: 0.23 32 | combine_ratio: 0.6 33 | load_class_path: "ckpts/classification/cub_efficientnetb7.json" 34 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4" 35 | load_token_path: "ckpts/cub/tokens/" 36 | load_unet_path: "ckpts/cub/unet/" 37 | save_vis_path: None 38 | save_log_path: "ckpts/cub/log.txt" 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /configs/imagenet.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "imagenet" 3 | root: "data/ImageNet_ILSVRC2012/" 4 | keep_class: None 5 | num_workers: 4 6 | resize_size: 512 7 | crop_size: 512 8 | 9 | train: 10 | batch_size: 2 #8 11 | save_steps: None 12 | num_train_epochs: 2 13 | max_train_steps: None 14 | gradient_accumulation_steps: 2 15 | learning_rate: 5.0e-05 16 | scale_lr: True 17 | scale_learning_rate: None 18 | lr_scheduler: "constant" 19 | lr_warmup_steps: 0 20 | adam_beta1: 0.9 21 | adam_beta2: 0.999 22 | adam_weight_decay: 1e-2 23 | adam_epsilon: 1e-08 24 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4" 25 | load_token_path: None 26 | save_path: "ckpts/imagenet" 27 | 28 | test: 29 | batch_size: 2 30 | eval_mode: "top1" #["gtk", "top1", "top5"] 31 | cam_thr: 0.25 32 | combine_ratio: 0.6 33 | load_class_path: "ckpts/classification/imagenet_efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k.json" 34 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4" 35 | load_token_path: "ckpts/imagenet/tokens/" 36 | load_unet_path: None 37 | save_vis_path: None 38 | save_log_path: "ckpts/imagenet/log.txt" 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /configs/imagenet_stage2.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "imagenet" 3 | root: "data/ImageNet_ILSVRC2012/" 4 | keep_class: None 5 | num_workers: 4 6 | resize_size: 512 7 | crop_size: 512 8 | 9 | train: 10 | batch_size: 2 #8 11 | save_steps: 100 12 | num_train_epochs: 1 13 | max_train_steps: 600 14 | gradient_accumulation_steps: 32 15 | learning_rate: 5.0e-08 16 | scale_lr: True 17 | scale_learning_rate: None 18 | lr_scheduler: "constant" 19 | lr_warmup_steps: 0 20 | adam_beta1: 0.9 21 | adam_beta2: 
0.999 22 | adam_weight_decay: 1e-2 23 | adam_epsilon: 1e-08 24 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4" 25 | load_token_path: "ckpts/imagenet/tokens" 26 | save_path: "ckpts/imagenet" 27 | 28 | test: 29 | batch_size: 2 30 | eval_mode: "top1" #["gtk", "top1", "top5"] 31 | cam_thr: 0.25 32 | combine_ratio: 0.6 33 | load_class_path: "ckpts/classification/imagenet_efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k.json" 34 | load_pretrain_path: "ckpts/pretrains/stable-diffusion-v1-4" 35 | load_token_path: "ckpts/imagenet/tokens/" 36 | load_unet_path: "ckpts/imagenet/unet/" 37 | save_vis_path: None 38 | save_log_path: "ckpts/imagenet/log.txt" 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/callsys/GenPromp/c480e0a106d7158da4a4e86cc27441f181996b04/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import json 4 | import pickle 5 | import random 6 | import numpy as np 7 | from PIL import Image 8 | import torch 9 | import torch.utils.checkpoint 10 | from torch.utils.data import Dataset 11 | from torchvision import transforms 12 | from packaging import version 13 | from transformers import CLIPTokenizer 14 | import matplotlib.pyplot as plt 15 | 16 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): 17 | PIL_INTERPOLATION = { 18 | "linear": PIL.Image.Resampling.BILINEAR, 19 | "bilinear": PIL.Image.Resampling.BILINEAR, 20 | "bicubic": PIL.Image.Resampling.BICUBIC, 21 | "lanczos": PIL.Image.Resampling.LANCZOS, 22 | "nearest": PIL.Image.Resampling.NEAREST, 23 | } 24 | else: 25 | PIL_INTERPOLATION = { 26 | "linear": PIL.Image.LINEAR, 27 | "bilinear": PIL.Image.BILINEAR, 28 | "bicubic": PIL.Image.BICUBIC, 29 | "lanczos": PIL.Image.LANCZOS, 30 | "nearest": PIL.Image.NEAREST, 31 | } 32 | 33 | 34 | caption_templates = [ 35 | "a photo of a {}", 36 | "a rendering of a {}", 37 | "the photo of a {}", 38 | "a photo of my {}", 39 | "a photo of the {}", 40 | "a photo of one {}", 41 | "a rendition of a {}", 42 | ] 43 | 44 | class BaseDataset(Dataset): 45 | def __init__(self, 46 | root=".", 47 | repeats=1, 48 | crop_size=512, 49 | resize_size=512, 50 | test_mode=False, 51 | keep_class=None, 52 | center_crop=False, 53 | text_encoder=None, 54 | load_class_path=None, 55 | load_token_path=None, 56 | load_pretrain_path=None, 57 | interpolation="bicubic", 58 | token_templates="", 59 | caption_templates=caption_templates, 60 | **kwargs, 61 | ): 62 | self.root = root 63 | self.data_repeat = repeats if not test_mode else 1 64 | self.test_mode = test_mode 65 | self.keep_class = keep_class 66 | self.load_class_path = load_class_path 67 | self.load_token_path = load_token_path 68 | self.load_pretrain_path = load_pretrain_path 69 | self.token_templates = token_templates 70 | self.caption_templates = caption_templates 71 | 72 | self.train_pipelines = self.init_train_pipelines( 73 | center_crop=center_crop, 74 | resize_size=resize_size, 75 | crop_size=crop_size, 76 | interpolation=interpolation, 77 | ) 78 | self.test_pipelines = self.init_test_pipelines( 79 | crop_size=crop_size, 80 | interpolation=interpolation, 81 | ) 82 | 83 | print(f"INFO: {self.__class__.__name__}:\t load data.", flush=True) 84 | 
self.load_data() 85 | print(f"INFO: {self.__class__.__name__}:\t init samples.", flush=True) 86 | self.init_samples() 87 | print(f"INFO: {self.__class__.__name__}:\t init text encoders.", flush=True) 88 | self.init_text_encoder(text_encoder) 89 | 90 | def load_data(self): 91 | class_file = os.path.join(self.root, 'ILSVRC2012_list', 'LOC_synset_mapping.txt') 92 | self.categories = [] 93 | with open(class_file, 'r') as f: 94 | discriptions = f.readlines() # "n01882714 koala..." 95 | for id, line in enumerate(discriptions): 96 | tag, description = line.strip().split(' ', maxsplit=1) 97 | self.categories.append(description) 98 | self.num_classes = len(self.categories) 99 | 100 | self.names = [] 101 | self.labels = [] 102 | self.bboxes = [] 103 | self.pred_logits = [] 104 | self.image_paths = [] 105 | 106 | data_file = os.path.join(self.root, 'ILSVRC2012_list', 'train.txt') 107 | image_dir = os.path.join(self.root, 'train') 108 | if self.test_mode: 109 | data_file = os.path.join(self.root, 'ILSVRC2012_list', 'val_folder_new.txt') 110 | image_dir = os.path.join(self.root, 'val') 111 | 112 | with open(data_file) as f: 113 | datamappings = f.readlines() # "n01440764/n01440764_10026.JEPG 0" 114 | for id, line in enumerate(datamappings): 115 | info = line.strip().split() 116 | self.names.append(info[0][:-5]) # "n01440764/n01440764_10026" 117 | self.labels.append(int(info[1])) # "0" 118 | if self.test_mode: 119 | self.bboxes.append(np.array(list(map(float, info[2:]))).reshape(-1, 4)) 120 | if self.keep_class is not None: 121 | self.filter_classes() 122 | self.pred_logits = None 123 | if self.test_mode: 124 | with open(self.load_class_path, 'r') as f: 125 | name2result = json.load(f) 126 | self.pred_logits = [torch.Tensor(name2result[name]['pred_scores']) for name in self.names] 127 | self.image_paths = [os.path.join(image_dir, name + '.JPEG') for name in self.names] 128 | self.num_images = len(self.labels) 129 | 130 | def init_samples(self): 131 | # format tokens by category 132 | def select_meta_tokens(cats=[], tokenizer=None): 133 | for c in cats: 134 | token_ids = tokenizer.encode(c, add_special_tokens=False) 135 | if len(token_ids) == 1: # has exist token to indicate input token 136 | return c, token_ids[-1], True 137 | token_ids = tokenizer.encode(cats[0], add_special_tokens=False) 138 | token = tokenizer.decode(token_ids[-1]) # only use the final one 139 | return token, token_ids[-1], False 140 | 141 | tokenizer = CLIPTokenizer.from_pretrained(self.load_pretrain_path, subfolder="tokenizer") 142 | concept_tokens = [self.token_templates.format(id) for id in range(self.num_classes)] 143 | tokenizer.add_tokens(concept_tokens) 144 | 145 | categories = [[t.strip() for t in c.strip().split(',')] for c in self.categories] 146 | categories = [[d.split(' ')[-1] for d in c] for c in categories] 147 | meta_tokens = [select_meta_tokens(c, tokenizer)[0] for c in categories] 148 | 149 | caption_ids_meta_token = [[tokenizer( 150 | template.format(token), 151 | padding="max_length", 152 | truncation=True, 153 | max_length=tokenizer.model_max_length, 154 | return_tensors="pt", 155 | ).input_ids[0] for template in self.caption_templates] for token in meta_tokens] 156 | caption_ids_concept_token = [[tokenizer( 157 | template.format(token), 158 | padding="max_length", 159 | truncation=True, 160 | max_length=tokenizer.model_max_length, 161 | return_tensors="pt", 162 | ).input_ids[0] for template in self.caption_templates] for token in concept_tokens] 163 | 164 | cat2tokens = [dict(meta_token=a, concept_token=b, 
caption_ids_meta_token=c, caption_ids_concept_token=d) 165 | for a, b, c, d in 166 | zip(meta_tokens, concept_tokens, caption_ids_meta_token, caption_ids_concept_token)] 167 | 168 | # format tokens by sample 169 | load_gt = True 170 | sample2tokens = [cat2tokens[id] for id in self.labels] 171 | 172 | samples = [] 173 | for idx, lt in enumerate(sample2tokens): 174 | scores = [] 175 | ids_concept_token = [] 176 | ids_meta_token = [] 177 | if self.pred_logits is not None: 178 | tensor = torch.Tensor(self.pred_logits[idx]).topk(5) 179 | tmp = [cat2tokens[id] for id in tensor.indices.numpy().tolist()] 180 | scores.extend([float(s) for s in tensor.values.numpy().tolist()]) 181 | ids_concept_token.extend([t['caption_ids_concept_token'] for t in tmp]) 182 | ids_meta_token.extend([t['caption_ids_meta_token'] for t in tmp]) 183 | if load_gt: 184 | scores.append(0.0) 185 | ids_concept_token.append(lt['caption_ids_concept_token']) 186 | ids_meta_token.append(lt['caption_ids_meta_token']) 187 | samples.append(dict( 188 | scores=scores, 189 | caption_ids_concept_token=ids_concept_token, 190 | caption_ids_meta_token=ids_meta_token, 191 | )) 192 | 193 | self.tokenizer = tokenizer 194 | self.cat2tokens = cat2tokens 195 | self.samples = samples 196 | 197 | def init_text_encoder(self, text_encoder): 198 | text_encoder.resize_token_embeddings(len(self.tokenizer)) 199 | if self.test_mode or (self.load_token_path is not None): 200 | text_encoder = self.load_embeddings(text_encoder) 201 | elif not self.test_mode: 202 | text_encoder = self.init_embeddings(text_encoder) 203 | return text_encoder 204 | 205 | def load_embeddings(self, text_encoder): 206 | missing_token = False 207 | token_embeds = text_encoder.get_input_embeddings().weight.data.clone() 208 | for token in self.cat2tokens: 209 | concept_token = token['concept_token'] 210 | concept_token_id = self.tokenizer.encode(concept_token, add_special_tokens=False)[-1] 211 | token_bin = os.path.join(self.load_token_path, f"{concept_token_id}.bin") 212 | if not os.path.exists(token_bin): 213 | missing_token = True 214 | continue 215 | token_embeds[concept_token_id] = torch.load(token_bin)[concept_token] 216 | text_encoder.get_input_embeddings().weight = torch.nn.Parameter(token_embeds) 217 | if missing_token: 218 | print(f"WARN: {self.__class__.__name__}:\t missing token.", flush=True) 219 | return text_encoder 220 | 221 | def init_embeddings(self, text_encoder): 222 | token_embeds = text_encoder.get_input_embeddings().weight.data.clone() 223 | for token in self.cat2tokens: 224 | meta_token_id = self.tokenizer.encode(token['meta_token'], add_special_tokens=False)[0] 225 | concept_token_id = self.tokenizer.encode(token['concept_token'], add_special_tokens=False)[0] 226 | token_embeds[concept_token_id] = token_embeds[meta_token_id] 227 | text_encoder.get_input_embeddings().weight = torch.nn.Parameter(token_embeds) 228 | return text_encoder 229 | 230 | def filter_classes(self): 231 | if isinstance(self.keep_class, int): 232 | mask = np.array(self.labels) == self.keep_class 233 | else: 234 | mask = np.zeros_like(self.labels) 235 | for cls in self.keep_class: 236 | mask = mask + (np.array(self.labels) == cls).astype(float) 237 | mask = mask.astype(bool) 238 | 239 | self.labels = [el for i, el in enumerate(self.labels) if mask[i]] 240 | if self.names is not None and len(self.names) != 0: 241 | self.names = [el for i, el in enumerate(self.names) if mask[i]] 242 | if self.bboxes is not None and len(self.bboxes) != 0: 243 | self.bboxes = [el for i, el in 
enumerate(self.bboxes) if mask[i]] 244 | 245 | def init_train_pipelines(self, **kwargs): 246 | center_crop = kwargs.pop("center_crop", False) 247 | resize_size = kwargs.pop("resize_size", 512) 248 | crop_size = kwargs.pop("crop_size", 512) 249 | brightness = kwargs.pop("brightness", 0.1) 250 | contrast = kwargs.pop("contrast", 0.1) 251 | saturation = kwargs.pop("saturation", 0.1) 252 | hue = kwargs.pop("hue", 0.1) 253 | interpolation = { 254 | "linear": PIL_INTERPOLATION["linear"], 255 | "bilinear": PIL_INTERPOLATION["bilinear"], 256 | "bicubic": PIL_INTERPOLATION["bicubic"], 257 | "lanczos": PIL_INTERPOLATION["lanczos"], 258 | }[kwargs.pop("interpolation", "bicubic")] 259 | 260 | train_transform = transforms.Compose([ 261 | transforms.Resize((resize_size, resize_size)), 262 | transforms.RandomCrop((crop_size, crop_size)), 263 | transforms.RandomHorizontalFlip(), 264 | transforms.ColorJitter(brightness, contrast, saturation, hue) 265 | ]) 266 | 267 | def train_pipeline(data): 268 | img = data['img'] 269 | if center_crop: 270 | crop = min(img.shape[0], img.shape[1]) 271 | h, w, = ( 272 | img.shape[0], 273 | img.shape[1], 274 | ) 275 | img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] 276 | img = Image.fromarray(img) 277 | img = train_transform(img) 278 | img = img.resize((crop_size, crop_size), resample=interpolation) 279 | img = np.array(img).astype(np.uint8) 280 | img = (img / 127.5 - 1.0).astype(np.float32) 281 | data["img"] = torch.from_numpy(img).permute(2, 0, 1) 282 | return data 283 | 284 | return train_pipeline 285 | 286 | def init_test_pipelines(self, **kwargs): 287 | crop_size = kwargs.pop("crop_size", 512) 288 | interpolation = { 289 | "linear": PIL_INTERPOLATION["linear"], 290 | "bilinear": PIL_INTERPOLATION["bilinear"], 291 | "bicubic": PIL_INTERPOLATION["bicubic"], 292 | "lanczos": PIL_INTERPOLATION["lanczos"], 293 | }[kwargs.pop("interpolation", "bicubic")] 294 | 295 | def test_pipeline(data): 296 | img = data['img'] 297 | bbox = data['gt_bboxes'] 298 | image_width, image_height = data['ori_shape'] 299 | img = Image.fromarray(img) 300 | img = img.resize((crop_size, crop_size), resample=interpolation) 301 | img = np.array(img).astype(np.uint8) 302 | img = (img / 127.5 - 1.0).astype(np.float32) 303 | data["img"] = torch.from_numpy(img).permute(2, 0, 1) 304 | 305 | [x1, y1, x2, y2] = np.split(bbox, 4, 1) 306 | resize_size = crop_size 307 | shift_size = 0 308 | left_bottom_x = np.maximum(x1 / image_width * resize_size - shift_size, 0).astype(int) 309 | left_bottom_y = np.maximum(y1 / image_height * resize_size - shift_size, 0).astype(int) 310 | right_top_x = np.minimum(x2 / image_width * resize_size - shift_size, crop_size - 1).astype(int) 311 | right_top_y = np.minimum(y2 / image_height * resize_size - shift_size, crop_size - 1).astype(int) 312 | 313 | gt_bbox = np.concatenate((left_bottom_x, left_bottom_y, right_top_x, right_top_y), axis=1).reshape(-1) 314 | gt_bbox = " ".join(list(map(str, gt_bbox))) 315 | data['gt_bboxes'] = gt_bbox 316 | 317 | 318 | return data 319 | 320 | return test_pipeline 321 | 322 | def prepare_data(self, idx): 323 | pipelines = self.train_pipelines if not self.test_mode else self.test_pipelines 324 | bboxes = [] if not self.test_mode else self.bboxes[idx] 325 | 326 | name = self.names[idx] 327 | label = self.labels[idx] 328 | image_path = self.image_paths[idx] 329 | image = Image.open(image_path).convert('RGB') 330 | image_size = list(image.size) 331 | img = np.array(image).astype(np.uint8) 332 | data = pipelines(dict( 
333 | img=img, ori_shape=image_size, 334 | gt_labels=label, gt_names=name, gt_bboxes=bboxes, 335 | )) 336 | 337 | data = dict( 338 | img=data["img"], 339 | ori_shape=data["ori_shape"], 340 | gt_labels=data["gt_labels"], 341 | gt_bboxes=data["gt_bboxes"], 342 | name=data["gt_names"], 343 | ) 344 | 345 | if self.test_mode: 346 | pred_scores = self.pred_logits 347 | data.update(dict( 348 | pred_logits=pred_scores[idx], 349 | pred_top5_ids=pred_scores[idx].topk(5).indices, 350 | )) 351 | 352 | return data 353 | 354 | def prepare_tokens(self, idx): 355 | sample = self.samples[idx] 356 | 357 | if not self.test_mode: 358 | choice = random.choice(range(len(self.caption_templates))) 359 | sample = dict( 360 | caption_ids_concept_token=sample['caption_ids_concept_token'][-1][choice], 361 | caption_ids_meta_token=sample['caption_ids_meta_token'][-1][choice], 362 | ) 363 | 364 | return sample 365 | 366 | def __len__(self): 367 | return self.num_images * self.data_repeat 368 | 369 | def __getitem__(self, idx): 370 | data = self.prepare_data(idx % self.num_images) 371 | data.update(self.prepare_tokens(idx % self.num_images)) 372 | return data 373 | 374 | class SubDataset(BaseDataset): 375 | def __init__(self, keep_class, dataset): 376 | self.dataset = dataset 377 | 378 | if isinstance(keep_class, int): 379 | keep_class = [keep_class] 380 | self.keep_class = keep_class 381 | 382 | if isinstance(self.keep_class, int): 383 | mask = np.array(self.dataset.labels) == self.keep_class 384 | else: 385 | mask = np.zeros_like(self.dataset.labels) 386 | for cls in self.keep_class: 387 | mask = mask + (np.array(self.dataset.labels) == cls).astype(float) 388 | mask = mask.astype(bool) 389 | 390 | self.names = self.dataset.names 391 | self.image_paths = self.dataset.image_paths 392 | self.data_repeat = self.dataset.data_repeat 393 | self.labels = self.dataset.labels 394 | self.samples = self.dataset.samples 395 | self.caption_templates = self.dataset.caption_templates 396 | self.test_mode = self.dataset.test_mode 397 | self.num_images = self.dataset.num_images 398 | self.train_pipelines = self.dataset.train_pipelines 399 | self.test_pipelines = self.dataset.test_pipelines 400 | 401 | self.names = [el for i, el in enumerate(self.names) if mask[i]] 402 | self.image_paths = [el for i, el in enumerate(self.image_paths) if mask[i]] 403 | self.labels = [el for i, el in enumerate(self.labels) if mask[i]] 404 | self.samples = [el for i, el in enumerate(self.samples) if mask[i]] 405 | self.num_images = len(self.image_paths) 406 | 407 | class ImagenetDataset(BaseDataset): 408 | def __init__(self, 409 | root=".", 410 | repeats=1, 411 | crop_size=512, 412 | resize_size=512, 413 | test_mode=False, 414 | keep_class=None, 415 | center_crop=False, 416 | text_encoder=None, 417 | load_class_path=None, 418 | load_token_path=None, 419 | load_pretrain_path=None, 420 | interpolation="bicubic", 421 | token_templates="", 422 | caption_templates=caption_templates, 423 | **kwargs, 424 | ): 425 | super().__init__(root=root, 426 | repeats=repeats, 427 | crop_size=crop_size, 428 | resize_size=resize_size, 429 | test_mode=test_mode, 430 | keep_class=keep_class, 431 | center_crop=center_crop, 432 | text_encoder=text_encoder, 433 | load_class_path=load_class_path, 434 | load_token_path=load_token_path, 435 | load_pretrain_path=load_pretrain_path, 436 | interpolation=interpolation, 437 | token_templates=token_templates, 438 | caption_templates=caption_templates, 439 | **kwargs) 440 | 441 | def load_data(self): 442 | """ 443 | returns: 444 | 
self.names: [] 445 | self.labels: [] 446 | self.bboxes: []; (x1,y1,x2,y2) 447 | self.pred_logits: [] | None 448 | self.image_paths: [] 449 | """ 450 | class_file = os.path.join(self.root, 'ILSVRC2012_list', 'LOC_synset_mapping.txt') 451 | self.categories = [] 452 | with open(class_file, 'r') as f: 453 | discriptions = f.readlines() # "n01882714 koala..." 454 | for id, line in enumerate(discriptions): 455 | tag, description = line.strip().split(' ', maxsplit=1) 456 | self.categories.append(description) 457 | self.num_classes = len(self.categories) 458 | 459 | self.names = [] 460 | self.labels = [] 461 | self.bboxes = [] 462 | self.pred_logits = [] 463 | self.image_paths = [] 464 | 465 | data_file = os.path.join(self.root, 'ILSVRC2012_list', 'train.txt') 466 | image_dir = os.path.join(self.root, 'train') 467 | if self.test_mode: 468 | data_file = os.path.join(self.root, 'ILSVRC2012_list', 'val_folder_new.txt') 469 | image_dir = os.path.join(self.root, 'val') 470 | 471 | with open(data_file) as f: 472 | datamappings = f.readlines() # "n01440764/n01440764_10026.JEPG 0" 473 | for id, line in enumerate(datamappings): 474 | info = line.strip().split() 475 | self.names.append(info[0][:-5]) # "n01440764/n01440764_10026" 476 | self.labels.append(int(info[1])) # "0" 477 | if self.test_mode: 478 | self.bboxes.append(np.array(list(map(float, info[2:]))).reshape(-1, 4)) 479 | if self.keep_class is not None: 480 | self.filter_classes() 481 | self.pred_logits = None 482 | if self.test_mode: 483 | with open(self.load_class_path, 'r') as f: 484 | name2result = json.load(f) 485 | self.pred_logits = [torch.Tensor(name2result[name]['pred_scores']) for name in self.names] 486 | self.image_paths = [os.path.join(image_dir, name + '.JPEG') for name in self.names] 487 | self.num_images = len(self.labels) 488 | 489 | class CUBDataset(BaseDataset): 490 | def __init__(self, 491 | root=".", 492 | repeats=1, 493 | crop_size=512, 494 | resize_size=512, 495 | test_mode=False, 496 | keep_class=None, 497 | center_crop=False, 498 | text_encoder=None, 499 | load_class_path=None, 500 | load_token_path=None, 501 | load_pretrain_path=None, 502 | interpolation="bicubic", 503 | token_templates="", 504 | caption_templates=caption_templates, 505 | **kwargs, 506 | ): 507 | super().__init__(root=root, 508 | repeats=repeats, 509 | crop_size=crop_size, 510 | resize_size=resize_size, 511 | test_mode=test_mode, 512 | keep_class=keep_class, 513 | center_crop=center_crop, 514 | text_encoder=text_encoder, 515 | load_class_path=load_class_path, 516 | load_token_path=load_token_path, 517 | load_pretrain_path=load_pretrain_path, 518 | interpolation=interpolation, 519 | token_templates=token_templates, 520 | caption_templates=caption_templates, 521 | **kwargs) 522 | 523 | def load_data(self): 524 | self.categories = ["bird"] * 200 525 | self.num_classes = len(self.categories) 526 | 527 | images_file = os.path.join(self.root, 'images.txt') 528 | labels_file = os.path.join(self.root, 'image_class_labels.txt') 529 | splits_file = os.path.join(self.root, 'train_test_split.txt') 530 | bboxes_file = os.path.join(self.root, 'bounding_boxes.txt') 531 | 532 | with open(images_file, 'r') as f: 533 | lines = f.readlines() 534 | image_list = [line.strip().split(' ')[1] for line in lines] 535 | with open(labels_file, 'r') as f: 536 | lines = f.readlines() 537 | labels_list = [line.strip().split(' ')[1] for line in lines] 538 | with open(splits_file, 'r') as f: 539 | lines = f.readlines() 540 | splits_list = [line.strip().split(' ')[1] for line in lines] 541 
| with open(bboxes_file, 'r') as f: 542 | lines = f.readlines() 543 | bboxes_list = [line.strip().split(' ')[1:] for line in lines] 544 | 545 | train_index = [i for i, v in enumerate(splits_list) if v == '1'] 546 | test_index = [i for i, v in enumerate(splits_list) if v == '0'] 547 | index_list = train_index if not self.test_mode else test_index 548 | self.names = [image_list[i] for i in index_list] 549 | self.labels = [int(labels_list[i]) - 1 for i in index_list] 550 | self.bboxes = [np.array(list(map(float, bboxes_list[i]))).reshape(-1, 4) 551 | for i in index_list] 552 | for i in range(len(self.bboxes)): 553 | self.bboxes[i][:, 2:4] = self.bboxes[i][:, 0:2] + self.bboxes[i][:, 2:4] 554 | 555 | if self.keep_class is not None: 556 | self.filter_classes() 557 | self.pred_logits = None 558 | if self.test_mode: 559 | with open(self.load_class_path, 'r') as f: 560 | name2result = json.load(f) 561 | self.pred_logits = [torch.Tensor(name2result[name]['pred_scores']) for name in self.names] 562 | self.image_paths = [os.path.join(self.root, 'images', name) for name in self.names] 563 | self.num_images = len(self.labels) 564 | 565 | 566 | -------------------------------------------------------------------------------- /datasets/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .cam import evaluate_cls_loc, list2acc 2 | 3 | # Evaluation code from TS-CAM 4 | class Evaluator(): 5 | def __init__(self, logfile='log.txt', len_dataloader=1): 6 | self.cls_top1 = [] 7 | self.cls_top5 = [] 8 | self.loc_top1 = [] 9 | self.loc_top5 = [] 10 | self.loc_gt_known = [] 11 | self.top1_loc_right = [] 12 | self.top1_loc_cls = [] 13 | self.top1_loc_mins = [] 14 | self.top1_loc_part = [] 15 | self.top1_loc_more = [] 16 | self.top1_loc_wrong = [] 17 | self.logfile = logfile 18 | self.len_dataloader = len_dataloader 19 | 20 | def __call__(self, input, target, bbox, logits, pad_cams, image_names, cfg, step): 21 | cls_top1_b, cls_top5_b, loc_top1_b, loc_top5_b, loc_gt_known_b, top1_loc_right_b, \ 22 | top1_loc_cls_b, top1_loc_mins_b, top1_loc_part_b, top1_loc_more_b, top1_loc_wrong_b = \ 23 | evaluate_cls_loc(input, target, bbox, logits, pad_cams, image_names, cfg) 24 | self.cls_top1.extend(cls_top1_b) 25 | self.cls_top5.extend(cls_top5_b) 26 | self.loc_top1.extend(loc_top1_b) 27 | self.loc_top5.extend(loc_top5_b) 28 | self.top1_loc_right.extend(top1_loc_right_b) 29 | self.top1_loc_cls.extend(top1_loc_cls_b) 30 | self.top1_loc_mins.extend(top1_loc_mins_b) 31 | self.top1_loc_more.extend(top1_loc_more_b) 32 | self.top1_loc_part.extend(top1_loc_part_b) 33 | self.top1_loc_wrong.extend(top1_loc_wrong_b) 34 | 35 | self.loc_gt_known.extend(loc_gt_known_b) 36 | 37 | if step != 0 and (step % 100 == 0 or step == self.len_dataloader - 1): 38 | str1 = 'Val Epoch: [{0}][{1}/{2}]\t'.format(0, step + 1, self.len_dataloader) 39 | str2 = 'Cls@1:{0:.3f}\tCls@5:{1:.3f}\tLoc@1:{2:.3f}\tLoc@5:{3:.3f}\tLoc_gt:{4:.3f}'.format( 40 | list2acc(self.cls_top1), list2acc(self.cls_top5),list2acc(self.loc_top1), list2acc(self.loc_top5), list2acc(self.loc_gt_known)) 41 | str3 = 'M-ins:{0:.3f}\tPart:{1:.3f}\tMore:{2:.3f}\tRight:{3:.3f}\tWrong:{4:.3f}\tCls:{5:.3f}'.format( 42 | list2acc(self.top1_loc_mins), list2acc(self.top1_loc_part), list2acc(self.top1_loc_more), 43 | list2acc(self.top1_loc_right), list2acc(self.top1_loc_wrong), list2acc(self.top1_loc_cls)) 44 | 45 | if self.logfile is not None: 46 | with open(self.logfile, 'a') as fw: 47 | fw.write('\n'+str1+'\n') 48 | fw.write(str2+'\n') 
49 | fw.write(str3+'\n') 50 | 51 | print(str1) 52 | print(str2) 53 | print(str3) 54 | 55 | -------------------------------------------------------------------------------- /datasets/evaluation/cam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | 7 | def resize_cam(cam, size=(224, 224)): 8 | cam = cv2.resize(cam , (size[0], size[1])) 9 | #cam = cam - cam.min() 10 | #cam = cam / cam.max() 11 | cam_min, cam_max = cam.min(), cam.max() 12 | cam = (cam - cam_min) / (cam_max - cam_min) 13 | return cam 14 | 15 | def blend_cam(image, cam, es_box=[0,0,1,1]): 16 | I = np.zeros_like(cam) 17 | x1, y1, x2, y2 = es_box 18 | I[y1:y2, x1:x2] = 1 19 | cam = cam * I 20 | cam = (cam * 255.).astype(np.uint8) 21 | heatmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET) 22 | blend = image * 0.5 + heatmap * 0.5 23 | 24 | blend = blend.astype(np.uint8) 25 | 26 | return blend, heatmap 27 | 28 | def get_bboxes(cam, cam_thr=0.2): 29 | """ 30 | cam: single image with shape (h, w, 1) 31 | thr_val: float value (0~1) 32 | return estimated bounding box 33 | """ 34 | cam = (cam * 255.).astype(np.uint8) 35 | map_thr = cam_thr * np.max(cam) 36 | 37 | _, thr_gray_heatmap = cv2.threshold(cam, 38 | int(map_thr), 255, 39 | cv2.THRESH_TOZERO) 40 | #thr_gray_heatmap = (thr_gray_heatmap*255.).astype(np.uint8) 41 | 42 | contours, _ = cv2.findContours(thr_gray_heatmap, 43 | cv2.RETR_TREE, 44 | cv2.CHAIN_APPROX_SIMPLE) 45 | if len(contours) != 0: 46 | c = max(contours, key=cv2.contourArea) 47 | x, y, w, h = cv2.boundingRect(c) 48 | estimated_bbox = [x, y, x + w, y + h] 49 | else: 50 | estimated_bbox = [0, 0, 1, 1] 51 | 52 | return estimated_bbox #, thr_gray_heatmap, len(contours) 53 | 54 | def count_max(x): 55 | count_dict = {} 56 | for xlist in x: 57 | for item in xlist: 58 | if item==0: 59 | continue 60 | if item not in count_dict.keys(): 61 | count_dict[item] = 0 62 | count_dict[item] += 1 63 | if count_dict == {}: 64 | return -1 65 | count_dict = sorted(count_dict.items(), key=lambda d:d[1], reverse=True) 66 | return count_dict[0][0] 67 | 68 | def tensor2image(input, image_mean, image_std): 69 | image_mean = torch.reshape(torch.tensor(image_mean), (1, 3, 1, 1)) 70 | image_std = torch.reshape(torch.tensor(image_std), (1, 3, 1, 1)) 71 | image = input * image_std + image_mean 72 | image = image.numpy().transpose(0, 2, 3, 1) 73 | image = image[:, :, :, ::-1] 74 | return image 75 | 76 | def calculate_IOU(boxA, boxB): 77 | xA = max(boxA[0], boxB[0]) 78 | yA = max(boxA[1], boxB[1]) 79 | xB = min(boxA[2], boxB[2]) 80 | yB = min(boxA[3], boxB[3]) 81 | 82 | # compute the area of intersection rectangle 83 | interArea = (xB - xA + 1) * (yB - yA + 1) 84 | 85 | # compute the area of both the prediction and ground-truth 86 | # rectangles 87 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) 88 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) 89 | 90 | # compute the intersection over union by taking the intersection 91 | # area and dividing it by the sum of prediction + ground-truth 92 | # areas - the interesection area 93 | iou = interArea / float(boxAArea + boxBArea - interArea) 94 | 95 | # return the intersection over union value 96 | return iou 97 | 98 | def draw_bbox(image, iou, gt_box, pred_box, gt_score, is_top1=False): 99 | 100 | def draw_bbox(img, box1, box2, color1=(0, 0, 255), color2=(0, 255, 0)): 101 | # for i in range(len(box1)): 102 | # cv2.rectangle(img, (box1[i,0], 
box1[i,1]), (box1[i,2], box1[i,3]), color1, 4) 103 | cv2.rectangle(img, (box2[0], box2[1]), (box2[2], box2[3]), color2, 4) 104 | return img 105 | 106 | def mark_target(img, text='target', pos=(25, 25), size=2): 107 | cv2.putText(img, text, pos, cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), size) 108 | return img 109 | 110 | boxed_image = image.copy() 111 | 112 | # draw bbox on image 113 | boxed_image = draw_bbox(boxed_image, gt_box, pred_box) 114 | 115 | # mark the iou 116 | # mark_target(boxed_image, '%.1f' % (iou * 100), (140, 30), 2) 117 | # mark_target(boxed_image, 'IOU%.2f' % (iou), (80, 30), 2) 118 | # # mark the top1 119 | # if is_top1: 120 | # mark_target(boxed_image, 'Top1', (10, 30)) 121 | # mark_target(boxed_image, 'GT_Score%.2f' % (gt_score), (10, 200), 2) 122 | 123 | return boxed_image 124 | 125 | def evaluate_cls_loc(input, cls_label, bbox_label, logits, cams, image_names, config): 126 | """ 127 | :param input: input tensors of the model 128 | :param cls_label: class label 129 | :param bbox_label: bounding box label 130 | :param logits: classification scores 131 | :param cams: cam of all the classes 132 | :param image_names: names of images 133 | :param cfg: configurations 134 | :param epoch: epoch 135 | :return: evaluate results 136 | """ 137 | cls_top1 = [] 138 | cls_top5 = [] 139 | loc_top1 = [] 140 | loc_top5 = [] 141 | loc_gt_known = [] 142 | top1_loc_right = [] 143 | top1_loc_cls = [] 144 | top1_loc_mins = [] 145 | top1_loc_part = [] 146 | top1_loc_more = [] 147 | top1_loc_wrong = [] 148 | 149 | # label, top1 and top5 results 150 | cls_label = cls_label.tolist() 151 | cls_scores = logits.tolist() 152 | _, top1_idx = logits.topk(1, 1, True, True) 153 | top1_idx = top1_idx.tolist() 154 | _, top5_idx = logits.topk(5, 1, True, True) 155 | top5_idx = top5_idx.tolist() 156 | 157 | batch = cams.shape[0] 158 | image_mean = [127.5, 127.5, 127.5] 159 | image_std = [127.5, 127.5, 127.5] 160 | image = tensor2image(input.clone().detach().cpu(), image_mean, image_std).astype(np.uint8) 161 | 162 | for b in range(batch): 163 | gt_bbox = bbox_label[b].strip().split(' ') 164 | gt_bbox = list(map(float, gt_bbox)) 165 | crop_size = config["data"]["test"]["dataset"]["crop_size"] 166 | top_bboxes, top_mask=get_topk_boxes(top5_idx[b], cams[b], crop_size=crop_size, threshold=config["test"]["cam_thr"]) 167 | topk_cls, topk_loc, wrong_details=cls_loc_err(top_bboxes, cls_label[b], gt_bbox, topk=(1,5)) 168 | cls_top1_b, cls_top5_b = topk_cls 169 | loc_top1_b, loc_top5_b = topk_loc 170 | cls_top1.append(cls_top1_b) 171 | cls_top5.append(cls_top5_b) 172 | loc_top1.append(loc_top1_b) 173 | loc_top5.append(loc_top5_b) 174 | cls_wrong, multi_instances, region_part, region_more, region_wrong = wrong_details 175 | right = 1 - (cls_wrong + multi_instances + region_part + region_more + region_wrong) 176 | top1_loc_right.append(right) 177 | top1_loc_cls.append(cls_wrong) 178 | top1_loc_mins.append(multi_instances) 179 | top1_loc_part.append(region_part) 180 | top1_loc_more.append(region_more) 181 | top1_loc_wrong.append(region_wrong) 182 | # gt_known 183 | # mean top k 184 | cam_b = cams[b, [cls_label[b]], :, :] 185 | cam_b = torch.mean(cam_b, dim=0, keepdim=True) 186 | 187 | cam_b = cam_b.detach().cpu().numpy().transpose(1, 2, 0) 188 | 189 | # Resize and Normalize CAM 190 | cam_b = resize_cam(cam_b, size=(crop_size, crop_size)) 191 | 192 | # Estimate BBOX 193 | estimated_bbox = get_bboxes(cam_b, cam_thr=config["test"]["cam_thr"]) 194 | 195 | # Calculate IoU 196 | gt_box_cnt = len(gt_bbox) // 4 197 | 
max_iou = 0 198 | for i in range(gt_box_cnt): 199 | gt_box = gt_bbox[i * 4:(i + 1) * 4] 200 | iou_i = cal_iou(estimated_bbox, gt_box) 201 | if iou_i > max_iou: 202 | max_iou = iou_i 203 | 204 | iou = max_iou 205 | # iou = calculate_IOU(bbox_label[b].numpy(), estimated_bbox) 206 | 207 | # print('cam_b shape', cam_b.shape, 'cam_b max', cam_b.max(), 'cam_b min', cam_b.min(), 'thre', cfg.MODEL.CAM_THR, 'iou ', iou) 208 | #if iou < 0.5: 209 | # pdb.set_trace() 210 | # gt known 211 | if iou >= 0.5: 212 | loc_gt_known.append(1) 213 | else: 214 | loc_gt_known.append(0) 215 | 216 | # Get blended image 217 | box = [0, 0, 512, 512] 218 | blend, heatmap = blend_cam(image[b], cam_b, estimated_bbox) 219 | # Get boxed image 220 | gt_score = cls_scores[b][top1_idx[b][0]] # score of gt class 221 | boxed_image = draw_bbox(blend, iou, np.array(gt_bbox).reshape(-1,4).astype(int), estimated_bbox, gt_score, False) 222 | 223 | # save result 224 | save_vis_path = config["test"]["save_vis_path"] 225 | if save_vis_path is not None: 226 | image_name = image_names[b] 227 | boxed_image_dir = save_vis_path 228 | if not os.path.exists(boxed_image_dir): 229 | os.mkdir(boxed_image_dir) 230 | class_dir = os.path.join(boxed_image_dir, image_name.split('/')[0]) 231 | save_path = os.path.join(class_dir, image_name.split('/')[-1] + '.jpg') 232 | 233 | if not os.path.exists(class_dir): 234 | os.mkdir(class_dir) 235 | cv2.imwrite(save_path, boxed_image) 236 | 237 | return cls_top1, cls_top5, loc_top1, loc_top5, loc_gt_known, top1_loc_right, top1_loc_cls, top1_loc_mins, \ 238 | top1_loc_part, top1_loc_more, top1_loc_wrong 239 | 240 | def get_topk_boxes(cls_inds, cam_map, crop_size, topk=(1, 5), threshold=0.2, ): 241 | maxk_boxes = [] 242 | maxk_maps = [] 243 | for cls in cls_inds: 244 | cam_map_ = cam_map[[cls], :, :] 245 | cam_map_ = cam_map_.detach().cpu().numpy().transpose(1, 2, 0) 246 | # Resize and Normalize CAM 247 | cam_map_ = resize_cam(cam_map_, size=(crop_size, crop_size)) 248 | maxk_maps.append(cam_map_.copy()) 249 | 250 | # Estimate BBOX 251 | estimated_bbox = get_bboxes(cam_map_, cam_thr=threshold) 252 | maxk_boxes.append([cls] + estimated_bbox) 253 | 254 | result = [maxk_boxes[:k] for k in topk] 255 | 256 | return result, maxk_maps 257 | 258 | def cls_loc_err(topk_boxes, gt_label, gt_boxes, topk=(1,), iou_th=0.5): 259 | assert len(topk_boxes) == len(topk) 260 | gt_boxes = gt_boxes 261 | gt_box_cnt = len(gt_boxes) // 4 262 | topk_loc = [] 263 | topk_cls = [] 264 | for topk_box in topk_boxes: 265 | loc_acc = 0 266 | cls_acc = 0 267 | for cls_box in topk_box: 268 | max_iou = 0 269 | max_gt_id = 0 270 | for i in range(gt_box_cnt): 271 | gt_box = gt_boxes[i*4:(i+1)*4] 272 | iou_i = cal_iou(cls_box[1:], gt_box) 273 | if iou_i> max_iou: 274 | max_iou = iou_i 275 | max_gt_id = i 276 | if len(topk_box) == 1: 277 | wrong_details = get_badcase_detail(cls_box, gt_boxes, gt_label, max_iou, max_gt_id) 278 | if cls_box[0] == gt_label: 279 | cls_acc = 1 280 | if cls_box[0] == gt_label and max_iou > iou_th: 281 | loc_acc = 1 282 | break 283 | topk_loc.append(float(loc_acc)) 284 | topk_cls.append(float(cls_acc)) 285 | return topk_cls, topk_loc, wrong_details 286 | 287 | def cal_iou(box1, box2, method='iou'): 288 | """ 289 | support: 290 | 1. box1 and box2 are the same shape: [N, 4] 291 | 2. 
292 | :param box1: 293 | :param box2: 294 | :return: 295 | """ 296 | box1 = np.asarray(box1, dtype=float) 297 | box2 = np.asarray(box2, dtype=float) 298 | if box1.ndim == 1: 299 | box1 = box1[np.newaxis, :] 300 | if box2.ndim == 1: 301 | box2 = box2[np.newaxis, :] 302 | 303 | iw = np.minimum(box1[:, 2], box2[:, 2]) - np.maximum(box1[:, 0], box2[:, 0]) + 1 304 | ih = np.minimum(box1[:, 3], box2[:, 3]) - np.maximum(box1[:, 1], box2[:, 1]) + 1 305 | 306 | i_area = np.maximum(iw, 0.0) * np.maximum(ih, 0.0) 307 | box1_area = (box1[:, 2] - box1[:, 0] + 1) * (box1[:, 3] - box1[:, 1] + 1) 308 | box2_area = (box2[:, 2] - box2[:, 0] + 1) * (box2[:, 3] - box2[:, 1] + 1) 309 | 310 | if method == 'iog': 311 | iou_val = i_area / (box2_area) 312 | elif method == 'iob': 313 | iou_val = i_area / (box1_area) 314 | else: 315 | iou_val = i_area / (box1_area + box2_area - i_area) 316 | return iou_val 317 | 318 | def get_badcase_detail(top1_bbox, gt_bboxes, gt_label, max_iou, max_gt_id): 319 | cls_wrong = 0 320 | multi_instances = 0 321 | region_part = 0 322 | region_more = 0 323 | region_wrong = 0 324 | 325 | pred_cls = top1_bbox[0] 326 | pred_bbox = top1_bbox[1:] 327 | 328 | if not int(pred_cls) == gt_label: 329 | cls_wrong = 1 330 | return cls_wrong, multi_instances, region_part, region_more, region_wrong 331 | 332 | if max_iou > 0.5: 333 | return 0, 0, 0, 0, 0 334 | 335 | # multi_instances error 336 | gt_box_cnt = len(gt_bboxes) // 4 337 | if gt_box_cnt > 1: 338 | iogs = [] 339 | for i in range(gt_box_cnt): 340 | gt_box = gt_bboxes[i * 4:(i + 1) * 4] 341 | iog = cal_iou(pred_bbox, gt_box, method='iog') 342 | iogs.append(iog) 343 | if sum(np.array(iogs) > 0.3)> 1: 344 | multi_instances = 1 345 | return cls_wrong, multi_instances, region_part, region_more, region_wrong 346 | # -region part error 347 | iog = cal_iou(pred_bbox, gt_bboxes[max_gt_id*4:(max_gt_id+1)*4], method='iog') 348 | iob = cal_iou(pred_bbox, gt_bboxes[max_gt_id*4:(max_gt_id+1)*4], method='iob') 349 | if iob >0.5: 350 | region_part = 1 351 | return cls_wrong, multi_instances, region_part, region_more, region_wrong 352 | if iog >= 0.7: 353 | region_more = 1 354 | return cls_wrong, multi_instances, region_part, region_more, region_wrong 355 | region_wrong = 1 356 | return cls_wrong, multi_instances, region_part, region_more, region_wrong 357 | 358 | def accuracy(output, target, topk=(1,)): 359 | """ Computes the precision@k for the specified values of k 360 | :param output: tensor of shape B x K, predicted logits of image from model 361 | :param target: tensor of shape B X 1, ground-truth logits of image 362 | :param topk: top predictions 363 | :return: list of precision@k 364 | """ 365 | maxk = max(topk) 366 | batch_size = target.size(0) 367 | 368 | _, pred = output.topk(maxk, 1, True, True) 369 | pred = pred.t() 370 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 371 | 372 | res = [] 373 | for k in topk: 374 | correct_k = correct[:k].reshape(-1).float().sum(0) 375 | res.append(correct_k.mul_(100.0 / batch_size)) 376 | return res 377 | 378 | def list2acc(results_list): 379 | """ 380 | :param results_list: list contains 0 and 1 381 | :return: accuarcy 382 | """ 383 | accuarcy = results_list.count(1)/len(results_list) 384 | return accuarcy -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import yaml 5 | import argparse 6 | import tqdm 7 | import copy 8 | import 
itertools 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | import torch.utils.checkpoint 15 | from torch.utils.data import Dataset 16 | from accelerate import Accelerator 17 | from accelerate.utils import set_seed 18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 19 | from diffusers.optimization import get_scheduler 20 | from transformers import CLIPTextModel, CLIPTokenizer 21 | 22 | from models.attn import AttentionStore 23 | from datasets.base import CUBDataset, ImagenetDataset, SubDataset 24 | from datasets.evaluation import Evaluator 25 | 26 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 27 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 28 | 29 | DATASETS = dict(cub=CUBDataset, imagenet=ImagenetDataset) 30 | OPTIMIZER = dict(AdamW=torch.optim.AdamW) 31 | 32 | def set_env(benchmark=True): 33 | # get config 34 | parser = argparse.ArgumentParser(description=globals()["__doc__"]) 35 | parser.add_argument("--function", type=str, default="test", required=True) 36 | parser.add_argument("--config", type=str, default="configs/cub.yml", help="Config file", required=True) 37 | parser.add_argument("--opt", type=str, default="dict()", help="Override options.") 38 | parser.add_argument("--seed", type=int, default=1234, help="Random seed") 39 | args = parser.parse_args() 40 | 41 | def load_yaml_conf(config): 42 | def dict_fix(d): 43 | for k, v in d.items(): 44 | if isinstance(v, dict): 45 | v = dict_fix(v) 46 | elif v == "None": 47 | v = None 48 | elif v == "False": 49 | v = False 50 | elif v == "True": 51 | v = True 52 | elif isinstance(v, str): 53 | v = float(v) if v.isdigit() else v 54 | d[k] = v 55 | return d 56 | 57 | assert os.path.exists(config), "ERROR: no config file found." 
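        # Note on value coercion: dict_fix above maps the strings "None"/"True"/"False" to
        # Python values and digit-only strings to floats; scientific-notation strings
        # (e.g. a value like "1e-08") stay strings here and are eval()'d later in
        # format_conf (adam_weight_decay, adam_epsilon).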
58 | with open(config, "r") as f: 59 | config = yaml.safe_load(f) 60 | config = dict_fix(config) 61 | return config 62 | 63 | def format_conf(config): 64 | if config["train"]["max_train_steps"] is None: 65 | config["train"]["max_train_steps"] = 0 66 | 67 | keep_class = config["data"]["keep_class"] 68 | if keep_class is not None: 69 | if isinstance(keep_class, int): 70 | keep_class = [keep_class] 71 | elif isinstance(keep_class, list) and len(keep_class)==2: 72 | keep_class = list(range(keep_class[0], keep_class[1]+1)) 73 | else: 74 | assert isinstance(keep_class, list) 75 | 76 | data = dict( 77 | train=dict( 78 | batch_size=config["train"]["batch_size"], 79 | shuffle=True, 80 | dataset=dict(type=config["data"]["dataset"], 81 | root=config["data"]["root"], 82 | keep_class=keep_class, 83 | crop_size=config["data"]["crop_size"], 84 | resize_size=config["data"]["resize_size"], 85 | load_pretrain_path=config["train"]["load_pretrain_path"], 86 | load_token_path=config["train"]["load_token_path"], 87 | save_path=config["train"]["save_path"], 88 | ), 89 | ), 90 | test=dict( 91 | batch_size=config["test"]["batch_size"], 92 | shuffle=False, 93 | dataset=dict(type=config["data"]["dataset"], 94 | root=config["data"]["root"], 95 | keep_class=keep_class, 96 | crop_size=config["data"]["crop_size"], 97 | resize_size=config["data"]["resize_size"], 98 | load_pretrain_path=config["test"]["load_pretrain_path"], 99 | load_class_path=config["test"]["load_class_path"], 100 | load_token_path=config["test"]["load_token_path"], 101 | ), 102 | ), 103 | ) 104 | 105 | optimizer = dict( 106 | type="AdamW", 107 | lr=config["train"]["learning_rate"], 108 | betas=(config["train"]["adam_beta1"], config["train"]["adam_beta2"]), 109 | weight_decay=eval(config["train"]["adam_weight_decay"]), 110 | eps=eval(config["train"]["adam_epsilon"]), 111 | ) 112 | 113 | lr_scheduler = dict( 114 | type=config["train"]["lr_scheduler"], 115 | num_warmup_steps=config["train"]["lr_warmup_steps"] * config["train"]["gradient_accumulation_steps"], 116 | num_training_steps=config["train"]["max_train_steps"] * config["train"]["gradient_accumulation_steps"], 117 | ) 118 | 119 | accelerator = dict( 120 | # logging_dir=os.path.join(config["train"]["save_path"], "logs"), 121 | gradient_accumulation_steps=config["train"]["gradient_accumulation_steps"], 122 | mixed_precision="no", 123 | log_with=None 124 | ) 125 | 126 | model = dict( 127 | 128 | ) 129 | 130 | train = dict( 131 | epochs=config["train"]["num_train_epochs"], 132 | scale_lr=config["train"]["scale_lr"], 133 | push_to_hub=False, 134 | save_path=config["train"]["save_path"], 135 | save_step=config["train"]["save_steps"], 136 | load_pretrain_path=config["train"]["load_pretrain_path"], 137 | load_token_path=config["train"]["load_token_path"], 138 | ) 139 | 140 | test = dict( 141 | cam_thr=config["test"]["cam_thr"], 142 | eval_mode=config["test"]["eval_mode"], 143 | combine_ratio=config["test"]["combine_ratio"], 144 | load_pretrain_path=config["test"]["load_pretrain_path"], 145 | load_token_path=config["test"]["load_token_path"], 146 | load_unet_path=config["test"]["load_unet_path"] if config["test"]["load_unet_path"] is not None else os.path.join(config["test"]["load_pretrain_path"], "unet"), 147 | save_vis_path=config["test"]["save_vis_path"], 148 | save_log_path=config["test"]["save_log_path"] 149 | ) 150 | 151 | config = dict( 152 | model=model, 153 | data=data, 154 | optimizer=optimizer, 155 | lr_scheduler=lr_scheduler, 156 | accelerator=accelerator, 157 | train=train, 158 | test=test, 159 
| ) 160 | 161 | return config 162 | 163 | def merge_conf(config, extra_config): 164 | for key, value in extra_config.items(): 165 | if not isinstance(value, dict): 166 | config[key] = value 167 | else: 168 | merge_value = merge_conf(config.get(key, dict()), value) 169 | config[key] = merge_value 170 | return config 171 | 172 | # parse config file 173 | config = load_yaml_conf(args.config) 174 | 175 | # override options 176 | extra_config = eval(args.opt) 177 | merge_conf(config, extra_config) 178 | 179 | config = format_conf(config) 180 | 181 | # set random seed 182 | if args.seed is not None: 183 | set_seed(args.seed, device_specific=False) 184 | 185 | # set benchmark 186 | torch.backends.cudnn.benchmark = benchmark 187 | 188 | return args, config 189 | 190 | def test(config): 191 | split = "test" 192 | device = 'cuda' 193 | torch_dtype = torch.float16 194 | 195 | eval_mode = config["test"]["eval_mode"] 196 | load_pretrain_path = config["test"]["load_pretrain_path"] 197 | keep_class = config["data"][split]["dataset"]["keep_class"] 198 | combine_ratio = config["test"]["combine_ratio"] 199 | save_log_path = config["test"]["save_log_path"] 200 | load_unet_path = config["test"]["load_unet_path"] 201 | batch_size = config["data"][split]["batch_size"] 202 | 203 | text_encoder = CLIPTextModel.from_pretrained(load_pretrain_path, subfolder="text_encoder").to(device) 204 | vae = AutoencoderKL.from_pretrained(load_pretrain_path, subfolder="vae", torch_dtype=torch_dtype).to(device) 205 | unet = UNet2DConditionModel.from_pretrained(load_unet_path, torch_dtype=torch_dtype).to(device) 206 | noise_scheduler = DDPMScheduler.from_pretrained(load_pretrain_path, subfolder="scheduler") 207 | 208 | data_configs = config["data"][split].copy() 209 | dataset_config = data_configs.pop("dataset", None) 210 | dataset_type = dataset_config.pop("type", "imagenet") 211 | dataset_config.update(dict( 212 | test_mode=(split == "val" or split == "test"), 213 | text_encoder=text_encoder)) 214 | dataset = DATASETS[dataset_type](**dataset_config) 215 | dataloader = torch.utils.data.DataLoader(dataset, **data_configs) 216 | 217 | vae.eval() 218 | unet.eval() 219 | text_encoder.eval() 220 | 221 | evaluator = Evaluator(logfile=save_log_path, len_dataloader=len(dataloader)) 222 | controller = AttentionStore(batch_size=batch_size) 223 | AttentionStore.register_attention_control(controller, unet) 224 | 225 | if keep_class is None: 226 | keep_class = list(range(dataset.num_classes)) 227 | print(f"INFO: Test Save:\t [log: {str(config['test']['save_log_path'])}] [vis: {str(config['test']['save_vis_path'])}]", flush=True) 228 | print(f"INFO: Test CheckPoint:\t [token: {str(config['test']['load_token_path'])}] [unet: {str(config['test']['load_unet_path'])}]", flush=True) 229 | print(f"INFO: Test Class [{keep_class[0]}-{keep_class[-1]}]:\t [dataset: {dataset_type}] [eval mode: {eval_mode}] " 230 | f"[cam thr: {config['test']['cam_thr']}] [combine ratio: {combine_ratio}]", flush=True) 231 | 232 | for step, data in enumerate(tqdm.tqdm(dataloader)): 233 | if eval_mode == "gtk": 234 | image = data["img"].to(torch_dtype).to(device) 235 | latents = vae.encode(image).latent_dist.sample().detach() * 0.18215 236 | noise = torch.randn(latents.shape).to(latents.device) 237 | timesteps = torch.randint( 238 | 0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device 239 | ).long() 240 | 241 | representative_embeddings = [text_encoder(ids.to(device))[0] for ids in data["caption_ids_concept_token"][-1]] 242 | 
representative_embeddings = sum(representative_embeddings)/len(data["caption_ids_concept_token"][-1]) 243 | 244 | discriminative_embeddings = [text_encoder(ids.to(device))[0] for ids in data["caption_ids_meta_token"][-1]] 245 | discriminative_embeddings = sum(discriminative_embeddings) / len(data["caption_ids_meta_token"][-1]) 246 | combine_embeddings = combine_ratio * representative_embeddings + (1-combine_ratio) * discriminative_embeddings 247 | combine_embeddings = combine_embeddings.to(torch_dtype) 248 | 249 | for t in [0, 99]: 250 | timesteps = torch.ones_like(timesteps) * t 251 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(torch_dtype) 252 | noise_pred = unet(noisy_latents, timesteps, combine_embeddings).sample 253 | 254 | cams = controller.diffusion_cam(idx=5) 255 | 256 | controller.reset() 257 | 258 | cams_tensor = torch.from_numpy(cams).to(device).unsqueeze(dim=1) 259 | pad_cams = cams_tensor.repeat(1, dataset.num_classes, 1, 1) 260 | 261 | evaluator(data["img"], data["gt_labels"], data['gt_bboxes'], data["pred_logits"], pad_cams, data["name"], config, step) 262 | elif eval_mode == "top1": 263 | cams_all = [] 264 | 265 | image = data["img"].to(torch_dtype).to(device) 266 | latents = vae.encode(image).latent_dist.sample().detach() * 0.18215 267 | noise = torch.randn(latents.shape).to(latents.device) 268 | timesteps = torch.randint( 269 | 0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device 270 | ).long() 271 | 272 | for top_idx in [0, 5]: 273 | 274 | # save inference cost 275 | if top_idx == 5: 276 | top1_idx = data["pred_top5_ids"][:, 0] 277 | gtk_idx = torch.LongTensor(data["gt_labels"]).to(torch.int64) 278 | if torch.all(top1_idx == gtk_idx): 279 | cams_all = cams_all*2 280 | break 281 | 282 | representative_embeddings = [text_encoder(ids.to(device))[0] for ids in 283 | data["caption_ids_concept_token"][top_idx]] 284 | representative_embeddings = sum(representative_embeddings) / len(data["caption_ids_concept_token"][top_idx]) 285 | 286 | discriminative_embeddings = [text_encoder(ids.to(device))[0] for ids in 287 | data["caption_ids_meta_token"][top_idx]] 288 | discriminative_embeddings = sum(discriminative_embeddings) / len(data["caption_ids_meta_token"][top_idx]) 289 | combine_embeddings = combine_ratio * representative_embeddings + ( 290 | 1 - combine_ratio) * discriminative_embeddings 291 | combine_embeddings = combine_embeddings.to(torch_dtype) 292 | 293 | for t in [0, 99]: 294 | timesteps = torch.ones_like(timesteps) * t 295 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(torch_dtype) 296 | noise_pred = unet(noisy_latents, timesteps, combine_embeddings).sample 297 | 298 | controller.batch_size = len(noise_pred) 299 | cams = controller.diffusion_cam(idx=5) 300 | controller.reset() 301 | cams_all.append(cams) 302 | 303 | cams_tensor = torch.from_numpy(cams_all[0]).to(device).unsqueeze(dim=1) 304 | pad_cams = cams_tensor.repeat(1, dataset.num_classes, 1, 1) 305 | for i, pad_cam in enumerate(pad_cams): 306 | pad_cam[data["gt_labels"][i]] = torch.from_numpy(cams_all[-1])[i] 307 | 308 | evaluator(data["img"], data["gt_labels"], data['gt_bboxes'], data["pred_logits"], pad_cams, data["name"], config, step) 309 | elif eval_mode == "top5": 310 | cams_all = [] 311 | 312 | image = data["img"].to(torch_dtype).to(device) 313 | latents = vae.encode(image).latent_dist.sample().detach() * 0.18215 314 | noise = torch.randn(latents.shape).to(latents.device) 315 | timesteps = torch.randint( 316 | 0, 
noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device 317 | ).long() 318 | 319 | for top_idx in range(6): 320 | representative_embeddings = [text_encoder(ids.to(device))[0] for ids in 321 | data["caption_ids_concept_token"][top_idx]] 322 | representative_embeddings = sum(representative_embeddings) / len(data["caption_ids_concept_token"][top_idx]) 323 | 324 | discriminative_embeddings = [text_encoder(ids.to(device))[0] for ids in 325 | data["caption_ids_meta_token"][top_idx]] 326 | discriminative_embeddings = sum(discriminative_embeddings) / len(data["caption_ids_meta_token"][top_idx]) 327 | combine_embeddings = combine_ratio * representative_embeddings + ( 328 | 1 - combine_ratio) * discriminative_embeddings 329 | combine_embeddings = combine_embeddings.to(torch_dtype) 330 | 331 | for t in [0, 99]: 332 | timesteps = torch.ones_like(timesteps) * t 333 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(torch_dtype) 334 | noise_pred = unet(noisy_latents, timesteps, combine_embeddings).sample 335 | 336 | cams = controller.diffusion_cam(idx=5) 337 | controller.reset() 338 | cams_all.append(cams) 339 | 340 | cams = torch.from_numpy(np.stack([cam for cam in cams_all[:-1]], 1)) 341 | pad_cams = torch.zeros((batch_size, dataset.num_classes, *cams.shape[-2:])) 342 | for i, pad_cam in enumerate(pad_cams): 343 | pad_cam[data["pred_top5_ids"][i]] = cams[i] 344 | pad_cam[data["gt_labels"][i]] = torch.from_numpy(cams_all[-1])[i] 345 | 346 | evaluator(data["img"], data["gt_labels"], data['gt_bboxes'], data["pred_logits"], pad_cams, data["name"], config, step) 347 | else: 348 | raise ValueError("select eval_mode in [gtk, top1, top5].") 349 | 350 | def train_token(config): 351 | split = "train" 352 | device = 'cuda' 353 | torch_dtype = torch.float32 354 | 355 | load_pretrain_path = config["train"]["load_pretrain_path"] 356 | keep_class = config["data"][split]["dataset"]["keep_class"] 357 | 358 | text_encoder = CLIPTextModel.from_pretrained(load_pretrain_path, subfolder="text_encoder", torch_dtype=torch_dtype) 359 | vae = AutoencoderKL.from_pretrained(load_pretrain_path, subfolder="vae", torch_dtype=torch_dtype) 360 | unet = UNet2DConditionModel.from_pretrained(load_pretrain_path, subfolder="unet", torch_dtype=torch_dtype) 361 | noise_scheduler = DDPMScheduler.from_pretrained(load_pretrain_path, subfolder="scheduler") 362 | 363 | def freeze_params(params): 364 | for param in params: 365 | param.requires_grad = False 366 | 367 | freeze_params(itertools.chain( 368 | vae.parameters(), 369 | unet.parameters(), 370 | text_encoder.text_model.encoder.parameters(), 371 | text_encoder.text_model.final_layer_norm.parameters(), 372 | text_encoder.text_model.embeddings.position_embedding.parameters(), 373 | )) 374 | 375 | data_configs = config["data"][split].copy() 376 | dataset_config = data_configs.pop("dataset", None) 377 | dataset_type = dataset_config.pop("type", "imagenet") 378 | dataset_config.update(dict( 379 | test_mode=(split == "val" or split == "test"), 380 | text_encoder=text_encoder)) 381 | dataset = DATASETS[dataset_type](**dataset_config) 382 | 383 | def train_loop(config, class_id, text_encoder=None, unet=None, vae=None, dataset=None): 384 | 385 | def get_grads_to_zero(class_id, dataset): 386 | tokenizer = dataset.tokenizer 387 | index_grads_to_zero = torch.ones((len(tokenizer))).bool() 388 | concept_token = dataset.cat2tokens[class_id]["concept_token"] 389 | token_id = tokenizer.encode(concept_token, add_special_tokens=False)[0] 390 | 
index_grads_to_zero[token_id] = False 391 | return index_grads_to_zero, token_id 392 | 393 | index_grads_to_zero, token_id = get_grads_to_zero(class_id, dataset) 394 | config = copy.deepcopy(config) 395 | 396 | subdataset = SubDataset(keep_class=[class_id], dataset=dataset) 397 | data_configs = config["data"][split].copy() 398 | data_configs.pop("dataset", None) 399 | dataloader = torch.utils.data.DataLoader(subdataset, **data_configs) 400 | 401 | accelerator_config = config.get("accelerator", None) 402 | accelerator = Accelerator(**accelerator_config) 403 | if accelerator.is_main_process: 404 | accelerator.init_trackers("wsol", config=config) 405 | 406 | save_path = config['train']['save_path'] 407 | batch_size = config['data']['train']['batch_size'] 408 | gradient_accumulation_steps = config['accelerator']['gradient_accumulation_steps'] 409 | num_train_epochs = config['train']['epochs'] 410 | max_train_steps = config['lr_scheduler']['num_training_steps'] // gradient_accumulation_steps 411 | total_batch_size = batch_size * accelerator.num_processes * gradient_accumulation_steps 412 | 413 | if config['train']['scale_lr']: 414 | config['optimizer']['lr'] = config['optimizer']['lr'] * total_batch_size 415 | if (max_train_steps is None) or (max_train_steps == 0): 416 | num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps) 417 | max_train_steps = num_train_epochs * num_update_steps_per_epoch // accelerator.num_processes 418 | 419 | config['lr_scheduler']['num_training_steps'] = max_train_steps * gradient_accumulation_steps 420 | optimizer_config = config.get("optimizer", None) 421 | optimizer_type = optimizer_config.pop("type", "AdamW") 422 | lr_scheduler_config = config.get("lr_scheduler", None) 423 | lr_scheduler_type = lr_scheduler_config.pop("type", "constant") 424 | 425 | optimizer = OPTIMIZER[optimizer_type](text_encoder.get_input_embeddings().parameters(), **optimizer_config) 426 | lr_scheduler = get_scheduler(name=lr_scheduler_type, optimizer=optimizer, **lr_scheduler_config) 427 | 428 | if accelerator.is_main_process: 429 | print(f"INFO: Train Save:\t [ckpt: {save_path}]", flush=True) 430 | print(f"INFO: Train Class [{class_id}]:\t [num samples: {len(dataloader)}] " 431 | f"[num epochs: {num_train_epochs}] [batch size: {total_batch_size}] " 432 | f"[total steps: {max_train_steps}]", flush=True) 433 | 434 | vae, unet, text_encoder, optimizer, lr_scheduler, dataloader = accelerator.prepare(vae, unet, text_encoder, 435 | optimizer, lr_scheduler, 436 | dataloader) 437 | vae.eval() 438 | unet.eval() 439 | 440 | global_step = 0 441 | progress_bar = tqdm.tqdm(range(max_train_steps), disable=(not accelerator.is_local_main_process)) 442 | for epoch in range(num_train_epochs): 443 | text_encoder.train() 444 | progress_bar.set_description(f"Epoch[{epoch+1}/{num_train_epochs}] ") 445 | for step, data in enumerate(dataloader): 446 | with accelerator.accumulate(text_encoder): 447 | combine_embeddings = text_encoder(data["caption_ids_concept_token"])[0] 448 | 449 | image = data["img"].to(torch_dtype) # use torch.float16 rather than float32 450 | latents = vae.encode(image).latent_dist.sample().detach() * 0.18215 451 | 452 | noise = torch.randn(latents.shape, device=latents.device, dtype=torch_dtype) 453 | timesteps = torch.randint(low=0, high=noise_scheduler.config.num_train_timesteps, 454 | size=(latents.shape[0],), device=latents.device).long() 455 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 456 | 457 | noise_pred = unet(noisy_latents, 
timesteps, combine_embeddings).sample 458 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 459 | 460 | accelerator.backward(loss) 461 | 462 | if accelerator.num_processes > 1: 463 | grads = text_encoder.module.get_input_embeddings().weight.grad 464 | else: 465 | grads = text_encoder.get_input_embeddings().weight.grad 466 | grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) 467 | 468 | optimizer.step() 469 | lr_scheduler.step() 470 | optimizer.zero_grad() 471 | 472 | # Checks if the accelerator has performed an optimization step behind the scenes 473 | if accelerator.sync_gradients: 474 | progress_bar.update(1) 475 | global_step += 1 476 | 477 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 478 | progress_bar.set_postfix(refresh=False, **logs) 479 | accelerator.log(logs, step=global_step) 480 | 481 | if global_step >= max_train_steps: 482 | break 483 | 484 | if global_step >= max_train_steps: 485 | break 486 | 487 | # save concept token embeddings per epoch 488 | accelerator.wait_for_everyone() 489 | if accelerator.is_main_process: 490 | out_dir = os.path.join(save_path, "tokens") 491 | if epoch != num_train_epochs - 1: 492 | out_dir = os.path.join(save_path, f"tokens_e{epoch}") 493 | os.makedirs(out_dir, exist_ok=True) 494 | unwrap_text_encoder = accelerator.unwrap_model(text_encoder) 495 | concept_token = dataset.tokenizer.decode(token_id) 496 | concept_token_embeddings = unwrap_text_encoder.get_input_embeddings().weight[token_id] 497 | dct = {concept_token: concept_token_embeddings.detach().cpu()} 498 | torch.save(dct, os.path.join(out_dir, f"{token_id}.bin")) 499 | 500 | accelerator.end_training() 501 | 502 | if keep_class is None: 503 | keep_class = list(range(dataset.num_classes)) 504 | for class_id in keep_class: 505 | train_loop(config, class_id, text_encoder, unet, vae, dataset) 506 | 507 | def train_unet(config): 508 | split = "train" 509 | device = 'cuda' 510 | torch_dtype = torch.float32 511 | 512 | load_pretrain_path = config["train"]["load_pretrain_path"] 513 | 514 | text_encoder = CLIPTextModel.from_pretrained(load_pretrain_path, subfolder="text_encoder", torch_dtype=torch_dtype) 515 | vae = AutoencoderKL.from_pretrained(load_pretrain_path, subfolder="vae", torch_dtype=torch_dtype) 516 | unet = UNet2DConditionModel.from_pretrained(load_pretrain_path, subfolder="unet", torch_dtype=torch_dtype) 517 | noise_scheduler = DDPMScheduler.from_pretrained(load_pretrain_path, subfolder="scheduler") 518 | 519 | def freeze_params(params): 520 | for param in params: 521 | param.requires_grad = False 522 | 523 | freeze_params(itertools.chain( 524 | vae.parameters(), 525 | text_encoder.parameters(), 526 | )) 527 | 528 | data_configs = config["data"][split].copy() 529 | dataset_config = data_configs.pop("dataset", None) 530 | dataset_type = dataset_config.pop("type", "imagenet") 531 | dataset_config.update(dict( 532 | test_mode=(split == "val" or split == "test"), 533 | text_encoder=text_encoder)) 534 | dataset = DATASETS[dataset_type](**dataset_config) 535 | dataloader = torch.utils.data.DataLoader(dataset, **data_configs) 536 | 537 | def train_loop(config, text_encoder=None, unet=None, vae=None, dataloader=None): 538 | config = copy.deepcopy(config) 539 | 540 | accelerator_config = config.get("accelerator", None) 541 | accelerator = Accelerator(**accelerator_config) 542 | if accelerator.is_main_process: 543 | accelerator.init_trackers("wsol", config=config) 544 | 545 | save_path = 
config['train']['save_path'] 546 | save_step = config['train']['save_step'] 547 | batch_size = config['data']['train']['batch_size'] 548 | gradient_accumulation_steps = config['accelerator']['gradient_accumulation_steps'] 549 | num_train_epochs = config['train']['epochs'] 550 | max_train_steps = config['lr_scheduler']['num_training_steps'] // gradient_accumulation_steps 551 | total_batch_size = batch_size * accelerator.num_processes * gradient_accumulation_steps 552 | 553 | if config['train']['scale_lr']: 554 | config['optimizer']['lr'] = config['optimizer']['lr'] * total_batch_size 555 | if (max_train_steps is None) or (max_train_steps == 0): 556 | num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps) 557 | max_train_steps = num_train_epochs * num_update_steps_per_epoch // accelerator.num_processes 558 | 559 | config['lr_scheduler']['num_training_steps'] = max_train_steps * gradient_accumulation_steps 560 | optimizer_config = config.get("optimizer", None) 561 | optimizer_type = optimizer_config.pop("type", "AdamW") 562 | lr_scheduler_config = config.get("lr_scheduler", None) 563 | lr_scheduler_type = lr_scheduler_config.pop("type", "constant") 564 | 565 | optimizer = OPTIMIZER[optimizer_type](accelerator.unwrap_model(unet).parameters(), **optimizer_config) 566 | lr_scheduler = get_scheduler(name=lr_scheduler_type, optimizer=optimizer, **lr_scheduler_config) 567 | 568 | if accelerator.is_main_process: 569 | print(f"INFO: Train Save:\t [ckpt: {save_path}]", flush=True) 570 | print(f"INFO: Train UNet:\t [num samples: {len(dataloader)}] " 571 | f"[num epochs: {num_train_epochs}] [batch size: {total_batch_size}] " 572 | f"[total steps: {max_train_steps}] [save step: {save_step}]", flush=True) 573 | 574 | vae, unet, text_encoder, optimizer, lr_scheduler, dataloader = accelerator.prepare(vae, unet, text_encoder, 575 | optimizer, lr_scheduler, 576 | dataloader) 577 | vae.eval() 578 | text_encoder.eval() 579 | 580 | global_step = 0 581 | progress_bar = tqdm.tqdm(range(max_train_steps), disable=(not accelerator.is_local_main_process)) 582 | for epoch in range(num_train_epochs): 583 | unet.train() 584 | progress_bar.set_description(f"Epoch[{epoch + 1}/{num_train_epochs}] ") 585 | for step, data in enumerate(dataloader): 586 | with accelerator.accumulate(unet): 587 | representative_embeddings = text_encoder(data["caption_ids_concept_token"])[0] 588 | discriminative_embeddings = text_encoder(data["caption_ids_meta_token"])[0] 589 | combine_embeddings = 0.5 * representative_embeddings + 0.5 * discriminative_embeddings 590 | 591 | image = data["img"].to(torch_dtype) 592 | latents = vae.encode(image).latent_dist.sample().detach() * 0.18215 593 | 594 | noise = torch.randn(latents.shape, device=latents.device, dtype=torch_dtype) 595 | timesteps = torch.randint(low=0, high=noise_scheduler.config.num_train_timesteps, 596 | size=(latents.shape[0],), device=latents.device).long() 597 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 598 | 599 | noise_pred = unet(noisy_latents, timesteps, combine_embeddings).sample 600 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 601 | 602 | accelerator.backward(loss) 603 | 604 | optimizer.step() 605 | lr_scheduler.step() 606 | optimizer.zero_grad() 607 | 608 | # Checks if the accelerator has performed an optimization step behind the scenes 609 | if accelerator.sync_gradients: 610 | progress_bar.update(1) 611 | global_step += 1 612 | 613 | logs = {"loss": loss.detach().item(), "lr": 
lr_scheduler.get_last_lr()[0]} 614 | progress_bar.set_postfix(refresh=False, **logs) 615 | accelerator.log(logs, step=global_step) 616 | 617 | if global_step >= max_train_steps: 618 | break 619 | elif (global_step + 1) % save_step == 0: 620 | if accelerator.sync_gradients: 621 | accelerator.wait_for_everyone() 622 | if accelerator.is_main_process: 623 | out_dir = os.path.join(save_path, f"unet_s{global_step}") 624 | os.makedirs(out_dir, exist_ok=True) 625 | try: 626 | unet.module.save_pretrained(save_directory=out_dir) 627 | except: 628 | unet.save_pretrained(save_directory=out_dir) 629 | 630 | 631 | if global_step >= max_train_steps: 632 | break 633 | 634 | accelerator.end_training() 635 | 636 | train_loop(config, text_encoder, unet, vae, dataloader) 637 | 638 | if __name__ == "__main__": 639 | args, config = set_env(benchmark=True) 640 | print(args) 641 | eval(args.function)(config) 642 | 643 | 644 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/callsys/GenPromp/c480e0a106d7158da4a4e86cc27441f181996b04/models/__init__.py -------------------------------------------------------------------------------- /models/attn.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | import torch.utils.checkpoint 5 | import matplotlib.pyplot as plt 6 | 7 | class AttentionStore(): 8 | def __init__(self, batch_size=2): 9 | self.cur_step = 0 10 | self.num_att_layers = -1 11 | self.cur_att_layer = 0 12 | self.step_store = self.get_empty_store() 13 | self.attention_store = {} 14 | self.active = True 15 | self.batch_size = batch_size 16 | 17 | def reset(self): 18 | self.cur_step = 0 19 | self.cur_att_layer = 0 20 | self.step_store = self.get_empty_store() 21 | self.attention_store = {} 22 | 23 | @property 24 | def num_uncond_att_layers(self): 25 | return 0 26 | 27 | def step_callback(self, x_t): 28 | return x_t 29 | 30 | @staticmethod 31 | def get_empty_store(): 32 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 33 | "down_self": [], "mid_self": [], "up_self": []} 34 | 35 | def forward(self, attn, is_cross: bool, place_in_unet: str): 36 | if self.active: 37 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 38 | self.step_store[key].append(attn) 39 | return attn 40 | 41 | def between_steps(self): 42 | if self.active: 43 | if len(self.attention_store) == 0: 44 | self.attention_store = self.step_store 45 | else: 46 | for key in self.attention_store: 47 | for i in range(len(self.attention_store[key])): 48 | self.attention_store[key][i] += self.step_store[key][i] 49 | self.step_store = self.get_empty_store() 50 | 51 | def get_average_attention(self): 52 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in 53 | self.attention_store} 54 | return average_attention 55 | 56 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 57 | if self.cur_att_layer >= self.num_uncond_att_layers: 58 | attn = self.forward(attn, is_cross, place_in_unet) 59 | self.cur_att_layer += 1 60 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 61 | self.cur_att_layer = 0 62 | self.cur_step += 1 63 | self.between_steps() 64 | return attn 65 | 66 | @staticmethod 67 | def register_attention_control(controller, model): 68 | 69 | def ca_attention(self, place_in_unet): 70 | 71 | def 
get_attention_scores(query, key, attention_mask=None): 72 | dtype = query.dtype 73 | 74 | if self.upcast_attention: 75 | query = query.float() 76 | key = key.float() 77 | 78 | if attention_mask is None: 79 | baddbmm_input = torch.empty( 80 | query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device 81 | ) 82 | beta = 0 83 | else: 84 | baddbmm_input = attention_mask 85 | beta = 1 86 | 87 | attention_scores = torch.baddbmm( 88 | baddbmm_input, 89 | query, 90 | key.transpose(-1, -2), 91 | beta=beta, 92 | alpha=self.scale, 93 | ) 94 | 95 | if self.upcast_softmax: 96 | attention_scores = attention_scores.float() 97 | 98 | attention_probs = attention_scores.softmax(dim=-1) 99 | attention_probs = attention_probs.to(dtype) 100 | 101 | if query.shape == key.shape: 102 | is_cross = False 103 | else: 104 | is_cross = True 105 | 106 | attention_probs = controller(attention_probs, is_cross, place_in_unet) 107 | 108 | return attention_probs 109 | 110 | return get_attention_scores 111 | 112 | def register_recr(net_, count, place_in_unet): 113 | if net_.__class__.__name__ == 'CrossAttention': 114 | # net_._attention = ca_attention(net_, place_in_unet) 115 | net_.get_attention_scores = ca_attention(net_, place_in_unet) 116 | return count + 1 117 | elif hasattr(net_, 'children'): 118 | for net__ in net_.children(): 119 | count = register_recr(net__, count, place_in_unet) 120 | return count 121 | 122 | cross_att_count = 0 123 | sub_nets = model.named_children() 124 | for net in sub_nets: 125 | if "down" in net[0]: 126 | cross_att_count += register_recr(net[1], 0, "down") 127 | elif "up" in net[0]: 128 | cross_att_count += register_recr(net[1], 0, "up") 129 | elif "mid" in net[0]: 130 | cross_att_count += register_recr(net[1], 0, "mid") 131 | controller.num_att_layers = cross_att_count 132 | 133 | def aggregate_attention(self, res, from_where, is_cross, bz): 134 | out = [] 135 | attention_maps = self.get_average_attention() 136 | num_pixels = res ** 2 137 | for location in from_where: 138 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 139 | if item.shape[1] == num_pixels: 140 | cross_maps = item.reshape(bz, -1, res, res, item.shape[-1]) 141 | out.append(cross_maps) 142 | out = torch.cat(out, dim=1) 143 | return out.cpu() 144 | 145 | def self_attention_map(self, res, from_where, bz, max_com=10, out_res=64): 146 | attention_maps = self.aggregate_attention(res, from_where, False, bz) 147 | maps = [] 148 | for b in range(bz): 149 | attention_map = attention_maps[b].detach().numpy().astype(np.float32).mean(0).reshape((res**2, res**2)) 150 | u, s, vh = np.linalg.svd(attention_map - np.mean(attention_map, axis=1, keepdims=True)) 151 | images = [] 152 | for i in range(max_com): 153 | image = vh[i].reshape(res, res) 154 | # image = image/image.max() 155 | # image = (image - image.min()) / (image.max() - image.min()) 156 | image = cv2.resize(image, (out_res, out_res), interpolation=cv2.INTER_CUBIC) 157 | images.append(image) 158 | map = np.stack(images, 0).max(0) 159 | maps.append(map) 160 | return np.stack(maps, 0) 161 | 162 | def cross_attention_map(self, res, from_where, bz, out_res=64, idx=5): 163 | attention_maps = self.aggregate_attention(res, from_where, True, bz) 164 | attention_maps = attention_maps[..., idx] 165 | attention_maps = attention_maps.sum(1) / attention_maps.shape[1] 166 | 167 | maps = [] 168 | for b in range(bz): 169 | map = attention_maps[b, :, :] 170 | map = cv2.resize(map.detach().numpy().astype(np.float32), (out_res, out_res), 171 | 
interpolation=cv2.INTER_CUBIC) 172 | # map = map / map.max() 173 | maps.append(map) 174 | return np.stack(maps, 0) 175 | 176 | def diffusion_cam(self, idx=5): 177 | bz = self.batch_size 178 | attention_maps_8_ca = self.cross_attention_map(8, ("up", "mid", "down"), bz, idx=idx) 179 | attention_maps_16_up_ca = self.cross_attention_map(16, ("up",), bz, idx=idx) 180 | attention_maps_16_down_ca = self.cross_attention_map(16, ("down",), bz, idx=idx) 181 | attention_maps_ca = (attention_maps_8_ca + attention_maps_16_up_ca + attention_maps_16_down_ca) / 3 182 | cams = attention_maps_ca 183 | cams = cams / cams.max((1,2))[:, None, None] 184 | return cams 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | --------------------------------------------------------------------------------
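For reference, the snippet below is a minimal, self-contained sketch of the CAM-to-box-to-IoU logic that evaluate_cls_loc in datasets/evaluation/cam.py applies per image. It is meant to be run from the repository root; the toy activation map, the 0.2 threshold, and the ground-truth box are made-up values, not taken from the configs.

import numpy as np

from datasets.evaluation.cam import resize_cam, get_bboxes, cal_iou

# Toy class activation map: a single rectangular blob of high activation.
cam = np.zeros((16, 16, 1), dtype=np.float32)
cam[4:10, 5:12, 0] = 1.0

# Resize to the crop size and min-max normalize, as evaluate_cls_loc does.
cam = resize_cam(cam, size=(224, 224))

# Threshold the map and take the bounding box of the largest contour.
pred_box = get_bboxes(cam, cam_thr=0.2)          # [x1, y1, x2, y2]

# Compare against a (made-up) ground-truth box; GT-known localization
# counts the image as correct when IoU >= 0.5.
gt_box = [60, 50, 170, 145]
iou = float(cal_iou(pred_box, gt_box)[0])        # cal_iou broadcasts inputs to [N, 4]
print(pred_box, round(iou, 3), iou >= 0.5)

The same three calls, with the threshold taken from config["test"]["cam_thr"], are what produce the estimated_bbox behind the Loc and GT-known metrics reported by the Evaluator.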