├── .gitignore ├── LICENSE ├── README.md ├── checkpoints ├── .gitkeep └── README.md ├── configs ├── CUB_resnet18.json ├── CUB_resnet34.json ├── CUB_resnet50.json ├── CUB_resnet50v2.json └── ImageNet_resnet50.json ├── data ├── CUB.py ├── RDD.py ├── RDD_4.py ├── RDD_bbox.py ├── __init__.py ├── imagenet.py └── sampler.py ├── datasets ├── .gitkeep └── README.md ├── evaluate.py ├── exp.sh ├── metrics ├── __init__.py ├── accuracy.py ├── base.py └── patch_insdel.py ├── models ├── __init__.py ├── attention_branch.py ├── lrp.py └── rise.py ├── oneshot.py ├── optim ├── __init__.py └── sam.py ├── outputs └── .gitkeep ├── poetry.lock ├── pyproject.toml ├── qual └── original │ ├── Arabian_camel.png │ ├── Brandt_Cormorant.png │ ├── Geococcyx.png │ ├── Rock_Wren.png │ ├── Savannah_Sparrow.png │ ├── bee.png │ ├── bubble.png │ ├── bustard.png │ ├── drumstick.png │ ├── oboe.png │ ├── ram.png │ ├── sock.png │ ├── solar_dish.png │ ├── water_ouzel.png │ └── wombat.png ├── scripts ├── calc_dataset_info.py └── visualize_transforms.py ├── src ├── __init__.py ├── data.py ├── data_processing.py ├── lrp.py ├── lrp_filter.py ├── lrp_layers.py ├── scorecam.py └── utils.py ├── train.py ├── utils ├── loss.py ├── utils.py └── visualize.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .ruff_cache 3 | /wandb 4 | /checkpoints 5 | /datasets 6 | /outputs 7 | /oneshot_images 8 | /qual* 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 
104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/#use-with-ide 119 | .pdm.toml 120 | 121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 169 | #.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The Clear BSD License 2 | 3 | Copyright (c) 2024 the authors of "Layer-Wise Relevance Propagation with Conservation Property for ResNet" (Seitaro Otsuki, Tsumugi Iida, Félix Doublet, 4 | Tsubasa Hirakawa, Takayoshi Yamashita, Hironobu Fujiyoshi and Komei Sugiura) 5 | 6 | All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without 9 | modification, are permitted (subject to the limitations in the disclaimer 10 | below) provided that the following conditions are met: 11 | 12 | * Redistributions of source code must retain the above copyright notice, 13 | this list of conditions and the following disclaimer. 14 | 15 | * Redistributions in binary form must reproduce the above copyright 16 | notice, this list of conditions and the following disclaimer in the 17 | documentation and/or other materials provided with the distribution. 18 | 19 | * Neither the name of the copyright holder nor the names of its 20 | contributors may be used to endorse or promote products derived from this 21 | software without specific prior written permission. 22 | 23 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 24 | THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 25 | CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 26 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 27 | PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 28 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 29 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 30 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 31 | BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER 32 | IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 33 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 34 | POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ECCV24] Layer-Wise Relevance Propagation with Conservation Property for ResNet 2 | 3 | - Accepted at ECCV 2024 4 | - [Project page](https://5ei74r0.github.io/lrp-for-resnet.page/) 5 | - [ArXiv](https://arxiv.org/abs/2407.09115) 6 | 7 | 8 | The transparent formulation of explanation methods is essential for elucidating the predictions of neural networks, which are typically black-box models. Layer-wise Relevance Propagation (LRP) is a well-established method that transparently traces the flow of a model's prediction backward through its architecture by backpropagating relevance scores. However, the conventional LRP does not fully consider the existence of skip connections, and thus its application to the widely used ResNet architecture has not been thoroughly explored. In this study, we extend LRP to ResNet models by introducing Relevance Splitting at points where the output from a skip connection converges with that from a residual block. Our formulation guarantees the conservation property throughout the process, thereby preserving the integrity of the generated explanations. To evaluate the effectiveness of our approach, we conduct experiments on ImageNet and the Caltech-UCSD Birds-200-2011 dataset. Our method achieves superior performance to that of baseline methods on standard evaluation metrics such as the Insertion-Deletion score while maintaining its conservation property. Our code is released in this repository for further research. 9 | 10 | 11 | 12 | ## Getting started 13 | Clone this repository and move into it. Then run `poetry install --no-root`. 14 | 15 | We used the following environment: 16 | - Python 3.9.15 17 | - Poetry 1.7.1 18 | - CUDA 11.7 19 | 20 | See [pyproject.toml](pyproject.toml) to check Python dependencies. 21 | 22 | ### Datasets 23 | Follow the instructions [here](datasets/README.md). 24 | 25 | ### Get models 26 | If you want to test the method on CUB, follow the instructions [here](checkpoints/README.md). 27 | You do not have to prepare models for ImageNet. 28 | 29 | 30 | ## Quantitative Experiments 31 | E.g.: Run our method on ImageNet. 32 | ```bash 33 | poetry run python visualize.py -c configs/ImageNet_resnet50.json --method "lrp" --heat-quantization --skip-connection-prop-type "flows_skip" --notes "imagenet--type:flows_skip--viz:norm+positive" --all_class --seed 42 --normalize --sign "positive" 34 | ``` 35 | 36 | 37 | ## Visualize attribution maps for specific images 38 | E.g.: Visualize attribution maps for water ouzel in ImageNet with our method.
39 | ```bash 40 | poetry run python oneshot.py -c configs/ImageNet_resnet50.json --method "lrp" --skip-connection-prop-type "flows_skip" --heat-quantization --image-path ./qual/original/water_ouzel.png --label 20 --save-path ./qual/ours/water_ouzel.png --normalize --sign "positive" 41 | ``` 42 | See `exp.sh` for more examples. 43 | 44 | 45 | ## BibTeX 46 | 47 | ``` 48 | @article{otsuki2024layer, 49 | title={{Layer-Wise Relevance Propagation with Conservation Property for ResNet}}, 50 | author={Seitaro Otsuki and Tsumugi Iida and F\'elix Doublet and Tsubasa Hirakawa and Takayoshi Yamashita and Hironobu Fujiyoshi and Komei Sugiura}, 51 | journal={arXiv preprint arXiv:2407.09115}, 52 | year={2024}, 53 | } 54 | ``` 55 | 56 | ## License 57 | This work is licensed under the BSD-3-Clause-Clear license. 58 | To view a copy of this license, see [LICENSE](LICENSE). 59 | -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/checkpoints/.gitkeep -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | We use ResNet50 models trained on CUB for the CUB experiments. Download them from [here](https://github.com/keio-smilab24/LRP-for-ResNet/releases) and put them into this (checkpoints/) directory. 2 | 3 | ### Structure 4 | Download models & place them as follows: 5 | ``` 6 | checkpoints 7 | ├── .gitkeep 8 | ├── CUB_resnet50_Seed40 9 | │ ├── best.pt 10 | │ ├── checkpoint.pt 11 | │ └── config.json 12 | ├── CUB_resnet50_Seed41 13 | │ ├── best.pt 14 | │ ├── checkpoint.pt 15 | │ └── config.json 16 | ├── CUB_resnet50_Seed42 17 | │ ├── best.pt 18 | │ ├── checkpoint.pt 19 | │ └── config.json 20 | ├── CUB_resnet50_Seed43 21 | │ ├── best.pt 22 | │ ├── checkpoint.pt 23 | │ └── config.json 24 | ├── CUB_resnet50_Seed44 25 | │ ├── best.pt 26 | │ ├── checkpoint.pt 27 | │ └── config.json 28 | ``` -------------------------------------------------------------------------------- /configs/CUB_resnet18.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 574, 3 | "model": "resnet18", 4 | "add_attention_branch": false, 5 | "base_pretrained": "", 6 | "trainable_module": -1, 7 | "dataset": "CUB", 8 | "image_size": 224, 9 | "batch_size": 32, 10 | "train_ratio": 0.9, 11 | "epochs": 300, 12 | "optimizer": "SGD", 13 | "lr": 1e-3, 14 | "lr_linear": 1e-3, 15 | "min_lr": 1e-6, 16 | "weight_decay": 1e-4, 17 | "factor": 0.333, 18 | "scheduler_patience": 1, 19 | "early_stopping_patience": 6 20 | } -------------------------------------------------------------------------------- /configs/CUB_resnet34.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 574, 3 | "model": "resnet34", 4 | "add_attention_branch": false, 5 | "base_pretrained": "", 6 | "trainable_module": -1, 7 | "dataset": "CUB", 8 | "image_size": 224, 9 | "batch_size": 32, 10 | "train_ratio": 0.9, 11 | "epochs": 300, 12 | "optimizer": "SGD", 13 | "lr": 1e-3, 14 | "lr_linear": 1e-3, 15 | "min_lr": 1e-6, 16 | "weight_decay": 1e-4, 17 | "factor": 0.333, 18 | "scheduler_patience": 1, 19 | "early_stopping_patience": 6 20 | } -------------------------------------------------------------------------------- /configs/CUB_resnet50.json:
-------------------------------------------------------------------------------- 1 | { 2 | "seed": 574, 3 | "model": "resnet50", 4 | "add_attention_branch": false, 5 | "base_pretrained": "", 6 | "trainable_module": -1, 7 | "dataset": "CUB", 8 | "image_size": 224, 9 | "batch_size": 32, 10 | "train_ratio": 0.9, 11 | "epochs": 300, 12 | "optimizer": "SGD", 13 | "lr": 1e-3, 14 | "lr_linear": 1e-3, 15 | "min_lr": 1e-6, 16 | "weight_decay": 1e-4, 17 | "factor": 0.333, 18 | "scheduler_patience": 1, 19 | "early_stopping_patience": 6 20 | } -------------------------------------------------------------------------------- /configs/CUB_resnet50v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 574, 3 | "model": "resnet50-v2", 4 | "add_attention_branch": false, 5 | "base_pretrained": "", 6 | "trainable_module": -1, 7 | "dataset": "CUB", 8 | "image_size": 224, 9 | "batch_size": 32, 10 | "train_ratio": 0.9, 11 | "epochs": 300, 12 | "optimizer": "SGD", 13 | "lr": 1e-3, 14 | "lr_linear": 1e-3, 15 | "min_lr": 1e-6, 16 | "weight_decay": 1e-4, 17 | "factor": 0.333, 18 | "scheduler_patience": 1, 19 | "early_stopping_patience": 6 20 | } -------------------------------------------------------------------------------- /configs/ImageNet_resnet50.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 574, 3 | "model": "resnet50", 4 | "add_attention_branch": false, 5 | "base_pretrained": "", 6 | "trainable_module": -1, 7 | "dataset": "ImageNet", 8 | "image_size": 224, 9 | "batch_size": 32, 10 | "train_ratio": 0.9, 11 | "epochs": 300, 12 | "optimizer": "SGD", 13 | "lr": 1e-3, 14 | "lr_linear": 1e-3, 15 | "min_lr": 1e-6, 16 | "weight_decay": 1e-4, 17 | "factor": 0.333, 18 | "scheduler_patience": 1, 19 | "early_stopping_patience": 6 20 | } -------------------------------------------------------------------------------- /data/CUB.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Callable, Optional, Tuple 3 | 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class CUBDataset(Dataset): 9 | 10 | def __init__( 11 | self, root: str, image_set: str = "train", transform: Optional[Callable] = None 12 | ) -> None: 13 | super().__init__() 14 | 15 | self.root = root 16 | self.image_set = image_set 17 | self.transform = transform 18 | 19 | self.base_dir = os.path.join(self.root, "CUB_200_2011") 20 | 21 | # Images and their corresponding paths 22 | image_list_path = os.path.join(self.base_dir, "images.txt") 23 | with open(image_list_path, "r") as f: 24 | lines = f.readlines() 25 | img_ids = [int(line.split()[0]) for line in lines] 26 | img_paths = [line.split()[1].strip() for line in lines] 27 | 28 | # Splitting the dataset into train/test based on train_test_split.txt 29 | train_test_split_path = os.path.join(self.base_dir, "train_test_split.txt") 30 | with open(train_test_split_path, "r") as f: 31 | lines = f.readlines() 32 | splits = {int(line.split()[0]): int(line.split()[1]) for line in lines} 33 | 34 | # Filtering images based on the desired set (train/test) 35 | if image_set == "train": 36 | self.images = [img_paths[i] for i, img_id in enumerate(img_ids) if splits[img_id] == 1] 37 | else: 38 | self.images = [img_paths[i] for i, img_id in enumerate(img_ids) if splits[img_id] == 0] 39 | 40 | # Targets and segments 41 | self.targets = [int(img_path.split('/')[0].split('.')[0]) - 1 for img_path in self.images] 42 | 
self.segments = [img_path.replace("images", "segmentations").replace(".jpg", ".png") for img_path in self.images] 43 | 44 | def __getitem__(self, index) -> Tuple[Any, Any]: 45 | image_path = os.path.join(self.base_dir, "images", self.images[index]) 46 | image = Image.open(image_path).convert("RGB") 47 | target = self.targets[index] 48 | 49 | if self.transform: 50 | image = self.transform(image) 51 | 52 | return image, target 53 | 54 | def __len__(self) -> int: 55 | return len(self.images) 56 | 57 | 58 | CUB_CLASSES = [ 59 | 'Black_footed_Albatross', 60 | 'Laysan_Albatross', 61 | 'Sooty_Albatross', 62 | 'Groove_billed_Ani', 63 | 'Crested_Auklet', 64 | 'Least_Auklet', 65 | 'Parakeet_Auklet', 66 | 'Rhinoceros_Auklet', 67 | 'Brewer_Blackbird', 68 | 'Red_winged_Blackbird', 69 | 'Rusty_Blackbird', 70 | 'Yellow_headed_Blackbird', 71 | 'Bobolink', 72 | 'Indigo_Bunting', 73 | 'Lazuli_Bunting', 74 | 'Painted_Bunting', 75 | 'Cardinal', 76 | 'Spotted_Catbird', 77 | 'Gray_Catbird', 78 | 'Yellow_breasted_Chat', 79 | 'Eastern_Towhee', 80 | 'Chuck_will_Widow', 81 | 'Brandt_Cormorant', 82 | 'Red_faced_Cormorant', 83 | 'Pelagic_Cormorant', 84 | 'Bronzed_Cowbird', 85 | 'Shiny_Cowbird', 86 | 'Brown_Creeper', 87 | 'American_Crow', 88 | 'Fish_Crow', 89 | 'Black_billed_Cuckoo', 90 | 'Mangrove_Cuckoo', 91 | 'Yellow_billed_Cuckoo', 92 | 'Gray_crowned_Rosy_Finch', 93 | 'Purple_Finch', 94 | 'Northern_Flicker', 95 | 'Acadian_Flycatcher', 96 | 'Great_Crested_Flycatcher', 97 | 'Least_Flycatcher', 98 | 'Olive_sided_Flycatcher', 99 | 'Scissor_tailed_Flycatcher', 100 | 'Vermilion_Flycatcher', 101 | 'Yellow_bellied_Flycatcher', 102 | 'Frigatebird', 103 | 'Northern_Fulmar', 104 | 'Gadwall', 105 | 'American_Goldfinch', 106 | 'European_Goldfinch', 107 | 'Boat_tailed_Grackle', 108 | 'Eared_Grebe', 109 | 'Horned_Grebe', 110 | 'Pied_billed_Grebe', 111 | 'Western_Grebe', 112 | 'Blue_Grosbeak', 113 | 'Evening_Grosbeak', 114 | 'Pine_Grosbeak', 115 | 'Rose_breasted_Grosbeak', 116 | 'Pigeon_Guillemot', 117 | 'California_Gull', 118 | 'Glaucous_winged_Gull', 119 | 'Heermann_Gull', 120 | 'Herring_Gull', 121 | 'Ivory_Gull', 122 | 'Ring_billed_Gull', 123 | 'Slaty_backed_Gull', 124 | 'Western_Gull', 125 | 'Anna_Hummingbird', 126 | 'Ruby_throated_Hummingbird', 127 | 'Rufous_Hummingbird', 128 | 'Green_Violetear', 129 | 'Long_tailed_Jaeger', 130 | 'Pomarine_Jaeger', 131 | 'Blue_Jay', 132 | 'Florida_Jay', 133 | 'Green_Jay', 134 | 'Dark_eyed_Junco', 135 | 'Tropical_Kingbird', 136 | 'Gray_Kingbird', 137 | 'Belted_Kingfisher', 138 | 'Green_Kingfisher', 139 | 'Pied_Kingfisher', 140 | 'Ringed_Kingfisher', 141 | 'White_breasted_Kingfisher', 142 | 'Red_legged_Kittiwake', 143 | 'Horned_Lark', 144 | 'Pacific_Loon', 145 | 'Mallard', 146 | 'Western_Meadowlark', 147 | 'Hooded_Merganser', 148 | 'Red_breasted_Merganser', 149 | 'Mockingbird', 150 | 'Nighthawk', 151 | 'Clark_Nutcracker', 152 | 'White_breasted_Nuthatch', 153 | 'Baltimore_Oriole', 154 | 'Hooded_Oriole', 155 | 'Orchard_Oriole', 156 | 'Scott_Oriole', 157 | 'Ovenbird', 158 | 'Brown_Pelican', 159 | 'White_Pelican', 160 | 'Western_Wood_Pewee', 161 | 'Sayornis', 162 | 'American_Pipit', 163 | 'Whip_poor_Will', 164 | 'Horned_Puffin', 165 | 'Common_Raven', 166 | 'White_necked_Raven', 167 | 'American_Redstart', 168 | 'Geococcyx', 169 | 'Loggerhead_Shrike', 170 | 'Great_Grey_Shrike', 171 | 'Baird_Sparrow', 172 | 'Black_throated_Sparrow', 173 | 'Brewer_Sparrow', 174 | 'Chipping_Sparrow', 175 | 'Clay_colored_Sparrow', 176 | 'House_Sparrow', 177 | 'Field_Sparrow', 178 | 'Fox_Sparrow', 179 | 
'Grasshopper_Sparrow', 180 | 'Harris_Sparrow', 181 | 'Henslow_Sparrow', 182 | 'Le_Conte_Sparrow', 183 | 'Lincoln_Sparrow', 184 | 'Nelson_Sharp_tailed_Sparrow', 185 | 'Savannah_Sparrow', 186 | 'Seaside_Sparrow', 187 | 'Song_Sparrow', 188 | 'Tree_Sparrow', 189 | 'Vesper_Sparrow', 190 | 'White_crowned_Sparrow', 191 | 'White_throated_Sparrow', 192 | 'Cape_Glossy_Starling', 193 | 'Bank_Swallow', 194 | 'Barn_Swallow', 195 | 'Cliff_Swallow', 196 | 'Tree_Swallow', 197 | 'Scarlet_Tanager', 198 | 'Summer_Tanager', 199 | 'Artic_Tern', 200 | 'Black_Tern', 201 | 'Caspian_Tern', 202 | 'Common_Tern', 203 | 'Elegant_Tern', 204 | 'Forsters_Tern', 205 | 'Least_Tern', 206 | 'Green_tailed_Towhee', 207 | 'Brown_Thrasher', 208 | 'Sage_Thrasher', 209 | 'Black_capped_Vireo', 210 | 'Blue_headed_Vireo', 211 | 'Philadelphia_Vireo', 212 | 'Red_eyed_Vireo', 213 | 'Warbling_Vireo', 214 | 'White_eyed_Vireo', 215 | 'Yellow_throated_Vireo', 216 | 'Bay_breasted_Warbler', 217 | 'Black_and_white_Warbler', 218 | 'Black_throated_Blue_Warbler', 219 | 'Blue_winged_Warbler', 220 | 'Canada_Warbler', 221 | 'Cape_May_Warbler', 222 | 'Cerulean_Warbler', 223 | 'Chestnut_sided_Warbler', 224 | 'Golden_winged_Warbler', 225 | 'Hooded_Warbler', 226 | 'Kentucky_Warbler', 227 | 'Magnolia_Warbler', 228 | 'Mourning_Warbler', 229 | 'Myrtle_Warbler', 230 | 'Nashville_Warbler', 231 | 'Orange_crowned_Warbler', 232 | 'Palm_Warbler', 233 | 'Pine_Warbler', 234 | 'Prairie_Warbler', 235 | 'Prothonotary_Warbler', 236 | 'Swainson_Warbler', 237 | 'Tennessee_Warbler', 238 | 'Wilson_Warbler', 239 | 'Worm_eating_Warbler', 240 | 'Yellow_Warbler', 241 | 'Northern_Waterthrush', 242 | 'Louisiana_Waterthrush', 243 | 'Bohemian_Waxwing', 244 | 'Cedar_Waxwing', 245 | 'American_Three_toed_Woodpecker', 246 | 'Pileated_Woodpecker', 247 | 'Red_bellied_Woodpecker', 248 | 'Red_cockaded_Woodpecker', 249 | 'Red_headed_Woodpecker', 250 | 'Downy_Woodpecker', 251 | 'Bewick_Wren', 252 | 'Cactus_Wren', 253 | 'Carolina_Wren', 254 | 'House_Wren', 255 | 'Marsh_Wren', 256 | 'Rock_Wren', 257 | 'Winter_Wren', 258 | 'Common_Yellowthroat', 259 | ] -------------------------------------------------------------------------------- /data/RDD.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from typing import Any, Callable, Optional, Tuple 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class RDDDataset(Dataset): 13 | 14 | def __init__( 15 | self, root: str, image_set: str = "train", transform: Optional[Callable] = None 16 | ) -> None: 17 | super().__init__() 18 | 19 | self.root = root 20 | self.image_set = image_set 21 | self.transform = transform 22 | 23 | self.base_dir = os.path.join(self.root, "RDD") 24 | 25 | image_dir = "JPEGImages" 26 | annot_dir = "Annotations" 27 | sam_dir = "SAM-mask" 28 | annotation_file = f"{image_set}.csv" 29 | 30 | self.image_dir = os.path.join(self.base_dir, image_dir) 31 | self.sam_dir = os.path.join(self.base_dir, sam_dir) 32 | self.annotation_file = os.path.join(self.base_dir, annot_dir, annotation_file) 33 | 34 | with open(self.annotation_file) as f: 35 | reader = csv.reader(f) 36 | # skip header 37 | reader.__next__() 38 | annotations = [row for row in reader] 39 | 40 | # Transpose [Image_fname, Label] 41 | annotations = [list(x) for x in zip(*annotations)] 42 | 43 | image_fnames = annotations[0] 44 | self.images = list( 45 | map(lambda x: os.path.join(self.image_dir, x + ".jpg"), image_fnames) 46 
| ) 47 | self.sam_segments = list( 48 | map(lambda x: os.path.join(self.sam_dir, x + ".png"), image_fnames) 49 | ) 50 | self.targets = self.targets = list( 51 | map(lambda x: int(1 <= int(x)), annotations[1]) 52 | ) 53 | 54 | def __getitem__(self, index) -> Tuple[Any, Any, Any, Any]: 55 | orig_image = Image.open(self.images[index]).convert("RGB") 56 | 57 | sam = cv2.imread(self.sam_segments[index], cv2.IMREAD_GRAYSCALE) 58 | sam_orig = np.expand_dims(cv2.resize(sam, (orig_image.size[:2])), -1) 59 | sam_mask = 1 * (sam_orig > 0) 60 | if sam_mask.sum() < sam_mask.size * 0.1: 61 | sam_mask = np.ones_like(sam_mask) 62 | 63 | # Calculate the bounding box coordinates of the mask 64 | mask_coords = np.argwhere(sam_mask.squeeze()) 65 | min_y, min_x = np.min(mask_coords, axis=0) 66 | max_y, max_x = np.max(mask_coords, axis=0) 67 | 68 | # Crop the image using the bounding box coordinates 69 | image = orig_image.crop((min_x, min_y, max_x + 1, max_y + 1)) 70 | 71 | # image = np.array(orig_image) * sam_mask 72 | # image = Image.fromarray(image.astype(np.uint8)) 73 | 74 | target = self.targets[index] 75 | 76 | if self.transform is not None: 77 | image = self.transform(image) 78 | orig_image = self.transform(orig_image) 79 | 80 | sam_resized = np.expand_dims(cv2.resize(sam, (orig_image.shape[1:])), -1) 81 | sam_mask = 1 * (sam_resized > 0) 82 | sam_mask = sam_mask.reshape(1, orig_image.shape[1], orig_image.shape[2]) 83 | sam_mask = torch.from_numpy(sam_mask.astype(np.uint8)).clone() 84 | 85 | return image, target, sam_mask, orig_image 86 | 87 | def __len__(self) -> int: 88 | return len(self.images) 89 | -------------------------------------------------------------------------------- /data/RDD_4.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from typing import Any, Callable, Optional, Tuple 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torchvision 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | 12 | 13 | class RDDDataset(Dataset): 14 | 15 | def __init__( 16 | self, root: str, image_set: str = "train", transform: Optional[Callable] = None 17 | ) -> None: 18 | super().__init__() 19 | 20 | self.root = root 21 | self.image_set = image_set 22 | self.transform = transform 23 | 24 | self.base_dir = os.path.join(self.root, "RDD") 25 | 26 | image_dir = "JPEGImages" 27 | annot_dir = "Annotations" 28 | sam_dir = "SAM-all" 29 | annotation_file = f"{image_set}.csv" 30 | 31 | self.image_dir = os.path.join(self.base_dir, image_dir) 32 | self.sam_dir = os.path.join(self.base_dir, sam_dir) 33 | self.annotation_file = os.path.join(self.base_dir, annot_dir, annotation_file) 34 | 35 | with open(self.annotation_file) as f: 36 | reader = csv.reader(f) 37 | # skip header 38 | reader.__next__() 39 | annotations = [row for row in reader] 40 | 41 | # Transpose [Image_fname, Label] 42 | annotations = [list(x) for x in zip(*annotations)] 43 | 44 | image_fnames = annotations[0] 45 | self.images = list( 46 | map(lambda x: os.path.join(self.image_dir, x + ".jpg"), image_fnames) 47 | ) 48 | self.sam_segments = list( 49 | map(lambda x: os.path.join(self.sam_dir, x + ".png"), image_fnames) 50 | ) 51 | self.targets = self.targets = list( 52 | map(lambda x: int(1 <= int(x)), annotations[1]) 53 | ) 54 | 55 | def __getitem__(self, index) -> Tuple[Any, Any]: 56 | image = Image.open(self.images[index]).convert("RGB") 57 | 58 | target = self.targets[index] 59 | 60 | if self.transform is not None: 61 | image = 
self.transform(image) 62 | 63 | sam = cv2.imread(self.sam_segments[index], cv2.IMREAD_GRAYSCALE) 64 | # sam = np.expand_dims(cv2.resize(sam, (image.size[:2])), -1) 65 | sam = cv2.resize(sam, (image.shape[1:])) 66 | sam_mask = 1 * (sam > 0) 67 | sam_mask_pil = Image.fromarray(sam_mask.astype(np.uint8), mode="L") 68 | 69 | mask = torchvision.transforms.functional.to_tensor(sam_mask_pil) 70 | # print(image.shape, mask.shape) 71 | image = torch.cat((image, mask), 0) 72 | 73 | return image, target 74 | 75 | def __len__(self) -> int: 76 | return len(self.images) 77 | -------------------------------------------------------------------------------- /data/RDD_bbox.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from typing import Any, Callable, Optional, Tuple 4 | 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class RDDBboxDataset(Dataset): 10 | 11 | def __init__( 12 | self, root: str, image_set: str = "train", transform: Optional[Callable] = None 13 | ) -> None: 14 | super().__init__() 15 | 16 | self.root = root 17 | self.image_set = image_set 18 | self.transform = transform 19 | 20 | self.base_dir = os.path.join(self.root, "RDD_bbox") 21 | annot_dir = "Annotations" 22 | annotation_file = f"{image_set}.csv" 23 | self.annotation_file = os.path.join(self.base_dir, annot_dir, annotation_file) 24 | 25 | with open(self.annotation_file) as f: 26 | reader = csv.reader(f) 27 | # skip header 28 | reader.__next__() 29 | annotations = [row for row in reader] 30 | 31 | # Transpose [Image_fname, Label] 32 | annotations = [list(x) for x in zip(*annotations)] 33 | 34 | self.images = list( 35 | map(lambda x: os.path.join(self.base_dir, x), annotations[0]) 36 | ) 37 | self.targets = list(map(int, annotations[1])) 38 | 39 | def __getitem__(self, index) -> Tuple[Any, Any]: 40 | image = Image.open(self.images[index]).convert("RGB") 41 | target = self.targets[index] 42 | 43 | if self.transform is not None: 44 | image = self.transform(image) 45 | 46 | return image, target 47 | 48 | def __len__(self) -> int: 49 | return len(self.images) 50 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, Callable, Dict, Optional 3 | 4 | import numpy 5 | import torch 6 | import torch.utils.data as data 7 | import torchvision.transforms as transforms 8 | from torch import nn 9 | 10 | from data.CUB import CUB_CLASSES, CUBDataset 11 | from data.imagenet import IMAGENET_CLASSES, ImageNetDataset 12 | from data.RDD import RDDDataset 13 | from data.RDD_bbox import RDDBboxDataset 14 | from data.sampler import BalancedBatchSampler 15 | from metrics.accuracy import Accuracy, MultiClassAccuracy 16 | 17 | ALL_DATASETS = ["RDD", "RDD_bbox", "CUB", "ImageNet"] 18 | 19 | 20 | def seed_worker(worker_id): 21 | worker_seed = torch.initial_seed() % 2**32 22 | numpy.random.seed(worker_seed) 23 | random.seed(worker_seed) 24 | 25 | 26 | def get_generator(seed=0): 27 | g = torch.Generator() 28 | g.manual_seed(seed) 29 | return g 30 | 31 | 32 | def create_dataloader_dict( 33 | dataset_name: str, 34 | batch_size: int, 35 | image_size: int = 224, 36 | only_test: bool = False, 37 | train_ratio: float = 0.9, 38 | shuffle_val: bool = False, 39 | dataloader_seed: int = 0, 40 | ) -> Dict[str, data.DataLoader]: 41 | """ 42 | Create dataloader dictionary 43 | 44 | Args: 45 | dataset_name(str) 
: Dataset name 46 | batch_size (int) : Batch size 47 | image_size (int) : Image size 48 | only_test (bool): Create only test dataset 49 | train_ratio(float): Train / val split ratio when there is no val 50 | 51 | Returns: 52 | dataloader_dict : Dataloader dictionary 53 | """ 54 | 55 | test_dataset = create_dataset(dataset_name, "test", image_size) 56 | test_dataloader = data.DataLoader( 57 | test_dataset, batch_size=batch_size, shuffle=shuffle_val, worker_init_fn=seed_worker, generator=get_generator(dataloader_seed), 58 | ) 59 | 60 | if only_test: 61 | return {"Test": test_dataloader} 62 | 63 | train_dataset = create_dataset( 64 | dataset_name, 65 | "train", 66 | image_size, 67 | ) 68 | 69 | dataset_params = get_parameter_depend_in_data_set(dataset_name) 70 | 71 | # Create val or split 72 | if dataset_params["has_val"]: 73 | val_dataset = create_dataset( 74 | dataset_name, 75 | "val", 76 | image_size, 77 | ) 78 | else: 79 | train_size = int(train_ratio * len(train_dataset)) 80 | val_size = len(train_dataset) - train_size 81 | 82 | train_dataset, val_dataset = data.random_split( 83 | train_dataset, [train_size, val_size] 84 | ) 85 | val_dataset.transform = test_dataset.transform 86 | 87 | if dataset_params["sampler"]: 88 | train_dataloader = data.DataLoader( 89 | train_dataset, 90 | batch_sampler=BalancedBatchSampler(train_dataset, 2, batch_size // 2), 91 | worker_init_fn=seed_worker, 92 | generator=get_generator(dataloader_seed), 93 | ) 94 | else: 95 | train_dataloader = data.DataLoader( 96 | train_dataset, batch_size=batch_size, shuffle=True, worker_init_fn=seed_worker, generator=get_generator(dataloader_seed) 97 | ) 98 | val_dataloader = data.DataLoader( 99 | val_dataset, batch_size=batch_size, shuffle=shuffle_val, worker_init_fn=seed_worker, generator=get_generator(dataloader_seed) 100 | ) 101 | 102 | dataloader_dict = { 103 | "Train": train_dataloader, 104 | "Val": val_dataloader, 105 | "Test": test_dataloader, 106 | } 107 | 108 | print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}") 109 | return dataloader_dict 110 | 111 | 112 | def create_dataset( 113 | dataset_name: str, 114 | image_set: str = "train", 115 | image_size: int = 224, 116 | transform: Optional[Callable] = None, 117 | ) -> data.Dataset: 118 | """ 119 | Create dataset 120 | Normalization parameters are created for each dataset 121 | 122 | Args: 123 | dataset_name(str) : Dataset name 124 | image_set(str) : Choose from train / val / test 125 | image_size(int) : Image size 126 | transform(Callable): transform 127 | 128 | Returns: 129 | data.Dataset : PyTorch dataset 130 | """ 131 | assert dataset_name in ALL_DATASETS 132 | params = get_parameter_depend_in_data_set(dataset_name) 133 | 134 | if transform is None: 135 | if image_set == "train": 136 | transform = transforms.Compose( 137 | [ 138 | transforms.Resize(int(image_size/0.875)), 139 | transforms.RandomResizedCrop(image_size), 140 | transforms.RandomHorizontalFlip(), 141 | transforms.ToTensor(), 142 | transforms.Normalize(params["mean"], params["std"]), 143 | ] 144 | ) 145 | else: 146 | transform = transforms.Compose( 147 | [ 148 | transforms.Resize(int(image_size/0.875)), 149 | transforms.CenterCrop(image_size), 150 | transforms.ToTensor(), 151 | transforms.Normalize(params["mean"], params["std"]), 152 | ] 153 | ) 154 | 155 | if params["has_params"]: 156 | dataset = params["dataset"]( 157 | root="./datasets", 158 | image_set=image_set, 159 | params=params, 160 | transform=transform, 161 | ) 162 | else: 163 | dataset = 
params["dataset"]( 164 | root="./datasets", image_set=image_set, transform=transform 165 | ) 166 | 167 | return dataset 168 | 169 | 170 | def get_parameter_depend_in_data_set( 171 | dataset_name: str, 172 | pos_weight: torch.Tensor = torch.Tensor([1]), 173 | dataset_root: str = "./datasets", 174 | ) -> Dict[str, Any]: 175 | """ 176 | Get parameters of the dataset 177 | 178 | Args: 179 | dataset_name(str): Dataset name 180 | 181 | Returns: 182 | dict[str, Any]: Parameters such as mean, variance, class name, etc. 183 | """ 184 | params = dict() 185 | params["name"] = dataset_name 186 | params["root"] = dataset_root 187 | # Whether to pass params to the dataset class 188 | params["has_params"] = False 189 | params["num_channel"] = 3 190 | params["sampler"] = True 191 | # ImageNet 192 | params["mean"] = (0.485, 0.456, 0.406) 193 | params["std"] = (0.229, 0.224, 0.225) 194 | 195 | if dataset_name == "RDD": 196 | params["dataset"] = RDDDataset 197 | params["num_channel"] = 3 198 | params["mean"] = (0.4770, 0.5026, 0.5094) 199 | params["std"] = (0.2619, 0.2684, 0.3001) 200 | params["classes"] = ("no crack", "crack") 201 | params["has_val"] = False 202 | params["has_params"] = False 203 | params["sampler"] = False 204 | 205 | params["metric"] = Accuracy() 206 | params["criterion"] = nn.CrossEntropyLoss() 207 | elif dataset_name == "RDD_bbox": 208 | params["dataset"] = RDDBboxDataset 209 | params["num_channel"] = 3 210 | params["mean"] = (0.4401, 0.4347, 0.4137) 211 | params["std"] = (0.2016, 0.1871, 0.1787) 212 | params["classes"] = ("no crack", "crack") 213 | params["has_val"] = False 214 | params["has_params"] = False 215 | params["sampler"] = False 216 | params["metric"] = Accuracy() 217 | params["criterion"] = nn.CrossEntropyLoss() 218 | elif dataset_name == "CUB": 219 | params["dataset"] = CUBDataset 220 | params["num_channel"] = 3 221 | # params["mean"] = (0.4859, 0.4996, 0.4318) 222 | # params["std"] = (0.2266, 0.2218, 0.2609) 223 | params["mean"] = (0.485, 0.456, 0.406) # trace ImageNet 224 | params["std"] = (0.229, 0.224, 0.225) # trace ImageNet 225 | params["classes"] = CUB_CLASSES 226 | params["has_val"] = False 227 | params["has_params"] = False 228 | params["sampler"] = False 229 | params["metric"] = MultiClassAccuracy() 230 | params["criterion"] = nn.CrossEntropyLoss() 231 | elif dataset_name == "ImageNet": 232 | params["dataset"] = ImageNetDataset 233 | params["num_channel"] = 3 234 | params["mean"] = (0.485, 0.456, 0.406) 235 | params["std"] = (0.229, 0.224, 0.225) 236 | params["classes"] = IMAGENET_CLASSES 237 | params["has_val"] = False # Use val. set as the test set 238 | params["has_params"] = False 239 | params["sampler"] = False 240 | params["metric"] = MultiClassAccuracy() 241 | params["criterion"] = nn.CrossEntropyLoss() 242 | 243 | return params 244 | -------------------------------------------------------------------------------- /data/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from torch.utils.data.sampler import BatchSampler 5 | 6 | 7 | class BalancedBatchSampler(BatchSampler): 8 | """ 9 | BatchSampler - from a MNIST-like dataset, samples n_classes and within these classes samples n_samples. 
10 | Returns batches of size n_classes * n_samples 11 | """ 12 | 13 | def __init__(self, dataset, n_classes, n_samples): 14 | loader = DataLoader(dataset) 15 | self.labels_list = [] 16 | for _, label in loader: 17 | self.labels_list.append(label) 18 | self.labels = torch.LongTensor(self.labels_list) 19 | self.labels_set = list(set(self.labels.numpy())) 20 | self.label_to_indices = { 21 | label: np.where(self.labels.numpy() == label)[0] 22 | for label in self.labels_set 23 | } 24 | for l in self.labels_set: 25 | np.random.shuffle(self.label_to_indices[l]) 26 | self.used_label_indices_count = {label: 0 for label in self.labels_set} 27 | self.count = 0 28 | self.n_classes = n_classes 29 | self.n_samples = n_samples 30 | self.dataset = dataset 31 | self.batch_size = self.n_samples * self.n_classes 32 | 33 | def __iter__(self): 34 | self.count = 0 35 | while self.count + self.batch_size < len(self.dataset): 36 | classes = np.random.choice(self.labels_set, self.n_classes, replace=False) 37 | indices = [] 38 | for class_ in classes: 39 | indices.extend( 40 | self.label_to_indices[class_][ 41 | self.used_label_indices_count[ 42 | class_ 43 | ] : self.used_label_indices_count[class_] 44 | + self.n_samples 45 | ] 46 | ) 47 | self.used_label_indices_count[class_] += self.n_samples 48 | if self.used_label_indices_count[class_] + self.n_samples > len( 49 | self.label_to_indices[class_] 50 | ): 51 | np.random.shuffle(self.label_to_indices[class_]) 52 | self.used_label_indices_count[class_] = 0 53 | yield indices 54 | self.count += self.n_classes * self.n_samples 55 | 56 | def __len__(self): 57 | return len(self.dataset) // self.batch_size 58 | -------------------------------------------------------------------------------- /datasets/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/datasets/.gitkeep -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | ### Structure 2 | Download datasets (ImageNet and Caltech-UCSD Birds-200-2011) and place them as follows: 3 | ``` 4 | . 
5 | ├── .gitkeep 6 | ├── CUB_200_2011 7 | │ ├── attributes 8 | │ ├── bounding_boxes.txt 9 | │ ├── classes.txt 10 | │ ├── image_class_labels.txt 11 | │ ├── images 12 | │ ├── images.txt 13 | │ ├── parts 14 | │ ├── README 15 | │ └── train_test_split.txt 16 | ├── imagenet 17 | │ ├── ILSVRC2012_devkit_t12.tar.gz 18 | │ ├── ILSVRC2012_img_val.tar 19 | │ ├── meta.bin 20 | │ └── val 21 | └── README.md 22 | ``` -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Dict, Optional, Tuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.utils.data as data 8 | from torchinfo import summary 9 | from tqdm import tqdm 10 | 11 | from data import (ALL_DATASETS, create_dataloader_dict, 12 | get_parameter_depend_in_data_set) 13 | from metrics.base import Metric 14 | from models import ALL_MODELS, create_model 15 | from utils.loss import calculate_loss 16 | from utils.utils import fix_seed, parse_with_config 17 | 18 | 19 | @torch.no_grad() 20 | def test( 21 | dataloader: data.DataLoader, 22 | model: nn.Module, 23 | criterion: nn.modules.loss._Loss, 24 | metrics: Metric, 25 | device: torch.device, 26 | phase: str = "Test", 27 | lambdas: Optional[Dict[str, float]] = None, 28 | ) -> Tuple[float, Metric]: 29 | total = 0 30 | total_loss: float = 0 31 | 32 | model.eval() 33 | for data in tqdm(dataloader, desc=f"{phase}: ", dynamic_ncols=True): 34 | inputs, labels = data[0].to(device), data[1].to(device) 35 | outputs = model(inputs) 36 | 37 | total_loss += calculate_loss(criterion, outputs, labels, model, lambdas).item() 38 | metrics.evaluate(outputs, labels) 39 | total += labels.size(0) 40 | 41 | test_loss = total_loss / total 42 | return test_loss, metrics 43 | 44 | 45 | import torchvision.transforms as transforms 46 | from PIL import Image 47 | 48 | from utils.visualize import save_image 49 | 50 | 51 | @torch.no_grad() 52 | def test_with_save( 53 | dataloader: data.DataLoader, 54 | model: nn.Module, 55 | criterion: nn.modules.loss._Loss, 56 | metrics: Metric, 57 | device: torch.device, 58 | phase: str = "Test", 59 | lambdas: Optional[Dict[str, float]] = None, 60 | ) -> Tuple[float, Metric]: 61 | total = 0 62 | total_loss: float = 0 63 | transform = transforms.ToPILImage() 64 | 65 | model.eval() 66 | for data in tqdm(dataloader, desc=f"{phase}: ", dynamic_ncols=True): 67 | inputs, labels = data[0].to(device), data[1].to(device) 68 | outputs = model(inputs) 69 | outputs = torch.softmax(outputs, dim=1) 70 | pred = torch.argmax(outputs, dim=1) 71 | misclassified = (pred != labels).nonzero().squeeze() 72 | if misclassified.numel() == 1: 73 | misclassified = [misclassified.item()] 74 | 75 | for index in misclassified: 76 | image = inputs[index, :3].cpu() 77 | # image = transform(image) 78 | 79 | true_label = labels[index].item() 80 | # predicted_label = pred[index].item() 81 | params = get_parameter_depend_in_data_set("RDD") 82 | save_image( 83 | image.detach().cpu().numpy(), 84 | f"outputs/error_anal/{true_label}/{outputs[index, true_label]:.4f}.png", 85 | params["mean"], 86 | params["std"], 87 | ) 88 | 89 | # image.save( 90 | # f"outputs/error_anal/{true_label}/{outputs[index, true_label]}.png" 91 | # ) 92 | 93 | total_loss += calculate_loss(criterion, outputs, labels, model, lambdas).item() 94 | metrics.evaluate(outputs, labels) 95 | total += labels.size(0) 96 | 97 | test_loss = total_loss / total 98 | return test_loss, 
metrics 99 | 100 | 101 | def main(args: argparse.Namespace) -> None: 102 | fix_seed(args.seed, args.no_deterministic) 103 | 104 | # Create the dataset 105 | dataloader_dict = create_dataloader_dict( 106 | args.dataset, args.batch_size, args.image_size, only_test=True 107 | ) 108 | dataloader = dataloader_dict["Test"] 109 | assert isinstance(dataloader, data.DataLoader) 110 | 111 | params = get_parameter_depend_in_data_set(args.dataset) 112 | 113 | # Create the model 114 | model = create_model( 115 | args.model, 116 | num_classes=len(params["classes"]), 117 | num_channel=params["num_channel"], 118 | base_pretrained=args.base_pretrained, 119 | base_pretrained2=args.base_pretrained2, 120 | pretrained_path=args.pretrained, 121 | attention_branch=args.add_attention_branch, 122 | division_layer=args.div, 123 | theta_attention=args.theta_att, 124 | ) 125 | assert model is not None, "Model name is invalid" 126 | 127 | model.load_state_dict(torch.load(args.pretrained)) 128 | print(f"pretrained {args.pretrained} loaded.") 129 | 130 | criterion = params["criterion"] 131 | metric = params["metric"] 132 | 133 | # Get run_name from the pretrained path 134 | # checkpoints/run_name/checkpoint.pt -> run_name 135 | run_name = args.pretrained.split(os.sep)[-2] 136 | save_dir = os.path.join("outputs", run_name) 137 | if not os.path.isdir(save_dir): 138 | os.makedirs(save_dir) 139 | 140 | summary( 141 | model, 142 | (args.batch_size, params["num_channel"], args.image_size, args.image_size), 143 | ) 144 | 145 | model.to(device) 146 | 147 | lambdas = {"att": args.lambda_att, "var": args.lambda_var} 148 | loss, metrics = test_with_save( 149 | dataloader, model, criterion, metric, device, lambdas=lambdas 150 | ) 151 | metric_log = metrics.log() 152 | print(f"Test\t| {metric_log} Loss: {loss:.5f}") 153 | 154 | 155 | def parse_args() -> argparse.Namespace: 156 | parser = argparse.ArgumentParser() 157 | 158 | parser.add_argument("-c", "--config", type=str, help="path to config file (json)") 159 | 160 | parser.add_argument("--seed", type=int, default=42) 161 | parser.add_argument("--no_deterministic", action="store_false") 162 | 163 | # Model 164 | parser.add_argument("-m", "--model", choices=ALL_MODELS, help="model name") 165 | parser.add_argument( 166 | "-add_ab", 167 | "--add_attention_branch", 168 | action="store_true", 169 | help="add Attention Branch", 170 | ) 171 | parser.add_argument( 172 | "-div", 173 | "--division_layer", 174 | type=str, 175 | choices=["layer1", "layer2", "layer3"], 176 | default="layer2", 177 | help="where to place the attention branch", 178 | ) 179 | parser.add_argument("--base_pretrained", type=str, help="path to base pretrained") 180 | parser.add_argument("--pretrained", type=str, help="path to pretrained") 181 | 182 | # Dataset 183 | parser.add_argument("--dataset", type=str, default="IDRiD", choices=ALL_DATASETS) 184 | parser.add_argument("--image_size", type=int, default=224) 185 | parser.add_argument("--batch_size", type=int, default=32) 186 | parser.add_argument( 187 | "--loss_weights", 188 | type=float, 189 | nargs="*", 190 | default=[1.0, 1.0], 191 | help="weights for label by class", 192 | ) 193 | 194 | parser.add_argument( 195 | "--lambda_att", type=float, default=0.1, help="weights for attention loss" 196 | ) 197 | parser.add_argument( 198 | "--lambda_var", type=float, default=0, help="weights for variance loss" 199 | ) 200 | 201 | parser.add_argument("--root_dir", type=str, default="./outputs/") 202 | 203 | return parse_with_config(parser) 204 | 205 | 206 | if __name__ == "__main__": 207 | device =
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 208 | 209 | main(parse_args()) 210 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/metrics/__init__.py -------------------------------------------------------------------------------- /metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import torch 4 | 5 | from metrics.base import Metric 6 | 7 | 8 | def num_correct_topk( 9 | output: torch.Tensor, target: torch.Tensor, topk: Tuple[int] = (1,) 10 | ) -> List[int]: 11 | """ 12 | Calculate top-k Accuracy 13 | 14 | Args: 15 | output(Tensor) : Model output 16 | target(Tensor) : Label 17 | topk(Tuple[int]): How many top ranks should be correct 18 | 19 | Returns: 20 | List[int] : top-k Accuracy 21 | """ 22 | maxk = max(topk) 23 | 24 | _, pred = output.topk(maxk, dim=1) 25 | pred = pred.t() 26 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 27 | 28 | # [[False, False, True], [F, F, F], [T, F, F]] 29 | # -> [0, 0, 1, 0, 0, 0, 1, 0, 0] -> 2 30 | result = [] 31 | for k in topk: 32 | correct_k = correct[:k].reshape(-1).float().sum(0) 33 | result.append(correct_k) 34 | return result 35 | 36 | 37 | class Accuracy(Metric): 38 | def __init__(self) -> None: 39 | self.total = 0 40 | self.correct = 0 41 | self.tp = 0 42 | self.tn = 0 43 | self.fp = 0 44 | self.fn = 0 45 | 46 | def evaluate(self, preds: torch.Tensor, labels: torch.Tensor) -> None: 47 | self.total += labels.size(0) 48 | self.correct += num_correct_topk(preds, labels)[0] 49 | 50 | preds = torch.max(preds, dim=-1)[1] 51 | tp, fp, tn, fn = confusion(preds, labels) 52 | self.tp += tp 53 | self.fp += fp 54 | self.tn += tn 55 | self.fn += fn 56 | 57 | def score(self) -> Dict[str, float]: 58 | return { 59 | "Acc": self.acc(), 60 | "TP": int(self.tp), 61 | "FP": int(self.fp), 62 | "FN": int(self.fn), 63 | "TN": int(self.tn), 64 | } 65 | 66 | def acc(self) -> float: 67 | return self.correct / self.total 68 | 69 | def clear(self) -> None: 70 | self.total = 0 71 | self.correct = 0 72 | self.tp = 0 73 | self.fp = 0 74 | self.fn = 0 75 | self.tn = 0 76 | 77 | 78 | def confusion(output, target) -> Tuple[int, int, int, int]: 79 | """ 80 | Calculate Confusion Matrix 81 | 82 | Args: 83 | output(Tensor) : Model output 84 | target(Tensor) : Label 85 | 86 | Returns: 87 | true_positive(int) : Number of TP 88 | false_positive(int): Number of FP 89 | true_negative(int) : Number of TN 90 | false_negative(int): Number of FN 91 | """ 92 | 93 | # TP: 1/1 = 1, FP: 1/0 -> inf, TN: 0/0 -> nan, FN: 0/1 -> 0 94 | confusion_vector = output / target 95 | 96 | true_positives = torch.sum(confusion_vector == 1).item() 97 | false_positives = torch.sum(confusion_vector == float("inf")).item() 98 | true_negatives = torch.sum(torch.isnan(confusion_vector)).item() 99 | false_negatives = torch.sum(confusion_vector == 0).item() 100 | 101 | return ( 102 | int(true_positives), 103 | int(false_positives), 104 | int(true_negatives), 105 | int(false_negatives), 106 | ) 107 | 108 | 109 | class MultiClassAccuracy(Metric): 110 | def __init__(self) -> None: 111 | self.total = 0 112 | self.top1_correct = 0 113 | self.top5_correct = 0 114 | 115 | def evaluate(self, preds: torch.Tensor, labels: torch.Tensor) -> None: 116 | self.total += 
labels.size(0) 117 | correct_top1, correct_top5 = num_correct_topk(preds, labels, (1, 5)) 118 | self.top1_correct += correct_top1 119 | self.top5_correct += correct_top5 120 | 121 | def score(self) -> Dict[str, float]: 122 | return { 123 | "Top-1 Acc": self.acc(), 124 | "Top-5 Acc": self.top5_acc(), 125 | } 126 | 127 | def acc(self) -> float: 128 | return self.top1_correct / self.total 129 | 130 | def top5_acc(self) -> float: 131 | return self.top5_correct / self.total 132 | 133 | def clear(self) -> None: 134 | self.total = 0 135 | self.top1_correct = 0 136 | self.top5_correct = 0 -------------------------------------------------------------------------------- /metrics/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Dict 3 | 4 | import torch 5 | 6 | 7 | class Metric(metaclass=ABCMeta): 8 | @abstractmethod 9 | def evaluate(self, preds: torch.Tensor, labels: torch.Tensor) -> None: 10 | pass 11 | 12 | @abstractmethod 13 | def score(self) -> Dict[str, float]: 14 | pass 15 | 16 | @abstractmethod 17 | def clear(self) -> None: 18 | pass 19 | 20 | def log(self) -> str: 21 | result = "" 22 | scores = self.score() 23 | for name, score in scores.items(): 24 | if isinstance(score, int): 25 | result += f"{name}: {score} " 26 | else: 27 | result += f"{name}: {score:.3f} " 28 | 29 | return result[:-1] 30 | -------------------------------------------------------------------------------- /metrics/patch_insdel.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from typing import Dict, Union 4 | 5 | import cv2 6 | import numpy as np 7 | import skimage.measure 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from data import get_parameter_depend_in_data_set 13 | from metrics.base import Metric 14 | from utils.utils import reverse_normalize 15 | from utils.visualize import save_data_as_plot, save_image 16 | 17 | 18 | class PatchInsertionDeletion(Metric): 19 | def __init__( 20 | self, 21 | model: nn.Module, 22 | batch_size: int, 23 | patch_size: int, 24 | step: int, 25 | dataset: str, 26 | device: torch.device, 27 | ) -> None: 28 | self.total = 0 29 | self.total_insertion = 0 30 | self.total_deletion = 0 31 | self.class_insertion: Dict[int, float] = {} 32 | self.num_by_classes: Dict[int, int] = {} 33 | self.class_deletion: Dict[int, float] = {} 34 | 35 | self.model = model 36 | self.batch_size = batch_size 37 | self.step = step 38 | self.device = device 39 | self.patch_size = patch_size 40 | self.dataset = dataset 41 | 42 | def evaluate( 43 | self, 44 | image: np.ndarray, 45 | attention: np.ndarray, 46 | label: Union[np.ndarray, torch.Tensor], 47 | ) -> None: 48 | self.image = image.copy() 49 | self.label = int(label.item()) 50 | 51 | # image (C, W, H), attention (1, W', H') -> attention (W, H) 52 | self.attention = attention 53 | if not (self.image.shape[1:] == attention.shape): 54 | self.attention = cv2.resize( 55 | attention[0], dsize=(self.image.shape[1], self.image.shape[2]) 56 | ) 57 | 58 | # Divide attention map into patches and calculate the order of patches 59 | self.divide_attention_map_into_patch() 60 | self.calculate_attention_order() 61 | 62 | # Create input for insertion and inference 63 | self.generate_insdel_images(mode="insertion") 64 | self.ins_preds = self.inference() # for plot 65 | self.ins_auc = auc(self.ins_preds) 66 | self.total_insertion += self.ins_auc 67 | del self.input 
68 | 69 | # deletion 70 | self.generate_insdel_images(mode="deletion") 71 | self.del_preds = self.inference() 72 | self.del_auc = auc(self.del_preds) 73 | self.total_deletion += self.del_auc 74 | del self.input 75 | 76 | self.total += 1 77 | 78 | def divide_attention_map_into_patch(self): 79 | assert self.attention is not None 80 | 81 | self.patch_attention = skimage.measure.block_reduce( 82 | self.attention, (self.patch_size, self.patch_size), np.max 83 | ) 84 | 85 | def calculate_attention_order(self): 86 | attention_flat = np.ravel(self.patch_attention) 87 | # Sort in descending order 88 | order = np.argsort(-attention_flat) 89 | 90 | W, H = self.attention.shape 91 | patch_w, _ = W // self.patch_size, H // self.patch_size 92 | self.order = np.apply_along_axis( 93 | lambda x: map_2d_indices(x, patch_w), axis=0, arr=order 94 | ) 95 | 96 | def generate_insdel_images(self, mode: str): 97 | C, W, H = self.image.shape 98 | patch_w, patch_h = W // self.patch_size, H // self.patch_size 99 | num_insertion = math.ceil(patch_w * patch_h / self.step) 100 | 101 | params = get_parameter_depend_in_data_set(self.dataset) 102 | self.input = np.zeros((num_insertion, C, W, H)) 103 | mean, std = params["mean"], params["std"] 104 | image = reverse_normalize(self.image.copy(), mean, std) 105 | 106 | for i in range(num_insertion): 107 | step_index = min(self.step * (i + 1), self.order.shape[1] - 1) 108 | w_indices = self.order[0, step_index] 109 | h_indices = self.order[1, step_index] 110 | threshold = self.patch_attention[w_indices, h_indices] 111 | 112 | if mode == "insertion": 113 | mask = np.where(threshold <= self.patch_attention, 1, 0) 114 | elif mode == "deletion": 115 | mask = np.where(threshold <= self.patch_attention, 0, 1) 116 | 117 | mask = cv2.resize(mask, dsize=(W, H), interpolation=cv2.INTER_NEAREST) 118 | 119 | for c in range(min(3, C)): 120 | self.input[i, c] = (image[c] * mask - mean[c]) / std[c] 121 | 122 | def inference(self): 123 | inputs = torch.Tensor(self.input) 124 | 125 | num_iter = math.ceil(inputs.size(0) / self.batch_size) 126 | result = torch.zeros(0) 127 | 128 | for iter in range(num_iter): 129 | start = self.batch_size * iter 130 | batch_inputs = inputs[start : start + self.batch_size].to(self.device) 131 | 132 | outputs = self.model(batch_inputs) 133 | 134 | outputs = F.softmax(outputs, 1) 135 | outputs = outputs[:, self.label] 136 | result = torch.cat([result, outputs.cpu().detach()], dim=0) 137 | 138 | return np.nan_to_num(result) 139 | 140 | def save_images(self): 141 | params = get_parameter_depend_in_data_set(self.dataset) 142 | for i, image in enumerate(self.input): 143 | save_image(image, f"tmp/{self.total}_{i}", params["mean"], params["std"]) 144 | 145 | def score(self) -> Dict[str, float]: 146 | result = { 147 | "Insertion": self.insertion(), 148 | "Deletion": self.deletion(), 149 | "PID": self.insertion() - self.deletion(), 150 | } 151 | 152 | for class_idx in self.class_insertion.keys(): 153 | self.class_insertion_score(class_idx) 154 | self.class_deletion_score(class_idx) 155 | 156 | return result 157 | 158 | def log(self) -> str: 159 | result = "Class\tPID\tIns\tDel\n" 160 | 161 | scores = self.score() 162 | result += f"All\t{scores['PID']:.3f}\t{scores['Insertion']:.3f}\t{scores['Deletion']:.3f}\n" 163 | 164 | for class_idx in self.class_insertion.keys(): 165 | pid = scores[f"PID_{class_idx}"] 166 | insertion = scores[f"Insertion_{class_idx}"] 167 | deletion = scores[f"Deletion_{class_idx}"] 168 | result +=
f"{class_idx}\t{pid:.3f}\t{insertion:.3f}\t{deletion:.3f}\n" 169 | 170 | return result 171 | 172 | def insertion(self) -> float: 173 | return self.total_insertion / self.total 174 | 175 | def deletion(self) -> float: 176 | return self.total_deletion / self.total 177 | 178 | def class_insertion_score(self, class_idx: int) -> float: 179 | num_samples = self.num_by_classes[class_idx] 180 | inserton_score = self.class_insertion[class_idx] 181 | 182 | return inserton_score / num_samples 183 | 184 | def class_deletion_score(self, class_idx: int) -> float: 185 | num_samples = self.num_by_classes[class_idx] 186 | deletion_score = self.class_deletion[class_idx] 187 | 188 | return deletion_score / num_samples 189 | 190 | def clear(self) -> None: 191 | self.total = 0 192 | self.ins_preds = None 193 | self.del_preds = None 194 | 195 | def save_roc_curve(self, save_dir: str) -> None: 196 | ins_fname = os.path.join(save_dir, f"{self.total}_insertion.png") 197 | save_data_as_plot(self.ins_preds, ins_fname, label=f"AUC = {self.ins_auc:.4f}") 198 | 199 | del_fname = os.path.join(save_dir, f"{self.total}_deletion.png") 200 | save_data_as_plot(self.del_preds, del_fname, label=f"AUC = {self.del_auc:.4f}") 201 | 202 | 203 | def map_2d_indices(indices_1d: int, width: int): 204 | """ 205 | Convert 1D index to 2D index 206 | 1D index is converted to 2D index 207 | 208 | Args: 209 | indices_1d(array): index 210 | width(int) : width 211 | 212 | Examples: 213 | [[0, 1, 2], [3, 4, 5]] 214 | -> [0, 1, 2, 3, 4, 5] 215 | 216 | map_2d_indices(1, 3) 217 | >>> [0, 1] 218 | map_ed_indices(5, 3) 219 | >>> [1, 2] 220 | 221 | Return the index of the array before flattening 222 | """ 223 | return [indices_1d // width, indices_1d % width] 224 | 225 | 226 | def auc(arr): 227 | """Returns normalized Area Under Curve of the array.""" 228 | return (arr.sum() - arr[0] / 2 - arr[-1] / 2) / (arr.shape[0] - 1) 229 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | from torchvision.models.resnet import ( 8 | ResNet, 9 | ResNet50_Weights, 10 | resnet18, 11 | resnet34, 12 | resnet50, 13 | ) 14 | from torchvision.models.vgg import VGG 15 | 16 | from models.attention_branch import add_attention_branch 17 | from models.lrp import replace_resnet_modules 18 | 19 | ALL_MODELS = [ 20 | "resnet18", 21 | "resnet34", 22 | "resnet", 23 | "resnet50", 24 | "resnet50-legacy", 25 | "EfficientNet", 26 | "wide", 27 | "vgg", 28 | "vgg19", 29 | "vgg19_bn", 30 | "vgg19_skip", 31 | "swin", 32 | ] 33 | 34 | 35 | class OneWayResNet(nn.Module): 36 | def __init__(self, model: ResNet) -> None: 37 | super().__init__() 38 | self.features = nn.Sequential( 39 | model.conv1, 40 | model.bn1, 41 | model.relu, 42 | model.maxpool, 43 | model.layer1, 44 | model.layer2, 45 | model.layer3, 46 | model.layer4, 47 | ) 48 | self.avgpool = model.avgpool 49 | self.classifier = nn.Sequential(model.fc) 50 | 51 | def forward(self, x: torch.Tensor) -> torch.Tensor: 52 | x = self.features(x) 53 | x = self.avgpool(x) 54 | x = torch.flatten(x, 1) 55 | x = self.classifier(x) 56 | return x 57 | 58 | 59 | class _vgg19_skip_forward: 60 | def __init__(self, model_ref: VGG) -> None: 61 | self.model_ref = model_ref 62 | 63 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 64 | # [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 
512, 512, "M", 512, 512, 512, 512, "M"] 65 | x = self.model_ref.features[0:2](x) # conv 3->64, relu 66 | x = self.model_ref.features[2:4](x) + x # conv 64->64, relu 67 | x = self.model_ref.features[4:5](x) # maxpool 68 | x = self.model_ref.features[5:7](x) # conv 64->128, relu 69 | x = self.model_ref.features[7:9](x) + x # conv 128->128, relu 70 | x = self.model_ref.features[9:10](x) # maxpool 71 | x = self.model_ref.features[10:12](x) # conv 128->256, relu 72 | x = self.model_ref.features[12:14](x) + x # conv 256->256, relu 73 | x = self.model_ref.features[14:16](x) + x # conv 256->256, relu 74 | x = self.model_ref.features[16:18](x) + x # conv 256->256, relu 75 | x = self.model_ref.features[18:19](x) # maxpool 76 | x = self.model_ref.features[19:21](x) # conv 256->512, relu 77 | x = self.model_ref.features[21:23](x) + x # conv 512->512, relu 78 | x = self.model_ref.features[23:25](x) + x # conv 512->512, relu 79 | x = self.model_ref.features[25:27](x) + x # conv 512->512, relu 80 | x = self.model_ref.features[27:28](x) # maxpool 81 | x = self.model_ref.features[28:30](x) + x # conv 512->512, relu 82 | x = self.model_ref.features[30:32](x) + x # conv 512->512, relu 83 | x = self.model_ref.features[32:34](x) + x # conv 512->512, relu 84 | x = self.model_ref.features[34:36](x) + x # conv 512->512, relu 85 | x = self.model_ref.features[36:37](x) # maxpool 86 | # x = self.model_ref.features(x) 87 | x = self.model_ref.avgpool(x) 88 | x = torch.flatten(x, 1) 89 | x = self.model_ref.classifier(x) 90 | return x 91 | 92 | 93 | def create_model( 94 | base_model: str, 95 | num_classes: int = 1000, 96 | num_channel: int = 3, 97 | base_pretrained: Optional[str] = None, 98 | base_pretrained2: Optional[str] = None, 99 | pretrained_path: Optional[str] = None, 100 | attention_branch: bool = False, 101 | division_layer: Optional[str] = None, 102 | theta_attention: float = 0, 103 | init_classifier: bool = True, 104 | ) -> Optional[nn.Module]: 105 | """ 106 | Create model from model name and parameters 107 | 108 | Args: 109 | base_model(str) : Base model name (resnet, etc.) 110 | num_classes(int) : Number of classes 111 | base_pretrained(str) : Base model pretrain path 112 | base_pretrained2(str) : Pretrain path after changing the final layer 113 | pretrained_path(str) : Pretrain path of the final model 114 | (After attention branch, etc.) 
115 | attention_branch(bool): Whether to attention branch 116 | division_layer(str) : Which layer to introduce attention branch 117 | theta_attention(float): Threshold when entering Attention Branch 118 | Set pixels with lower attention than this value to 0 and input 119 | 120 | Returns: 121 | nn.Module: Created model 122 | """ 123 | # Create base model 124 | if base_model == "resnet": 125 | model = resnet18(pretrained=(base_pretrained is not None)) 126 | layer_index = {"layer1": -6, "layer2": -5, "layer3": -4} 127 | if init_classifier: 128 | model.fc = nn.Linear(model.fc.in_features, num_classes) 129 | elif base_model == "resnet18": 130 | model = resnet18(pretrained=True) 131 | if init_classifier: 132 | model.fc = nn.Linear(model.fc.in_features, num_classes) 133 | elif base_model == "resnet34": 134 | model = resnet34(pretrained=True) 135 | if init_classifier: 136 | model.fc = nn.Linear(model.fc.in_features, num_classes) 137 | elif base_model == "resnet50": 138 | model = resnet50(pretrained=True) 139 | if init_classifier: 140 | model.fc = nn.Linear(model.fc.in_features, num_classes) 141 | elif base_model == "resnet50-v2": 142 | model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) 143 | if init_classifier: 144 | model.fc = nn.Linear(model.fc.in_features, num_classes) 145 | elif base_model == "vgg": 146 | model = torchvision.models.vgg11(pretrained=True) 147 | model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes) 148 | elif base_model == "vgg19": 149 | model = torchvision.models.vgg19(pretrained=True) 150 | model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes) 151 | elif base_model == "vgg19_bn": 152 | model = torchvision.models.vgg19_bn(pretrained=True) 153 | model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes) 154 | elif base_model == "vgg19_skip": 155 | model = torchvision.models.vgg19(pretrained=True) 156 | model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes) 157 | model.forward = _vgg19_skip_forward(model) 158 | else: 159 | return None 160 | 161 | # Load if base_pretrained is path 162 | if base_pretrained is not None and os.path.isfile(base_pretrained): 163 | model.load_state_dict(torch.load(base_pretrained)) 164 | print(f"base pretrained {base_pretrained} loaded.") 165 | 166 | if attention_branch: 167 | assert division_layer is not None 168 | model = add_attention_branch( 169 | model, 170 | layer_index[division_layer], 171 | num_classes, 172 | theta_attention, 173 | ) 174 | model.attention_branch = replace_resnet_modules(model.attention_branch) 175 | 176 | # Load if pretrained is path 177 | if pretrained_path is not None and os.path.isfile(pretrained_path): 178 | model.load_state_dict(torch.load(pretrained_path)) 179 | print(f"pretrained {pretrained_path} loaded.") 180 | 181 | return model 182 | -------------------------------------------------------------------------------- /models/attention_branch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Type 2 | 3 | import torch 4 | import torch.nn as nn 5 | from timm.models.byobnet import ByobNet 6 | 7 | from utils.utils import module_generator 8 | 9 | 10 | class BatchNorm2dWithActivation(nn.BatchNorm2d): 11 | def __init__(self, module): 12 | super(BatchNorm2dWithActivation, self).__init__( 13 | module.num_features, 14 | eps=module.eps, 15 | momentum=module.momentum, 16 | affine=module.affine, 17 | track_running_stats=module.track_running_stats, 18 | ) 19 | 
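    # Note: this wrapper mirrors BatchNorm2dWithActivation in models/lrp.py. forward()
    # below only snapshots the incoming tensor (presumably for later attribution /
    # relevance propagation) before delegating to the usual nn.BatchNorm2d computation.
    # Within this file it is referenced only by a commented-out line in
    # AttentionBranch.__init__.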
20 | def forward(self, x): 21 | self.activations = x.detach().clone() 22 | return super().forward(x) 23 | 24 | 25 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 26 | """3x3 convolution with padding""" 27 | return nn.Conv2d( 28 | in_planes, 29 | out_planes, 30 | kernel_size=3, 31 | stride=stride, 32 | padding=dilation, 33 | groups=groups, 34 | bias=False, 35 | dilation=dilation, 36 | ) 37 | 38 | 39 | def conv1x1(in_planes, out_planes, stride=1): 40 | """1x1 convolution""" 41 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 42 | 43 | 44 | class BasicBlock(nn.Module): 45 | expansion = 1 46 | 47 | def __init__(self, inplanes, planes, stride=1, downsample=None): 48 | super(BasicBlock, self).__init__() 49 | self.inplanes = inplanes 50 | self.conv1 = conv3x3(inplanes, planes, stride) 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.conv2 = conv3x3(planes, planes) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.downsample = downsample 56 | self.stride = stride 57 | 58 | def forward(self, x): 59 | residual = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv2(out) 66 | out = self.bn2(out) 67 | 68 | if self.downsample is not None: 69 | residual = self.downsample(x) 70 | 71 | out += residual 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class Bottleneck(nn.Module): 78 | expansion = 4 79 | 80 | def __init__(self, inplanes, planes, stride=1, downsample=None): 81 | super(Bottleneck, self).__init__() 82 | self.inplanes = inplanes 83 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 84 | self.bn1 = nn.BatchNorm2d(planes) 85 | self.conv2 = nn.Conv2d( 86 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 87 | ) 88 | self.bn2 = nn.BatchNorm2d(planes) 89 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 90 | self.bn3 = nn.BatchNorm2d(planes * 4) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | residual = x 97 | 98 | out = self.conv1(x) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | residual = self.downsample(x) 111 | 112 | out += residual 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | 118 | class AttentionBranch(nn.Module): 119 | def __init__( 120 | self, 121 | block: Bottleneck, 122 | num_layer: int, 123 | num_classes: int = 1000, 124 | inplanes: int = 64, 125 | multi_task: bool = False, 126 | num_tasks: Optional[List[int]] = None, 127 | ) -> None: 128 | super().__init__() 129 | 130 | self.inplanes = inplanes 131 | 132 | self.layer1 = self._make_layer( 133 | block, self.inplanes, num_layer, stride=1, down_size=False 134 | ) 135 | 136 | hidden_channel = 10 137 | 138 | self.bn1 = nn.BatchNorm2d(self.inplanes * block.expansion) 139 | # self.bn1 = BatchNorm2dWithActivation(bn1) 140 | self.conv1 = conv1x1(self.inplanes * block.expansion, hidden_channel) 141 | self.relu = nn.ReLU(inplace=True) 142 | 143 | if multi_task: 144 | if num_tasks is None: 145 | num_tasks = [1 for _ in range(num_classes)] 146 | assert num_classes == sum(num_tasks) 147 | self.conv2 = conv3x3(num_classes, len(num_tasks)) 148 | self.bn2 = nn.BatchNorm2d(len(num_tasks)) 149 | else: 150 | self.conv2 = 
conv1x1(hidden_channel, hidden_channel) 151 | self.bn2 = nn.BatchNorm2d(hidden_channel) 152 | self.conv3 = conv1x1(hidden_channel, 1) 153 | self.bn3 = nn.BatchNorm2d(1) 154 | 155 | self.sigmoid = nn.Sigmoid() 156 | 157 | self.conv4 = conv1x1(hidden_channel, num_classes) 158 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 159 | self.flatten = nn.Flatten(1) 160 | 161 | def _make_layer( 162 | self, 163 | block: Bottleneck, 164 | planes: int, 165 | blocks: int, 166 | stride: int = 1, 167 | down_size: bool = True, 168 | ) -> nn.Sequential: 169 | downsample = None 170 | if stride != 1 or self.inplanes != planes * block.expansion: 171 | downsample = nn.Sequential( 172 | nn.Conv2d( 173 | self.inplanes, 174 | planes * block.expansion, 175 | kernel_size=1, 176 | stride=stride, 177 | bias=False, 178 | ), 179 | nn.BatchNorm2d(planes * block.expansion), 180 | ) 181 | 182 | layers = [] 183 | layers.append(block(self.inplanes, planes, stride, downsample)) 184 | if down_size: 185 | self.inplanes = planes * block.expansion 186 | for _ in range(1, blocks): 187 | layers.append(block(self.inplanes, planes)) 188 | 189 | return nn.Sequential(*layers) 190 | else: 191 | inplanes = planes * block.expansion 192 | for _ in range(1, blocks): 193 | layers.append(block(inplanes, planes)) 194 | 195 | return nn.Sequential(*layers) 196 | 197 | def forward(self, x): 198 | x = self.layer1(x) 199 | 200 | self.bn1_activation = x 201 | x = self.bn1(x) 202 | x = self.conv1(x) 203 | x = self.relu(x) 204 | 205 | self.class_attention = self.sigmoid(x) 206 | 207 | attention = self.conv2(x) 208 | attention = self.bn2(attention) 209 | attention = self.relu(attention) 210 | 211 | attention = self.conv3(attention) 212 | attention = self.bn3(attention) 213 | 214 | self.attention_order = attention 215 | self.attention = self.sigmoid(attention) 216 | 217 | x = self.conv4(x) 218 | x = self.avgpool(x) 219 | x = self.flatten(x) 220 | 221 | return x 222 | 223 | 224 | class AttentionBranchModel(nn.Module): 225 | def __init__( 226 | self, 227 | feature_extractor: nn.Module, 228 | perception_branch: nn.Module, 229 | block: Type = Bottleneck, 230 | num_layer: int = 2, 231 | num_classes: int = 1000, 232 | inplanes: int = 64, 233 | theta_attention: float = 0, 234 | ) -> None: 235 | super().__init__() 236 | 237 | self.feature_extractor = feature_extractor 238 | 239 | self.attention_branch = AttentionBranch(block, num_layer, num_classes, inplanes) 240 | # self.attention_branch = replace_resnet_modules(self.attention_branch) 241 | 242 | self.perception_branch = perception_branch 243 | self.theta_attention = theta_attention 244 | 245 | def forward(self, x): 246 | x = self.feature_extractor(x) 247 | 248 | # For Attention Loss 249 | self.attention_pred = self.attention_branch(x) 250 | 251 | attention = self.attention_branch.attention 252 | attention = torch.where( 253 | self.theta_attention < attention, attention.double(), 0.0 254 | ) 255 | 256 | x = x * attention 257 | x = x.float() 258 | # x = attention_x + x 259 | 260 | return self.perception_branch(x) 261 | 262 | 263 | def add_attention_branch( 264 | base_model: nn.Module, 265 | division_index: int, 266 | num_classes: int, 267 | theta_attention: float = 0, 268 | ) -> nn.Module: 269 | modules = list(base_model.children()) 270 | # [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc] 271 | 272 | pre_model = nn.Sequential(*modules[:division_index]) 273 | post_model = nn.Sequential(*modules[division_index:-1], nn.Flatten(1), modules[-1]) 274 | 275 | if isinstance(base_model, ByobNet): 
276 | return AttentionBranchModel( 277 | pre_model, 278 | post_model, 279 | Bottleneck, 280 | 2, 281 | num_classes, 282 | inplanes=64, 283 | theta_attention=theta_attention, 284 | ) 285 | 286 | final_layer = module_generator(modules[division_index], reverse=True) 287 | for module in final_layer: 288 | if isinstance(module, nn.modules.batchnorm._NormBase): 289 | inplanes = module.num_features // 2 290 | break 291 | elif isinstance(module, nn.modules.conv._ConvNd): 292 | inplanes = module.out_channels 293 | break 294 | 295 | return AttentionBranchModel( 296 | pre_model, 297 | post_model, 298 | Bottleneck, 299 | 2, 300 | num_classes, 301 | inplanes=inplanes, 302 | theta_attention=theta_attention, 303 | ) 304 | -------------------------------------------------------------------------------- /models/lrp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models.resnet import BasicBlock as OriginalBasicBlock 5 | from torchvision.models.resnet import Bottleneck as OriginalBottleneck 6 | 7 | from models.attention_branch import BasicBlock as ABNBasicBlock 8 | from models.attention_branch import Bottleneck as ABNBottleneck 9 | 10 | 11 | class LinearWithActivation(nn.Linear): 12 | def __init__(self, module, out_features=None): 13 | in_features = module.in_features 14 | if out_features is None: 15 | out_features = module.out_features 16 | bias = module.bias is not None 17 | super(LinearWithActivation, self).__init__(in_features, out_features, bias) 18 | self.activations = None 19 | 20 | def forward(self, x): 21 | if len(x.shape) > 2: 22 | bs = x.shape[0] 23 | x = x.view(bs, -1) 24 | self.activations = x.detach().clone() 25 | return super(LinearWithActivation, self).forward(x) 26 | 27 | 28 | class Conv2dWithActivation(nn.Conv2d): 29 | def __init__(self, module, num_channel=None): 30 | in_channels = module.in_channels if num_channel is None else num_channel 31 | super(Conv2dWithActivation, self).__init__( 32 | in_channels, 33 | module.out_channels, 34 | module.kernel_size, 35 | stride=module.stride, 36 | padding=module.padding, 37 | dilation=module.dilation, 38 | groups=module.groups, 39 | padding_mode=module.padding_mode, 40 | bias=module.bias is not None, 41 | ) 42 | self.activations = None 43 | 44 | def forward(self, x): 45 | self.activations = x.detach().clone() 46 | return super().forward(x) 47 | 48 | 49 | class BatchNorm2dWithActivation(nn.BatchNorm2d): 50 | def __init__(self, module): 51 | super(BatchNorm2dWithActivation, self).__init__( 52 | module.num_features, 53 | eps=module.eps, 54 | momentum=module.momentum, 55 | affine=module.affine, 56 | track_running_stats=module.track_running_stats, 57 | ) 58 | self.activations = None 59 | 60 | def forward(self, x): 61 | self.activations = x.detach().clone() 62 | return super().forward(x) 63 | 64 | 65 | class ReLUWithActivation(nn.ReLU): 66 | def __init__(self, *args, **kwargs): 67 | super(ReLUWithActivation, self).__init__(*args, **kwargs) 68 | self.activations = None 69 | 70 | def forward(self, x): 71 | self.activations = x.detach().clone() 72 | return super(ReLUWithActivation, self).forward(x) 73 | 74 | 75 | class MaxPool2dWithActivation(nn.MaxPool2d): 76 | def __init__(self, module): 77 | super(MaxPool2dWithActivation, self).__init__( 78 | module.kernel_size, module.stride, module.padding, module.dilation 79 | ) 80 | self.activations = None 81 | 82 | def forward(self, x): 83 | self.activations = x.detach().clone() 84 | 
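        # The detached copy stashed above follows the same pattern as the other
        # *WithActivation wrappers in this file: the layer input stays readable for
        # attribution code after the forward pass, while the pooling below behaves
        # exactly like nn.MaxPool2d.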
return super(MaxPool2dWithActivation, self).forward(x) 85 | 86 | 87 | class AdaptiveAvgPool2dWithActivation(nn.AdaptiveAvgPool2d): 88 | def __init__(self, module: nn.AdaptiveAvgPool2d): 89 | super(AdaptiveAvgPool2dWithActivation, self).__init__(module.output_size) 90 | self.activations = None 91 | 92 | def forward(self, x): 93 | self.activations = x.detach().clone() 94 | return super(AdaptiveAvgPool2dWithActivation, self).forward(x) 95 | 96 | 97 | def copy_weights(target_module, source_module): 98 | if isinstance( 99 | target_module, (nn.AdaptiveAvgPool2d, nn.MaxPool2d, nn.ReLU) 100 | ) and isinstance(source_module, (nn.AdaptiveAvgPool2d, nn.MaxPool2d, nn.ReLU)): 101 | # Do nothing for layers without weights 102 | return 103 | if isinstance(target_module, nn.Linear) and isinstance(source_module, nn.Linear): 104 | target_module.weight.data.copy_(source_module.weight.data) 105 | target_module.bias.data.copy_(source_module.bias.data) 106 | elif isinstance(target_module, nn.Conv2d) and isinstance(source_module, nn.Conv2d): 107 | target_module.weight.data.copy_(source_module.weight.data) 108 | if source_module.bias is not None: 109 | target_module.bias = source_module.bias 110 | elif isinstance(target_module, nn.BatchNorm2d) and isinstance( 111 | source_module, nn.BatchNorm2d 112 | ): 113 | target_module.weight.data.copy_(source_module.weight.data) 114 | target_module.bias.data.copy_(source_module.bias.data) 115 | target_module.running_mean.data.copy_(source_module.running_mean.data) 116 | target_module.running_var.data.copy_(source_module.running_var.data) 117 | else: 118 | raise ValueError( 119 | f"Unsupported module types for copy_weights source: {source_module} and target: {target_module}" 120 | ) 121 | 122 | 123 | def replace_modules(model, wrapper): 124 | for name, module in model.named_children(): 125 | if isinstance(module, nn.Sequential): 126 | setattr(model, name, replace_modules(module, wrapper)) 127 | else: 128 | wrapped_module = wrapper(module) 129 | wrapped_module = copy_weights(module, wrapped_module) 130 | setattr(model, name, wrapped_module) 131 | return model 132 | 133 | 134 | class BasicBlockWithActivation(OriginalBasicBlock): 135 | def __init__(self, block: OriginalBasicBlock): 136 | inplanes = block.conv1.in_channels 137 | planes = block.conv1.out_channels 138 | 139 | super(BasicBlockWithActivation, self).__init__( 140 | inplanes=inplanes, 141 | planes=planes, 142 | stride=block.stride, 143 | downsample=block.downsample, 144 | ) 145 | 146 | def forward(self, x): 147 | self.activations = x.detach().clone() 148 | return super().forward(x) 149 | 150 | 151 | class BottleneckWithActivation(OriginalBottleneck): 152 | def __init__(self, block: OriginalBottleneck): 153 | inplanes = block.conv1.in_channels 154 | planes = block.conv3.out_channels // block.expansion 155 | groups = block.conv2.groups 156 | dilation = block.conv2.dilation 157 | width = block.conv1.out_channels 158 | base_width = 64 * width // (groups * planes) 159 | 160 | super(BottleneckWithActivation, self).__init__( 161 | inplanes=inplanes, 162 | planes=planes, 163 | stride=block.stride, 164 | downsample=block.downsample, 165 | groups=groups, 166 | base_width=base_width, 167 | dilation=dilation, 168 | ) 169 | 170 | def forward(self, x): 171 | self.activations = x.detach().clone() 172 | return super().forward(x) 173 | 174 | 175 | def replace_resnet_modules(model): 176 | for name, module in list(model.named_children()): 177 | if isinstance(module, nn.Sequential): 178 | for i, block in enumerate(module): 179 | if 
isinstance(block, OriginalBasicBlock) or isinstance( 180 | block, ABNBasicBlock 181 | ): 182 | module[i] = BasicBlockWithActivation(block) 183 | elif isinstance(block, OriginalBottleneck) or isinstance( 184 | block, ABNBottleneck 185 | ): 186 | module[i] = BottleneckWithActivation(block) 187 | setattr(model, name, module) 188 | elif isinstance(module, nn.AdaptiveAvgPool2d): 189 | setattr(model, name, AdaptiveAvgPool2dWithActivation(module)) 190 | return model 191 | -------------------------------------------------------------------------------- /models/rise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class RISE(nn.Module): 8 | def __init__( 9 | self, 10 | model, 11 | n_masks=10000, 12 | p1=0.1, 13 | input_size=(224, 224), 14 | initial_mask_size=(7, 7), 15 | n_batch=128, 16 | mask_path=None, 17 | ): 18 | super().__init__() 19 | self.model = model 20 | self.n_masks = n_masks 21 | self.p1 = p1 22 | self.input_size = input_size 23 | self.initial_mask_size = initial_mask_size 24 | self.n_batch = n_batch 25 | 26 | if mask_path is not None: 27 | self.masks = self.load_masks(mask_path) 28 | else: 29 | self.masks = self.generate_masks() 30 | 31 | def generate_masks(self): 32 | # cell size in the upsampled mask 33 | Ch = np.ceil(self.input_size[0] / self.initial_mask_size[0]) 34 | Cw = np.ceil(self.input_size[1] / self.initial_mask_size[1]) 35 | 36 | resize_h = int((self.initial_mask_size[0] + 1) * Ch) 37 | resize_w = int((self.initial_mask_size[1] + 1) * Cw) 38 | 39 | masks = [] 40 | 41 | for _ in range(self.n_masks): 42 | # generate binary mask 43 | binary_mask = torch.randn( 44 | 1, 1, self.initial_mask_size[0], self.initial_mask_size[1] 45 | ) 46 | binary_mask = (binary_mask < self.p1).float() 47 | 48 | # upsampling mask 49 | mask = F.interpolate( 50 | binary_mask, (resize_h, resize_w), mode="bilinear", align_corners=False 51 | ) 52 | 53 | # random cropping 54 | i = np.random.randint(0, Ch) 55 | j = np.random.randint(0, Cw) 56 | mask = mask[:, :, i : i + self.input_size[0], j : j + self.input_size[1]] 57 | 58 | masks.append(mask) 59 | 60 | masks = torch.cat(masks, dim=0) # (N_masks, 1, H, W) 61 | 62 | return masks 63 | 64 | def load_masks(self, filepath): 65 | masks = torch.load(filepath) 66 | return masks 67 | 68 | def save_masks(self, filepath): 69 | torch.save(self.masks, filepath) 70 | 71 | def forward(self, x): 72 | # x: input image. 
(1, 3, H, W) 73 | device = x.device 74 | 75 | # keep probabilities of each class 76 | probs = [] 77 | # shape (n_masks, 3, H, W) 78 | masked_x = torch.mul(self.masks, x.to("cpu").data) 79 | 80 | for i in range(0, self.n_masks, self.n_batch): 81 | input = masked_x[i : min(i + self.n_batch, self.n_masks)].to(device) 82 | out = self.model(input) 83 | probs.append(torch.softmax(out, dim=1).to("cpu").data) 84 | 85 | probs = torch.cat(probs) # shape => (n_masks, n_classes) 86 | n_classes = probs.shape[1] 87 | 88 | # caluculate saliency map using probability scores as weights 89 | saliency = torch.matmul( 90 | probs.data.transpose(0, 1), self.masks.view(self.n_masks, -1) 91 | ) 92 | saliency = saliency.view((n_classes, self.input_size[0], self.input_size[1])) 93 | saliency = saliency / (self.n_masks * self.p1) 94 | 95 | # normalize 96 | m, _ = torch.min(saliency.view(n_classes, -1), dim=1) 97 | saliency -= m.view(n_classes, 1, 1) 98 | M, _ = torch.max(saliency.view(n_classes, -1), dim=1) 99 | saliency /= M.view(n_classes, 1, 1) 100 | return saliency.data 101 | -------------------------------------------------------------------------------- /oneshot.py: -------------------------------------------------------------------------------- 1 | """ Visualize attention map of a model with a given image. 2 | 3 | Example: 4 | ```sh 5 | poetry run python oneshot.py -c checkpoints/CUB_resnet50_Seed42/config.json --method "scorecam" \ 6 | --image-path ./qual/original/painted_bunting.png --label 15 --save-path ./qual/scorecam/painted_bunting.png 7 | 8 | poetry run python oneshot.py -c checkpoints/CUB_resnet50_Seed42/config.json --method "lrp" \ 9 | --skip-connection-prop-type "flows_skip" --heat-quantization \ 10 | --image-path ./qual/original/painted_bunting.png --label 15 --save-path ./qual/ours/painted_bunting.png 11 | 12 | poetry run python oneshot.py -c configs/ImageNet_resnet50.json --method "lrp" \ 13 | --skip-connection-prop-type "flows_skip" --heat-quantization \ 14 | --image-path ./qual/original/bee.png --label 309 --save-path ./qual/ours/bee.png 15 | ``` 16 | 17 | """ 18 | 19 | import argparse 20 | import os 21 | from typing import Any, Dict, Optional, Union 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torchvision.transforms as transforms 26 | from PIL import Image 27 | from skimage.transform import resize 28 | 29 | from data import ALL_DATASETS, get_parameter_depend_in_data_set 30 | from metrics.base import Metric 31 | from models import ALL_MODELS, create_model 32 | from models.lrp import * 33 | from src.utils import SkipConnectionPropType 34 | from utils.utils import fix_seed, parse_with_config 35 | from utils.visualize import ( 36 | save_image_with_attention_map, 37 | ) 38 | from visualize import ( # TODO: move these functions to src directory 39 | apply_heat_quantization, 40 | calculate_attention, 41 | remove_other_components, 42 | ) 43 | 44 | 45 | def load_image(image_path: str, image_size: int = 224) -> torch.Tensor: 46 | imagenet_mean = (0.485, 0.456, 0.406) 47 | imagenet_std = (0.229, 0.224, 0.225) 48 | transform = transforms.Compose( 49 | [ 50 | transforms.ToTensor(), 51 | transforms.Normalize(imagenet_mean, imagenet_std), 52 | ] 53 | ) 54 | image = Image.open(image_path).convert("RGB") 55 | image = transform(image) 56 | return image 57 | 58 | 59 | def visualize( 60 | image_path: str, 61 | label: int, 62 | model: nn.Module, 63 | method: str, 64 | save_path: str, 65 | params: Dict[str, Any], 66 | device: torch.device, 67 | attention_dir: Optional[str] = None, 68 | use_c1c: 
bool = False, 69 | heat_quantization: bool = False, 70 | hq_level: int = 8, 71 | skip_connection_prop_type: SkipConnectionPropType = "latest", 72 | normalize: bool = False, 73 | sign: str = "all", 74 | ) -> Union[None, Metric]: 75 | model.eval() 76 | torch.cuda.memory_summary(device=device) 77 | inputs = load_image(image_path, params["input_size"][0]).unsqueeze(0).to(device) 78 | image = inputs[0].cpu().numpy() 79 | label = torch.tensor(label).to(device) 80 | 81 | attention, _ = calculate_attention( 82 | model, inputs, label, method, params, None, skip_connection_prop_type 83 | ) 84 | if use_c1c: 85 | attention = resize(attention, (28, 28)) 86 | attention = remove_other_components(attention, threshold=attention.mean()) 87 | 88 | if heat_quantization: 89 | attention = apply_heat_quantization(attention, hq_level) 90 | 91 | save_path = save_path + ".png" if save_path[-4:] != ".png" else save_path 92 | save_image_with_attention_map( 93 | image, attention, save_path, params["mean"], params["std"], only_img=True, normalize=normalize, sign=sign 94 | ) 95 | 96 | 97 | def main(args: argparse.Namespace) -> None: 98 | fix_seed(args.seed, args.no_deterministic) 99 | params = get_parameter_depend_in_data_set(args.dataset) 100 | 101 | mask_path = os.path.join(args.root_dir, "masks.npy") 102 | if not os.path.isfile(mask_path): 103 | mask_path = None 104 | rise_params = { 105 | "n_masks": args.num_masks, 106 | "p1": args.p1, 107 | "input_size": (args.image_size, args.image_size), 108 | "initial_mask_size": (args.rise_scale, args.rise_scale), 109 | "n_batch": args.batch_size, 110 | "mask_path": mask_path, 111 | } 112 | params.update(rise_params) 113 | 114 | # Create model 115 | model = create_model( 116 | args.model, 117 | num_classes=len(params["classes"]), 118 | num_channel=params["num_channel"], 119 | base_pretrained=args.base_pretrained, 120 | base_pretrained2=args.base_pretrained2, 121 | pretrained_path=args.pretrained, 122 | attention_branch=args.add_attention_branch, 123 | division_layer=args.div, 124 | theta_attention=args.theta_att, 125 | init_classifier=args.dataset != "ImageNet", # Use pretrained classifier in ImageNet 126 | ) 127 | assert model is not None, "Model name is invalid" 128 | 129 | model.to(device) 130 | visualize( 131 | args.image_path, 132 | args.label, 133 | model, 134 | args.method, 135 | args.save_path, 136 | params, 137 | device, 138 | attention_dir=args.attention_dir, 139 | use_c1c=args.use_c1c, 140 | heat_quantization=args.heat_quantization, 141 | hq_level=args.hq_level, 142 | skip_connection_prop_type=args.skip_connection_prop_type, 143 | normalize=args.normalize, 144 | sign=args.sign, 145 | ) 146 | 147 | 148 | def parse_args() -> argparse.Namespace: 149 | parser = argparse.ArgumentParser() 150 | 151 | parser.add_argument("-c", "--config", type=str, help="path to config file (json)") 152 | 153 | parser.add_argument("--seed", type=int, default=42) 154 | parser.add_argument("--no_deterministic", action="store_false") 155 | 156 | # Model 157 | parser.add_argument("-m", "--model", choices=ALL_MODELS, help="model name") 158 | parser.add_argument( 159 | "-add_ab", 160 | "--add_attention_branch", 161 | action="store_true", 162 | help="add Attention Branch", 163 | ) 164 | parser.add_argument( 165 | "--div", 166 | type=str, 167 | choices=["layer1", "layer2", "layer3"], 168 | default="layer2", 169 | help="place to attention branch", 170 | ) 171 | parser.add_argument("--base_pretrained", type=str, help="path to base pretrained") 172 | parser.add_argument( 173 | "--base_pretrained2", 
174 | type=str, 175 | help="path to base pretrained2 ( after change_num_classes() )", 176 | ) 177 | parser.add_argument("--pretrained", type=str, help="path to pretrained") 178 | parser.add_argument( 179 | "--orig_model", 180 | action="store_true", 181 | help="calc insdel score by using original model", 182 | ) 183 | parser.add_argument( 184 | "--theta_att", type=float, default=0, help="threthold of attention branch" 185 | ) 186 | 187 | # Target Image 188 | parser.add_argument("--image-path", type=str, help="path to target image") 189 | parser.add_argument("--label", type=int, help="label of target image") 190 | 191 | # Visualize option 192 | parser.add_argument("--normalize", action="store_true", help="normalize attribution") 193 | parser.add_argument("--sign", type=str, default="all", help="sign of attribution to show") 194 | 195 | # Save 196 | parser.add_argument("--save-path", type=str, help="path to save image") 197 | 198 | # Dataset 199 | parser.add_argument("--dataset", type=str, default="IDRiD", choices=ALL_DATASETS) 200 | parser.add_argument("--image_size", type=int, default=224) 201 | parser.add_argument("--batch_size", type=int, default=16) 202 | parser.add_argument( 203 | "--loss_weights", 204 | type=float, 205 | nargs="*", 206 | default=[1.0, 1.0], 207 | help="weights for label by class", 208 | ) 209 | 210 | parser.add_argument("--root_dir", type=str, default="./outputs/") 211 | parser.add_argument("--visualize_only", action="store_false") 212 | parser.add_argument("--all_class", action="store_true") 213 | # recommend (step, size) in 512x512 = (1, 10000), (2, 2500), (4, 500), (8, 100), (16, 20), (32, 10), (64, 5), (128, 1) 214 | # recommend (step, size) in 224x224 = (1, 500), (2, 100), (4, 20), (8, 10), (16, 5), (32, 1) 215 | parser.add_argument("--insdel_step", type=int, default=500) 216 | parser.add_argument("--block_size", type=int, default=1) 217 | 218 | parser.add_argument( 219 | "--method", 220 | type=str, 221 | default="gradcam", 222 | ) 223 | 224 | parser.add_argument("--num_masks", type=int, default=5000) 225 | parser.add_argument("--rise_scale", type=int, default=9) 226 | parser.add_argument( 227 | "--p1", type=float, default=0.3, help="percentage of mask [pixel = (0, 0, 0)]" 228 | ) 229 | 230 | parser.add_argument("--attention_dir", type=str, help="path to attention npy file") 231 | 232 | parser.add_argument("--use-c1c", action="store_true", help="use C1C technique") 233 | 234 | parser.add_argument("--heat-quantization", action="store_true", help="use heat quantization technique") 235 | parser.add_argument("--hq-level", type=int, default=8, help="number of quantization level") 236 | 237 | parser.add_argument("--skip-connection-prop-type", type=str, default="latest", help="type of skip connection propagation") 238 | 239 | return parse_with_config(parser) 240 | 241 | 242 | if __name__ == "__main__": 243 | # import pdb; pdb.set_trace() 244 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 245 | 246 | main(parse_args()) 247 | -------------------------------------------------------------------------------- /optim/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import torch.optim as optim 4 | 5 | 6 | from optim.sam import SAM 7 | 8 | ALL_OPTIM = ["SGD", "Adam", "AdamW", "SAM"] 9 | 10 | 11 | def create_optimizer( 12 | optim_name: str, 13 | params: Iterable, 14 | lr: float, 15 | weight_decay: float = 0.9, 16 | momentum: float = 0.9, 17 | ) -> optim.Optimizer: 18 | 
""" 19 | Create an optimizer 20 | 21 | Args: 22 | optim_name(str) : Name of the optimizer 23 | params(Iterable) : params 24 | lr(float) : Learning rate 25 | weight_decay(float): weight_decay 26 | momentum(float) : momentum 27 | 28 | Returns: 29 | Optimizer 30 | """ 31 | assert optim_name in ALL_OPTIM 32 | 33 | if optim_name == "SGD": 34 | optimizer = optim.SGD( 35 | params, lr=lr, momentum=momentum, weight_decay=weight_decay 36 | ) 37 | elif optim_name == "Adam": 38 | optimizer = optim.Adam(params, lr=lr, weight_decay=weight_decay) 39 | elif optim_name == "AdamW": 40 | optimizer = optim.AdamW(params, lr=lr, weight_decay=weight_decay) 41 | elif optim_name == "SAM": 42 | base_optimizer = optim.SGD 43 | optimizer = SAM( 44 | params, base_optimizer, lr=lr, weight_decay=weight_decay, momentum=momentum 45 | ) 46 | else: 47 | raise ValueError 48 | 49 | return optimizer 50 | -------------------------------------------------------------------------------- /optim/sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class SAM(torch.optim.Optimizer): 5 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): 6 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 7 | 8 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 9 | super(SAM, self).__init__(params, defaults) 10 | 11 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 12 | self.param_groups = self.base_optimizer.param_groups 13 | 14 | @torch.no_grad() 15 | def first_step(self, zero_grad=False): 16 | grad_norm = self._grad_norm() 17 | for group in self.param_groups: 18 | scale = group["rho"] / (grad_norm + 1e-12) 19 | 20 | for p in group["params"]: 21 | if p.grad is None: 22 | continue 23 | self.state[p]["old_p"] = p.data.clone() 24 | e_w = ( 25 | (torch.pow(p, 2) if group["adaptive"] else 1.0) 26 | * p.grad 27 | * scale.to(p) 28 | ) 29 | p.add_(e_w) # climb to the local maximum "w + e(w)" 30 | 31 | if zero_grad: 32 | self.zero_grad() 33 | 34 | @torch.no_grad() 35 | def second_step(self, zero_grad=False): 36 | for group in self.param_groups: 37 | for p in group["params"]: 38 | if p.grad is None: 39 | continue 40 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 41 | 42 | self.base_optimizer.step() # do the actual "sharpness-aware" update 43 | 44 | if zero_grad: 45 | self.zero_grad() 46 | 47 | @torch.no_grad() 48 | def step(self, closure=None): 49 | assert ( 50 | closure is not None 51 | ), "Sharpness Aware Minimization requires closure, but it was not provided" 52 | closure = torch.enable_grad()( 53 | closure 54 | ) # the closure should do a full forward-backward pass 55 | 56 | self.first_step(zero_grad=True) 57 | closure() 58 | self.second_step() 59 | 60 | def _grad_norm(self): 61 | shared_device = self.param_groups[0]["params"][ 62 | 0 63 | ].device # put everything on the same device, in case of model parallelism 64 | norm = torch.norm( 65 | torch.stack( 66 | [ 67 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad) 68 | .norm(p=2) 69 | .to(shared_device) 70 | for group in self.param_groups 71 | for p in group["params"] 72 | if p.grad is not None 73 | ] 74 | ), 75 | p=2, 76 | ) 77 | return norm 78 | 79 | def load_state_dict(self, state_dict): 80 | super().load_state_dict(state_dict) 81 | self.base_optimizer.param_groups = self.param_groups 82 | -------------------------------------------------------------------------------- /outputs/.gitkeep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/outputs/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | # line-length = 119 3 | target-version = ['py39'] 4 | 5 | [tool.mypy] 6 | python_version = "3.9" 7 | 8 | [tool.ruff] 9 | # Never enforce `E501` (line length violations). 10 | ignore = ["C901", "E501", "E741", "F405", "F403", "W605"] 11 | # select = ["C", "E", "F", "I", "W"] 12 | 13 | # # Ignore import violations in all `__init__.py` files. 14 | # [tool.ruff.per-file-ignores] 15 | # "__init__.py" = ["E402", "F401", "F403", "F811"] 16 | 17 | # [tool.ruff.isort] 18 | # lines-after-imports = 2 19 | 20 | [tool.poetry] 21 | name = "lrp-for-resnet" 22 | version = "0.1.0" 23 | description = "" 24 | authors = ["Foo "] 25 | readme = "README.md" 26 | 27 | [tool.poetry.dependencies] 28 | python = ">=3.9,<3.12" 29 | torch = "2.0.0" 30 | torchvision = "0.15.1" 31 | torchaudio = "2.0.1" 32 | tqdm = "^4.65.0" 33 | matplotlib = "^3.7.1" 34 | numpy = "^1.24.3" 35 | opencv-python = "^4.7.0.72" 36 | pillow = "^9.5.0" 37 | scikit-image = "^0.20.0" 38 | scikit-learn = "^1.2.2" 39 | torchinfo = "^1.8.0" 40 | wandb = "^0.15.3" 41 | timm = "^0.9.2" 42 | grad-cam = "^1.4.6" 43 | captum = "^0.6.0" 44 | plotly = "^5.19.0" 45 | kaleido = "0.2.1" 46 | 47 | [tool.poetry.group.dev.dependencies] 48 | ruff = "^0.0.269" 49 | black = "^23.3.0" 50 | mypy = "^1.3.0" 51 | 52 | [build-system] 53 | requires = ["poetry-core"] 54 | build-backend = "poetry.core.masonry.api" 55 | -------------------------------------------------------------------------------- /qual/original/Arabian_camel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/Arabian_camel.png -------------------------------------------------------------------------------- /qual/original/Brandt_Cormorant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/Brandt_Cormorant.png -------------------------------------------------------------------------------- /qual/original/Geococcyx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/Geococcyx.png -------------------------------------------------------------------------------- /qual/original/Rock_Wren.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/Rock_Wren.png -------------------------------------------------------------------------------- /qual/original/Savannah_Sparrow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/Savannah_Sparrow.png -------------------------------------------------------------------------------- /qual/original/bee.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/bee.png -------------------------------------------------------------------------------- /qual/original/bubble.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/bubble.png -------------------------------------------------------------------------------- /qual/original/bustard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/bustard.png -------------------------------------------------------------------------------- /qual/original/drumstick.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/drumstick.png -------------------------------------------------------------------------------- /qual/original/oboe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/oboe.png -------------------------------------------------------------------------------- /qual/original/ram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/ram.png -------------------------------------------------------------------------------- /qual/original/sock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/sock.png -------------------------------------------------------------------------------- /qual/original/solar_dish.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/solar_dish.png -------------------------------------------------------------------------------- /qual/original/water_ouzel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/water_ouzel.png -------------------------------------------------------------------------------- /qual/original/wombat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/wombat.png -------------------------------------------------------------------------------- /scripts/calc_dataset_info.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.utils.data as data 6 | from torchvision import transforms 7 | from tqdm import tqdm 8 | 9 | from data import ALL_DATASETS, create_dataset, get_generator, seed_worker 10 | from utils.utils 
import fix_seed 11 | 12 | 13 | def calc_mean_std( 14 | dataloader: data.DataLoader, image_size 15 | ) -> Tuple[torch.Tensor, torch.Tensor]: 16 | """ 17 | Calculate the mean and variance of the image dataset 18 | 19 | Args: 20 | dataloader(DataLoader): DataLoader 21 | image_size(int) : Image size 22 | 23 | Returns: 24 | Mean and variance of each channel 25 | Tuple[torch.Tensor, torch.Tensor] 26 | """ 27 | 28 | sum = torch.tensor([0.0, 0.0, 0.0]) 29 | sum_square = torch.tensor([0.0, 0.0, 0.0]) 30 | total = 0 31 | 32 | for inputs, _ in tqdm(dataloader, dynamic_ncols=True): 33 | inputs.to(device) 34 | sum += inputs.sum(axis=[0, 2, 3]) 35 | sum_square += (inputs**2).sum(axis=[0, 2, 3]) 36 | total += inputs.size(0) 37 | 38 | count = total * image_size * image_size 39 | 40 | total_mean = sum / count 41 | total_var = (sum_square / count) - (total_mean**2) 42 | total_std = torch.sqrt(total_var) 43 | 44 | return total_mean, total_std 45 | 46 | 47 | def main(): 48 | fix_seed(args.seed, True) 49 | 50 | transform = transforms.Compose( 51 | [ 52 | transforms.Resize((args.image_size, args.image_size)), 53 | transforms.ToTensor(), 54 | ] 55 | ) 56 | 57 | dataset = create_dataset(args.dataset, "train", args.image_size, transform) 58 | dataloader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, worker_init_fn=seed_worker, generator=get_generator()) 59 | 60 | mean, std = calc_mean_std(dataloader, args.image_size) 61 | print(f"mean: {mean}") 62 | print(f"std: {std}") 63 | 64 | 65 | if __name__ == "__main__": 66 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 67 | 68 | parser = argparse.ArgumentParser() 69 | 70 | parser.add_argument("--seed", default=42, type=int) 71 | 72 | parser.add_argument("--dataset", type=str, choices=ALL_DATASETS) 73 | parser.add_argument("--image_size", type=int, default=224) 74 | parser.add_argument("--batch_size", default=32, type=int) 75 | 76 | args = parser.parse_args() 77 | 78 | main() 79 | -------------------------------------------------------------------------------- /scripts/visualize_transforms.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Tuple 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | from data import ALL_DATASETS, create_dataloader_dict, get_parameter_depend_in_data_set 9 | from utils.utils import fix_seed, reverse_normalize 10 | 11 | 12 | def save_normalized_image( 13 | image: np.ndarray, 14 | fname: str, 15 | mean: Tuple[float, float, float], 16 | std: Tuple[float, float, float], 17 | ) -> None: 18 | image = reverse_normalize(image.copy(), mean, std) 19 | image = np.transpose(image, (1, 2, 0)) 20 | 21 | fig, ax = plt.subplots() 22 | ax.imshow(image) 23 | 24 | plt.savefig(fname) 25 | plt.clf() 26 | plt.close() 27 | 28 | 29 | def main(): 30 | fix_seed(args.seed, True) 31 | 32 | dataloader = create_dataloader_dict(args.dataset, 1, args.image_size) 33 | params = get_parameter_depend_in_data_set(args.dataset) 34 | 35 | save_dir = os.path.join(args.root_dir, "transforms") 36 | if not os.path.isdir(save_dir): 37 | os.makedirs(save_dir) 38 | 39 | for phase, inputs in dataloader.items(): 40 | if not phase == "Train": 41 | continue 42 | 43 | for i, data in enumerate(inputs): 44 | image = data[0].cpu().numpy()[0] 45 | save_fname = os.path.join(save_dir, f"{i}.png") 46 | save_normalized_image(image, save_fname, params["mean"], params["std"]) 47 | 48 | 49 | if __name__ == "__main__": 50 | device = 
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 51 | 52 | parser = argparse.ArgumentParser() 53 | 54 | parser.add_argument("--seed", default=42, type=int) 55 | 56 | parser.add_argument("--dataset", type=str, default="IDRiD", choices=ALL_DATASETS) 57 | parser.add_argument("--image_size", type=int, default=224) 58 | 59 | parser.add_argument("--root_dir", type=str, default="./outputs/") 60 | 61 | args = parser.parse_args() 62 | 63 | main() 64 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/src/__init__.py -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | """Script holds data loader methods. 2 | """ 3 | import argparse 4 | 5 | import torch 6 | import torchvision 7 | from torch.utils import data 8 | 9 | from data import get_generator, seed_worker 10 | 11 | 12 | def get_data_loader(config: argparse.Namespace) -> torch.utils.data.DataLoader: 13 | """Creates dataloader for networks from PyTorch's Model Zoo. 14 | 15 | Data loader uses mean and standard deviation for ImageNet. 16 | 17 | Args: 18 | config: Argparse namespace object. 19 | 20 | Returns: 21 | Data loader object. 22 | 23 | """ 24 | input_dir = config.input_dir 25 | batch_size = config.batch_size 26 | 27 | mean = [0.485, 0.456, 0.406] 28 | std = [0.229, 0.224, 0.225] 29 | 30 | transforms = [] 31 | 32 | if config.resize: 33 | transforms += [ 34 | torchvision.transforms.Resize(size=int(1.1 * config.resize)), 35 | torchvision.transforms.CenterCrop(size=config.resize), 36 | ] 37 | 38 | transforms += [ 39 | torchvision.transforms.ToTensor(), 40 | torchvision.transforms.Normalize(mean, std), 41 | ] 42 | 43 | transform = torchvision.transforms.Compose(transforms=transforms) 44 | dataset = torchvision.datasets.ImageFolder(root=input_dir, transform=transform) 45 | 46 | data_loader = data.DataLoader(dataset, batch_size=batch_size, worker_init_fn=seed_worker, generator=get_generator()) 47 | 48 | return data_loader 49 | -------------------------------------------------------------------------------- /src/data_processing.py: -------------------------------------------------------------------------------- 1 | """Script with method for pre- and post-processing.""" 2 | import argparse 3 | 4 | import cv2 5 | import numpy 6 | import torch 7 | import torchvision.transforms 8 | 9 | 10 | class DataProcessing: 11 | def __init__(self, config: argparse.Namespace, device: torch.device) -> None: 12 | """Initializes data processing class.""" 13 | 14 | mean = [0.485, 0.456, 0.406] 15 | std = [0.229, 0.224, 0.225] 16 | 17 | transforms = [ 18 | torchvision.transforms.ToPILImage(), 19 | ] 20 | 21 | if config.resize: 22 | transforms += [ 23 | torchvision.transforms.Resize(size=int(1.1 * config.resize)), 24 | torchvision.transforms.CenterCrop(size=config.resize), 25 | ] 26 | 27 | transforms += [ 28 | torchvision.transforms.ToTensor(), 29 | torchvision.transforms.Normalize(mean, std), 30 | ] 31 | 32 | self.transform = torchvision.transforms.Compose(transforms=transforms) 33 | self.device = device 34 | 35 | def preprocess(self, frame: numpy.ndarray) -> torch.Tensor: 36 | """Preprocesses frame captured by webcam.""" 37 | return self.transform(frame).to(self.device)[None, ...] 
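    # Minimal usage sketch (hypothetical names: `config`, `frame`, `lrp_model`):
    #
    #   processing = DataProcessing(config, device)
    #   x = processing.preprocess(frame)             # (1, 3, H, W) tensor on `device`
    #   relevance = lrp_model(x)                     # e.g. src.lrp.LRPModel -> (H, W)
    #   heatmap = processing.postprocess(relevance)  # uint8 colour heatmap via OpenCV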
38 | 39 | def postprocess(self, relevance_scores: torch.Tensor): 40 | """Normalizes relevance scores and applies colormap.""" 41 | relevance_scores = relevance_scores.numpy() 42 | relevance_scores = cv2.normalize( 43 | relevance_scores, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8UC1 44 | ) 45 | relevance_scores = cv2.applyColorMap(relevance_scores, cv2.COLORMAP_HOT) 46 | return relevance_scores 47 | -------------------------------------------------------------------------------- /src/lrp.py: -------------------------------------------------------------------------------- 1 | """Class for layer-wise relevance propagation. 2 | 3 | Layer-wise relevance propagation for VGG-like networks from PyTorch's Model Zoo. 4 | Implementation can be adapted to work with other architectures as well by adding the corresponding operations. 5 | 6 | Typical usage example: 7 | 8 | model = torchvision.models.vgg16(pretrained=True) 9 | lrp_model = LRPModel(model) 10 | r = lrp_model.forward(x) 11 | 12 | """ 13 | from copy import deepcopy 14 | from typing import Union 15 | 16 | import torch 17 | from torch import nn 18 | from torchvision.models.resnet import BasicBlock, Bottleneck 19 | from torchvision.models.vgg import VGG 20 | 21 | from models import OneWayResNet 22 | from src.utils import SkipConnectionPropType, layers_lookup 23 | 24 | 25 | class LRPModel(nn.Module): 26 | """Class wraps PyTorch model to perform layer-wise relevance propagation.""" 27 | 28 | def __init__(self, model: torch.nn.Module, rel_pass_ratio: float = 0.0, skip_connection_prop="latest") -> None: 29 | super().__init__() 30 | self.model: Union[VGG, OneWayResNet] = model 31 | self.rel_pass_ratio = rel_pass_ratio 32 | self.skip_connection_prop = skip_connection_prop 33 | 34 | self.model.eval() # self.model.train() activates dropout / batch normalization etc.! 35 | 36 | # Parse network (architecture must be based on VGG...) 37 | self.layers = self._get_layer_operations() 38 | 39 | # Create LRP network 40 | self.lrp_layers = self._create_lrp_model() 41 | 42 | def _create_lrp_model(self) -> torch.nn.ModuleList: 43 | """Method builds the model for layer-wise relevance propagation. 44 | 45 | Returns: 46 | LRP-model as module list. 47 | 48 | """ 49 | # Clone layers from original model. This is necessary as we might modify the weights. 50 | layers = deepcopy(self.layers) 51 | lookup_table = layers_lookup(self.skip_connection_prop) 52 | 53 | # Run backwards through layers 54 | for i, layer in enumerate(layers[::-1]): 55 | try: 56 | layers[i] = lookup_table[layer.__class__](layer=layer, top_k=self.rel_pass_ratio) 57 | except KeyError: 58 | message = ( 59 | f"Layer-wise relevance propagation not implemented for " 60 | f"{layer.__class__.__name__} layer." 61 | ) 62 | raise NotImplementedError(message) 63 | 64 | return layers 65 | 66 | def _get_layer_operations(self) -> torch.nn.ModuleList: 67 | """Get all network operations and store them in a list. 68 | 69 | This method is adapted to VGG networks from PyTorch's Model Zoo. 70 | Modify this method to work also for other networks. 71 | 72 | Returns: 73 | Layers of original model stored in module list. 
74 | 75 | """ 76 | layers = torch.nn.ModuleList() 77 | 78 | # Parse VGG, OneWayResNet 79 | for layer in self.model.features: 80 | is_resnet_tower = isinstance(layer, nn.Sequential) and (isinstance(layer[0], BasicBlock) or isinstance(layer[0], Bottleneck)) 81 | if is_resnet_tower: 82 | for sub_layer in layer: 83 | assert isinstance(sub_layer, BasicBlock) or isinstance(sub_layer, Bottleneck) 84 | layers.append(sub_layer) 85 | else: 86 | layers.append(layer) 87 | 88 | layers.append(self.model.avgpool) 89 | layers.append(torch.nn.Flatten(start_dim=1)) 90 | 91 | for layer in self.model.classifier: 92 | layers.append(layer) 93 | 94 | return layers 95 | 96 | def forward(self, x: torch.tensor, topk=-1) -> torch.tensor: 97 | """Forward method that first performs standard inference followed by layer-wise relevance propagation. 98 | 99 | Args: 100 | x: Input tensor representing an image / images (N, C, H, W). 101 | 102 | Returns: 103 | Tensor holding relevance scores with dimensions (N, 1, H, W). 104 | 105 | """ 106 | activations = list() 107 | 108 | # Run inference and collect activations. 109 | with torch.no_grad(): 110 | # Replace image with ones avoids using image information for relevance computation. 111 | activations.append(torch.ones_like(x)) 112 | for layer in self.layers: 113 | x = layer.forward(x) 114 | activations.append(x) 115 | 116 | # Reverse order of activations to run backwards through model 117 | activations = activations[::-1] 118 | activations = [a.data.requires_grad_(True) for a in activations] 119 | 120 | # Initial relevance scores are the network's output activations 121 | relevance = torch.softmax(activations.pop(0), dim=-1) # Unsupervised 122 | if topk != -1: 123 | relevance_zero = torch.zeros_like(relevance) 124 | top_k_indices = torch.topk(relevance, topk).indices 125 | for index in top_k_indices: 126 | relevance_zero[..., index] = relevance[..., index] 127 | relevance = relevance_zero 128 | 129 | # Perform relevance propagation 130 | for i, layer in enumerate(self.lrp_layers): 131 | a = activations.pop(0) 132 | try: 133 | relevance = layer.forward(a, relevance) 134 | except RuntimeError: 135 | print(f"RuntimeError at layer {i}.\n" 136 | f"Layer: {layer.__class__.__name__}\n" 137 | f"Relevance shape: {relevance.shape}\n" 138 | f"Activation shape: {activations[0].shape}\n") 139 | exit(1) 140 | 141 | # # disturb 142 | # disturbance = (torch.rand(relevance.shape).to(relevance.device) - 0.5) 143 | # relevance = relevance + disturbance * 5e-4 144 | 145 | ### patch for debug and/or analysis ### 146 | # Escape to see relevance scores at critical layers 147 | # Uncomment the following lines to see relevance scores at critical layers 148 | # 149 | # if i == 3: # stop at 3, 7, 11, 15, 18 (1st, 5th, 9th, 13th, 16th last bottleneck block) 150 | # break 151 | # 152 | ### value of i is determined manually by checking each layer with the following code 153 | # ```python 154 | # print(len(self.lrp_layers)) # 23 for ResNet50 155 | # for i, layer in enumerate(self.lrp_layers): 156 | # print(f"{i}:\n {layer.__class__.__name__}") 157 | # ``` 158 | # RelevancePropagationBottleneckFlowsPureSkip x16 (i = 3-18) 159 | 160 | return relevance.permute(0, 2, 3, 1).sum(dim=-1).squeeze().detach().cpu() 161 | 162 | 163 | # legacy code 164 | class LRPModules(nn.Module): 165 | """Class wraps PyTorch model to perform layer-wise relevance propagation.""" 166 | 167 | def __init__( 168 | self, layers: nn.ModuleList, out_relevance: torch.Tensor, top_k: float = 0.0 169 | ) -> None: 170 | super().__init__() 171 | 
self.top_k = top_k 172 | 173 | # Parse network 174 | self.layers = layers 175 | self.out_relevance = out_relevance 176 | 177 | # Create LRP network 178 | self.lrp_layers = self._create_lrp_model() 179 | 180 | def _create_lrp_model(self) -> nn.ModuleList: 181 | """Method builds the model for layer-wise relevance propagation. 182 | 183 | Returns: 184 | LRP-model as module list. 185 | 186 | """ 187 | # Clone layers from original model. This is necessary as we might modify the weights. 188 | layers = deepcopy(self.layers) 189 | lookup_table = layers_lookup() 190 | 191 | # Run backwards through layers 192 | for i, layer in enumerate(layers[::-1]): 193 | try: 194 | layers[i] = lookup_table[layer.__class__](layer=layer, top_k=self.top_k) 195 | except KeyError: 196 | message = ( 197 | f"Layer-wise relevance propagation not implemented for " 198 | f"{layer.__class__.__name__} layer." 199 | ) 200 | raise NotImplementedError(message) 201 | 202 | return layers 203 | 204 | def forward(self, x: torch.tensor) -> torch.tensor: 205 | """Forward method that first performs standard inference followed by layer-wise relevance propagation. 206 | 207 | Args: 208 | x: Input tensor representing an image / images (N, C, H, W). 209 | 210 | Returns: 211 | Tensor holding relevance scores with dimensions (N, 1, H, W). 212 | 213 | """ 214 | activations = list() 215 | 216 | # Run inference and collect activations. 217 | with torch.no_grad(): 218 | # Replace image with ones avoids using image information for relevance computation. 219 | activations.append(torch.ones_like(x)) 220 | for layer in self.layers: 221 | x = layer.forward(x) 222 | activations.append(x) 223 | 224 | # Reverse order of activations to run backwards through model 225 | activations = activations[::-1] 226 | activations = [a.data.requires_grad_(True) for a in activations] 227 | 228 | # Initial relevance scores are the network's output activations 229 | relevance = torch.softmax(activations.pop(0), dim=-1) # Unsupervised 230 | if self.out_relevance is not None: 231 | relevance = self.out_relevance.to(relevance.device) 232 | 233 | # Perform relevance propagation 234 | for i, layer in enumerate(self.lrp_layers): 235 | relevance = layer.forward(activations.pop(0), relevance) 236 | 237 | return relevance.permute(0, 2, 3, 1).sum(dim=-1).squeeze().detach().cpu() 238 | 239 | 240 | def basic_lrp( 241 | model, image, rel_pass_ratio=1.0, topk=1, skip_connection_prop: SkipConnectionPropType = "latest" 242 | ): 243 | lrp_model = LRPModel(model, rel_pass_ratio=rel_pass_ratio, skip_connection_prop=skip_connection_prop) 244 | R = lrp_model.forward(image, topk) 245 | return R 246 | 247 | 248 | # Legacy code ----------------------- 249 | def resnet_lrp(model, image, topk=0.2): 250 | output = model(image) 251 | score, class_index = torch.max(output, 1) 252 | R = torch.zeros_like(output) 253 | R[0, class_index] = score 254 | 255 | post_modules = divide_module_by_name(model, "avgpool") 256 | new_post = post_modules[:-1] 257 | new_post.append(torch.nn.Flatten(start_dim=1)) 258 | new_post.append(post_modules[-1]) 259 | post_modules = new_post 260 | 261 | post_lrp = LRPModules(post_modules, R, top_k=topk) 262 | R = post_lrp.forward(post_modules[0].activations) 263 | 264 | R = resnet_layer_lrp(model.layer4, R, top_k=topk) 265 | R = resnet_layer_lrp(model.layer3, R, top_k=topk) 266 | R = resnet_layer_lrp(model.layer2, R, top_k=topk) 267 | R = resnet_layer_lrp(model.layer1, R, top_k=topk) 268 | 269 | pre_modules = divide_module_by_name(model, "layer1", before_module=True) 270 | 
pre_lrp = LRPModules(pre_modules, R, top_k=topk) 271 | R = pre_lrp.forward(image) 272 | 273 | return R 274 | 275 | 276 | def abn_lrp(model, image, topk=0.2): 277 | output = model(image) 278 | score, class_index = torch.max(output, 1) 279 | R = torch.zeros_like(output) 280 | R[0, class_index] = score 281 | 282 | ######################### 283 | ### Perception Branch ### 284 | ######################### 285 | post_modules = nn.ModuleList( 286 | [ 287 | model.perception_branch[2], 288 | model.perception_branch[3], 289 | model.perception_branch[4], 290 | ] 291 | ) 292 | new_post = post_modules[:-1] 293 | new_post.append(torch.nn.Flatten(start_dim=1)) 294 | new_post.append(post_modules[-1]) 295 | post_modules = new_post 296 | 297 | post_lrp = LRPModules(post_modules, R, top_k=topk) 298 | R_pb = post_lrp.forward(post_modules[0].activations) 299 | 300 | for sequential_blocks in model.perception_branch[:2][::-1]: 301 | R_pb = resnet_layer_lrp(sequential_blocks, R_pb, topk) 302 | 303 | ######################### 304 | ### Attention Branch ### 305 | ######################### 306 | # h -> layer1, bn1, conv1, relu, conv4, avgpool, flatten 307 | ab_modules = nn.ModuleList( 308 | [ 309 | model.attention_branch.bn1, 310 | model.attention_branch.conv1, 311 | model.attention_branch.relu, 312 | model.attention_branch.conv4, 313 | model.attention_branch.avgpool, 314 | model.attention_branch.flatten, 315 | ] 316 | ) 317 | ab_lrp = LRPModules(ab_modules, R, top_k=topk) 318 | R_ab = ab_lrp.forward(model.attention_branch.bn1_activation) 319 | R_ab = resnet_layer_lrp(model.attention_branch.layer1, R_ab, topk) 320 | 321 | ######################### 322 | ### Feature Extractor ### 323 | ######################### 324 | R_fe_out = R_pb + R_ab 325 | R = resnet_layer_lrp(model.feature_extractor[-1], R_fe_out, topk) 326 | R = resnet_layer_lrp(model.feature_extractor[-2], R, topk) 327 | 328 | pre_modules = nn.ModuleList( 329 | [ 330 | model.feature_extractor[0], 331 | model.feature_extractor[1], 332 | model.feature_extractor[2], 333 | model.feature_extractor[3], 334 | ] 335 | ) 336 | pre_lrp = LRPModules(pre_modules, R, top_k=topk) 337 | R = pre_lrp.forward(image) 338 | 339 | return R 340 | 341 | 342 | def resnet_layer_lrp( 343 | layer: nn.Sequential, out_relevance: torch.Tensor, top_k: float = 0.0 344 | ): 345 | for res_block in layer[::-1]: 346 | inputs = res_block.activations 347 | 348 | identify = out_relevance 349 | if res_block.downsample is not None: 350 | downsample = nn.ModuleList( 351 | [res_block.downsample[0], res_block.downsample[1]] 352 | ) 353 | skip_lrp = LRPModules(downsample, identify, top_k=top_k) 354 | skip_relevance = skip_lrp.forward(inputs) 355 | else: 356 | skip_relevance = identify 357 | 358 | main_modules = nn.ModuleList() 359 | for name, module in res_block._modules.items(): 360 | if name == "downsample": 361 | continue 362 | main_modules.append(module) 363 | main_lrp = LRPModules(main_modules, identify, top_k=top_k) 364 | main_relevance = main_lrp.forward(inputs) 365 | 366 | gamma = 0.5 367 | out_relevance = gamma * main_relevance + (1 - gamma) * skip_relevance 368 | return out_relevance 369 | 370 | 371 | def divide_module_by_name(model, module_name: str, before_module: bool = False): 372 | use_module = before_module 373 | modules = nn.ModuleList() 374 | for name, module in model._modules.items(): 375 | if name == module_name: 376 | use_module = not use_module 377 | if not use_module: 378 | continue 379 | modules.append(module) 380 | 381 | return modules 382 | 
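# --- Usage sketch (added; not part of the original module) --------------------
# Minimal, hedged example of running `basic_lrp` end to end. A stock torchvision
# VGG-16 is used because it already exposes the `features` / `avgpool` /
# `classifier` layout that LRPModel parses; a ResNet would first need to be
# wrapped in `OneWayResNet` (whose exact constructor is assumed, not shown here).
if __name__ == "__main__":
    import torchvision

    vgg = torchvision.models.vgg16(pretrained=True)
    x = torch.rand(1, 3, 224, 224)  # stand-in for a normalized input image
    relevance_map = basic_lrp(vgg, x, rel_pass_ratio=1.0, topk=1)
    print(relevance_map.shape)  # (224, 224): per-pixel relevance on the input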
-------------------------------------------------------------------------------- /src/lrp_filter.py: -------------------------------------------------------------------------------- 1 | """Implements filter method for relevance scores. 2 | """ 3 | import torch 4 | 5 | 6 | def relevance_filter(r: torch.tensor, top_k_percent: float = 1.0) -> torch.tensor: 7 | """Filter that allows largest k percent values to pass for each batch dimension. 8 | 9 | Filter keeps k% of the largest tensor elements. Other tensor elements are set to 10 | zero. Here, k = 1 means that all relevance scores are passed on to the next layer. 11 | 12 | Args: 13 | r: Tensor holding relevance scores of current layer. 14 | top_k_percent: Proportion of top k values that is passed on. 15 | 16 | Returns: 17 | Tensor of same shape as input tensor. 18 | 19 | """ 20 | assert 0.0 < top_k_percent <= 1.0 21 | 22 | if top_k_percent < 1.0: 23 | size = r.size() 24 | r = r.flatten(start_dim=1) 25 | num_elements = r.size(-1) 26 | k = max(1, int(top_k_percent * num_elements)) 27 | top_k = torch.topk(input=r, k=k, dim=-1) 28 | r = torch.zeros_like(r) 29 | r.scatter_(dim=1, index=top_k.indices, src=top_k.values) 30 | return r.view(size) 31 | else: 32 | return r 33 | -------------------------------------------------------------------------------- /src/lrp_layers.py: -------------------------------------------------------------------------------- 1 | """Layers for layer-wise relevance propagation. 2 | 3 | Layers for layer-wise relevance propagation can be modified. 4 | 5 | """ 6 | import torch 7 | from torch import nn 8 | from torchvision.models.resnet import BasicBlock, Bottleneck 9 | 10 | from src.lrp_filter import relevance_filter 11 | 12 | 13 | class RelevancePropagationBasicBlock(nn.Module): 14 | def __init__( 15 | self, 16 | layer: BasicBlock, 17 | eps: float = 1.0e-05, 18 | top_k: float = 0.0, 19 | ) -> None: 20 | super().__init__() 21 | self.layers = [ 22 | layer.conv1, 23 | layer.bn1, 24 | layer.relu, 25 | layer.conv2, 26 | layer.bn2, 27 | ] 28 | self.downsample = layer.downsample 29 | self.relu = layer.relu 30 | self.eps = eps 31 | self.top_k = top_k 32 | 33 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor: 34 | if self.downsample is None: 35 | return torch.ones_like(input_) 36 | mainstream = input_ 37 | shortcut = input_ 38 | for layer in self.layers: 39 | mainstream = layer(mainstream) 40 | if self.downsample is not None: 41 | shortcut = self.downsample(shortcut) 42 | assert mainstream.shape == shortcut.shape 43 | return mainstream.abs() / (shortcut.abs() + mainstream.abs()) 44 | 45 | def mainstream_backward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 46 | with torch.no_grad(): 47 | activations = [a] 48 | for layer in self.layers: 49 | activations.append(layer.forward(activations[-1])) 50 | 51 | activations.pop() # ignore output of this basic block 52 | activations = [a.data.requires_grad_(True) for a in activations] 53 | 54 | # NOW, IGNORES DOWN-SAMPLING & SKIP CONNECTION 55 | r_out = r 56 | for layer in self.layers[::-1]: 57 | a = activations.pop() 58 | if self.top_k: 59 | r_out = relevance_filter(r_out, top_k_percent=self.top_k) 60 | 61 | if isinstance(layer, nn.Conv2d): 62 | r_in = RelevancePropagationConv2d(layer, eps=self.eps, top_k=self.top_k)( 63 | a, r_out 64 | ) 65 | elif isinstance(layer, nn.BatchNorm2d): 66 | r_in = RelevancePropagationBatchNorm2d(layer, top_k=self.top_k)(a, r_out) 67 | elif isinstance(layer, nn.ReLU): 68 | r_in = RelevancePropagationReLU(layer, 
top_k=self.top_k)(a, r_out) 69 | else: 70 | raise NotImplementedError 71 | r_out = r_in 72 | return r_in 73 | 74 | def shortcut_backward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 75 | if self.downsample is None: 76 | return r 77 | a = a.data.requires_grad_(True) 78 | assert isinstance(self.downsample[0], nn.Conv2d) 79 | return RelevancePropagationConv2d(self.downsample[0], eps=self.eps, top_k=self.top_k)(a, r) 80 | 81 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 82 | ratio = self._calc_mainstream_flow_ratio(a) 83 | assert r.shape == ratio.shape 84 | r_mainstream = ratio * r 85 | r_shortcut = (1 - ratio) * r 86 | r_mainstream = self.mainstream_backward(a, r_mainstream) 87 | r_shortcut = self.shortcut_backward(a, r_shortcut) 88 | return r_mainstream + r_shortcut 89 | 90 | 91 | class RelevancePropagationBasicBlockSimple(RelevancePropagationBasicBlock): 92 | """ Relevance propagation for BasicBlock Proto A 93 | Divide relevance score for plain shortcuts 94 | """ 95 | def __init__(self, layer: BasicBlock, eps: float = 1.0e-05, top_k: float = 0.0) -> None: 96 | super().__init__(layer=layer, eps=eps, top_k=top_k) 97 | 98 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor: 99 | if self.downsample is None: 100 | return torch.ones_like(input_) 101 | return torch.full_like(self.downsample(input_), 0.5) 102 | 103 | 104 | class RelevancePropagationBasicBlockFlowsPureSkip(RelevancePropagationBasicBlock): 105 | """ Relevance propagation for BasicBlock Proto A 106 | Divide relevance score for plain shortcuts 107 | """ 108 | def __init__(self, layer: BasicBlock, eps: float = 1.0e-05, top_k: float = 0.0) -> None: 109 | super().__init__(layer=layer, eps=eps, top_k=top_k) 110 | 111 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor: 112 | mainstream = input_ 113 | shortcut = input_ 114 | for layer in self.layers: 115 | mainstream = layer(mainstream) 116 | if self.downsample is not None: 117 | shortcut = self.downsample(shortcut) 118 | assert mainstream.shape == shortcut.shape 119 | return mainstream.abs() / (shortcut.abs() + mainstream.abs()) 120 | 121 | 122 | class RelevancePropagationBasicBlockSimpleFlowsPureSkip(RelevancePropagationBasicBlock): 123 | """ Relevance propagation for BasicBlock Proto A 124 | Divide relevance score for plain shortcuts 125 | """ 126 | def __init__(self, layer: BasicBlock, eps: float = 1.0e-05, top_k: float = 0.0) -> None: 127 | super().__init__(layer=layer, eps=eps, top_k=top_k) 128 | 129 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor: 130 | if self.downsample is None: 131 | return torch.full_like(input_, 0.5) 132 | return torch.full_like(self.downsample(input_), 0.5) 133 | 134 | 135 | class RelevancePropagationBottleneck(nn.Module): 136 | def __init__( 137 | self, 138 | layer: Bottleneck, 139 | eps: float = 1.0e-05, 140 | top_k: float = 0.0, 141 | ) -> None: 142 | super().__init__() 143 | self.layers = [ 144 | layer.conv1, 145 | layer.bn1, 146 | layer.relu, 147 | layer.conv2, 148 | layer.bn2, 149 | layer.relu, 150 | layer.conv3, 151 | layer.bn3, 152 | ] 153 | self.downsample = layer.downsample 154 | self.relu = layer.relu 155 | self.eps = eps 156 | self.top_k = top_k 157 | 158 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor: 159 | if self.downsample is None: 160 | return torch.ones_like(input_) 161 | mainstream = input_ 162 | shortcut = input_ 163 | for layer in self.layers: 164 | mainstream = layer(mainstream) 165 | if 
self.downsample is not None: 166 | shortcut = self.downsample(shortcut) 167 | assert mainstream.shape == shortcut.shape 168 | return mainstream.abs() / (shortcut.abs() + mainstream.abs()) 169 | 170 | def mainstream_backward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 171 | with torch.no_grad(): 172 | activations = [a] 173 | for layer in self.layers: 174 | activations.append(layer.forward(activations[-1])) 175 | 176 | activations.pop() # ignore output of this bottleneck block 177 | activations = [a.data.requires_grad_(True) for a in activations] 178 | 179 | # NOW, IGNORES DOWN-SAMPLING & SKIP CONNECTION 180 | r_out = r 181 | for layer in self.layers[::-1]: 182 | a = activations.pop() 183 | if self.top_k: 184 | r_out = relevance_filter(r_out, top_k_percent=self.top_k) 185 | 186 | if isinstance(layer, nn.Conv2d): 187 | r_in = RelevancePropagationConv2d(layer, eps=self.eps, top_k=self.top_k)( 188 | a, r_out 189 | ) 190 | elif isinstance(layer, nn.BatchNorm2d): 191 | r_in = RelevancePropagationBatchNorm2d(layer, top_k=self.top_k)(a, r_out) 192 | elif isinstance(layer, nn.ReLU): 193 | r_in = RelevancePropagationReLU(layer, top_k=self.top_k)(a, r_out) 194 | else: 195 | raise NotImplementedError 196 | r_out = r_in 197 | return r_in 198 | 199 | def shortcut_backward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 200 | if self.downsample is None: 201 | return r 202 | a = a.data.requires_grad_(True) 203 | assert isinstance(self.downsample[0], nn.Conv2d) 204 | return RelevancePropagationConv2d(self.downsample[0], eps=self.eps, top_k=self.top_k)(a, r) 205 | 206 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 207 | ratio = self._calc_mainstream_flow_ratio(a) 208 | assert r.shape == ratio.shape 209 | r_mainstream = ratio * r 210 | r_shortcut = (1 - ratio) * r 211 | r_mainstream = self.mainstream_backward(a, r_mainstream) 212 | r_shortcut = self.shortcut_backward(a, r_shortcut) 213 | return r_mainstream + r_shortcut 214 | 215 | 216 | class RelevancePropagationBottleneckSimple(RelevancePropagationBottleneck): 217 | """ Relevance propagation for Bottleneck Proto A 218 | Divide relevance score for plain shortcuts 219 | """ 220 | def __init__(self, layer: Bottleneck, eps: float = 1.0e-05, top_k: float = 0.0) -> None: 221 | super().__init__(layer=layer, eps=eps, top_k=top_k) 222 | 223 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor: 224 | if self.downsample is None: 225 | return torch.ones_like(input_) 226 | return torch.full_like(self.downsample(input_), 0.5) 227 | 228 | 229 | class RelevancePropagationBottleneckFlowsPureSkip(RelevancePropagationBottleneck): 230 | """ Relevance propagation for Bottleneck Proto A 231 | Divide relevance score for plain shortcuts 232 | """ 233 | def __init__(self, layer: Bottleneck, eps: float = 1.0e-05, top_k: float = 0.0) -> None: 234 | super().__init__(layer=layer, eps=eps, top_k=top_k) 235 | 236 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor: 237 | mainstream = input_ 238 | shortcut = input_ 239 | for layer in self.layers: 240 | mainstream = layer(mainstream) 241 | if self.downsample is not None: 242 | shortcut = self.downsample(shortcut) 243 | assert mainstream.shape == shortcut.shape 244 | return mainstream.abs() / (shortcut.abs() + mainstream.abs()) 245 | 246 | 247 | class RelevancePropagationBottleneckSimpleFlowsPureSkip(RelevancePropagationBottleneck): 248 | """ Relevance propagation for Bottleneck Proto A 249 | Divide relevance score for plain shortcuts 250 | """ 
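    # Note (added): the four Bottleneck variants above differ only in how
    # _calc_mainstream_flow_ratio() splits the incoming relevance between the
    # residual (mainstream) path and the skip connection:
    #   RelevancePropagationBottleneck          - |main| / (|main| + |skip|), but only when the
    #                                             block downsamples; otherwise all goes mainstream
    #   ...BottleneckSimple                     - fixed 50/50 split, only when the block downsamples
    #   ...BottleneckFlowsPureSkip              - magnitude-based split for every block
    #   ...BottleneckSimpleFlowsPureSkip (this) - fixed 50/50 split for every block
    # The BasicBlock variants follow the same scheme.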
251 | def __init__(self, layer: Bottleneck, eps: float = 1.0e-05, top_k: float = 0.0) -> None: 252 | super().__init__(layer=layer, eps=eps, top_k=top_k) 253 | 254 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor: 255 | if self.downsample is None: 256 | return torch.full_like(input_, 0.5) 257 | return torch.full_like(self.downsample(input_), 0.5) 258 | 259 | 260 | class RelevancePropagationAdaptiveAvgPool2d(nn.Module): 261 | """Layer-wise relevance propagation for 2D adaptive average pooling. 262 | 263 | Attributes: 264 | layer: 2D adaptive average pooling layer. 265 | eps: A value added to the denominator for numerical stability. 266 | 267 | """ 268 | 269 | def __init__( 270 | self, 271 | layer: torch.nn.AdaptiveAvgPool2d, 272 | eps: float = 1.0e-05, 273 | top_k: float = 0.0, 274 | ) -> None: 275 | super().__init__() 276 | self.layer = layer 277 | self.eps = eps 278 | self.top_k = top_k 279 | 280 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 281 | if self.top_k: 282 | r = relevance_filter(r, top_k_percent=self.top_k) 283 | z = self.layer.forward(a) + self.eps 284 | s = (r / z).data 285 | (z * s).sum().backward() 286 | c = a.grad 287 | r = (a * c).data 288 | return r 289 | 290 | 291 | class RelevancePropagationAvgPool2d(nn.Module): 292 | """Layer-wise relevance propagation for 2D average pooling. 293 | 294 | Attributes: 295 | layer: 2D average pooling layer. 296 | eps: A value added to the denominator for numerical stability. 297 | 298 | """ 299 | 300 | def __init__( 301 | self, layer: torch.nn.AvgPool2d, eps: float = 1.0e-05, top_k: float = 0.0 302 | ) -> None: 303 | super().__init__() 304 | self.layer = layer 305 | self.eps = eps 306 | self.top_k = top_k 307 | 308 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 309 | if self.top_k: 310 | r = relevance_filter(r, top_k_percent=self.top_k) 311 | z = self.layer.forward(a) + self.eps 312 | s = (r / z).data 313 | (z * s).sum().backward() 314 | c = a.grad 315 | r = (a * c).data 316 | return r 317 | 318 | 319 | class RelevancePropagationMaxPool2d(nn.Module): 320 | """Layer-wise relevance propagation for 2D max pooling. 321 | 322 | Optionally substitutes max pooling by average pooling layers. 323 | 324 | Attributes: 325 | layer: 2D max pooling layer. 326 | eps: a value added to the denominator for numerical stability. 327 | 328 | """ 329 | 330 | def __init__( 331 | self, 332 | layer: torch.nn.MaxPool2d, 333 | mode: str = "avg", 334 | eps: float = 1.0e-05, 335 | top_k: float = 0.0, 336 | ) -> None: 337 | super().__init__() 338 | 339 | if mode == "avg": 340 | self.layer = torch.nn.AvgPool2d(kernel_size=(2, 2)) 341 | elif mode == "max": 342 | self.layer = layer 343 | 344 | self.eps = eps 345 | self.top_k = top_k 346 | 347 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 348 | if self.top_k: 349 | r = relevance_filter(r, top_k_percent=self.top_k) 350 | z = self.layer.forward(a) + self.eps 351 | s = (r / z).data 352 | (z * s).sum().backward() 353 | c = a.grad 354 | r = (a * c).data 355 | # print(f"maxpool2d {r.min()}, {r.max()}") 356 | return r 357 | 358 | 359 | class RelevancePropagationConv2d(nn.Module): 360 | """Layer-wise relevance propagation for 2D convolution. 361 | 362 | Optionally modifies layer weights according to propagation rule. Here z^+-rule 363 | 364 | Attributes: 365 | layer: 2D convolutional layer. 366 | eps: a value added to the denominator for numerical stability. 
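        mode: propagation rule (note added for clarity). With the default
            "z_plus" rule the layer's negative weights are clamped to zero, so
            each output's relevance is redistributed to its inputs in proportion
            to their positive contributions a_i * w_ij^+ to the (eps-stabilised)
            pre-activation; forward() evaluates this via the gradient trick
            c = d(z*s)/da with s = r / z held constant.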
367 | 368 | """ 369 | 370 | def __init__( 371 | self, 372 | layer: torch.nn.Conv2d, 373 | mode: str = "z_plus", 374 | eps: float = 1.0e-05, 375 | top_k: float = 0.0, 376 | ) -> None: 377 | super().__init__() 378 | 379 | self.layer = layer 380 | 381 | if mode == "z_plus": 382 | self.layer.weight = torch.nn.Parameter(self.layer.weight.clamp(min=0.0)) 383 | 384 | self.eps = eps 385 | self.top_k = top_k 386 | 387 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 388 | if self.top_k: 389 | r = relevance_filter(r, top_k_percent=self.top_k) 390 | z = self.layer.forward(a) + self.eps 391 | s = (r / z).data 392 | (z * s).sum().backward() 393 | c = a.grad 394 | r = (a * c).data 395 | # print(f"before norm: {r.sum()}") 396 | # r = (r - r.min()) / (r.max() - r.min()) 397 | # print(f"after norm: {r.sum()}\n") 398 | if r.shape != a.shape: 399 | raise RuntimeError("r.shape != a.shape") 400 | return r 401 | 402 | 403 | class RelevancePropagationLinear(nn.Module): 404 | """Layer-wise relevance propagation for linear transformation. 405 | 406 | Optionally modifies layer weights according to propagation rule. Here z^+-rule 407 | 408 | Attributes: 409 | layer: linear transformation layer. 410 | eps: a value added to the denominator for numerical stability. 411 | 412 | """ 413 | 414 | def __init__( 415 | self, 416 | layer: torch.nn.Linear, 417 | mode: str = "z_plus", 418 | eps: float = 1.0e-05, 419 | top_k: float = 0.0, 420 | ) -> None: 421 | super().__init__() 422 | 423 | self.layer = layer 424 | 425 | if mode == "z_plus": 426 | self.layer.weight = torch.nn.Parameter(self.layer.weight.clamp(min=0.0)) 427 | self.layer.bias = torch.nn.Parameter(torch.zeros_like(self.layer.bias)) 428 | 429 | self.eps = eps 430 | self.top_k = top_k 431 | 432 | @torch.no_grad() 433 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 434 | if self.top_k: 435 | r = relevance_filter(r, top_k_percent=self.top_k) 436 | z = self.layer.forward(a) + self.eps 437 | s = r / z 438 | c = torch.mm(s, self.layer.weight) 439 | r = (a * c).data 440 | # print(f"Linear {r.min()}, {r.max()}") 441 | return r 442 | 443 | 444 | class RelevancePropagationFlatten(nn.Module): 445 | """Layer-wise relevance propagation for flatten operation. 446 | 447 | Attributes: 448 | layer: flatten layer. 449 | 450 | """ 451 | 452 | def __init__(self, layer: torch.nn.Flatten, top_k: float = 0.0) -> None: 453 | super().__init__() 454 | self.layer = layer 455 | 456 | @torch.no_grad() 457 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 458 | r = r.view(size=a.shape) 459 | return r 460 | 461 | 462 | class RelevancePropagationReLU(nn.Module): 463 | """Layer-wise relevance propagation for ReLU activation. 464 | 465 | Passes the relevance scores without modification. Might be of use later. 466 | 467 | """ 468 | 469 | def __init__(self, layer: torch.nn.ReLU, top_k: float = 0.0) -> None: 470 | super().__init__() 471 | 472 | @torch.no_grad() 473 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 474 | return r 475 | 476 | 477 | class RelevancePropagationBatchNorm2d(nn.Module): 478 | """Layer-wise relevance propagation for ReLU activation. 479 | 480 | Passes the relevance scores without modification. Might be of use later. 
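    (Note added: despite the copied heading above, this class handles
    torch.nn.BatchNorm2d layers; like the ReLU rule it simply passes the
    relevance scores through unchanged.)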
481 | 482 | """ 483 | 484 | def __init__(self, layer: torch.nn.BatchNorm2d, top_k: float = 0.0) -> None: 485 | super().__init__() 486 | 487 | @torch.no_grad() 488 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 489 | return r 490 | 491 | 492 | class RelevancePropagationDropout(nn.Module): 493 | """Layer-wise relevance propagation for dropout layer. 494 | 495 | Passes the relevance scores without modification. Might be of use later. 496 | 497 | """ 498 | 499 | def __init__(self, layer: torch.nn.Dropout, top_k: float = 0.0) -> None: 500 | super().__init__() 501 | 502 | @torch.no_grad() 503 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 504 | return r 505 | 506 | 507 | class RelevancePropagationIdentity(nn.Module): 508 | """Identity layer for relevance propagation. 509 | 510 | Passes relevance scores without modifying them. 511 | 512 | """ 513 | 514 | def __init__(self, layer: nn.Module, top_k: float = 0.0) -> None: 515 | super().__init__() 516 | 517 | @torch.no_grad() 518 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 519 | return r 520 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | """Script with helper function.""" 2 | from typing import Literal 3 | 4 | from models.lrp import * 5 | from src.lrp_layers import * 6 | 7 | SkipConnectionPropType = Literal["simple", "flows_skip", "flows_skip_simple", "latest"] 8 | 9 | def layers_lookup(version: SkipConnectionPropType = "latest") -> dict: 10 | """Lookup table to map network layer to associated LRP operation. 11 | 12 | Returns: 13 | Dictionary holding class mappings. 14 | """ 15 | 16 | # For the purpose of the ablation study on relevance propagation for skip connections 17 | if version == "simple": 18 | return layers_lookup_simple() 19 | elif version == "flows_skip": 20 | return layers_lookup_flows_pure_skip() 21 | elif version == "flows_skip_simple": 22 | return layers_lookup_simple_flows_pure_skip() 23 | elif version == "latest": 24 | return layers_lookup_latest() 25 | else: 26 | raise ValueError("Invalid version was specified.") 27 | 28 | 29 | def layers_lookup_latest() -> dict: 30 | lookup_table = { 31 | torch.nn.modules.linear.Linear: RelevancePropagationLinear, 32 | torch.nn.modules.conv.Conv2d: RelevancePropagationConv2d, 33 | torch.nn.modules.activation.ReLU: RelevancePropagationReLU, 34 | torch.nn.modules.dropout.Dropout: RelevancePropagationDropout, 35 | torch.nn.modules.flatten.Flatten: RelevancePropagationFlatten, 36 | torch.nn.modules.pooling.AvgPool2d: RelevancePropagationAvgPool2d, 37 | torch.nn.modules.pooling.MaxPool2d: RelevancePropagationMaxPool2d, 38 | torch.nn.modules.pooling.AdaptiveAvgPool2d: RelevancePropagationAdaptiveAvgPool2d, 39 | AdaptiveAvgPool2dWithActivation: RelevancePropagationAdaptiveAvgPool2d, 40 | torch.nn.BatchNorm2d: RelevancePropagationBatchNorm2d, 41 | BatchNorm2dWithActivation: RelevancePropagationBatchNorm2d, 42 | BasicBlock: RelevancePropagationBasicBlock, 43 | Bottleneck: RelevancePropagationBottleneck, 44 | } 45 | return lookup_table 46 | 47 | 48 | def layers_lookup_simple() -> dict: 49 | lookup_table = { 50 | torch.nn.modules.linear.Linear: RelevancePropagationLinear, 51 | torch.nn.modules.conv.Conv2d: RelevancePropagationConv2d, 52 | torch.nn.modules.activation.ReLU: RelevancePropagationReLU, 53 | torch.nn.modules.dropout.Dropout: RelevancePropagationDropout, 54 | torch.nn.modules.flatten.Flatten: 
RelevancePropagationFlatten, 55 | torch.nn.modules.pooling.AvgPool2d: RelevancePropagationAvgPool2d, 56 | torch.nn.modules.pooling.MaxPool2d: RelevancePropagationMaxPool2d, 57 | torch.nn.modules.pooling.AdaptiveAvgPool2d: RelevancePropagationAdaptiveAvgPool2d, 58 | AdaptiveAvgPool2dWithActivation: RelevancePropagationAdaptiveAvgPool2d, 59 | torch.nn.BatchNorm2d: RelevancePropagationBatchNorm2d, 60 | BatchNorm2dWithActivation: RelevancePropagationBatchNorm2d, 61 | BasicBlock: RelevancePropagationBasicBlockSimple, 62 | Bottleneck: RelevancePropagationBottleneckSimple, 63 | } 64 | return lookup_table 65 | 66 | 67 | def layers_lookup_flows_pure_skip() -> dict: 68 | lookup_table = { 69 | torch.nn.modules.linear.Linear: RelevancePropagationLinear, 70 | torch.nn.modules.conv.Conv2d: RelevancePropagationConv2d, 71 | torch.nn.modules.activation.ReLU: RelevancePropagationReLU, 72 | torch.nn.modules.dropout.Dropout: RelevancePropagationDropout, 73 | torch.nn.modules.flatten.Flatten: RelevancePropagationFlatten, 74 | torch.nn.modules.pooling.AvgPool2d: RelevancePropagationAvgPool2d, 75 | torch.nn.modules.pooling.MaxPool2d: RelevancePropagationMaxPool2d, 76 | torch.nn.modules.pooling.AdaptiveAvgPool2d: RelevancePropagationAdaptiveAvgPool2d, 77 | AdaptiveAvgPool2dWithActivation: RelevancePropagationAdaptiveAvgPool2d, 78 | torch.nn.BatchNorm2d: RelevancePropagationBatchNorm2d, 79 | BatchNorm2dWithActivation: RelevancePropagationBatchNorm2d, 80 | BasicBlock: RelevancePropagationBasicBlockFlowsPureSkip, 81 | Bottleneck: RelevancePropagationBottleneckFlowsPureSkip, 82 | } 83 | return lookup_table 84 | 85 | 86 | def layers_lookup_simple_flows_pure_skip() -> dict: 87 | lookup_table = { 88 | torch.nn.modules.linear.Linear: RelevancePropagationLinear, 89 | torch.nn.modules.conv.Conv2d: RelevancePropagationConv2d, 90 | torch.nn.modules.activation.ReLU: RelevancePropagationReLU, 91 | torch.nn.modules.dropout.Dropout: RelevancePropagationDropout, 92 | torch.nn.modules.flatten.Flatten: RelevancePropagationFlatten, 93 | torch.nn.modules.pooling.AvgPool2d: RelevancePropagationAvgPool2d, 94 | torch.nn.modules.pooling.MaxPool2d: RelevancePropagationMaxPool2d, 95 | torch.nn.modules.pooling.AdaptiveAvgPool2d: RelevancePropagationAdaptiveAvgPool2d, 96 | AdaptiveAvgPool2dWithActivation: RelevancePropagationAdaptiveAvgPool2d, 97 | torch.nn.BatchNorm2d: RelevancePropagationBatchNorm2d, 98 | BatchNorm2dWithActivation: RelevancePropagationBatchNorm2d, 99 | BasicBlock: RelevancePropagationBasicBlockSimpleFlowsPureSkip, 100 | Bottleneck: RelevancePropagationBottleneckSimpleFlowsPureSkip, 101 | } 102 | return lookup_table 103 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import random 5 | from typing import Dict, Iterable, List, Optional, Tuple 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torch.utils.data import DataLoader 13 | from torchinfo import summary 14 | from tqdm import tqdm 15 | 16 | import wandb 17 | from data import ALL_DATASETS, create_dataloader_dict, get_parameter_depend_in_data_set 18 | from evaluate import test 19 | from metrics.base import Metric 20 | from models import ALL_MODELS, create_model 21 | from models.attention_branch import AttentionBranchModel 22 | from models.lrp import BottleneckWithActivation 23 | from optim 
import ALL_OPTIM, create_optimizer 24 | from optim.sam import SAM 25 | from utils.loss import calculate_loss 26 | from utils.utils import fix_seed, module_generator, parse_with_config, save_json 27 | from utils.visualize import save_attention_map 28 | 29 | 30 | class EarlyStopping: 31 | """ 32 | Attributes: 33 | patience(int): How long to wait after last time validation loss improved. 34 | delta(float) : Minimum change in the monitored quantity to qualify as an improvement. 35 | save_dir(str): Directory to save a model when improvement is found. 36 | """ 37 | 38 | def __init__( 39 | self, patience: int = 7, delta: float = 0, save_dir: str = "." 40 | ) -> None: 41 | self.patience = patience 42 | self.delta = delta 43 | self.save_dir = save_dir 44 | 45 | self.counter: int = 0 46 | self.early_stop: bool = False 47 | self.best_val_loss: float = np.Inf 48 | 49 | def __call__(self, val_loss: float, net: nn.Module) -> str: 50 | if val_loss + self.delta < self.best_val_loss: 51 | log = f"({self.best_val_loss:.5f} --> {val_loss:.5f})" 52 | self._save_checkpoint(net) 53 | self.best_val_loss = val_loss 54 | self.counter = 0 55 | return log 56 | 57 | self.counter += 1 58 | log = f"(> {self.best_val_loss:.5f} {self.counter}/{self.patience})" 59 | if self.counter >= self.patience: 60 | self.early_stop = True 61 | return log 62 | 63 | def _save_checkpoint(self, net: nn.Module) -> None: 64 | save_path = os.path.join(self.save_dir, "checkpoint.pt") 65 | torch.save(net.state_dict(), save_path) 66 | 67 | 68 | def set_parameter_trainable(module: nn.Module, is_trainable: bool = True) -> None: 69 | """ 70 | Set all parameters of the module to is_trainable(bool) 71 | 72 | Args: 73 | module(nn.Module): Target module 74 | is_trainable(bool): Whether to train the parameters 75 | """ 76 | for param in module.parameters(): 77 | param.requires_grad = is_trainable 78 | 79 | 80 | def set_trainable_bottlenecks(model, num_trainable): 81 | # 2. 
Set the last num_trainable Bottleneck/BottleneckWithActivation layers as trainable 82 | count = 0 83 | for child in reversed(list(model.children())): 84 | # Go through each layer in the current sequential block 85 | for layer in reversed(list(child.children())): 86 | if isinstance(layer, (BottleneckWithActivation)): 87 | for param in layer.parameters(): 88 | param.requires_grad = True 89 | count += 1 90 | if count == num_trainable: 91 | return 92 | 93 | 94 | def freeze_model( 95 | model: nn.Module, 96 | num_trainable_module: int = 0, 97 | fe_trainable: bool = False, 98 | ab_trainable: bool = False, 99 | perception_trainable: bool = False, 100 | final_trainable: bool = True, 101 | ) -> None: 102 | """ 103 | Freeze the model 104 | After freezing the whole model, only the final layer is trainable 105 | Then, the last num_trainable_module are trainable from the back 106 | 107 | Args: 108 | num_trainable_module(int): Number of trainable modules 109 | fe_trainable(bool): Whether to train the Feature Extractor 110 | ab_trainable(bool): Whether to train the Attention Branch 111 | perception_trainable(bool): Whether to train the Perception Branch 112 | 113 | Note: 114 | (fe|ab|perception)_trainable is only used when AttentionBranchModel 115 | num_trainable_module takes precedence over the above 116 | """ 117 | if isinstance(model, AttentionBranchModel): 118 | set_parameter_trainable(model.feature_extractor, fe_trainable) 119 | set_parameter_trainable(model.attention_branch, ab_trainable) 120 | set_parameter_trainable(model.perception_branch, perception_trainable) 121 | modules = module_generator(model.perception_branch, reverse=True) 122 | else: 123 | if num_trainable_module < 0: 124 | set_parameter_trainable(model) 125 | return 126 | set_parameter_trainable(model, is_trainable=False) 127 | modules = module_generator(model, reverse=True) 128 | 129 | final_layer = modules.__next__() 130 | set_parameter_trainable(final_layer, final_trainable) 131 | # set_parameter_trainable(model.perception_branch[0], False) 132 | # set_parameter_trainable(model.feature_extractor[0], True) 133 | # set_parameter_trainable(model.feature_extractor[1], True) 134 | 135 | # for i, module in enumerate(modules): 136 | # if num_trainable_module <= i: 137 | # break 138 | 139 | # set_parameter_trainable(module) 140 | set_trainable_bottlenecks(model, num_trainable_module) 141 | 142 | 143 | def setting_learning_rate( 144 | model: nn.Module, lr: float, lr_linear: float, lr_ab: Optional[float] = None 145 | ) -> Iterable: 146 | """ 147 | Set learning rate for each layer 148 | 149 | Args: 150 | model (nn.Module): Model to set learning rate 151 | lr(float) : Learning rate for the last layer/Attention Branch 152 | lr_linear(float): Learning rate for the last layer 153 | lr_ab(float) : Learning rate for Attention Branch 154 | 155 | Returns: 156 | Iterable with learning rate 157 | It is given to the argument of optim.Optimizer 158 | """ 159 | if isinstance(model, AttentionBranchModel): 160 | if lr_ab is None: 161 | lr_ab = lr_linear 162 | params = [ 163 | {"params": model.attention_branch.parameters(), "lr": lr_ab}, 164 | {"params": model.perception_branch[:-1].parameters(), "lr": lr}, 165 | {"params": model.perception_branch[-1].parameters(), "lr": lr_linear}, 166 | ] 167 | else: 168 | try: 169 | params = [ 170 | {"params": model[:-1].parameters(), "lr": lr}, 171 | {"params": model[-1].parameters(), "lr": lr_linear}, 172 | ] 173 | except TypeError: 174 | params = [{"params": model.parameters(), "lr": lr}] 175 | 176 | return params 177 | 
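# Usage sketch (added, illustrative only): the parameter groups returned by
# setting_learning_rate() are handed to create_optimizer() in main() below, so
# the final linear layer (and the Attention Branch, when present) can use a
# larger learning rate than the backbone, e.g.
#     params = setting_learning_rate(model, lr=1e-4, lr_linear=1e-3, lr_ab=1e-3)
#     optimizer = create_optimizer("AdamW", params, 1e-4, 0.01, 0.9)
# (argument order mirrors the call in main(); the values are the CLI defaults).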
178 | 179 | def wandb_log(loss: float, metrics: Metric, phase: str) -> None: 180 | """ 181 | Output logs to wandb 182 | Add phase to each metric for easy understanding 183 | (e.g. Acc -> Train_Acc) 184 | 185 | Args: 186 | loss(float) : Loss value 187 | metircs(Metric): Evaluation metrics 188 | phase(str) : train / val / test 189 | """ 190 | log_items = {f"{phase}_loss": loss} 191 | 192 | for metric, value in metrics.score().items(): 193 | log_items[f"{phase}_{metric}"] = value 194 | 195 | wandb.log(log_items) 196 | 197 | 198 | def train_insdel( 199 | model: nn.Module, 200 | images: torch.Tensor, 201 | labels: torch.Tensor, 202 | criterion: nn.modules.loss._Loss, 203 | mode: str, 204 | theta_dist: List[float] = [0.3, 0.5, 0.7], 205 | ): 206 | assert isinstance(model, AttentionBranchModel) 207 | attention_map = model.attention_branch.attention 208 | attention_map = F.interpolate(attention_map, images.shape[2:]) 209 | att_base = attention_map.max() 210 | 211 | theta = random.choice(theta_dist) 212 | assert mode in ["insertion", "deletion"] 213 | if mode == "insertion": 214 | labels = torch.ones_like(labels) 215 | attention_map = torch.where(attention_map > att_base * theta, 1.0, 0.0) 216 | if mode == "deletion": 217 | labels = torch.zeros_like(labels) 218 | attention_map = torch.where(attention_map > att_base * theta, 0.0, 1.0) 219 | 220 | inputs = images * attention_map 221 | output = model(inputs.float()) 222 | loss = criterion(output, labels) 223 | 224 | return loss 225 | 226 | 227 | def train( 228 | dataloader: DataLoader, 229 | model: nn.Module, 230 | criterion: nn.modules.loss._Loss, 231 | optimizer: optim.Optimizer, 232 | metric: Metric, 233 | lambdas: Optional[Dict[str, float]] = None, 234 | saliency: bool = False, 235 | ) -> Tuple[float, Metric]: 236 | total = 0 237 | total_loss: float = 0 238 | torch.autograd.set_detect_anomaly(True) 239 | 240 | model.train() 241 | for data_ in tqdm(dataloader, desc="Train: ", dynamic_ncols=True): 242 | inputs, labels = ( 243 | data_[0].to(device), 244 | data_[1].to(device), 245 | ) 246 | optimizer.zero_grad() 247 | outputs = model(inputs) 248 | 249 | loss = calculate_loss(criterion, outputs, labels, model, lambdas) 250 | loss.backward() 251 | total_loss += loss.item() 252 | 253 | metric.evaluate(outputs, labels) 254 | 255 | # When the optimizer is SAM, backward twice 256 | if isinstance(optimizer, SAM): 257 | optimizer.first_step(zero_grad=True) 258 | loss_sam = calculate_loss(criterion, model(inputs), labels, model, lambdas) 259 | loss_sam.backward() 260 | optimizer.second_step(zero_grad=True) 261 | else: 262 | optimizer.step() 263 | 264 | total += labels.size(0) 265 | 266 | train_loss = total_loss / total 267 | return train_loss, metric 268 | 269 | 270 | def main(args: argparse.Namespace): 271 | now = datetime.datetime.now() 272 | now_str = now.strftime("%Y-%m-%d_%H%M%S") 273 | 274 | fix_seed(args.seed, args.no_deterministic) 275 | 276 | # Create dataloaders 277 | dataloader_dict = create_dataloader_dict( 278 | args.dataset, 279 | args.batch_size, 280 | args.image_size, 281 | train_ratio=args.train_ratio, 282 | ) 283 | data_params = get_parameter_depend_in_data_set( 284 | args.dataset, pos_weight=torch.Tensor(args.loss_weights).to(device) 285 | ) 286 | 287 | # Create a model 288 | model = create_model( 289 | args.model, 290 | num_classes=len(data_params["classes"]), 291 | num_channel=data_params["num_channel"], 292 | base_pretrained=args.base_pretrained, 293 | base_pretrained2=args.base_pretrained2, 294 | pretrained_path=args.pretrained, 295 | 
attention_branch=args.add_attention_branch, 296 | division_layer=args.div, 297 | theta_attention=args.theta_att, 298 | ) 299 | assert model is not None, "Model name is invalid" 300 | freeze_model( 301 | model, 302 | args.trainable_module, 303 | "fe" not in args.freeze, 304 | "ab" not in args.freeze, 305 | "pb" not in args.freeze, 306 | "linear" not in args.freeze, 307 | ) 308 | 309 | # Setup optimizer and scheduler 310 | params = setting_learning_rate(model, args.lr, args.lr_linear, args.lr_ab) 311 | optimizer = create_optimizer( 312 | args.optimizer, params, args.lr, args.weight_decay, args.momentum 313 | ) 314 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) 315 | # scheduler = CosineLRScheduler(optimizer, t_initial=100, lr_min=args.min_lr, warmup_t=10, warmup_prefix=True) 316 | 317 | if args.saliency_guided: 318 | set_parameter_trainable(model.perception_branch[0], False) 319 | 320 | criterion = data_params["criterion"] 321 | metric = data_params["metric"] 322 | 323 | # Create run_name (for save_dir / wandb) 324 | if args.model is not None: 325 | config_file = os.path.basename(args.config) 326 | run_name = os.path.splitext(config_file)[0] 327 | else: 328 | run_name = args.model 329 | run_name += ["", f"_div{args.div}"][args.add_attention_branch] 330 | run_name = f"{run_name}_{now_str}" 331 | if args.run_name is not None: 332 | run_name = args.run_name 333 | 334 | save_dir = os.path.join(args.save_dir, run_name) 335 | assert not os.path.isdir(save_dir) 336 | os.makedirs(save_dir) 337 | best_path = os.path.join(save_dir, "best.pt") 338 | 339 | configs = vars(args) 340 | configs.pop("config") # To prevent the old config from being included in the new config applied with the model 341 | 342 | early_stopping = EarlyStopping( 343 | patience=args.early_stopping_patience, save_dir=save_dir 344 | ) 345 | 346 | wandb.init(project=args.dataset, name=run_name, notes=args.notes) 347 | wandb.config.update(configs) 348 | configs["pretrained"] = best_path 349 | save_json(configs, os.path.join(save_dir, "config.json")) 350 | 351 | # Model details display (torchsummary) 352 | summary( 353 | model, 354 | (args.batch_size, data_params["num_channel"], args.image_size, args.image_size), 355 | ) 356 | 357 | lambdas = {"att": args.lambda_att} 358 | 359 | save_test_acc = 0 360 | model.to(device) 361 | for epoch in range(args.epochs): 362 | print(f"\n[Epoch {epoch+1}]") 363 | for phase, dataloader in dataloader_dict.items(): 364 | if phase == "Train": 365 | loss, metric = train( 366 | dataloader, 367 | model, 368 | criterion, 369 | optimizer, 370 | metric, 371 | lambdas=lambdas, 372 | ) 373 | else: 374 | loss, metric = test( 375 | dataloader, 376 | model, 377 | criterion, 378 | metric, 379 | device, 380 | phase, 381 | lambdas=lambdas, 382 | ) 383 | 384 | metric_log = metric.log() 385 | log = f"{phase}\t| {metric_log} Loss: {loss:.5f} " 386 | 387 | wandb_log(loss, metric, phase) 388 | 389 | if phase == "Val": 390 | early_stopping_log = early_stopping(loss, model) 391 | log += early_stopping_log 392 | scheduler.step(loss) 393 | 394 | print(log) 395 | if phase == "Test" and not early_stopping.early_stop: 396 | save_test_acc = metric.acc() 397 | 398 | metric.clear() 399 | if args.add_attention_branch: 400 | save_attention_map( 401 | model.attention_branch.attention[0][0], "attention.png" 402 | ) 403 | 404 | if early_stopping.early_stop: 405 | print("Early Stopping") 406 | model.load_state_dict(torch.load(os.path.join(save_dir, "checkpoint.pt"))) 407 | break 408 | 409 | 
torch.save(model.state_dict(), os.path.join(save_dir, "best.pt")) 410 | configs["test_acc"] = save_test_acc.item() 411 | save_json(configs, os.path.join(save_dir, "config.json")) 412 | wandb.log({"final_test_acc": save_test_acc}) 413 | print("Training Finished") 414 | 415 | 416 | def parse_args(): 417 | parser = argparse.ArgumentParser() 418 | 419 | parser.add_argument("-c", "--config", type=str, help="path to config file (json)") 420 | 421 | parser.add_argument("--seed", type=int, default=42) 422 | parser.add_argument("--no_deterministic", action="store_false") 423 | 424 | parser.add_argument("-n", "--notes", type=str, default="") 425 | 426 | # Model 427 | parser.add_argument("-m", "--model", choices=ALL_MODELS, help="model name") 428 | parser.add_argument( 429 | "-add_ab", 430 | "--add_attention_branch", 431 | action="store_true", 432 | help="add Attention Branch", 433 | ) 434 | parser.add_argument( 435 | "--div", 436 | type=str, 437 | choices=["layer1", "layer2", "layer3"], 438 | default="layer2", 439 | help="place to attention branch", 440 | ) 441 | parser.add_argument("--base_pretrained", type=str, help="path to base pretrained") 442 | parser.add_argument( 443 | "--base_pretrained2", 444 | type=str, 445 | help="path to base pretrained2 ( after change_num_classes() )", 446 | ) 447 | parser.add_argument("--pretrained", type=str, help="path to pretrained") 448 | parser.add_argument( 449 | "--theta_att", type=float, default=0, help="threthold of attention branch" 450 | ) 451 | 452 | # Freeze 453 | parser.add_argument( 454 | "--freeze", 455 | type=str, 456 | nargs="*", 457 | choices=["fe", "ab", "pb", "linear"], 458 | default=[], 459 | help="freezing layer", 460 | ) 461 | parser.add_argument( 462 | "--trainable_module", 463 | type=int, 464 | default=-1, 465 | help="number of trainable modules, -1: all trainable", 466 | ) 467 | 468 | # Dataset 469 | parser.add_argument("--dataset", type=str, default="IDRiD", choices=ALL_DATASETS) 470 | parser.add_argument("--image_size", type=int, default=224) 471 | parser.add_argument("--batch_size", type=int, default=32) 472 | parser.add_argument( 473 | "--train_ratio", type=float, default=0.8, help="ratio for train val split" 474 | ) 475 | parser.add_argument( 476 | "--loss_weights", 477 | type=float, 478 | nargs="*", 479 | default=[1.0, 1.0], 480 | help="weights for label by class", 481 | ) 482 | 483 | # Optimizer 484 | parser.add_argument("--epochs", type=int, default=200) 485 | parser.add_argument( 486 | "-optim", "--optimizer", type=str, default="AdamW", choices=ALL_OPTIM 487 | ) 488 | parser.add_argument( 489 | "--lr", 490 | "--learning_rate", 491 | type=float, 492 | default=1e-4, 493 | ) 494 | parser.add_argument( 495 | "--lr_linear", 496 | type=float, 497 | default=1e-3, 498 | ) 499 | parser.add_argument( 500 | "--lr_ab", 501 | "--lr_attention_branch", 502 | type=float, 503 | default=1e-3, 504 | ) 505 | parser.add_argument( 506 | "--min_lr", 507 | type=float, 508 | default=1e-6, 509 | ) 510 | parser.add_argument( 511 | "--momentum", 512 | type=float, 513 | default=0.9, 514 | ) 515 | parser.add_argument( 516 | "--weight_decay", 517 | type=float, 518 | default=0.01, 519 | ) 520 | parser.add_argument( 521 | "--factor", type=float, default=0.3333, help="new_lr = lr * factor" 522 | ) 523 | parser.add_argument( 524 | "--scheduler_patience", 525 | type=int, 526 | default=2, 527 | help="Number of epochs with no improvement after which learning rate will be reduced", 528 | ) 529 | 530 | parser.add_argument( 531 | "--lambda_att", type=float, default=0.1, 
help="weights for attention loss" 532 | ) 533 | 534 | parser.add_argument( 535 | "--early_stopping_patience", type=int, default=6, help="Early Stopping patience" 536 | ) 537 | parser.add_argument( 538 | "--save_dir", type=str, default="checkpoints", help="path to save checkpoints" 539 | ) 540 | 541 | parser.add_argument( 542 | "--run_name", type=str, help="save in save_dir/run_name and wandb name" 543 | ) 544 | 545 | return parse_with_config(parser) 546 | 547 | 548 | if __name__ == "__main__": 549 | # import pdb; pdb.set_trace() 550 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 551 | 552 | main(parse_args()) 553 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from models.attention_branch import AttentionBranchModel 8 | 9 | 10 | def criterion_with_cast_targets( 11 | criterion: nn.modules.loss._Loss, preds: torch.Tensor, targets: torch.Tensor 12 | ) -> torch.Tensor: 13 | """ 14 | Calculate the loss after changing the type 15 | 16 | Args: 17 | criterion(Loss): loss function 18 | preds(Tensor) : prediction 19 | targets(Tensor): label 20 | 21 | Returns: 22 | torch.Tensor: loss value 23 | 24 | Note: 25 | The type required by the loss function is different, so we convert it 26 | """ 27 | if isinstance(criterion, nn.CrossEntropyLoss): 28 | # targets = F.one_hot(targets, num_classes=2) 29 | targets = targets.long() 30 | 31 | if isinstance(criterion, nn.BCEWithLogitsLoss): 32 | targets = F.one_hot(targets, num_classes=2) 33 | targets = targets.to(preds.dtype) 34 | 35 | return criterion(preds, targets) 36 | 37 | 38 | def calculate_loss( 39 | criterion: nn.modules.loss._Loss, 40 | outputs: torch.Tensor, 41 | targets: torch.Tensor, 42 | model: nn.Module, 43 | lambdas: Optional[Dict[str, float]] = None, 44 | ) -> torch.Tensor: 45 | """ 46 | Calculate the loss 47 | Add the attention loss when AttentionBranchModel 48 | 49 | Args: 50 | criterion(Loss) : Loss function 51 | preds(Tensor) : Prediction 52 | targets(Tensor) : Label 53 | model(nn.Module) : Model that made the prediction 54 | lambdas(Dict[str, float]): Weight of each term of the loss 55 | 56 | Returns: 57 | torch.Tensor: Loss value 58 | """ 59 | loss = criterion_with_cast_targets(criterion, outputs, targets) 60 | 61 | # Attention Loss 62 | if isinstance(model, AttentionBranchModel): 63 | keys = ["att", "var"] 64 | if lambdas is None: 65 | lambdas = {key: 1 for key in keys} 66 | for key in keys: 67 | if key not in lambdas: 68 | lambdas[key] = 1 69 | 70 | attention_loss = criterion_with_cast_targets( 71 | criterion, model.attention_pred, targets 72 | ) 73 | # loss = loss + attention_loss 74 | # attention = model.attention_branch.attention 75 | # _, _, W, H = attention.size() 76 | 77 | # att_sum = torch.sum(attention, dim=(-1, -2)) 78 | # attention_loss = torch.mean(att_sum / (W * H)) 79 | loss = loss + lambdas["att"] * attention_loss 80 | # attention = model.attention_branch.attention 81 | # attention_varmean = attention.var(dim=(1, 2, 3)).mean() 82 | # loss = loss - lambdas["var"] * attention_varmean 83 | 84 | return loss 85 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | 
import sys 5 | from typing import Dict, List, Tuple, Union 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | 12 | def fix_seed(seed: int, deterministic: bool = False) -> None: 13 | # random 14 | random.seed(seed) 15 | # numpy 16 | np.random.seed(seed) 17 | # pytorch 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | # cudnn 21 | torch.backends.cudnn.deterministic = deterministic 22 | torch.backends.cudnn.benchmark = not deterministic 23 | 24 | 25 | def reverse_normalize( 26 | x: np.ndarray, 27 | mean: Union[Tuple[float], Tuple[float, float, float]], 28 | std: Union[Tuple[float], Tuple[float, float, float]], 29 | ): 30 | """ 31 | Restore normalization 32 | 33 | Args: 34 | x(ndarray) : Matrix that has been normalized 35 | mean(Tuple): Mean specified at the time of normalization 36 | std(Tuple) : Standard deviation specified at the time of normalization 37 | """ 38 | if x.shape[0] == 1: 39 | x = x * std + mean 40 | return x 41 | x[0, :, :] = x[0, :, :] * std[0] + mean[0] 42 | x[1, :, :] = x[1, :, :] * std[1] + mean[1] 43 | x[2, :, :] = x[2, :, :] * std[2] + mean[2] 44 | 45 | return x 46 | 47 | 48 | def module_generator(model: nn.Module, reverse: bool = False): 49 | """ 50 | Generator for nested Module, can handle nested Sequential in one layer 51 | Note that you cannot get layers by index 52 | 53 | Args: 54 | model (nn.Module): Model 55 | reverse(bool) : Whether to reverse 56 | 57 | Yields: 58 | Each layer of the model 59 | """ 60 | modules = list(model.children()) 61 | if reverse: 62 | modules = modules[::-1] 63 | 64 | for module in modules: 65 | if list(module.children()): 66 | yield from module_generator(module, reverse) 67 | continue 68 | yield module 69 | 70 | 71 | def save_json(data: Union[List, Dict], save_path: str) -> None: 72 | """ 73 | Save list/dict to json 74 | 75 | Args: 76 | data (List/Dict): Data to save 77 | save_path(str) : Path to save (including extension) 78 | """ 79 | with open(save_path, "w") as f: 80 | json.dump(data, f, indent=4) 81 | 82 | 83 | def softmax_image(image: torch.Tensor) -> torch.Tensor: 84 | image_size = image.size() 85 | if len(image_size) == 4: 86 | B, C, W, H = image_size 87 | elif len(image_size) == 3: 88 | B = 1 89 | C, W, H = image_size 90 | else: 91 | raise ValueError 92 | 93 | image = image.view(B, C, W * H) 94 | image = torch.softmax(image, dim=-1) 95 | 96 | image = image.view(B, C, W, H) 97 | if len(image_size) == 3: 98 | image = image[0] 99 | 100 | return image 101 | 102 | 103 | def tensor_to_numpy(tensor: Union[torch.Tensor, np.ndarray]) -> np.ndarray: 104 | if isinstance(tensor, torch.Tensor): 105 | result: np.ndarray = tensor.cpu().detach().numpy() 106 | else: 107 | result = tensor 108 | 109 | return result 110 | 111 | 112 | def parse_with_config(parser: argparse.ArgumentParser) -> argparse.Namespace: 113 | """ 114 | Coexistence of argparse and json file config 115 | 116 | Args: 117 | parser(ArgumentParser) 118 | 119 | Returns: 120 | Namespace: Can be used in the same way as argparser 121 | 122 | Note: 123 | Values specified in the arguments take precedence over the config 124 | """ 125 | args, _unknown = parser.parse_known_args() 126 | if args.config is not None: 127 | config_args = json.load(open(args.config)) 128 | override_keys = { 129 | arg[2:].split("=")[0] for arg in sys.argv[1:] if arg.startswith("--") 130 | } 131 | for k, v in config_args.items(): 132 | if k not in override_keys: 133 | setattr(args, k, v) 134 | return args 135 | 
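# --- Usage sketch (added; not part of the original module) --------------------
# Hedged illustration of parse_with_config(): flags given explicitly on the
# command line take precedence, everything else falls back to the JSON config.
# The option names below are only examples (they mirror train.py).
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, help="path to config file (json)")
    parser.add_argument("--batch_size", type=int, default=32)
    args = parse_with_config(parser)
    # e.g. `python -m utils.utils --config configs/CUB_resnet50.json --batch_size 16`
    # keeps batch_size=16 from the CLI while reading the remaining keys from the JSON.
    print(vars(args))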
-------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Literal, Optional, Tuple, Union 3 | 4 | import cv2 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import plotly.graph_objects as go 8 | import torch 9 | 10 | from utils.utils import reverse_normalize, tensor_to_numpy 11 | 12 | 13 | def save_attention_map(attention: Union[np.ndarray, torch.Tensor], fname: str) -> None: 14 | attention = tensor_to_numpy(attention) 15 | 16 | fig, ax = plt.subplots() 17 | 18 | min_att = attention.min() 19 | max_att = attention.max() 20 | 21 | im = ax.imshow( 22 | attention, interpolation="nearest", cmap="jet", vmin=min_att, vmax=max_att 23 | ) 24 | fig.colorbar(im) 25 | plt.savefig(fname) 26 | plt.clf() 27 | plt.close() 28 | 29 | 30 | def save_image_with_attention_map( 31 | image: np.ndarray, 32 | attention: np.ndarray, 33 | fname: str, 34 | mean: Tuple[float, float, float], 35 | std: Tuple[float, float, float], 36 | only_img: bool = False, 37 | normalize: bool = False, 38 | sign: Literal["all", "positive", "negative", "absolute_value"] = "all", 39 | ) -> None: 40 | if len(attention.shape) == 3: 41 | attention = attention[0] 42 | 43 | image = image[:3] 44 | mean = mean[:3] 45 | std = std[:3] 46 | 47 | # attention = (attention - attention.min()) / (attention.max() - attention.min()) 48 | if normalize: 49 | attention = normalize_attr(attention, sign) 50 | 51 | # image : (C, W, H) 52 | attention = cv2.resize(attention, dsize=(image.shape[1], image.shape[2])) 53 | image = reverse_normalize(image.copy(), mean, std) 54 | image = np.transpose(image, (1, 2, 0)) 55 | image = np.clip(image, 0, 1) 56 | 57 | fig, ax = plt.subplots() 58 | if only_img: 59 | ax.axis('off') # No axes for a cleaner look 60 | if image.shape[2] == 1: 61 | ax.imshow(image, cmap="gray", vmin=0, vmax=1) 62 | else: 63 | ax.imshow(image, vmin=0, vmax=1) 64 | 65 | im = ax.imshow(attention, cmap="jet", alpha=0.4, vmin=attention.min(), vmax=attention.max()) 66 | if not only_img: 67 | fig.colorbar(im) 68 | 69 | if only_img: 70 | plt.savefig(fname, bbox_inches='tight', pad_inches=0) 71 | else: 72 | plt.savefig(fname) 73 | 74 | plt.clf() 75 | plt.close() 76 | 77 | 78 | def save_image(image: np.ndarray, fname: str, mean: Tuple[float, float, float], std: Tuple[float, float, float]) -> None: 79 | # image : (C, W, H) 80 | image = reverse_normalize(image.copy(), mean, std) 81 | image = image.clip(0, 1) 82 | image = np.transpose(image, (1, 2, 0)) 83 | 84 | # Convert image from [0, 1] float to [0, 255] uint8 for saving 85 | image = (image * 255).astype(np.uint8) 86 | # image = image.astype(np.uint8) 87 | 88 | if image.shape[2] == 1: # if grayscale 89 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) 90 | else: 91 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 92 | 93 | cv2.imwrite(fname, image) 94 | 95 | 96 | def simple_plot_and_save(x, y, save_path): 97 | fig = go.Figure() 98 | fig.add_trace(go.Scatter(x=x, y=x, mode='lines', line=dict(color='#999999', width=2))) 99 | fig.add_trace(go.Scatter(x=x, y=y, mode='markers', marker=dict(color='#82B366', size=12))) 100 | fig.update_layout( 101 | plot_bgcolor='white', 102 | showlegend=False, 103 | # xaxis_title=r"$\Huge {\sum_i R^{(\textrm{First block})}_i}$", 104 | # yaxis_title=r"$\Huge {p(\hat y_c)}$", 105 | xaxis=dict( 106 | title_standoff=28, 107 | tickfont=dict(family="Times New Roman", size=30, color="black"), 108 | 
linecolor='black', 109 | showgrid=False, 110 | ticks='outside', 111 | tickcolor='black', 112 | ), 113 | yaxis=dict( 114 | title_standoff=26, 115 | tickfont=dict(family="Times New Roman", size=30, color="black"), 116 | linecolor='black', 117 | showgrid=False, 118 | ticks='outside', 119 | tickcolor='black', 120 | scaleanchor="x", 121 | scaleratio=1, 122 | ), 123 | autosize=False, 124 | width=600, 125 | height=600, 126 | # margin=dict(l=115, r=5, b=115, t=5), 127 | margin=dict(l=5, r=5, b=5, t=5), 128 | ) 129 | fig.write_image(save_path) 130 | 131 | 132 | def simple_plot_and_save_legacy(x, y, save_path): 133 | plt.figure(figsize=(6, 6)) 134 | plt.scatter(x, y, color='orange') 135 | plt.plot(x, x, color='#999999') # y=x line 136 | plt.gca().spines['top'].set_visible(False) 137 | plt.gca().spines['right'].set_visible(False) 138 | plt.gca().set_aspect('equal', adjustable='box') 139 | plt.savefig(save_path) 140 | 141 | 142 | def save_data_as_plot( 143 | data: np.ndarray, 144 | fname: str, 145 | x: Optional[np.ndarray] = None, 146 | label: Optional[str] = None, 147 | xlim: Optional[Union[int, float]] = None, 148 | ) -> None: 149 | """ 150 | Save data as plot 151 | 152 | Args: 153 | data(ndarray): Data to save 154 | fname(str) : File name to save 155 | x(ndarray) : X-axis 156 | label(str) : Label of legend 157 | """ 158 | fig, ax = plt.subplots() 159 | 160 | if x is None: 161 | x = range(len(data)) 162 | 163 | ax.plot(x, data, label=label) 164 | 165 | xmax = len(data) if xlim is None else xlim 166 | ax.set_xlim(0, xmax) 167 | ax.set_ylim(-0.05, 1.05) 168 | 169 | plt.legend() 170 | plt.savefig(fname, bbox_inches="tight", pad_inches=0.05) 171 | plt.clf() 172 | plt.close() 173 | 174 | 175 | """followings are borrowed & modified from captum.attr._utils""" 176 | 177 | def _normalize_scale(attr: np.ndarray, scale_factor: float): 178 | assert scale_factor != 0, "Cannot normalize by scale factor = 0" 179 | if abs(scale_factor) < 1e-5: 180 | warnings.warn( 181 | "Attempting to normalize by value approximately 0, visualized results" 182 | "may be misleading. This likely means that attribution values are all" 183 | "close to 0." 184 | ) 185 | attr_norm = attr / scale_factor 186 | return np.clip(attr_norm, -1, 1) 187 | 188 | 189 | def _cumulative_sum_threshold(values: np.ndarray, percentile: Union[int, float]): 190 | # given values should be non-negative 191 | assert percentile >= 0 and percentile <= 100, ( 192 | "Percentile for thresholding must be " "between 0 and 100 inclusive." 193 | ) 194 | sorted_vals = np.sort(values.flatten()) 195 | cum_sums = np.cumsum(sorted_vals) 196 | threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0] 197 | return sorted_vals[threshold_id] 198 | 199 | 200 | def normalize_attr( 201 | attr: np.ndarray, 202 | sign: str, 203 | outlier_perc: Union[int, float] = 2, 204 | reduction_axis: Optional[int] = None, 205 | ): 206 | attr_combined = attr 207 | if reduction_axis is not None: 208 | attr_combined = np.sum(attr, axis=reduction_axis) 209 | 210 | # Choose appropriate signed values and rescale, removing given outlier percentage. 
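    # How the thresholding below works: _cumulative_sum_threshold sorts the non-negative
    # values and returns the value at which the cumulative sum reaches (100 - outlier_perc)%
    # of the total mass; _normalize_scale then divides by that threshold and clips to
    # [-1, 1], so the top slice of attribution mass is saturated instead of stretching
    # the colour scale.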
211 | if sign == "all": 212 | threshold = _cumulative_sum_threshold(np.abs(attr_combined), 100 - outlier_perc) 213 | elif sign == "positive": 214 | attr_combined = (attr_combined > 0) * attr_combined 215 | threshold = _cumulative_sum_threshold(attr_combined, 100 - outlier_perc) 216 | elif sign == "negative": 217 | attr_combined = (attr_combined < 0) * attr_combined 218 | threshold = -1 * _cumulative_sum_threshold( 219 | np.abs(attr_combined), 100 - outlier_perc 220 | ) 221 | elif sign == "absolute_value": 222 | attr_combined = np.abs(attr_combined) 223 | threshold = _cumulative_sum_threshold(attr_combined, 100 - outlier_perc) 224 | else: 225 | raise AssertionError("Visualize Sign type is not valid.") 226 | return _normalize_scale(attr_combined, threshold) 227 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Any, Dict, Optional, Tuple, Union 4 | 5 | import captum 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.data as data 11 | from pytorch_grad_cam import GradCAM 12 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget 13 | from skimage.transform import resize 14 | from torchinfo import summary 15 | from torchvision.models.resnet import ResNet 16 | from tqdm import tqdm 17 | 18 | import wandb 19 | from data import ALL_DATASETS, create_dataloader_dict, get_parameter_depend_in_data_set 20 | from metrics.base import Metric 21 | from metrics.patch_insdel import PatchInsertionDeletion 22 | from models import ALL_MODELS, OneWayResNet, create_model 23 | from models.attention_branch import AttentionBranchModel 24 | from models.lrp import * 25 | from models.rise import RISE 26 | from src.lrp import abn_lrp, basic_lrp, resnet_lrp 27 | from src.utils import SkipConnectionPropType 28 | from utils.utils import fix_seed, parse_with_config 29 | from utils.visualize import ( 30 | save_image, 31 | save_image_with_attention_map, 32 | simple_plot_and_save, 33 | ) 34 | 35 | 36 | def calculate_attention( 37 | model: nn.Module, 38 | image: torch.Tensor, 39 | label: torch.Tensor, 40 | method: str, 41 | rise_params: Optional[Dict[str, Any]], 42 | fname: Optional[str], 43 | skip_connection_prop_type: SkipConnectionPropType = "latest", 44 | ) -> Tuple[np.ndarray, Optional[float]]: 45 | relevance_out = None 46 | if method.lower() == "abn": 47 | assert isinstance(model, AttentionBranchModel) 48 | model(image) 49 | attentions = model.attention_branch.attention # (1, W, W) 50 | attention = attentions[0] 51 | attention: np.ndarray = attention[0].detach().cpu().numpy() 52 | elif method.lower() == "rise": 53 | assert rise_params is not None 54 | rise_model = RISE( 55 | model, 56 | n_masks=rise_params["n_masks"], 57 | p1=rise_params["p1"], 58 | input_size=rise_params["input_size"], 59 | initial_mask_size=rise_params["initial_mask_size"], 60 | n_batch=rise_params["n_batch"], 61 | mask_path=rise_params["mask_path"], 62 | ) 63 | attentions = rise_model(image) # (N_class, W, H) 64 | attention = attentions[label] 65 | attention: np.ndarray = attention.cpu().numpy() 66 | elif method.lower() == "npy": 67 | assert fname is not None 68 | attention: np.ndarray = np.load(fname) 69 | elif method.lower() == "gradcam": 70 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1) 71 | pred_label = tpl.indices[0].item() 72 | relevance_out = tpl.values[0].item() 73 | target_layers = 
[model.layer4[-1]] 74 | cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True) 75 | targets = [ClassifierOutputTarget(pred_label)] 76 | grayscale_cam = cam(input_tensor=image, targets=targets) 77 | attention = grayscale_cam[0, :] 78 | elif method.lower() == "scorecam": 79 | from src import scorecam as scam 80 | resnet_model_dict = dict(type='resnet50', arch=model, layer_name='layer4', input_size=(224, 224)) 81 | cam = scam.ScoreCAM(resnet_model_dict) 82 | attention = cam(image) 83 | attention = attention[0].detach().cpu().numpy() 84 | elif method.lower() == "lrp": # ours 85 | if isinstance(model, ResNet): 86 | model = OneWayResNet(model) 87 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1) 88 | pred_label = tpl.indices[0].item() 89 | relevance_out = tpl.values[0].item() 90 | attention: np.ndarray = basic_lrp( 91 | model, image, rel_pass_ratio=1.0, topk=1, skip_connection_prop=skip_connection_prop_type 92 | ).detach().cpu().numpy() 93 | elif method.lower() == "captumlrp": 94 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1) 95 | pred_label = tpl.indices[0].item() 96 | relevance_out = tpl.values[0].item() 97 | attention = captum.attr.LRP(model).attribute(image, target=pred_label) 98 | attention = attention[0].detach().cpu().numpy() 99 | elif method.lower() == "captumlrp-positive": 100 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1) 101 | pred_label = tpl.indices[0].item() 102 | relevance_out = tpl.values[0].item() 103 | attention = captum.attr.LRP(model).attribute(image, target=pred_label) 104 | attention = attention[0].detach().cpu().numpy().clip(0, None) 105 | elif method.lower() in ["captumgradxinput", "gradxinput"]: 106 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1) 107 | pred_label = tpl.indices[0].item() 108 | relevance_out = tpl.values[0].item() 109 | attention = captum.attr.InputXGradient(model).attribute(image, target=pred_label) 110 | attention = attention[0].detach().cpu().numpy() 111 | elif method.lower() in ["captumig", "ig"]: 112 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1) 113 | pred_label = tpl.indices[0].item() 114 | relevance_out = tpl.values[0].item() 115 | attention = captum.attr.IntegratedGradients(model).attribute(image, target=pred_label) 116 | attention = attention[0].detach().cpu().numpy() 117 | elif method.lower() in ["captumig-positive", "ig-positive"]: 118 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1) 119 | pred_label = tpl.indices[0].item() 120 | relevance_out = tpl.values[0].item() 121 | attention = captum.attr.IntegratedGradients(model).attribute(image, target=pred_label) 122 | attention = attention[0].detach().cpu().numpy().clip(0, None) 123 | elif method.lower() in ["captumguidedbackprop", "guidedbackprop", "gbp"]: 124 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1) 125 | pred_label = tpl.indices[0].item() 126 | relevance_out = tpl.values[0].item() 127 | attention = captum.attr.GuidedBackprop(model).attribute(image, target=pred_label) 128 | attention = attention[0].detach().cpu().numpy() 129 | elif method.lower() in ["captumguidedbackprop-positive", "guidedbackprop-positive", "gbp-positive"]: 130 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1) 131 | pred_label = tpl.indices[0].item() 132 | relevance_out = tpl.values[0].item() 133 | attention = captum.attr.GuidedBackprop(model).attribute(image, target=pred_label) 134 | attention = attention[0].detach().cpu().numpy().clip(0, None) 135 | # elif method.lower() in ["smoothgrad", "vargrad"]: # OOM... 
136 | # tpl = torch.softmax(model(image), dim=-1).max(dim=-1) 137 | # pred_label = tpl.indices[0].item() 138 | # relevance_out = tpl.values[0].item() 139 | # ig = captum.attr.IntegratedGradients(model) 140 | # attention = captum.attr.NoiseTunnel(ig).attribute(image, nt_type=method.lower(), nt_samples=5, target=pred_label) 141 | # attention = attention[0].detach().cpu().numpy() 142 | elif method.lower() == "lime": 143 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1) 144 | pred_label = tpl.indices[0].item() 145 | relevance_out = tpl.values[0].item() 146 | attention = captum.attr.Lime(model).attribute(image, target=pred_label) 147 | attention = attention[0].detach().cpu().numpy() 148 | elif method.lower() == "captumgradcam": 149 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1) 150 | pred_label = tpl.indices[0].item() 151 | relevance_out = tpl.values[0].item() 152 | attention = captum.attr.LayerGradCam(model, model.layer4[-1]).attribute(image, target=pred_label) 153 | attention = attention[0].detach().cpu().numpy() 154 | else: 155 | raise ValueError(f"Invalid method was requested: {method}") 156 | 157 | return attention, relevance_out 158 | 159 | 160 | def apply_gaussian_to_max_pixel(mask, kernel_size=5, sigma=1): 161 | # Find max value's indices 162 | max_val_indices = np.where(mask == np.amax(mask)) 163 | 164 | # Initialize a new mask with the same size of the original mask 165 | new_mask = np.zeros(mask.shape) 166 | 167 | # Assign the max value of the original mask to the corresponding location of the new mask 168 | new_mask[max_val_indices] = mask[max_val_indices] 169 | 170 | # Apply Gaussian blur 171 | blurred_mask = cv2.GaussianBlur(new_mask, (kernel_size, kernel_size), sigma) 172 | 173 | return blurred_mask 174 | 175 | 176 | def remove_other_components(mask, threshold=0.5): 177 | # Binarize the mask 178 | binary_mask = (mask > threshold).astype(np.uint8) 179 | 180 | # Detect connected components 181 | num_labels, labels = cv2.connectedComponents(binary_mask) 182 | 183 | # Calculate the maximum value in the original mask for each component 184 | max_values = [np.max(mask[labels == i]) for i in range(num_labels)] 185 | 186 | # Find the component with the largest maximum value 187 | largest_max_value_component = np.argmax(max_values) 188 | # second_index = np.argsort(max_values)[-2] 189 | # third_index = np.argsort(max_values)[-3] 190 | # print(max_values) 191 | # print(largest_max_value_component) 192 | # print(second_index) 193 | 194 | # Create a new mask where all components other than the one with the largest max value are removed 195 | first_mask = np.where(labels == largest_max_value_component, mask*3, 0) 196 | # second_mask = np.where(labels == second_index, mask, 0) 197 | # third_mask = np.where(labels == third_index, mask / 3, 0) 198 | new_mask = first_mask 199 | # new_mask = first_mask + second_mask + third_mask 200 | 201 | return new_mask 202 | 203 | 204 | def apply_heat_quantization(attention, q_level: int = 8): 205 | max_ = attention.max() 206 | min_ = attention.min() 207 | 208 | # quantization 209 | bin = np.linspace(min_, max_, q_level) 210 | 211 | # apply quantization 212 | for i in range(q_level - 1): 213 | attention[(attention >= bin[i]) & (attention < bin[i + 1])] = bin[i] 214 | 215 | return attention 216 | 217 | 218 | # @torch.no_grad() 219 | def visualize( 220 | dataloader: data.DataLoader, 221 | model: nn.Module, 222 | method: str, 223 | batch_size: int, 224 | patch_size: int, 225 | step: int, 226 | save_dir: str, 227 | all_class: bool, 228 | params: 
Dict[str, Any], 229 | device: torch.device, 230 | evaluate: bool = False, 231 | attention_dir: Optional[str] = None, 232 | use_c1c: bool = False, 233 | heat_quantization: bool = False, 234 | hq_level: int = 8, 235 | skip_connection_prop_type: SkipConnectionPropType = "latest", 236 | data_limit: int = -1, 237 | ) -> Union[None, Metric]: 238 | if evaluate: 239 | metrics = PatchInsertionDeletion( 240 | model, batch_size, patch_size, step, params["name"], device 241 | ) 242 | insdel_save_dir = os.path.join(save_dir, "insdel") 243 | if not os.path.isdir(insdel_save_dir): 244 | os.makedirs(insdel_save_dir) 245 | 246 | rel_ins = [] 247 | rel_outs = [] 248 | 249 | counter = {} 250 | 251 | model.eval() 252 | # # print accuracy on test set 253 | # total = 0 254 | # correct = 0 255 | # for i, data_ in enumerate( 256 | # tqdm(dataloader, desc="Count failures", dynamic_ncols=True) 257 | # ): 258 | # inputs, labels = ( 259 | # data_[0].to(device), 260 | # data_[1].to(device), 261 | # ) 262 | # outputs = model(inputs) 263 | # _, predicted = torch.max(outputs.data, 1) 264 | # total += labels.size(0) 265 | # correct += (predicted == labels).sum().item() 266 | # print(f"Total: {total}") 267 | # print(f"Correct: {correct}") 268 | # print(f"Accuracy: {correct / total:.4f}") 269 | 270 | # Inference time estimation 271 | elapsed_time_sum = 0.0 272 | start_event = torch.cuda.Event(enable_timing=True) 273 | end_event = torch.cuda.Event(enable_timing=True) 274 | torch.cuda.empty_cache() 275 | batch = torch.randn(1, 3, 224, 224).to(device) 276 | for _ in range(100): 277 | start_event.record() 278 | model.forward(batch) 279 | end_event.record() 280 | torch.cuda.synchronize() 281 | elapsed_time_sum += start_event.elapsed_time(end_event) 282 | print(f"Average inference time: {elapsed_time_sum / 100}") 283 | 284 | # Elapsed time to inference + generate explanation 285 | elapsed_time_sum = 0.0 286 | start_event = torch.cuda.Event(enable_timing=True) 287 | end_event = torch.cuda.Event(enable_timing=True) 288 | torch.cuda.empty_cache() 289 | image = torch.randn(1, 3, 224, 224).to(device) 290 | if isinstance(model, ResNet): 291 | oneway_model = OneWayResNet(model) 292 | for _ in range(100): 293 | start_event.record() 294 | basic_lrp( 295 | oneway_model, image, rel_pass_ratio=1.0, topk=1, skip_connection_prop=skip_connection_prop_type 296 | ) 297 | end_event.record() 298 | torch.cuda.synchronize() 299 | elapsed_time_sum += start_event.elapsed_time(end_event) 300 | print(f"Average time to inference and generate an attribution map: {elapsed_time_sum / 100}") 301 | 302 | for i, data_ in enumerate( 303 | tqdm(dataloader, desc="Visualizing: ", dynamic_ncols=True) 304 | ): 305 | if data_limit > 0 and i >= data_limit: 306 | break 307 | torch.cuda.memory_summary(device=device) 308 | inputs, labels = ( 309 | data_[0].to(device), 310 | data_[1].to(device), 311 | ) 312 | image: torch.Tensor = inputs[0].cpu().numpy() 313 | label: torch.Tensor = labels[0] 314 | 315 | if label != 1 and not all_class: 316 | continue 317 | 318 | counter.setdefault(label.item(), 0) 319 | counter[label.item()] += 1 320 | n_eval_per_class = 1000 / len(params["classes"]) 321 | # n_eval_per_class = 5 # TODO: try 322 | if counter[label.item()] > n_eval_per_class: 323 | continue 324 | 325 | base_fname = f"{i+1}_{params['classes'][label]}" 326 | 327 | attention_fname = None 328 | if attention_dir is not None: 329 | attention_fname = os.path.join(attention_dir, f"{base_fname}.npy") 330 | 331 | attention, rel_ = calculate_attention( 332 | model, inputs, label, 
method, params, attention_fname, skip_connection_prop_type 333 | ) 334 | if rel_ is not None: 335 | rel_ins.append(attention.sum().item()) 336 | rel_outs.append(rel_) 337 | if use_c1c: 338 | attention = resize(attention, (28, 28)) 339 | attention = remove_other_components(attention, threshold=attention.mean()) 340 | 341 | if heat_quantization: 342 | attention = apply_heat_quantization(attention, hq_level) 343 | 344 | if attention is None: 345 | continue 346 | if method == "RISE": 347 | np.save(f"{save_dir}/{base_fname}.npy", attention) 348 | 349 | if evaluate: 350 | metrics.evaluate( 351 | image.copy(), 352 | attention, 353 | label, 354 | ) 355 | metrics.save_roc_curve(insdel_save_dir) 356 | base_fname = f"{base_fname}_{metrics.ins_auc - metrics.del_auc:.4f}" 357 | if i % 50 == 0: 358 | print(metrics.log()) 359 | 360 | save_fname = os.path.join(save_dir, f"{base_fname}.png") 361 | save_image_with_attention_map( 362 | image, attention, save_fname, params["mean"], params["std"] 363 | ) 364 | 365 | save_image(image, save_fname[:-4]+".original.png", params["mean"], params["std"]) 366 | 367 | # Plot conservation (sum of input relevances vs. output score) 368 | conservation_dir = os.path.join(save_dir, "conservation") 369 | if not os.path.isdir(conservation_dir): 370 | os.makedirs(conservation_dir) 371 | simple_plot_and_save(rel_ins, rel_outs, os.path.join(conservation_dir, "plot.png")) 372 | 373 | if evaluate: 374 | return metrics 375 | 376 | 377 | def main(args: argparse.Namespace) -> None: 378 | fix_seed(args.seed, args.no_deterministic) 379 | 380 | # Create the dataset 381 | dataloader_dict = create_dataloader_dict( 382 | args.dataset, 1, args.image_size, only_test=True, shuffle_val=True, dataloader_seed=args.seed 383 | ) 384 | dataloader = dataloader_dict["Test"] 385 | assert isinstance(dataloader, data.DataLoader) 386 | 387 | params = get_parameter_depend_in_data_set(args.dataset) 388 | 389 | mask_path = os.path.join(args.root_dir, "masks.npy") 390 | if not os.path.isfile(mask_path): 391 | mask_path = None 392 | rise_params = { 393 | "n_masks": args.num_masks, 394 | "p1": args.p1, 395 | "input_size": (args.image_size, args.image_size), 396 | "initial_mask_size": (args.rise_scale, args.rise_scale), 397 | "n_batch": args.batch_size, 398 | "mask_path": mask_path, 399 | } 400 | params.update(rise_params) 401 | 402 | # Create the model 403 | model = create_model( 404 | args.model, 405 | num_classes=len(params["classes"]), 406 | num_channel=params["num_channel"], 407 | base_pretrained=args.base_pretrained, 408 | base_pretrained2=args.base_pretrained2, 409 | pretrained_path=args.pretrained, 410 | attention_branch=args.add_attention_branch, 411 | division_layer=args.div, 412 | theta_attention=args.theta_att, 413 | init_classifier=args.dataset != "ImageNet", # Use the pretrained classifier for ImageNet 414 | ) 415 | assert model is not None, "Model name is invalid" 416 | 417 | # Derive run_name from the pretrained path: 418 | # checkpoints/run_name/checkpoint.pt -> run_name 419 | run_name = args.pretrained.split(os.sep)[-2] if args.pretrained is not None else "pretrained" 420 | save_dir = os.path.join( 421 | "outputs", 422 | f"{run_name}_{args.notes}_{args.method}{args.block_size}", 423 | ) 424 | if not os.path.isdir(save_dir): 425 | os.makedirs(save_dir) 426 | 427 | summary( 428 | model, 429 | (args.batch_size, params["num_channel"], args.image_size, args.image_size), 430 | ) 431 | 432 | model.to(device) 433 | 434 | wandb.init(project=args.dataset, name=run_name, notes=args.notes) 435 | wandb.config.update(vars(args)) 436 | 437 | metrics = visualize( 438 | dataloader, 439 | 
model, 440 | args.method, 441 | args.batch_size, 442 | args.block_size, 443 | args.insdel_step, 444 | save_dir, 445 | args.all_class, 446 | params, 447 | device, 448 | args.visualize_only, 449 | attention_dir=args.attention_dir, 450 | use_c1c=args.use_c1c, 451 | heat_quantization=args.heat_quantization, 452 | hq_level=args.hq_level, 453 | skip_connection_prop_type=args.skip_connection_prop_type, 454 | data_limit=args.data_limit, 455 | ) 456 | 457 | if hasattr(args, "test_acc"): 458 | print(f"Test Acc: {args.test_acc}") 459 | 460 | if metrics is not None: 461 | print(metrics.log()) 462 | for key, value in metrics.score().items(): 463 | wandb.run.summary[key] = value 464 | 465 | 466 | def parse_args() -> argparse.Namespace: 467 | parser = argparse.ArgumentParser() 468 | 469 | parser.add_argument("-c", "--config", type=str, help="path to config file (json)") 470 | 471 | parser.add_argument("--seed", type=int, default=42) 472 | parser.add_argument("--no_deterministic", action="store_false") 473 | 474 | parser.add_argument("-n", "--notes", type=str, default="") 475 | 476 | # Model 477 | parser.add_argument("-m", "--model", choices=ALL_MODELS, help="model name") 478 | parser.add_argument( 479 | "-add_ab", 480 | "--add_attention_branch", 481 | action="store_true", 482 | help="add an Attention Branch", 483 | ) 484 | parser.add_argument( 485 | "--div", 486 | type=str, 487 | choices=["layer1", "layer2", "layer3"], 488 | default="layer2", 489 | help="layer at which to insert the attention branch", 490 | ) 491 | parser.add_argument("--base_pretrained", type=str, help="path to base pretrained weights") 492 | parser.add_argument( 493 | "--base_pretrained2", 494 | type=str, 495 | help="path to base pretrained2 (after change_num_classes())", 496 | ) 497 | parser.add_argument("--pretrained", type=str, help="path to pretrained weights") 498 | parser.add_argument( 499 | "--orig_model", 500 | action="store_true", 501 | help="calculate the insdel score using the original model", 502 | ) 503 | parser.add_argument( 504 | "--theta_att", type=float, default=0, help="threshold of the attention branch" 505 | ) 506 | 507 | # Dataset 508 | parser.add_argument("--dataset", type=str, default="IDRiD", choices=ALL_DATASETS) 509 | parser.add_argument("--data_limit", type=int, default=-1) 510 | parser.add_argument("--image_size", type=int, default=224) 511 | parser.add_argument("--batch_size", type=int, default=16) 512 | parser.add_argument( 513 | "--loss_weights", 514 | type=float, 515 | nargs="*", 516 | default=[1.0, 1.0], 517 | help="loss weights per class", 518 | ) 519 | 520 | parser.add_argument("--root_dir", type=str, default="./outputs/") 521 | parser.add_argument("--visualize_only", action="store_false") 522 | parser.add_argument("--all_class", action="store_true") 523 | # recommended (step, size) in 512x512 = (1, 10000), (2, 2500), (4, 500), (8, 100), (16, 20), (32, 10), (64, 5), (128, 1) 524 | # recommended (step, size) in 224x224 = (1, 500), (2, 100), (4, 20), (8, 10), (16, 5), (32, 1) 525 | parser.add_argument("--insdel_step", type=int, default=500) 526 | parser.add_argument("--block_size", type=int, default=1) 527 | 528 | parser.add_argument( 529 | "--method", 530 | type=str, 531 | default="gradcam", 532 | ) 533 | parser.add_argument("--normalize", action="store_true") 534 | parser.add_argument("--sign", type=str, default="all") 535 | 536 | parser.add_argument("--num_masks", type=int, default=5000) 537 | parser.add_argument("--rise_scale", type=int, default=9) 538 | parser.add_argument( 539 | "--p1", type=float, default=0.3, help="percentage of mask [pixel = 
(0, 0, 0)]" 540 | ) 541 | 542 | parser.add_argument("--attention_dir", type=str, help="path to attention npy file") 543 | 544 | parser.add_argument("--use-c1c", action="store_true", help="use C1C technique") 545 | 546 | parser.add_argument("--heat-quantization", action="store_true", help="use heat quantization technique") 547 | parser.add_argument("--hq-level", type=int, default=8, help="number of quantization level") 548 | 549 | parser.add_argument("--skip-connection-prop-type", type=str, default="latest", help="type of skip connection propagation") 550 | 551 | return parse_with_config(parser) 552 | 553 | 554 | if __name__ == "__main__": 555 | # import pdb; pdb.set_trace() 556 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 557 | 558 | main(parse_args()) 559 | --------------------------------------------------------------------------------