├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   ├── .gitkeep
│   └── README.md
├── configs
│   ├── CUB_resnet18.json
│   ├── CUB_resnet34.json
│   ├── CUB_resnet50.json
│   ├── CUB_resnet50v2.json
│   └── ImageNet_resnet50.json
├── data
│   ├── CUB.py
│   ├── RDD.py
│   ├── RDD_4.py
│   ├── RDD_bbox.py
│   ├── __init__.py
│   ├── imagenet.py
│   └── sampler.py
├── datasets
│   ├── .gitkeep
│   └── README.md
├── evaluate.py
├── exp.sh
├── metrics
│   ├── __init__.py
│   ├── accuracy.py
│   ├── base.py
│   └── patch_insdel.py
├── models
│   ├── __init__.py
│   ├── attention_branch.py
│   ├── lrp.py
│   └── rise.py
├── oneshot.py
├── optim
│   ├── __init__.py
│   └── sam.py
├── outputs
│   └── .gitkeep
├── poetry.lock
├── pyproject.toml
├── qual
│   └── original
│       ├── Arabian_camel.png
│       ├── Brandt_Cormorant.png
│       ├── Geococcyx.png
│       ├── Rock_Wren.png
│       ├── Savannah_Sparrow.png
│       ├── bee.png
│       ├── bubble.png
│       ├── bustard.png
│       ├── drumstick.png
│       ├── oboe.png
│       ├── ram.png
│       ├── sock.png
│       ├── solar_dish.png
│       ├── water_ouzel.png
│       └── wombat.png
├── scripts
│   ├── calc_dataset_info.py
│   └── visualize_transforms.py
├── src
│   ├── __init__.py
│   ├── data.py
│   ├── data_processing.py
│   ├── lrp.py
│   ├── lrp_filter.py
│   ├── lrp_layers.py
│   ├── scorecam.py
│   └── utils.py
├── train.py
├── utils
│   ├── loss.py
│   ├── utils.py
│   └── visualize.py
└── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | .ruff_cache
3 | /wandb
4 | /checkpoints
5 | /datasets
6 | /outputs
7 | /oneshot_images
8 | /qual*
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .nox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | *.py,cover
59 | .hypothesis/
60 | .pytest_cache/
61 | cover/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | .pybuilder/
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | # For a library or package, you might want to ignore these files since the code is
96 | # intended to run in multiple environments; otherwise, check them in:
97 | # .python-version
98 |
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 |
106 | # poetry
107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108 | # This is especially recommended for binary packages to ensure reproducibility, and is more
109 | # commonly ignored for libraries.
110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111 | #poetry.lock
112 |
113 | # pdm
114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
115 | #pdm.lock
116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
117 | # in version control.
118 | # https://pdm.fming.dev/#use-with-ide
119 | .pdm.toml
120 |
121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
122 | __pypackages__/
123 |
124 | # Celery stuff
125 | celerybeat-schedule
126 | celerybeat.pid
127 |
128 | # SageMath parsed files
129 | *.sage.py
130 |
131 | # Environments
132 | .env
133 | .venv
134 | env/
135 | venv/
136 | ENV/
137 | env.bak/
138 | venv.bak/
139 |
140 | # Spyder project settings
141 | .spyderproject
142 | .spyproject
143 |
144 | # Rope project settings
145 | .ropeproject
146 |
147 | # mkdocs documentation
148 | /site
149 |
150 | # mypy
151 | .mypy_cache/
152 | .dmypy.json
153 | dmypy.json
154 |
155 | # Pyre type checker
156 | .pyre/
157 |
158 | # pytype static type analyzer
159 | .pytype/
160 |
161 | # Cython debug symbols
162 | cython_debug/
163 |
164 | # PyCharm
165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
167 | # and can be added to the global gitignore or merged into this file. For a more nuclear
168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
169 | #.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The Clear BSD License
2 |
3 | Copyright (c) 2024 the authors of "Layer-Wise Relevance Propagation with Conservation Property for ResNet" (Seitaro Otsuki, Tsumugi Iida, Félix Doublet,
4 | Tsubasa Hirakawa, Takayoshi Yamashita, Hironobu Fujiyoshi and Komei Sugiura)
5 |
6 | All rights reserved.
7 |
8 | Redistribution and use in source and binary forms, with or without
9 | modification, are permitted (subject to the limitations in the disclaimer
10 | below) provided that the following conditions are met:
11 |
12 | * Redistributions of source code must retain the above copyright notice,
13 | this list of conditions and the following disclaimer.
14 |
15 | * Redistributions in binary form must reproduce the above copyright
16 | notice, this list of conditions and the following disclaimer in the
17 | documentation and/or other materials provided with the distribution.
18 |
19 | * Neither the name of the copyright holder nor the names of its
20 | contributors may be used to endorse or promote products derived from this
21 | software without specific prior written permission.
22 |
23 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
24 | THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
25 | CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
26 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
27 | PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
28 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
29 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
30 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
31 | BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
32 | IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
33 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34 | POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [ECCV24] Layer-Wise Relevance Propagation with Conservation Property for ResNet
2 |
3 | - Accepted at ECCV 2024
4 | - [Project page](https://5ei74r0.github.io/lrp-for-resnet.page/)
5 | - [ArXiv](https://arxiv.org/abs/2407.09115)
6 |
7 |
8 | The transparent formulation of explanation methods is essential for elucidating the predictions of neural networks, which are typically black-box models. Layer-wise Relevance Propagation (LRP) is a well-established method that transparently traces the flow of a model's prediction backward through its architecture by backpropagating relevance scores. However, the conventional LRP does not fully consider the existence of skip connections, and thus its application to the widely used ResNet architecture has not been thoroughly explored. In this study, we extend LRP to ResNet models by introducing Relevance Splitting at points where the output from a skip connection converges with that from a residual block. Our formulation guarantees the conservation property throughout the process, thereby preserving the integrity of the generated explanations. To evaluate the effectiveness of our approach, we conduct experiments on ImageNet and the Caltech-UCSD Birds-200-2011 dataset. Our method achieves superior performance to that of baseline methods on standard evaluation metrics such as the Insertion-Deletion score while maintaining its conservation property. Our code is released in this repository for further research.
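For intuition, the split at such a junction can be sketched as follows (illustrative notation only; see the paper for the exact Relevance Splitting rules). Where a skip-connection output $z_{\mathrm{skip}}$ and a residual-block output $z_{\mathrm{res}}$ are summed, $z = z_{\mathrm{skip}} + z_{\mathrm{res}}$, the incoming relevance $R$ is divided between the two branches, e.g. in proportion to their contributions:

$$R_{\mathrm{skip}} = \frac{|z_{\mathrm{skip}}|}{|z_{\mathrm{skip}}| + |z_{\mathrm{res}}|}\,R, \qquad R_{\mathrm{res}} = \frac{|z_{\mathrm{res}}|}{|z_{\mathrm{skip}}| + |z_{\mathrm{res}}|}\,R,$$

so that $R_{\mathrm{skip}} + R_{\mathrm{res}} = R$ holds at every junction, which is the conservation property referred to above.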
9 |
10 |
11 |
12 | ## Getting started
13 | Clone this repository and `cd` into it. Then run `poetry install --no-root`.
14 |
15 | We used the following environment:
16 | - Python 3.9.15
17 | - Poetry 1.7.1
18 | - CUDA 11.7
19 |
20 | See [pyproject.toml](pyproject.toml) for the Python dependencies.
21 |
22 | ### Datasets
23 | Follow the instructions [here](datasets/README.md).
24 |
25 | ### Get models
26 | If you want to test the method on CUB, follow the instructions [here](checkpoints/README.md).
27 | You do not have to prepare models for ImageNet.
28 |
29 |
30 | ## Quantitative Experiments
31 | E.g., run our method on ImageNet:
32 | ```bash
33 | poetry run python visualize.py -c configs/ImageNet_resnet50.json --method "lrp" --heat-quantization --skip-connection-prop-type "flows_skip" --notes "imagenet--type:flows_skip--viz:norm+positive" --all_class --seed 42 --normalize --sign "positive"
34 | ```
35 |
36 |
37 | ## Visualize attribution maps for specific images
38 | E.g., visualize attribution maps for the water ouzel class in ImageNet with our method:
39 | ```bash
40 | poetry run python oneshot.py -c configs/ImageNet_resnet50.json --method "lrp" --skip-connection-prop-type "flows_skip" --heat-quantization --image-path ./qual/original/water_ouzel.png --label 20 --save-path ./qual/ours/water_ouzel.png --normalize --sign "positive"
41 | ```
42 | See `exp.sh` for more examples.
43 |
44 |
45 | ## Bibtex
46 |
47 | ```
48 | @article{otsuki2024layer,
49 | title={{Layer-Wise Relevance Propagation with Conservation Property for ResNet}},
50 |   author={Seitaro Otsuki and Tsumugi Iida and F\'elix Doublet and Tsubasa Hirakawa and Takayoshi Yamashita and Hironobu Fujiyoshi and Komei Sugiura},
51 | journal={arXiv preprint arXiv:2407.09115},
52 | year={2024},
53 | }
54 | ```
55 |
56 | ## License
57 | This work is licensed under the BSD-3-Clause-Clear license.
58 | To view a copy of this license, see [LICENSE](LICENSE).
59 |
--------------------------------------------------------------------------------
/checkpoints/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/checkpoints/.gitkeep
--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
1 | We use ResNet-50 models trained on CUB for the CUB experiments. Download them from [here](https://github.com/keio-smilab24/LRP-for-ResNet/releases) and put them into this (checkpoints/) directory.
2 |
3 | ### Structure
4 | Download models & place them as follows:
5 | ```
6 | checkpoints
7 | ├── .gitkeep
8 | ├── CUB_resnet50_Seed40
9 | │ ├── best.pt
10 | │ ├── checkpoint.pt
11 | │ └── config.json
12 | ├── CUB_resnet50_Seed41
13 | │ ├── best.pt
14 | │ ├── checkpoint.pt
15 | │ └── config.json
16 | ├── CUB_resnet50_Seed42
17 | │ ├── best.pt
18 | │ ├── checkpoint.pt
19 | │ └── config.json
20 | ├── CUB_resnet50_Seed43
21 | │ ├── best.pt
22 | │ ├── checkpoint.pt
23 | │ └── config.json
24 | └── CUB_resnet50_Seed44
25 |     ├── best.pt
26 |     ├── checkpoint.pt
27 |     └── config.json
28 | ```
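
After downloading, a checkpoint can be sanity-checked by loading it directly. This is a minimal sketch and assumes `best.pt` is a plain `state_dict` for a 200-class ResNet-50; in the actual pipeline the model is built via `models.create_model` and the weights are loaded through `evaluate.py`'s `--pretrained` option:

```python
import torch
from torchvision.models import resnet50

model = resnet50(num_classes=200)  # CUB-200-2011 has 200 classes
state_dict = torch.load("checkpoints/CUB_resnet50_Seed42/best.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```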
--------------------------------------------------------------------------------
/configs/CUB_resnet18.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 574,
3 | "model": "resnet18",
4 | "add_attention_branch": false,
5 | "base_pretrained": "",
6 | "trainable_module": -1,
7 | "dataset": "CUB",
8 | "image_size": 224,
9 | "batch_size": 32,
10 | "train_ratio": 0.9,
11 | "epochs": 300,
12 | "optimizer": "SGD",
13 | "lr": 1e-3,
14 | "lr_linear": 1e-3,
15 | "min_lr": 1e-6,
16 | "weight_decay": 1e-4,
17 | "factor": 0.333,
18 | "scheduler_patience": 1,
19 | "early_stopping_patience": 6
20 | }
--------------------------------------------------------------------------------
/configs/CUB_resnet34.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 574,
3 | "model": "resnet34",
4 | "add_attention_branch": false,
5 | "base_pretrained": "",
6 | "trainable_module": -1,
7 | "dataset": "CUB",
8 | "image_size": 224,
9 | "batch_size": 32,
10 | "train_ratio": 0.9,
11 | "epochs": 300,
12 | "optimizer": "SGD",
13 | "lr": 1e-3,
14 | "lr_linear": 1e-3,
15 | "min_lr": 1e-6,
16 | "weight_decay": 1e-4,
17 | "factor": 0.333,
18 | "scheduler_patience": 1,
19 | "early_stopping_patience": 6
20 | }
--------------------------------------------------------------------------------
/configs/CUB_resnet50.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 574,
3 | "model": "resnet50",
4 | "add_attention_branch": false,
5 | "base_pretrained": "",
6 | "trainable_module": -1,
7 | "dataset": "CUB",
8 | "image_size": 224,
9 | "batch_size": 32,
10 | "train_ratio": 0.9,
11 | "epochs": 300,
12 | "optimizer": "SGD",
13 | "lr": 1e-3,
14 | "lr_linear": 1e-3,
15 | "min_lr": 1e-6,
16 | "weight_decay": 1e-4,
17 | "factor": 0.333,
18 | "scheduler_patience": 1,
19 | "early_stopping_patience": 6
20 | }
--------------------------------------------------------------------------------
/configs/CUB_resnet50v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 574,
3 | "model": "resnet50-v2",
4 | "add_attention_branch": false,
5 | "base_pretrained": "",
6 | "trainable_module": -1,
7 | "dataset": "CUB",
8 | "image_size": 224,
9 | "batch_size": 32,
10 | "train_ratio": 0.9,
11 | "epochs": 300,
12 | "optimizer": "SGD",
13 | "lr": 1e-3,
14 | "lr_linear": 1e-3,
15 | "min_lr": 1e-6,
16 | "weight_decay": 1e-4,
17 | "factor": 0.333,
18 | "scheduler_patience": 1,
19 | "early_stopping_patience": 6
20 | }
--------------------------------------------------------------------------------
/configs/ImageNet_resnet50.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 574,
3 | "model": "resnet50",
4 | "add_attention_branch": false,
5 | "base_pretrained": "",
6 | "trainable_module": -1,
7 | "dataset": "ImageNet",
8 | "image_size": 224,
9 | "batch_size": 32,
10 | "train_ratio": 0.9,
11 | "epochs": 300,
12 | "optimizer": "SGD",
13 | "lr": 1e-3,
14 | "lr_linear": 1e-3,
15 | "min_lr": 1e-6,
16 | "weight_decay": 1e-4,
17 | "factor": 0.333,
18 | "scheduler_patience": 1,
19 | "early_stopping_patience": 6
20 | }
--------------------------------------------------------------------------------
/data/CUB.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Callable, Optional, Tuple
3 |
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class CUBDataset(Dataset):
9 |
10 | def __init__(
11 | self, root: str, image_set: str = "train", transform: Optional[Callable] = None
12 | ) -> None:
13 | super().__init__()
14 |
15 | self.root = root
16 | self.image_set = image_set
17 | self.transform = transform
18 |
19 | self.base_dir = os.path.join(self.root, "CUB_200_2011")
20 |
21 | # Images and their corresponding paths
22 | image_list_path = os.path.join(self.base_dir, "images.txt")
23 | with open(image_list_path, "r") as f:
24 | lines = f.readlines()
25 | img_ids = [int(line.split()[0]) for line in lines]
26 | img_paths = [line.split()[1].strip() for line in lines]
27 |
28 | # Splitting the dataset into train/test based on train_test_split.txt
29 | train_test_split_path = os.path.join(self.base_dir, "train_test_split.txt")
30 | with open(train_test_split_path, "r") as f:
31 | lines = f.readlines()
32 | splits = {int(line.split()[0]): int(line.split()[1]) for line in lines}
33 |
34 | # Filtering images based on the desired set (train/test)
35 | if image_set == "train":
36 | self.images = [img_paths[i] for i, img_id in enumerate(img_ids) if splits[img_id] == 1]
37 | else:
38 | self.images = [img_paths[i] for i, img_id in enumerate(img_ids) if splits[img_id] == 0]
39 |
40 |         # Targets and segments (class folders are named "<id>.<name>", so the 1-based id prefix minus 1 gives the 0-based label)
41 | self.targets = [int(img_path.split('/')[0].split('.')[0]) - 1 for img_path in self.images]
42 | self.segments = [img_path.replace("images", "segmentations").replace(".jpg", ".png") for img_path in self.images]
43 |
44 | def __getitem__(self, index) -> Tuple[Any, Any]:
45 | image_path = os.path.join(self.base_dir, "images", self.images[index])
46 | image = Image.open(image_path).convert("RGB")
47 | target = self.targets[index]
48 |
49 | if self.transform:
50 | image = self.transform(image)
51 |
52 | return image, target
53 |
54 | def __len__(self) -> int:
55 | return len(self.images)
56 |
57 |
58 | CUB_CLASSES = [
59 | 'Black_footed_Albatross',
60 | 'Laysan_Albatross',
61 | 'Sooty_Albatross',
62 | 'Groove_billed_Ani',
63 | 'Crested_Auklet',
64 | 'Least_Auklet',
65 | 'Parakeet_Auklet',
66 | 'Rhinoceros_Auklet',
67 | 'Brewer_Blackbird',
68 | 'Red_winged_Blackbird',
69 | 'Rusty_Blackbird',
70 | 'Yellow_headed_Blackbird',
71 | 'Bobolink',
72 | 'Indigo_Bunting',
73 | 'Lazuli_Bunting',
74 | 'Painted_Bunting',
75 | 'Cardinal',
76 | 'Spotted_Catbird',
77 | 'Gray_Catbird',
78 | 'Yellow_breasted_Chat',
79 | 'Eastern_Towhee',
80 | 'Chuck_will_Widow',
81 | 'Brandt_Cormorant',
82 | 'Red_faced_Cormorant',
83 | 'Pelagic_Cormorant',
84 | 'Bronzed_Cowbird',
85 | 'Shiny_Cowbird',
86 | 'Brown_Creeper',
87 | 'American_Crow',
88 | 'Fish_Crow',
89 | 'Black_billed_Cuckoo',
90 | 'Mangrove_Cuckoo',
91 | 'Yellow_billed_Cuckoo',
92 | 'Gray_crowned_Rosy_Finch',
93 | 'Purple_Finch',
94 | 'Northern_Flicker',
95 | 'Acadian_Flycatcher',
96 | 'Great_Crested_Flycatcher',
97 | 'Least_Flycatcher',
98 | 'Olive_sided_Flycatcher',
99 | 'Scissor_tailed_Flycatcher',
100 | 'Vermilion_Flycatcher',
101 | 'Yellow_bellied_Flycatcher',
102 | 'Frigatebird',
103 | 'Northern_Fulmar',
104 | 'Gadwall',
105 | 'American_Goldfinch',
106 | 'European_Goldfinch',
107 | 'Boat_tailed_Grackle',
108 | 'Eared_Grebe',
109 | 'Horned_Grebe',
110 | 'Pied_billed_Grebe',
111 | 'Western_Grebe',
112 | 'Blue_Grosbeak',
113 | 'Evening_Grosbeak',
114 | 'Pine_Grosbeak',
115 | 'Rose_breasted_Grosbeak',
116 | 'Pigeon_Guillemot',
117 | 'California_Gull',
118 | 'Glaucous_winged_Gull',
119 | 'Heermann_Gull',
120 | 'Herring_Gull',
121 | 'Ivory_Gull',
122 | 'Ring_billed_Gull',
123 | 'Slaty_backed_Gull',
124 | 'Western_Gull',
125 | 'Anna_Hummingbird',
126 | 'Ruby_throated_Hummingbird',
127 | 'Rufous_Hummingbird',
128 | 'Green_Violetear',
129 | 'Long_tailed_Jaeger',
130 | 'Pomarine_Jaeger',
131 | 'Blue_Jay',
132 | 'Florida_Jay',
133 | 'Green_Jay',
134 | 'Dark_eyed_Junco',
135 | 'Tropical_Kingbird',
136 | 'Gray_Kingbird',
137 | 'Belted_Kingfisher',
138 | 'Green_Kingfisher',
139 | 'Pied_Kingfisher',
140 | 'Ringed_Kingfisher',
141 | 'White_breasted_Kingfisher',
142 | 'Red_legged_Kittiwake',
143 | 'Horned_Lark',
144 | 'Pacific_Loon',
145 | 'Mallard',
146 | 'Western_Meadowlark',
147 | 'Hooded_Merganser',
148 | 'Red_breasted_Merganser',
149 | 'Mockingbird',
150 | 'Nighthawk',
151 | 'Clark_Nutcracker',
152 | 'White_breasted_Nuthatch',
153 | 'Baltimore_Oriole',
154 | 'Hooded_Oriole',
155 | 'Orchard_Oriole',
156 | 'Scott_Oriole',
157 | 'Ovenbird',
158 | 'Brown_Pelican',
159 | 'White_Pelican',
160 | 'Western_Wood_Pewee',
161 | 'Sayornis',
162 | 'American_Pipit',
163 | 'Whip_poor_Will',
164 | 'Horned_Puffin',
165 | 'Common_Raven',
166 | 'White_necked_Raven',
167 | 'American_Redstart',
168 | 'Geococcyx',
169 | 'Loggerhead_Shrike',
170 | 'Great_Grey_Shrike',
171 | 'Baird_Sparrow',
172 | 'Black_throated_Sparrow',
173 | 'Brewer_Sparrow',
174 | 'Chipping_Sparrow',
175 | 'Clay_colored_Sparrow',
176 | 'House_Sparrow',
177 | 'Field_Sparrow',
178 | 'Fox_Sparrow',
179 | 'Grasshopper_Sparrow',
180 | 'Harris_Sparrow',
181 | 'Henslow_Sparrow',
182 | 'Le_Conte_Sparrow',
183 | 'Lincoln_Sparrow',
184 | 'Nelson_Sharp_tailed_Sparrow',
185 | 'Savannah_Sparrow',
186 | 'Seaside_Sparrow',
187 | 'Song_Sparrow',
188 | 'Tree_Sparrow',
189 | 'Vesper_Sparrow',
190 | 'White_crowned_Sparrow',
191 | 'White_throated_Sparrow',
192 | 'Cape_Glossy_Starling',
193 | 'Bank_Swallow',
194 | 'Barn_Swallow',
195 | 'Cliff_Swallow',
196 | 'Tree_Swallow',
197 | 'Scarlet_Tanager',
198 | 'Summer_Tanager',
199 | 'Artic_Tern',
200 | 'Black_Tern',
201 | 'Caspian_Tern',
202 | 'Common_Tern',
203 | 'Elegant_Tern',
204 | 'Forsters_Tern',
205 | 'Least_Tern',
206 | 'Green_tailed_Towhee',
207 | 'Brown_Thrasher',
208 | 'Sage_Thrasher',
209 | 'Black_capped_Vireo',
210 | 'Blue_headed_Vireo',
211 | 'Philadelphia_Vireo',
212 | 'Red_eyed_Vireo',
213 | 'Warbling_Vireo',
214 | 'White_eyed_Vireo',
215 | 'Yellow_throated_Vireo',
216 | 'Bay_breasted_Warbler',
217 | 'Black_and_white_Warbler',
218 | 'Black_throated_Blue_Warbler',
219 | 'Blue_winged_Warbler',
220 | 'Canada_Warbler',
221 | 'Cape_May_Warbler',
222 | 'Cerulean_Warbler',
223 | 'Chestnut_sided_Warbler',
224 | 'Golden_winged_Warbler',
225 | 'Hooded_Warbler',
226 | 'Kentucky_Warbler',
227 | 'Magnolia_Warbler',
228 | 'Mourning_Warbler',
229 | 'Myrtle_Warbler',
230 | 'Nashville_Warbler',
231 | 'Orange_crowned_Warbler',
232 | 'Palm_Warbler',
233 | 'Pine_Warbler',
234 | 'Prairie_Warbler',
235 | 'Prothonotary_Warbler',
236 | 'Swainson_Warbler',
237 | 'Tennessee_Warbler',
238 | 'Wilson_Warbler',
239 | 'Worm_eating_Warbler',
240 | 'Yellow_Warbler',
241 | 'Northern_Waterthrush',
242 | 'Louisiana_Waterthrush',
243 | 'Bohemian_Waxwing',
244 | 'Cedar_Waxwing',
245 | 'American_Three_toed_Woodpecker',
246 | 'Pileated_Woodpecker',
247 | 'Red_bellied_Woodpecker',
248 | 'Red_cockaded_Woodpecker',
249 | 'Red_headed_Woodpecker',
250 | 'Downy_Woodpecker',
251 | 'Bewick_Wren',
252 | 'Cactus_Wren',
253 | 'Carolina_Wren',
254 | 'House_Wren',
255 | 'Marsh_Wren',
256 | 'Rock_Wren',
257 | 'Winter_Wren',
258 | 'Common_Yellowthroat',
259 | ]
--------------------------------------------------------------------------------
/data/RDD.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | from typing import Any, Callable, Optional, Tuple
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 | from torch.utils.data import Dataset
10 |
11 |
12 | class RDDDataset(Dataset):
13 |
14 | def __init__(
15 | self, root: str, image_set: str = "train", transform: Optional[Callable] = None
16 | ) -> None:
17 | super().__init__()
18 |
19 | self.root = root
20 | self.image_set = image_set
21 | self.transform = transform
22 |
23 | self.base_dir = os.path.join(self.root, "RDD")
24 |
25 | image_dir = "JPEGImages"
26 | annot_dir = "Annotations"
27 | sam_dir = "SAM-mask"
28 | annotation_file = f"{image_set}.csv"
29 |
30 | self.image_dir = os.path.join(self.base_dir, image_dir)
31 | self.sam_dir = os.path.join(self.base_dir, sam_dir)
32 | self.annotation_file = os.path.join(self.base_dir, annot_dir, annotation_file)
33 |
34 | with open(self.annotation_file) as f:
35 | reader = csv.reader(f)
36 | # skip header
37 | reader.__next__()
38 | annotations = [row for row in reader]
39 |
40 | # Transpose [Image_fname, Label]
41 | annotations = [list(x) for x in zip(*annotations)]
42 |
43 | image_fnames = annotations[0]
44 | self.images = list(
45 | map(lambda x: os.path.join(self.image_dir, x + ".jpg"), image_fnames)
46 | )
47 | self.sam_segments = list(
48 | map(lambda x: os.path.join(self.sam_dir, x + ".png"), image_fnames)
49 | )
50 |         self.targets = list(
51 | map(lambda x: int(1 <= int(x)), annotations[1])
52 | )
53 |
54 | def __getitem__(self, index) -> Tuple[Any, Any, Any, Any]:
55 | orig_image = Image.open(self.images[index]).convert("RGB")
56 |
57 | sam = cv2.imread(self.sam_segments[index], cv2.IMREAD_GRAYSCALE)
58 | sam_orig = np.expand_dims(cv2.resize(sam, (orig_image.size[:2])), -1)
59 | sam_mask = 1 * (sam_orig > 0)
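        # If the SAM mask covers less than 10% of the image, fall back to a full-image (all-ones) mask.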
60 | if sam_mask.sum() < sam_mask.size * 0.1:
61 | sam_mask = np.ones_like(sam_mask)
62 |
63 | # Calculate the bounding box coordinates of the mask
64 | mask_coords = np.argwhere(sam_mask.squeeze())
65 | min_y, min_x = np.min(mask_coords, axis=0)
66 | max_y, max_x = np.max(mask_coords, axis=0)
67 |
68 | # Crop the image using the bounding box coordinates
69 | image = orig_image.crop((min_x, min_y, max_x + 1, max_y + 1))
70 |
71 | # image = np.array(orig_image) * sam_mask
72 | # image = Image.fromarray(image.astype(np.uint8))
73 |
74 | target = self.targets[index]
75 |
76 | if self.transform is not None:
77 | image = self.transform(image)
78 | orig_image = self.transform(orig_image)
79 |
80 | sam_resized = np.expand_dims(cv2.resize(sam, (orig_image.shape[1:])), -1)
81 | sam_mask = 1 * (sam_resized > 0)
82 | sam_mask = sam_mask.reshape(1, orig_image.shape[1], orig_image.shape[2])
83 | sam_mask = torch.from_numpy(sam_mask.astype(np.uint8)).clone()
84 |
85 | return image, target, sam_mask, orig_image
86 |
87 | def __len__(self) -> int:
88 | return len(self.images)
89 |
--------------------------------------------------------------------------------
/data/RDD_4.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | from typing import Any, Callable, Optional, Tuple
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import torchvision
9 | from PIL import Image
10 | from torch.utils.data import Dataset
11 |
12 |
13 | class RDDDataset(Dataset):
14 |
15 | def __init__(
16 | self, root: str, image_set: str = "train", transform: Optional[Callable] = None
17 | ) -> None:
18 | super().__init__()
19 |
20 | self.root = root
21 | self.image_set = image_set
22 | self.transform = transform
23 |
24 | self.base_dir = os.path.join(self.root, "RDD")
25 |
26 | image_dir = "JPEGImages"
27 | annot_dir = "Annotations"
28 | sam_dir = "SAM-all"
29 | annotation_file = f"{image_set}.csv"
30 |
31 | self.image_dir = os.path.join(self.base_dir, image_dir)
32 | self.sam_dir = os.path.join(self.base_dir, sam_dir)
33 | self.annotation_file = os.path.join(self.base_dir, annot_dir, annotation_file)
34 |
35 | with open(self.annotation_file) as f:
36 | reader = csv.reader(f)
37 | # skip header
38 | reader.__next__()
39 | annotations = [row for row in reader]
40 |
41 | # Transpose [Image_fname, Label]
42 | annotations = [list(x) for x in zip(*annotations)]
43 |
44 | image_fnames = annotations[0]
45 | self.images = list(
46 | map(lambda x: os.path.join(self.image_dir, x + ".jpg"), image_fnames)
47 | )
48 | self.sam_segments = list(
49 | map(lambda x: os.path.join(self.sam_dir, x + ".png"), image_fnames)
50 | )
51 |         self.targets = list(
52 | map(lambda x: int(1 <= int(x)), annotations[1])
53 | )
54 |
55 | def __getitem__(self, index) -> Tuple[Any, Any]:
56 | image = Image.open(self.images[index]).convert("RGB")
57 |
58 | target = self.targets[index]
59 |
60 | if self.transform is not None:
61 | image = self.transform(image)
62 |
63 | sam = cv2.imread(self.sam_segments[index], cv2.IMREAD_GRAYSCALE)
64 | # sam = np.expand_dims(cv2.resize(sam, (image.size[:2])), -1)
65 | sam = cv2.resize(sam, (image.shape[1:]))
66 | sam_mask = 1 * (sam > 0)
67 | sam_mask_pil = Image.fromarray(sam_mask.astype(np.uint8), mode="L")
68 |
69 | mask = torchvision.transforms.functional.to_tensor(sam_mask_pil)
70 | # print(image.shape, mask.shape)
71 | image = torch.cat((image, mask), 0)
72 |
73 | return image, target
74 |
75 | def __len__(self) -> int:
76 | return len(self.images)
77 |
--------------------------------------------------------------------------------
/data/RDD_bbox.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | from typing import Any, Callable, Optional, Tuple
4 |
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class RDDBboxDataset(Dataset):
10 |
11 | def __init__(
12 | self, root: str, image_set: str = "train", transform: Optional[Callable] = None
13 | ) -> None:
14 | super().__init__()
15 |
16 | self.root = root
17 | self.image_set = image_set
18 | self.transform = transform
19 |
20 | self.base_dir = os.path.join(self.root, "RDD_bbox")
21 | annot_dir = "Annotations"
22 | annotation_file = f"{image_set}.csv"
23 | self.annotation_file = os.path.join(self.base_dir, annot_dir, annotation_file)
24 |
25 | with open(self.annotation_file) as f:
26 | reader = csv.reader(f)
27 | # skip header
28 | reader.__next__()
29 | annotations = [row for row in reader]
30 |
31 | # Transpose [Image_fname, Label]
32 | annotations = [list(x) for x in zip(*annotations)]
33 |
34 | self.images = list(
35 | map(lambda x: os.path.join(self.base_dir, x), annotations[0])
36 | )
37 | self.targets = list(map(int, annotations[1]))
38 |
39 | def __getitem__(self, index) -> Tuple[Any, Any]:
40 | image = Image.open(self.images[index]).convert("RGB")
41 | target = self.targets[index]
42 |
43 | if self.transform is not None:
44 | image = self.transform(image)
45 |
46 | return image, target
47 |
48 | def __len__(self) -> int:
49 | return len(self.images)
50 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Any, Callable, Dict, Optional
3 |
4 | import numpy
5 | import torch
6 | import torch.utils.data as data
7 | import torchvision.transforms as transforms
8 | from torch import nn
9 |
10 | from data.CUB import CUB_CLASSES, CUBDataset
11 | from data.imagenet import IMAGENET_CLASSES, ImageNetDataset
12 | from data.RDD import RDDDataset
13 | from data.RDD_bbox import RDDBboxDataset
14 | from data.sampler import BalancedBatchSampler
15 | from metrics.accuracy import Accuracy, MultiClassAccuracy
16 |
17 | ALL_DATASETS = ["RDD", "RDD_bbox", "CUB", "ImageNet"]
18 |
19 |
20 | def seed_worker(worker_id):
21 | worker_seed = torch.initial_seed() % 2**32
22 | numpy.random.seed(worker_seed)
23 | random.seed(worker_seed)
24 |
25 |
26 | def get_generator(seed=0):
27 | g = torch.Generator()
28 | g.manual_seed(seed)
29 | return g
30 |
31 |
32 | def create_dataloader_dict(
33 | dataset_name: str,
34 | batch_size: int,
35 | image_size: int = 224,
36 | only_test: bool = False,
37 | train_ratio: float = 0.9,
38 | shuffle_val: bool = False,
39 | dataloader_seed: int = 0,
40 | ) -> Dict[str, data.DataLoader]:
41 | """
42 | Create dataloader dictionary
43 |
44 | Args:
45 | dataset_name(str) : Dataset name
46 | batch_size (int) : Batch size
47 | image_size (int) : Image size
48 | only_test (bool): Create only test dataset
49 | train_ratio(float): Train / val split ratio when there is no val
50 |
51 | Returns:
52 | dataloader_dict : Dataloader dictionary
53 | """
54 |
55 | test_dataset = create_dataset(dataset_name, "test", image_size)
56 | test_dataloader = data.DataLoader(
57 | test_dataset, batch_size=batch_size, shuffle=shuffle_val, worker_init_fn=seed_worker, generator=get_generator(dataloader_seed),
58 | )
59 |
60 | if only_test:
61 | return {"Test": test_dataloader}
62 |
63 | train_dataset = create_dataset(
64 | dataset_name,
65 | "train",
66 | image_size,
67 | )
68 |
69 | dataset_params = get_parameter_depend_in_data_set(dataset_name)
70 |
71 | # Create val or split
72 | if dataset_params["has_val"]:
73 | val_dataset = create_dataset(
74 | dataset_name,
75 | "val",
76 | image_size,
77 | )
78 | else:
79 | train_size = int(train_ratio * len(train_dataset))
80 | val_size = len(train_dataset) - train_size
81 |
82 | train_dataset, val_dataset = data.random_split(
83 | train_dataset, [train_size, val_size]
84 | )
85 | val_dataset.transform = test_dataset.transform
86 |
87 | if dataset_params["sampler"]:
88 | train_dataloader = data.DataLoader(
89 | train_dataset,
90 | batch_sampler=BalancedBatchSampler(train_dataset, 2, batch_size // 2),
91 | worker_init_fn=seed_worker,
92 | generator=get_generator(dataloader_seed),
93 | )
94 | else:
95 | train_dataloader = data.DataLoader(
96 | train_dataset, batch_size=batch_size, shuffle=True, worker_init_fn=seed_worker, generator=get_generator(dataloader_seed)
97 | )
98 | val_dataloader = data.DataLoader(
99 | val_dataset, batch_size=batch_size, shuffle=shuffle_val, worker_init_fn=seed_worker, generator=get_generator(dataloader_seed)
100 | )
101 |
102 | dataloader_dict = {
103 | "Train": train_dataloader,
104 | "Val": val_dataloader,
105 | "Test": test_dataloader,
106 | }
107 |
108 | print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
109 | return dataloader_dict
110 |
111 |
112 | def create_dataset(
113 | dataset_name: str,
114 | image_set: str = "train",
115 | image_size: int = 224,
116 | transform: Optional[Callable] = None,
117 | ) -> data.Dataset:
118 | """
119 | Create dataset
120 | Normalization parameters are created for each dataset
121 |
122 | Args:
123 | dataset_name(str) : Dataset name
124 | image_set(str) : Choose from train / val / test
125 | image_size(int) : Image size
126 | transform(Callable): transform
127 |
128 | Returns:
129 | data.Dataset : PyTorch dataset
130 | """
131 | assert dataset_name in ALL_DATASETS
132 | params = get_parameter_depend_in_data_set(dataset_name)
133 |
134 | if transform is None:
135 | if image_set == "train":
136 | transform = transforms.Compose(
137 | [
138 | transforms.Resize(int(image_size/0.875)),
139 | transforms.RandomResizedCrop(image_size),
140 | transforms.RandomHorizontalFlip(),
141 | transforms.ToTensor(),
142 | transforms.Normalize(params["mean"], params["std"]),
143 | ]
144 | )
145 | else:
146 | transform = transforms.Compose(
147 | [
148 | transforms.Resize(int(image_size/0.875)),
149 | transforms.CenterCrop(image_size),
150 | transforms.ToTensor(),
151 | transforms.Normalize(params["mean"], params["std"]),
152 | ]
153 | )
154 |
155 | if params["has_params"]:
156 | dataset = params["dataset"](
157 | root="./datasets",
158 | image_set=image_set,
159 | params=params,
160 | transform=transform,
161 | )
162 | else:
163 | dataset = params["dataset"](
164 | root="./datasets", image_set=image_set, transform=transform
165 | )
166 |
167 | return dataset
168 |
169 |
170 | def get_parameter_depend_in_data_set(
171 | dataset_name: str,
172 | pos_weight: torch.Tensor = torch.Tensor([1]),
173 | dataset_root: str = "./datasets",
174 | ) -> Dict[str, Any]:
175 | """
176 | Get parameters of the dataset
177 |
178 | Args:
179 | dataset_name(str): Dataset name
180 |
181 | Returns:
182 | dict[str, Any]: Parameters such as mean, variance, class name, etc.
183 | """
184 | params = dict()
185 | params["name"] = dataset_name
186 | params["root"] = dataset_root
187 | # Whether to pass params to the dataset class
188 | params["has_params"] = False
189 | params["num_channel"] = 3
190 | params["sampler"] = True
191 | # ImageNet
192 | params["mean"] = (0.485, 0.456, 0.406)
193 | params["std"] = (0.229, 0.224, 0.225)
194 |
195 | if dataset_name == "RDD":
196 | params["dataset"] = RDDDataset
197 | params["num_channel"] = 3
198 | params["mean"] = (0.4770, 0.5026, 0.5094)
199 | params["std"] = (0.2619, 0.2684, 0.3001)
200 | params["classes"] = ("no crack", "crack")
201 | params["has_val"] = False
202 | params["has_params"] = False
203 | params["sampler"] = False
204 |
205 | params["metric"] = Accuracy()
206 | params["criterion"] = nn.CrossEntropyLoss()
207 | elif dataset_name == "RDD_bbox":
208 | params["dataset"] = RDDBboxDataset
209 | params["num_channel"] = 3
210 | params["mean"] = (0.4401, 0.4347, 0.4137)
211 | params["std"] = (0.2016, 0.1871, 0.1787)
212 | params["classes"] = ("no crack", "crack")
213 | params["has_val"] = False
214 | params["has_params"] = False
215 | params["sampler"] = False
216 | params["metric"] = Accuracy()
217 | params["criterion"] = nn.CrossEntropyLoss()
218 | elif dataset_name == "CUB":
219 | params["dataset"] = CUBDataset
220 | params["num_channel"] = 3
221 | # params["mean"] = (0.4859, 0.4996, 0.4318)
222 | # params["std"] = (0.2266, 0.2218, 0.2609)
223 | params["mean"] = (0.485, 0.456, 0.406) # trace ImageNet
224 | params["std"] = (0.229, 0.224, 0.225) # trace ImageNet
225 | params["classes"] = CUB_CLASSES
226 | params["has_val"] = False
227 | params["has_params"] = False
228 | params["sampler"] = False
229 | params["metric"] = MultiClassAccuracy()
230 | params["criterion"] = nn.CrossEntropyLoss()
231 | elif dataset_name == "ImageNet":
232 | params["dataset"] = ImageNetDataset
233 | params["num_channel"] = 3
234 | params["mean"] = (0.485, 0.456, 0.406)
235 | params["std"] = (0.229, 0.224, 0.225)
236 | params["classes"] = IMAGENET_CLASSES
237 | params["has_val"] = False # Use val. set as the test set
238 | params["has_params"] = False
239 | params["sampler"] = False
240 | params["metric"] = MultiClassAccuracy()
241 | params["criterion"] = nn.CrossEntropyLoss()
242 |
243 | return params
244 |
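# Illustrative usage (a sketch, assuming the datasets/ directory is prepared as described
# in datasets/README.md and that this is run from the repository root), mirroring how
# evaluate.py consumes these helpers:
#
#     from data import create_dataloader_dict, get_parameter_depend_in_data_set
#
#     dataloaders = create_dataloader_dict("CUB", batch_size=32, image_size=224, only_test=True)
#     params = get_parameter_depend_in_data_set("CUB")
#     criterion, metric = params["criterion"], params["metric"]  # CrossEntropyLoss, MultiClassAccuracy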
--------------------------------------------------------------------------------
/data/sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import DataLoader
4 | from torch.utils.data.sampler import BatchSampler
5 |
6 |
7 | class BalancedBatchSampler(BatchSampler):
8 | """
9 | BatchSampler - from a MNIST-like dataset, samples n_classes and within these classes samples n_samples.
10 | Returns batches of size n_classes * n_samples
11 | """
12 |
13 | def __init__(self, dataset, n_classes, n_samples):
14 | loader = DataLoader(dataset)
15 | self.labels_list = []
16 | for _, label in loader:
17 | self.labels_list.append(label)
18 | self.labels = torch.LongTensor(self.labels_list)
19 | self.labels_set = list(set(self.labels.numpy()))
20 | self.label_to_indices = {
21 | label: np.where(self.labels.numpy() == label)[0]
22 | for label in self.labels_set
23 | }
24 | for l in self.labels_set:
25 | np.random.shuffle(self.label_to_indices[l])
26 | self.used_label_indices_count = {label: 0 for label in self.labels_set}
27 | self.count = 0
28 | self.n_classes = n_classes
29 | self.n_samples = n_samples
30 | self.dataset = dataset
31 | self.batch_size = self.n_samples * self.n_classes
32 |
33 | def __iter__(self):
34 | self.count = 0
35 | while self.count + self.batch_size < len(self.dataset):
36 | classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
37 | indices = []
38 | for class_ in classes:
39 | indices.extend(
40 | self.label_to_indices[class_][
41 | self.used_label_indices_count[
42 | class_
43 | ] : self.used_label_indices_count[class_]
44 | + self.n_samples
45 | ]
46 | )
47 | self.used_label_indices_count[class_] += self.n_samples
48 | if self.used_label_indices_count[class_] + self.n_samples > len(
49 | self.label_to_indices[class_]
50 | ):
51 | np.random.shuffle(self.label_to_indices[class_])
52 | self.used_label_indices_count[class_] = 0
53 | yield indices
54 | self.count += self.n_classes * self.n_samples
55 |
56 | def __len__(self):
57 | return len(self.dataset) // self.batch_size
58 |
--------------------------------------------------------------------------------
/datasets/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/datasets/.gitkeep
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | ### Structure
2 | Download datasets (ImageNet and Caltech-UCSD Birds-200-2011) and place them as follows:
3 | ```
4 | .
5 | ├── .gitkeep
6 | ├── CUB_200_2011
7 | │ ├── attributes
8 | │ ├── bounding_boxes.txt
9 | │ ├── classes.txt
10 | │ ├── image_class_labels.txt
11 | │ ├── images
12 | │ ├── images.txt
13 | │ ├── parts
14 | │ ├── README
15 | │ └── train_test_split.txt
16 | ├── imagenet
17 | │ ├── ILSVRC2012_devkit_t12.tar.gz
18 | │ ├── ILSVRC2012_img_val.tar
19 | │ ├── meta.bin
20 | │ └── val
21 | └── README.md
22 | ```
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from typing import Dict, Optional, Tuple
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.utils.data as data
8 | from torchinfo import summary
9 | from tqdm import tqdm
10 |
11 | from data import (ALL_DATASETS, create_dataloader_dict,
12 | get_parameter_depend_in_data_set)
13 | from metrics.base import Metric
14 | from models import ALL_MODELS, create_model
15 | from utils.loss import calculate_loss
16 | from utils.utils import fix_seed, parse_with_config
17 |
18 |
19 | @torch.no_grad()
20 | def test(
21 | dataloader: data.DataLoader,
22 | model: nn.Module,
23 | criterion: nn.modules.loss._Loss,
24 | metrics: Metric,
25 | device: torch.device,
26 | phase: str = "Test",
27 | lambdas: Optional[Dict[str, float]] = None,
28 | ) -> Tuple[float, Metric]:
29 | total = 0
30 | total_loss: float = 0
31 |
32 | model.eval()
33 | for data in tqdm(dataloader, desc=f"{phase}: ", dynamic_ncols=True):
34 | inputs, labels = data[0].to(device), data[1].to(device)
35 | outputs = model(inputs)
36 |
37 | total_loss += calculate_loss(criterion, outputs, labels, model, lambdas).item()
38 | metrics.evaluate(outputs, labels)
39 | total += labels.size(0)
40 |
41 | test_loss = total_loss / total
42 | return test_loss, metrics
43 |
44 |
45 | import torchvision.transforms as transforms
46 | from PIL import Image
47 |
48 | from utils.visualize import save_image
49 |
50 |
51 | @torch.no_grad()
52 | def test_with_save(
53 | dataloader: data.DataLoader,
54 | model: nn.Module,
55 | criterion: nn.modules.loss._Loss,
56 | metrics: Metric,
57 | device: torch.device,
58 | phase: str = "Test",
59 | lambdas: Optional[Dict[str, float]] = None,
60 | ) -> Tuple[float, Metric]:
61 | total = 0
62 | total_loss: float = 0
63 | transform = transforms.ToPILImage()
64 |
65 | model.eval()
66 | for data in tqdm(dataloader, desc=f"{phase}: ", dynamic_ncols=True):
67 | inputs, labels = data[0].to(device), data[1].to(device)
68 | outputs = model(inputs)
69 | outputs = torch.softmax(outputs, dim=1)
70 | pred = torch.argmax(outputs, dim=1)
71 | misclassified = (pred != labels).nonzero().squeeze()
72 | if misclassified.numel() == 1:
73 | misclassified = [misclassified.item()]
74 |
75 | for index in misclassified:
76 | image = inputs[index, :3].cpu()
77 | # image = transform(image)
78 |
79 | true_label = labels[index].item()
80 | # predicted_label = pred[index].item()
81 | params = get_parameter_depend_in_data_set("RDD")
82 | save_image(
83 | image.detach().cpu().numpy(),
84 | f"outputs/error_anal/{true_label}/{outputs[index, true_label]:.4f}.png",
85 | params["mean"],
86 | params["std"],
87 | )
88 |
89 | # image.save(
90 | # f"outputs/error_anal/{true_label}/{outputs[index, true_label]}.png"
91 | # )
92 |
93 | total_loss += calculate_loss(criterion, outputs, labels, model, lambdas).item()
94 | metrics.evaluate(outputs, labels)
95 | total += labels.size(0)
96 |
97 | test_loss = total_loss / total
98 | return test_loss, metrics
99 |
100 |
101 | def main(args: argparse.Namespace) -> None:
102 | fix_seed(args.seed, args.no_deterministic)
103 |
104 |     # Create the dataset
105 | dataloader_dict = create_dataloader_dict(
106 | args.dataset, args.batch_size, args.image_size, only_test=True
107 | )
108 | dataloader = dataloader_dict["Test"]
109 | assert isinstance(dataloader, data.DataLoader)
110 |
111 | params = get_parameter_depend_in_data_set(args.dataset)
112 |
113 |     # Create the model
114 | model = create_model(
115 | args.model,
116 | num_classes=len(params["classes"]),
117 | num_channel=params["num_channel"],
118 | base_pretrained=args.base_pretrained,
119 | base_pretrained2=args.base_pretrained2,
120 | pretrained_path=args.pretrained,
121 | attention_branch=args.add_attention_branch,
122 | division_layer=args.div,
123 | theta_attention=args.theta_att,
124 | )
125 | assert model is not None, "Model name is invalid"
126 |
127 | model.load_state_dict(torch.load(args.pretrained))
128 | print(f"pretrained {args.pretrained} loaded.")
129 |
130 | criterion = params["criterion"]
131 | metric = params["metric"]
132 |
133 |     # Get run_name from the pretrained path
134 | # checkpoints/run_name/checkpoint.pt -> run_name
135 | run_name = args.pretrained.split(os.sep)[-2]
136 | save_dir = os.path.join("outputs", run_name)
137 | if not os.path.isdir(save_dir):
138 | os.makedirs(save_dir)
139 |
140 | summary(
141 | model,
142 | (args.batch_size, params["num_channel"], args.image_size, args.image_size),
143 | )
144 |
145 | model.to(device)
146 |
147 | lambdas = {"att": args.lambda_att, "var": args.lambda_var}
148 | loss, metrics = test_with_save(
149 | dataloader, model, criterion, metric, device, lambdas=lambdas
150 | )
151 | metric_log = metrics.log()
152 | print(f"Test\t| {metric_log} Loss: {loss:.5f}")
153 |
154 |
155 | def parse_args() -> argparse.Namespace:
156 | parser = argparse.ArgumentParser()
157 |
158 | parser.add_argument("-c", "--config", type=str, help="path to config file (json)")
159 |
160 | parser.add_argument("--seed", type=int, default=42)
161 | parser.add_argument("--no_deterministic", action="store_false")
162 |
163 | # Model
164 | parser.add_argument("-m", "--model", choices=ALL_MODELS, help="model name")
165 | parser.add_argument(
166 | "-add_ab",
167 | "--add_attention_branch",
168 | action="store_true",
169 | help="add Attention Branch",
170 | )
171 | parser.add_argument(
172 | "-div",
173 | "--division_layer",
174 | type=str,
175 | choices=["layer1", "layer2", "layer3"],
176 | default="layer2",
177 | help="place to attention branch",
178 | )
179 | parser.add_argument("--base_pretrained", type=str, help="path to base pretrained")
180 | parser.add_argument("--pretrained", type=str, help="path to pretrained")
181 |
182 | # Dataset
183 | parser.add_argument("--dataset", type=str, default="IDRiD", choices=ALL_DATASETS)
184 | parser.add_argument("--image_size", type=int, default=224)
185 | parser.add_argument("--batch_size", type=int, default=32)
186 | parser.add_argument(
187 | "--loss_weights",
188 | type=float,
189 | nargs="*",
190 | default=[1.0, 1.0],
191 | help="weights for label by class",
192 | )
193 |
194 | parser.add_argument(
195 | "--lambda_att", type=float, default=0.1, help="weights for attention loss"
196 | )
197 | parser.add_argument(
198 | "--lambda_var", type=float, default=0, help="weights for variance loss"
199 | )
200 |
201 | parser.add_argument("--root_dir", type=str, default="./outputs/")
202 |
203 | return parse_with_config(parser)
204 |
205 |
206 | if __name__ == "__main__":
207 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
208 |
209 | main(parse_args())
210 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/metrics/__init__.py
--------------------------------------------------------------------------------
/metrics/accuracy.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple
2 |
3 | import torch
4 |
5 | from metrics.base import Metric
6 |
7 |
8 | def num_correct_topk(
9 | output: torch.Tensor, target: torch.Tensor, topk: Tuple[int] = (1,)
10 | ) -> List[int]:
11 | """
12 |     Count the number of correctly classified samples within the top-k predictions
13 |
14 | Args:
15 | output(Tensor) : Model output
16 | target(Tensor) : Label
17 | topk(Tuple[int]): How many top ranks should be correct
18 |
19 | Returns:
20 |         List[int] : Number of correct predictions for each k
21 | """
22 | maxk = max(topk)
23 |
24 | _, pred = output.topk(maxk, dim=1)
25 | pred = pred.t()
26 | correct = pred.eq(target.view(1, -1).expand_as(pred))
27 |
28 | # [[False, False, True], [F, F, F], [T, F, F]]
29 | # -> [0, 0, 1, 0, 0, 0, 1, 0, 0] -> 2
30 | result = []
31 | for k in topk:
32 | correct_k = correct[:k].reshape(-1).float().sum(0)
33 | result.append(correct_k)
34 | return result
35 |
36 |
37 | class Accuracy(Metric):
38 | def __init__(self) -> None:
39 | self.total = 0
40 | self.correct = 0
41 | self.tp = 0
42 | self.tn = 0
43 | self.fp = 0
44 | self.fn = 0
45 |
46 | def evaluate(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
47 | self.total += labels.size(0)
48 | self.correct += num_correct_topk(preds, labels)[0]
49 |
50 | preds = torch.max(preds, dim=-1)[1]
51 | tp, fp, tn, fn = confusion(preds, labels)
52 | self.tp += tp
53 | self.fp += fp
54 | self.tn += tn
55 | self.fn += fn
56 |
57 | def score(self) -> Dict[str, float]:
58 | return {
59 | "Acc": self.acc(),
60 | "TP": int(self.tp),
61 | "FP": int(self.fp),
62 | "FN": int(self.fn),
63 | "TN": int(self.tn),
64 | }
65 |
66 | def acc(self) -> float:
67 | return self.correct / self.total
68 |
69 | def clear(self) -> None:
70 | self.total = 0
71 | self.correct = 0
72 | self.tp = 0
73 | self.fp = 0
74 | self.fn = 0
75 | self.tn = 0
76 |
77 |
78 | def confusion(output, target) -> Tuple[int, int, int, int]:
79 | """
80 | Calculate Confusion Matrix
81 |
82 | Args:
83 | output(Tensor) : Model output
84 | target(Tensor) : Label
85 |
86 | Returns:
87 | true_positive(int) : Number of TP
88 | false_positive(int): Number of FP
89 | true_negative(int) : Number of TN
90 | false_negative(int): Number of FN
91 | """
92 |
93 | # TP: 1/1 = 1, FP: 1/0 -> inf, TN: 0/0 -> nan, FN: 0/1 -> 0
94 | confusion_vector = output / target
95 |
96 | true_positives = torch.sum(confusion_vector == 1).item()
97 | false_positives = torch.sum(confusion_vector == float("inf")).item()
98 | true_negatives = torch.sum(torch.isnan(confusion_vector)).item()
99 | false_negatives = torch.sum(confusion_vector == 0).item()
100 |
101 | return (
102 | int(true_positives),
103 | int(false_positives),
104 | int(true_negatives),
105 | int(false_negatives),
106 | )
107 |
108 |
109 | class MultiClassAccuracy(Metric):
110 | def __init__(self) -> None:
111 | self.total = 0
112 | self.top1_correct = 0
113 | self.top5_correct = 0
114 |
115 | def evaluate(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
116 | self.total += labels.size(0)
117 | correct_top1, correct_top5 = num_correct_topk(preds, labels, (1, 5))
118 | self.top1_correct += correct_top1
119 | self.top5_correct += correct_top5
120 |
121 | def score(self) -> Dict[str, float]:
122 | return {
123 | "Top-1 Acc": self.acc(),
124 | "Top-5 Acc": self.top5_acc(),
125 | }
126 |
127 | def acc(self) -> float:
128 | return self.top1_correct / self.total
129 |
130 | def top5_acc(self) -> float:
131 | return self.top5_correct / self.total
132 |
133 | def clear(self) -> None:
134 | self.total = 0
135 | self.top1_correct = 0
136 | self.top5_correct = 0
--------------------------------------------------------------------------------
/metrics/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | from typing import Dict
3 |
4 | import torch
5 |
6 |
7 | class Metric(metaclass=ABCMeta):
8 | @abstractmethod
9 | def evaluate(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
10 | pass
11 |
12 | @abstractmethod
13 | def score(self) -> Dict[str, float]:
14 | pass
15 |
16 | @abstractmethod
17 | def clear(self) -> None:
18 | pass
19 |
20 | def log(self) -> str:
21 | result = ""
22 | scores = self.score()
23 | for name, score in scores.items():
24 | if isinstance(score, int):
25 | result += f"{name}: {score} "
26 | else:
27 | result += f"{name}: {score:.3f} "
28 |
29 | return result[:-1]
30 |
--------------------------------------------------------------------------------
/metrics/patch_insdel.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from typing import Dict, Union
4 |
5 | import cv2
6 | import numpy as np
7 | import skimage.measure
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from data import get_parameter_depend_in_data_set
13 | from metrics.base import Metric
14 | from utils.utils import reverse_normalize
15 | from utils.visualize import save_data_as_plot, save_image
16 |
17 |
18 | class PatchInsertionDeletion(Metric):
19 | def __init__(
20 | self,
21 | model: nn.Module,
22 | batch_size: int,
23 | patch_size: int,
24 | step: int,
25 | dataset: str,
26 | device: torch.device,
27 | ) -> None:
28 | self.total = 0
29 | self.total_insertion = 0
30 | self.total_deletion = 0
31 | self.class_insertion: Dict[int, float] = {}
32 | self.num_by_classes: Dict[int, int] = {}
33 | self.class_deletion: Dict[int, float] = {}
34 |
35 | self.model = model
36 | self.batch_size = batch_size
37 | self.step = step
38 | self.device = device
39 | self.patch_size = patch_size
40 | self.dataset = dataset
41 |
42 | def evaluate(
43 | self,
44 | image: np.ndarray,
45 | attention: np.ndarray,
46 | label: Union[np.ndarray, torch.Tensor],
47 | ) -> None:
48 | self.image = image.copy()
49 | self.label = int(label.item())
50 |
51 | # image (C, W, H), attention (1, W', H') -> attention (W, H)
52 | self.attention = attention
53 | if not (self.image.shape[1:] == attention.shape):
54 | self.attention = cv2.resize(
55 | attention[0], dsize=(self.image.shape[1], self.image.shape[2])
56 | )
57 |
58 | # Divide attention map into patches and calculate the order of patches
59 | self.divide_attention_map_into_patch()
60 | self.calculate_attention_order()
61 |
62 | # Create input for insertion and inference
63 | self.generate_insdel_images(mode="insertion")
64 | self.ins_preds = self.inference() # for plot
65 | self.ins_auc = auc(self.ins_preds)
66 | self.total_insertion += self.ins_auc
67 | del self.input
68 |
69 | # deletion
70 | self.generate_insdel_images(mode="deletion")
71 | self.del_preds = self.inference()
72 | self.del_auc = auc(self.del_preds)
73 | self.total_deletion += self.del_auc
74 | del self.input
75 |
76 | self.total += 1
77 |
78 | def divide_attention_map_into_patch(self):
79 | assert self.attention is not None
80 |
81 | self.patch_attention = skimage.measure.block_reduce(
82 | self.attention, (self.patch_size, self.patch_size), np.max
83 | )
84 |
85 | def calculate_attention_order(self):
86 | attention_flat = np.ravel(self.patch_attention)
87 | # Sort in descending order
88 | order = np.argsort(-attention_flat)
89 |
90 | W, H = self.attention.shape
91 | patch_w, _ = W // self.patch_size, H // self.patch_size
92 | self.order = np.apply_along_axis(
93 | lambda x: map_2d_indices(x, patch_w), axis=0, arr=order
94 | )
95 |
96 | def generate_insdel_images(self, mode: str):
97 | C, W, H = self.image.shape
98 | patch_w, patch_h = W // self.patch_size, H // self.patch_size
99 | num_insertion = math.ceil(patch_w * patch_h / self.step)
100 |
101 | params = get_parameter_depend_in_data_set(self.dataset)
102 | self.input = np.zeros((num_insertion, C, W, H))
103 | mean, std = params["mean"], params["std"]
104 | image = reverse_normalize(self.image.copy(), mean, std)
105 |
106 | for i in range(num_insertion):
107 | step_index = min(self.step * (i + 1), self.order.shape[1] - 1)
108 | w_indices = self.order[0, step_index]
109 | h_indices = self.order[1, step_index]
110 |             threshold = self.patch_attention[w_indices, h_indices]
111 | 
112 |             if mode == "insertion":
113 |                 mask = np.where(threshold <= self.patch_attention, 1, 0)
114 |             elif mode == "deletion":
115 |                 mask = np.where(threshold <= self.patch_attention, 0, 1)
116 |
117 | mask = cv2.resize(mask, dsize=(W, H), interpolation=cv2.INTER_NEAREST)
118 |
119 | for c in range(min(3, C)):
120 | self.input[i, c] = (image[c] * mask - mean[c]) / std[c]
121 |
122 | def inference(self):
123 | inputs = torch.Tensor(self.input)
124 |
125 | num_iter = math.ceil(inputs.size(0) / self.batch_size)
126 | result = torch.zeros(0)
127 |
128 | for iter in range(num_iter):
129 | start = self.batch_size * iter
130 | batch_inputs = inputs[start : start + self.batch_size].to(self.device)
131 |
132 | outputs = self.model(batch_inputs)
133 |
134 | outputs = F.softmax(outputs, 1)
135 | outputs = outputs[:, self.label]
136 | result = torch.cat([result, outputs.cpu().detach()], dim=0)
137 |
138 | return np.nan_to_num(result)
139 |
140 | def save_images(self):
141 | params = get_parameter_depend_in_data_set(self.dataset)
142 | for i, image in enumerate(self.input):
143 | save_image(image, f"tmp/{self.total}_{i}", params["mean"], params["std"])
144 |
145 | def score(self) -> Dict[str, float]:
146 | result = {
147 | "Insertion": self.insertion(),
148 | "Deletion": self.deletion(),
149 | "PID": self.insertion() - self.deletion(),
150 | }
151 |
152 | for class_idx in self.class_insertion.keys():
153 | self.class_insertion_score(class_idx)
154 | self.class_deletion_score(class_idx)
155 |
156 | return result
157 |
158 | def log(self) -> str:
159 | result = "Class\tPID\tIns\tDel\n"
160 |
161 | scores = self.score()
162 | result += f"All\t{scores['PID']:.3f}\t{scores['Insertion']:.3f}\t{scores['Deletion']:.3f}\n"
163 |
164 | for class_idx in self.class_insertion.keys():
165 | pid = scores[f"PID_{class_idx}"]
166 | insertion = scores[f"Insertion_{class_idx}"]
167 | deletion = scores[f"Deletion_{class_idx}"]
168 | result += f"{class_idx}\t{pid:.3f}\t{insertion:.3f}\t{deletion:.3f}\n"
169 |
170 | return result
171 |
172 | def insertion(self) -> float:
173 | return self.total_insertion / self.total
174 |
175 | def deletion(self) -> float:
176 | return self.total_deletion / self.total
177 |
178 | def class_insertion_score(self, class_idx: int) -> float:
179 | num_samples = self.num_by_classes[class_idx]
180 |         insertion_score = self.class_insertion[class_idx]
181 | 
182 |         return insertion_score / num_samples
183 |
184 | def class_deletion_score(self, class_idx: int) -> float:
185 | num_samples = self.num_by_classes[class_idx]
186 | deletion_score = self.class_deletion[class_idx]
187 |
188 | return deletion_score / num_samples
189 |
190 | def clear(self) -> None:
191 | self.total = 0
192 | self.ins_preds = None
193 | self.del_preds = None
194 |
195 | def save_roc_curve(self, save_dir: str) -> None:
196 | ins_fname = os.path.join(save_dir, f"{self.total}_insertion.png")
197 | save_data_as_plot(self.ins_preds, ins_fname, label=f"AUC = {self.ins_auc:.4f}")
198 |
199 | del_fname = os.path.join(save_dir, f"{self.total}_deletion.png")
200 | save_data_as_plot(self.del_preds, del_fname, label=f"AUC = {self.del_auc:.4f}")
201 |
202 |
203 | def map_2d_indices(indices_1d: int, width: int):
204 |     """
205 |     Convert a 1D (flattened) index into the corresponding 2D index.
206 | 
207 |     Args:
208 |         indices_1d(int): Index into the flattened array
209 |         width(int)     : Width (number of columns) of the original 2D array
210 | 
211 |     Examples:
212 |         [[0, 1, 2], [3, 4, 5]]
213 |         -> [0, 1, 2, 3, 4, 5]
214 | 
215 |         map_2d_indices(1, 3)
216 |         >>> [0, 1]
217 |         map_2d_indices(5, 3)
218 |         >>> [1, 2]
219 | 
220 |     Returns:
221 |         [row, column] index of the array before flattening
222 |     """
223 | return [indices_1d // width, indices_1d % width]
224 |
225 |
226 | def auc(arr):
227 | """Returns normalized Area Under Curve of the array."""
228 | return (arr.sum() - arr[0] / 2 - arr[-1] / 2) / (arr.shape[0] - 1)
229 |
--------------------------------------------------------------------------------
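Note: the two module-level helpers at the bottom of `metrics/patch_insdel.py` can be sanity-checked in isolation. A minimal sketch, assuming the repo environment is installed; the `scores` array is a made-up insertion curve, not data from the experiments:

```python
import numpy as np

from metrics.patch_insdel import auc, map_2d_indices

# map_2d_indices: flat (row-major) index -> [row, col] for an array of the given width
assert map_2d_indices(5, 3) == [1, 2]      # element 5 of a 2x3 array sits at row 1, col 2

# auc: trapezoidal area under the insertion/deletion curve, normalized by its length
scores = np.array([0.1, 0.4, 0.8, 0.9])    # class probability after each insertion step
assert np.isclose(auc(scores), np.trapz(scores) / (len(scores) - 1))
print(auc(scores))                          # ~0.567
```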
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torchvision
7 | from torchvision.models.resnet import (
8 | ResNet,
9 | ResNet50_Weights,
10 | resnet18,
11 | resnet34,
12 | resnet50,
13 | )
14 | from torchvision.models.vgg import VGG
15 |
16 | from models.attention_branch import add_attention_branch
17 | from models.lrp import replace_resnet_modules
18 |
19 | ALL_MODELS = [
20 | "resnet18",
21 | "resnet34",
22 | "resnet",
23 | "resnet50",
24 | "resnet50-legacy",
25 | "EfficientNet",
26 | "wide",
27 | "vgg",
28 | "vgg19",
29 | "vgg19_bn",
30 | "vgg19_skip",
31 | "swin",
32 | ]
33 |
34 |
35 | class OneWayResNet(nn.Module):
36 | def __init__(self, model: ResNet) -> None:
37 | super().__init__()
38 | self.features = nn.Sequential(
39 | model.conv1,
40 | model.bn1,
41 | model.relu,
42 | model.maxpool,
43 | model.layer1,
44 | model.layer2,
45 | model.layer3,
46 | model.layer4,
47 | )
48 | self.avgpool = model.avgpool
49 | self.classifier = nn.Sequential(model.fc)
50 |
51 | def forward(self, x: torch.Tensor) -> torch.Tensor:
52 | x = self.features(x)
53 | x = self.avgpool(x)
54 | x = torch.flatten(x, 1)
55 | x = self.classifier(x)
56 | return x
57 |
58 |
59 | class _vgg19_skip_forward:
60 | def __init__(self, model_ref: VGG) -> None:
61 | self.model_ref = model_ref
62 |
63 | def __call__(self, x: torch.Tensor) -> torch.Tensor:
64 | # [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]
65 | x = self.model_ref.features[0:2](x) # conv 3->64, relu
66 | x = self.model_ref.features[2:4](x) + x # conv 64->64, relu
67 | x = self.model_ref.features[4:5](x) # maxpool
68 | x = self.model_ref.features[5:7](x) # conv 64->128, relu
69 | x = self.model_ref.features[7:9](x) + x # conv 128->128, relu
70 | x = self.model_ref.features[9:10](x) # maxpool
71 | x = self.model_ref.features[10:12](x) # conv 128->256, relu
72 | x = self.model_ref.features[12:14](x) + x # conv 256->256, relu
73 | x = self.model_ref.features[14:16](x) + x # conv 256->256, relu
74 | x = self.model_ref.features[16:18](x) + x # conv 256->256, relu
75 | x = self.model_ref.features[18:19](x) # maxpool
76 | x = self.model_ref.features[19:21](x) # conv 256->512, relu
77 | x = self.model_ref.features[21:23](x) + x # conv 512->512, relu
78 | x = self.model_ref.features[23:25](x) + x # conv 512->512, relu
79 | x = self.model_ref.features[25:27](x) + x # conv 512->512, relu
80 | x = self.model_ref.features[27:28](x) # maxpool
81 | x = self.model_ref.features[28:30](x) + x # conv 512->512, relu
82 | x = self.model_ref.features[30:32](x) + x # conv 512->512, relu
83 | x = self.model_ref.features[32:34](x) + x # conv 512->512, relu
84 | x = self.model_ref.features[34:36](x) + x # conv 512->512, relu
85 | x = self.model_ref.features[36:37](x) # maxpool
86 | # x = self.model_ref.features(x)
87 | x = self.model_ref.avgpool(x)
88 | x = torch.flatten(x, 1)
89 | x = self.model_ref.classifier(x)
90 | return x
91 |
92 |
93 | def create_model(
94 | base_model: str,
95 | num_classes: int = 1000,
96 | num_channel: int = 3,
97 | base_pretrained: Optional[str] = None,
98 | base_pretrained2: Optional[str] = None,
99 | pretrained_path: Optional[str] = None,
100 | attention_branch: bool = False,
101 | division_layer: Optional[str] = None,
102 | theta_attention: float = 0,
103 | init_classifier: bool = True,
104 | ) -> Optional[nn.Module]:
105 | """
106 | Create model from model name and parameters
107 |
108 | Args:
109 | base_model(str) : Base model name (resnet, etc.)
110 | num_classes(int) : Number of classes
111 |         base_pretrained(str)  : Path to pretrained weights for the base model
112 |         base_pretrained2(str) : Path to pretrained weights loaded after the final layer is replaced
113 |         pretrained_path(str)  : Path to pretrained weights for the final model
114 |                                 (e.g. after the attention branch has been added)
115 |         attention_branch(bool): Whether to add an attention branch
116 |         division_layer(str)   : Layer at which the model is split to insert the attention branch
117 |         theta_attention(float): Threshold applied to the attention map; attention values
118 |                                 below it are zeroed before the perception branch
119 |
120 | Returns:
121 | nn.Module: Created model
122 | """
123 | # Create base model
124 | if base_model == "resnet":
125 | model = resnet18(pretrained=(base_pretrained is not None))
126 | layer_index = {"layer1": -6, "layer2": -5, "layer3": -4}
127 | if init_classifier:
128 | model.fc = nn.Linear(model.fc.in_features, num_classes)
129 | elif base_model == "resnet18":
130 | model = resnet18(pretrained=True)
131 | if init_classifier:
132 | model.fc = nn.Linear(model.fc.in_features, num_classes)
133 | elif base_model == "resnet34":
134 | model = resnet34(pretrained=True)
135 | if init_classifier:
136 | model.fc = nn.Linear(model.fc.in_features, num_classes)
137 | elif base_model == "resnet50":
138 | model = resnet50(pretrained=True)
139 | if init_classifier:
140 | model.fc = nn.Linear(model.fc.in_features, num_classes)
141 | elif base_model == "resnet50-v2":
142 | model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
143 | if init_classifier:
144 | model.fc = nn.Linear(model.fc.in_features, num_classes)
145 | elif base_model == "vgg":
146 | model = torchvision.models.vgg11(pretrained=True)
147 | model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
148 | elif base_model == "vgg19":
149 | model = torchvision.models.vgg19(pretrained=True)
150 | model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
151 | elif base_model == "vgg19_bn":
152 | model = torchvision.models.vgg19_bn(pretrained=True)
153 | model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
154 | elif base_model == "vgg19_skip":
155 | model = torchvision.models.vgg19(pretrained=True)
156 | model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
157 | model.forward = _vgg19_skip_forward(model)
158 | else:
159 | return None
160 |
161 | # Load if base_pretrained is path
162 | if base_pretrained is not None and os.path.isfile(base_pretrained):
163 | model.load_state_dict(torch.load(base_pretrained))
164 | print(f"base pretrained {base_pretrained} loaded.")
165 |
166 | if attention_branch:
167 | assert division_layer is not None
168 | model = add_attention_branch(
169 | model,
170 | layer_index[division_layer],
171 | num_classes,
172 | theta_attention,
173 | )
174 | model.attention_branch = replace_resnet_modules(model.attention_branch)
175 |
176 | # Load if pretrained is path
177 | if pretrained_path is not None and os.path.isfile(pretrained_path):
178 | model.load_state_dict(torch.load(pretrained_path))
179 | print(f"pretrained {pretrained_path} loaded.")
180 |
181 | return model
182 |
--------------------------------------------------------------------------------
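A minimal usage sketch for the `create_model` factory above, assuming the repo environment (torchvision downloads the ImageNet weights on first use); the 200-class setting and the dummy input are illustrative:

```python
import torch

from models import create_model

# ImageNet-pretrained ResNet-50 with a freshly initialized 200-way classifier
model = create_model("resnet50", num_classes=200)
assert model is not None                    # create_model returns None for unknown names

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)                         # torch.Size([1, 200])

# Attention-branch variant: the division-layer indices are defined for base_model="resnet"
abn_model = create_model(
    "resnet",
    num_classes=200,
    attention_branch=True,
    division_layer="layer2",
)
```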
/models/attention_branch.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Type
2 |
3 | import torch
4 | import torch.nn as nn
5 | from timm.models.byobnet import ByobNet
6 |
7 | from utils.utils import module_generator
8 |
9 |
10 | class BatchNorm2dWithActivation(nn.BatchNorm2d):
11 | def __init__(self, module):
12 | super(BatchNorm2dWithActivation, self).__init__(
13 | module.num_features,
14 | eps=module.eps,
15 | momentum=module.momentum,
16 | affine=module.affine,
17 | track_running_stats=module.track_running_stats,
18 | )
19 |
20 | def forward(self, x):
21 | self.activations = x.detach().clone()
22 | return super().forward(x)
23 |
24 |
25 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
26 | """3x3 convolution with padding"""
27 | return nn.Conv2d(
28 | in_planes,
29 | out_planes,
30 | kernel_size=3,
31 | stride=stride,
32 | padding=dilation,
33 | groups=groups,
34 | bias=False,
35 | dilation=dilation,
36 | )
37 |
38 |
39 | def conv1x1(in_planes, out_planes, stride=1):
40 | """1x1 convolution"""
41 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
42 |
43 |
44 | class BasicBlock(nn.Module):
45 | expansion = 1
46 |
47 | def __init__(self, inplanes, planes, stride=1, downsample=None):
48 | super(BasicBlock, self).__init__()
49 | self.inplanes = inplanes
50 | self.conv1 = conv3x3(inplanes, planes, stride)
51 | self.bn1 = nn.BatchNorm2d(planes)
52 | self.relu = nn.ReLU(inplace=True)
53 | self.conv2 = conv3x3(planes, planes)
54 | self.bn2 = nn.BatchNorm2d(planes)
55 | self.downsample = downsample
56 | self.stride = stride
57 |
58 | def forward(self, x):
59 | residual = x
60 |
61 | out = self.conv1(x)
62 | out = self.bn1(out)
63 | out = self.relu(out)
64 |
65 | out = self.conv2(out)
66 | out = self.bn2(out)
67 |
68 | if self.downsample is not None:
69 | residual = self.downsample(x)
70 |
71 | out += residual
72 | out = self.relu(out)
73 |
74 | return out
75 |
76 |
77 | class Bottleneck(nn.Module):
78 | expansion = 4
79 |
80 | def __init__(self, inplanes, planes, stride=1, downsample=None):
81 | super(Bottleneck, self).__init__()
82 | self.inplanes = inplanes
83 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
84 | self.bn1 = nn.BatchNorm2d(planes)
85 | self.conv2 = nn.Conv2d(
86 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
87 | )
88 | self.bn2 = nn.BatchNorm2d(planes)
89 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
90 | self.bn3 = nn.BatchNorm2d(planes * 4)
91 | self.relu = nn.ReLU(inplace=True)
92 | self.downsample = downsample
93 | self.stride = stride
94 |
95 | def forward(self, x):
96 | residual = x
97 |
98 | out = self.conv1(x)
99 | out = self.bn1(out)
100 | out = self.relu(out)
101 |
102 | out = self.conv2(out)
103 | out = self.bn2(out)
104 | out = self.relu(out)
105 |
106 | out = self.conv3(out)
107 | out = self.bn3(out)
108 |
109 | if self.downsample is not None:
110 | residual = self.downsample(x)
111 |
112 | out += residual
113 | out = self.relu(out)
114 |
115 | return out
116 |
117 |
118 | class AttentionBranch(nn.Module):
119 | def __init__(
120 | self,
121 | block: Bottleneck,
122 | num_layer: int,
123 | num_classes: int = 1000,
124 | inplanes: int = 64,
125 | multi_task: bool = False,
126 | num_tasks: Optional[List[int]] = None,
127 | ) -> None:
128 | super().__init__()
129 |
130 | self.inplanes = inplanes
131 |
132 | self.layer1 = self._make_layer(
133 | block, self.inplanes, num_layer, stride=1, down_size=False
134 | )
135 |
136 | hidden_channel = 10
137 |
138 | self.bn1 = nn.BatchNorm2d(self.inplanes * block.expansion)
139 | # self.bn1 = BatchNorm2dWithActivation(bn1)
140 | self.conv1 = conv1x1(self.inplanes * block.expansion, hidden_channel)
141 | self.relu = nn.ReLU(inplace=True)
142 |
143 | if multi_task:
144 | if num_tasks is None:
145 | num_tasks = [1 for _ in range(num_classes)]
146 | assert num_classes == sum(num_tasks)
147 | self.conv2 = conv3x3(num_classes, len(num_tasks))
148 | self.bn2 = nn.BatchNorm2d(len(num_tasks))
149 | else:
150 | self.conv2 = conv1x1(hidden_channel, hidden_channel)
151 | self.bn2 = nn.BatchNorm2d(hidden_channel)
152 | self.conv3 = conv1x1(hidden_channel, 1)
153 | self.bn3 = nn.BatchNorm2d(1)
154 |
155 | self.sigmoid = nn.Sigmoid()
156 |
157 | self.conv4 = conv1x1(hidden_channel, num_classes)
158 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
159 | self.flatten = nn.Flatten(1)
160 |
161 | def _make_layer(
162 | self,
163 | block: Bottleneck,
164 | planes: int,
165 | blocks: int,
166 | stride: int = 1,
167 | down_size: bool = True,
168 | ) -> nn.Sequential:
169 | downsample = None
170 | if stride != 1 or self.inplanes != planes * block.expansion:
171 | downsample = nn.Sequential(
172 | nn.Conv2d(
173 | self.inplanes,
174 | planes * block.expansion,
175 | kernel_size=1,
176 | stride=stride,
177 | bias=False,
178 | ),
179 | nn.BatchNorm2d(planes * block.expansion),
180 | )
181 |
182 | layers = []
183 | layers.append(block(self.inplanes, planes, stride, downsample))
184 | if down_size:
185 | self.inplanes = planes * block.expansion
186 | for _ in range(1, blocks):
187 | layers.append(block(self.inplanes, planes))
188 |
189 | return nn.Sequential(*layers)
190 | else:
191 | inplanes = planes * block.expansion
192 | for _ in range(1, blocks):
193 | layers.append(block(inplanes, planes))
194 |
195 | return nn.Sequential(*layers)
196 |
197 | def forward(self, x):
198 | x = self.layer1(x)
199 |
200 | self.bn1_activation = x
201 | x = self.bn1(x)
202 | x = self.conv1(x)
203 | x = self.relu(x)
204 |
205 | self.class_attention = self.sigmoid(x)
206 |
207 | attention = self.conv2(x)
208 | attention = self.bn2(attention)
209 | attention = self.relu(attention)
210 |
211 | attention = self.conv3(attention)
212 | attention = self.bn3(attention)
213 |
214 | self.attention_order = attention
215 | self.attention = self.sigmoid(attention)
216 |
217 | x = self.conv4(x)
218 | x = self.avgpool(x)
219 | x = self.flatten(x)
220 |
221 | return x
222 |
223 |
224 | class AttentionBranchModel(nn.Module):
225 | def __init__(
226 | self,
227 | feature_extractor: nn.Module,
228 | perception_branch: nn.Module,
229 | block: Type = Bottleneck,
230 | num_layer: int = 2,
231 | num_classes: int = 1000,
232 | inplanes: int = 64,
233 | theta_attention: float = 0,
234 | ) -> None:
235 | super().__init__()
236 |
237 | self.feature_extractor = feature_extractor
238 |
239 | self.attention_branch = AttentionBranch(block, num_layer, num_classes, inplanes)
240 | # self.attention_branch = replace_resnet_modules(self.attention_branch)
241 |
242 | self.perception_branch = perception_branch
243 | self.theta_attention = theta_attention
244 |
245 | def forward(self, x):
246 | x = self.feature_extractor(x)
247 |
248 | # For Attention Loss
249 | self.attention_pred = self.attention_branch(x)
250 |
251 | attention = self.attention_branch.attention
252 | attention = torch.where(
253 | self.theta_attention < attention, attention.double(), 0.0
254 | )
255 |
256 | x = x * attention
257 | x = x.float()
258 | # x = attention_x + x
259 |
260 | return self.perception_branch(x)
261 |
262 |
263 | def add_attention_branch(
264 | base_model: nn.Module,
265 | division_index: int,
266 | num_classes: int,
267 | theta_attention: float = 0,
268 | ) -> nn.Module:
269 | modules = list(base_model.children())
270 | # [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc]
271 |
272 | pre_model = nn.Sequential(*modules[:division_index])
273 | post_model = nn.Sequential(*modules[division_index:-1], nn.Flatten(1), modules[-1])
274 |
275 | if isinstance(base_model, ByobNet):
276 | return AttentionBranchModel(
277 | pre_model,
278 | post_model,
279 | Bottleneck,
280 | 2,
281 | num_classes,
282 | inplanes=64,
283 | theta_attention=theta_attention,
284 | )
285 |
286 | final_layer = module_generator(modules[division_index], reverse=True)
287 | for module in final_layer:
288 | if isinstance(module, nn.modules.batchnorm._NormBase):
289 | inplanes = module.num_features // 2
290 | break
291 | elif isinstance(module, nn.modules.conv._ConvNd):
292 | inplanes = module.out_channels
293 | break
294 |
295 | return AttentionBranchModel(
296 | pre_model,
297 | post_model,
298 | Bottleneck,
299 | 2,
300 | num_classes,
301 | inplanes=inplanes,
302 | theta_attention=theta_attention,
303 | )
304 |
--------------------------------------------------------------------------------
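A minimal sketch showing how `add_attention_branch` wraps a plain torchvision ResNet-18, assuming torchvision and timm as pinned in pyproject.toml; the 200-class setting and the random input are illustrative. `division_index=-5` corresponds to the `layer_index["layer2"]` entry used by `create_model`:

```python
import torch
from torchvision.models import resnet18

from models.attention_branch import AttentionBranchModel, add_attention_branch

base = resnet18(num_classes=200)
# children() order: [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc]
model = add_attention_branch(base, division_index=-5, num_classes=200)
assert isinstance(model, AttentionBranchModel)

x = torch.randn(2, 3, 224, 224)
logits = model(x)                               # perception-branch output
print(logits.shape)                             # torch.Size([2, 200])
print(model.attention_pred.shape)               # attention-branch logits, torch.Size([2, 200])
print(model.attention_branch.attention.shape)   # attention map, torch.Size([2, 1, 56, 56])
```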
/models/lrp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision.models.resnet import BasicBlock as OriginalBasicBlock
5 | from torchvision.models.resnet import Bottleneck as OriginalBottleneck
6 |
7 | from models.attention_branch import BasicBlock as ABNBasicBlock
8 | from models.attention_branch import Bottleneck as ABNBottleneck
9 |
10 |
11 | class LinearWithActivation(nn.Linear):
12 | def __init__(self, module, out_features=None):
13 | in_features = module.in_features
14 | if out_features is None:
15 | out_features = module.out_features
16 | bias = module.bias is not None
17 | super(LinearWithActivation, self).__init__(in_features, out_features, bias)
18 | self.activations = None
19 |
20 | def forward(self, x):
21 | if len(x.shape) > 2:
22 | bs = x.shape[0]
23 | x = x.view(bs, -1)
24 | self.activations = x.detach().clone()
25 | return super(LinearWithActivation, self).forward(x)
26 |
27 |
28 | class Conv2dWithActivation(nn.Conv2d):
29 | def __init__(self, module, num_channel=None):
30 | in_channels = module.in_channels if num_channel is None else num_channel
31 | super(Conv2dWithActivation, self).__init__(
32 | in_channels,
33 | module.out_channels,
34 | module.kernel_size,
35 | stride=module.stride,
36 | padding=module.padding,
37 | dilation=module.dilation,
38 | groups=module.groups,
39 | padding_mode=module.padding_mode,
40 | bias=module.bias is not None,
41 | )
42 | self.activations = None
43 |
44 | def forward(self, x):
45 | self.activations = x.detach().clone()
46 | return super().forward(x)
47 |
48 |
49 | class BatchNorm2dWithActivation(nn.BatchNorm2d):
50 | def __init__(self, module):
51 | super(BatchNorm2dWithActivation, self).__init__(
52 | module.num_features,
53 | eps=module.eps,
54 | momentum=module.momentum,
55 | affine=module.affine,
56 | track_running_stats=module.track_running_stats,
57 | )
58 | self.activations = None
59 |
60 | def forward(self, x):
61 | self.activations = x.detach().clone()
62 | return super().forward(x)
63 |
64 |
65 | class ReLUWithActivation(nn.ReLU):
66 | def __init__(self, *args, **kwargs):
67 | super(ReLUWithActivation, self).__init__(*args, **kwargs)
68 | self.activations = None
69 |
70 | def forward(self, x):
71 | self.activations = x.detach().clone()
72 | return super(ReLUWithActivation, self).forward(x)
73 |
74 |
75 | class MaxPool2dWithActivation(nn.MaxPool2d):
76 | def __init__(self, module):
77 | super(MaxPool2dWithActivation, self).__init__(
78 | module.kernel_size, module.stride, module.padding, module.dilation
79 | )
80 | self.activations = None
81 |
82 | def forward(self, x):
83 | self.activations = x.detach().clone()
84 | return super(MaxPool2dWithActivation, self).forward(x)
85 |
86 |
87 | class AdaptiveAvgPool2dWithActivation(nn.AdaptiveAvgPool2d):
88 | def __init__(self, module: nn.AdaptiveAvgPool2d):
89 | super(AdaptiveAvgPool2dWithActivation, self).__init__(module.output_size)
90 | self.activations = None
91 |
92 | def forward(self, x):
93 | self.activations = x.detach().clone()
94 | return super(AdaptiveAvgPool2dWithActivation, self).forward(x)
95 |
96 |
97 | def copy_weights(target_module, source_module):
98 | if isinstance(
99 | target_module, (nn.AdaptiveAvgPool2d, nn.MaxPool2d, nn.ReLU)
100 | ) and isinstance(source_module, (nn.AdaptiveAvgPool2d, nn.MaxPool2d, nn.ReLU)):
101 | # Do nothing for layers without weights
102 | return
103 | if isinstance(target_module, nn.Linear) and isinstance(source_module, nn.Linear):
104 | target_module.weight.data.copy_(source_module.weight.data)
105 | target_module.bias.data.copy_(source_module.bias.data)
106 | elif isinstance(target_module, nn.Conv2d) and isinstance(source_module, nn.Conv2d):
107 | target_module.weight.data.copy_(source_module.weight.data)
108 | if source_module.bias is not None:
109 | target_module.bias = source_module.bias
110 | elif isinstance(target_module, nn.BatchNorm2d) and isinstance(
111 | source_module, nn.BatchNorm2d
112 | ):
113 | target_module.weight.data.copy_(source_module.weight.data)
114 | target_module.bias.data.copy_(source_module.bias.data)
115 | target_module.running_mean.data.copy_(source_module.running_mean.data)
116 | target_module.running_var.data.copy_(source_module.running_var.data)
117 | else:
118 | raise ValueError(
119 | f"Unsupported module types for copy_weights source: {source_module} and target: {target_module}"
120 | )
121 |
122 |
123 | def replace_modules(model, wrapper):
124 | for name, module in model.named_children():
125 | if isinstance(module, nn.Sequential):
126 | setattr(model, name, replace_modules(module, wrapper))
127 | else:
128 |             wrapped_module = wrapper(module)
129 |             copy_weights(wrapped_module, module)
130 |             setattr(model, name, wrapped_module)
131 | return model
132 |
133 |
134 | class BasicBlockWithActivation(OriginalBasicBlock):
135 | def __init__(self, block: OriginalBasicBlock):
136 | inplanes = block.conv1.in_channels
137 | planes = block.conv1.out_channels
138 |
139 | super(BasicBlockWithActivation, self).__init__(
140 | inplanes=inplanes,
141 | planes=planes,
142 | stride=block.stride,
143 | downsample=block.downsample,
144 | )
145 |
146 | def forward(self, x):
147 | self.activations = x.detach().clone()
148 | return super().forward(x)
149 |
150 |
151 | class BottleneckWithActivation(OriginalBottleneck):
152 | def __init__(self, block: OriginalBottleneck):
153 | inplanes = block.conv1.in_channels
154 | planes = block.conv3.out_channels // block.expansion
155 | groups = block.conv2.groups
156 | dilation = block.conv2.dilation
157 | width = block.conv1.out_channels
158 | base_width = 64 * width // (groups * planes)
159 |
160 | super(BottleneckWithActivation, self).__init__(
161 | inplanes=inplanes,
162 | planes=planes,
163 | stride=block.stride,
164 | downsample=block.downsample,
165 | groups=groups,
166 | base_width=base_width,
167 | dilation=dilation,
168 | )
169 |
170 | def forward(self, x):
171 | self.activations = x.detach().clone()
172 | return super().forward(x)
173 |
174 |
175 | def replace_resnet_modules(model):
176 | for name, module in list(model.named_children()):
177 | if isinstance(module, nn.Sequential):
178 | for i, block in enumerate(module):
179 | if isinstance(block, OriginalBasicBlock) or isinstance(
180 | block, ABNBasicBlock
181 | ):
182 | module[i] = BasicBlockWithActivation(block)
183 | elif isinstance(block, OriginalBottleneck) or isinstance(
184 | block, ABNBottleneck
185 | ):
186 | module[i] = BottleneckWithActivation(block)
187 | setattr(model, name, module)
188 | elif isinstance(module, nn.AdaptiveAvgPool2d):
189 | setattr(model, name, AdaptiveAvgPool2dWithActivation(module))
190 | return model
191 |
--------------------------------------------------------------------------------
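A minimal sketch for `replace_resnet_modules`, assuming the repo environment. The wrapper blocks are constructed with fresh parameters; in `create_model` a checkpoint is loaded after the replacement, so this only demonstrates the activation recording used by the LRP code:

```python
import torch
from torchvision.models import resnet18

from models.lrp import BasicBlockWithActivation, replace_resnet_modules

model = replace_resnet_modules(resnet18(num_classes=10)).eval()

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))

first_block = model.layer1[0]
assert isinstance(first_block, BasicBlockWithActivation)
print(first_block.activations.shape)    # input to the first residual block, torch.Size([1, 64, 56, 56])
print(model.avgpool.activations.shape)  # input to the global average pool, torch.Size([1, 512, 7, 7])
```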
/models/rise.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class RISE(nn.Module):
8 | def __init__(
9 | self,
10 | model,
11 | n_masks=10000,
12 | p1=0.1,
13 | input_size=(224, 224),
14 | initial_mask_size=(7, 7),
15 | n_batch=128,
16 | mask_path=None,
17 | ):
18 | super().__init__()
19 | self.model = model
20 | self.n_masks = n_masks
21 | self.p1 = p1
22 | self.input_size = input_size
23 | self.initial_mask_size = initial_mask_size
24 | self.n_batch = n_batch
25 |
26 | if mask_path is not None:
27 | self.masks = self.load_masks(mask_path)
28 | else:
29 | self.masks = self.generate_masks()
30 |
31 | def generate_masks(self):
32 | # cell size in the upsampled mask
33 | Ch = np.ceil(self.input_size[0] / self.initial_mask_size[0])
34 | Cw = np.ceil(self.input_size[1] / self.initial_mask_size[1])
35 |
36 | resize_h = int((self.initial_mask_size[0] + 1) * Ch)
37 | resize_w = int((self.initial_mask_size[1] + 1) * Cw)
38 |
39 | masks = []
40 |
41 | for _ in range(self.n_masks):
42 | # generate binary mask
43 | binary_mask = torch.randn(
44 | 1, 1, self.initial_mask_size[0], self.initial_mask_size[1]
45 | )
46 | binary_mask = (binary_mask < self.p1).float()
47 |
48 | # upsampling mask
49 | mask = F.interpolate(
50 | binary_mask, (resize_h, resize_w), mode="bilinear", align_corners=False
51 | )
52 |
53 | # random cropping
54 | i = np.random.randint(0, Ch)
55 | j = np.random.randint(0, Cw)
56 | mask = mask[:, :, i : i + self.input_size[0], j : j + self.input_size[1]]
57 |
58 | masks.append(mask)
59 |
60 | masks = torch.cat(masks, dim=0) # (N_masks, 1, H, W)
61 |
62 | return masks
63 |
64 | def load_masks(self, filepath):
65 | masks = torch.load(filepath)
66 | return masks
67 |
68 | def save_masks(self, filepath):
69 | torch.save(self.masks, filepath)
70 |
71 | def forward(self, x):
72 | # x: input image. (1, 3, H, W)
73 | device = x.device
74 |
75 | # keep probabilities of each class
76 | probs = []
77 | # shape (n_masks, 3, H, W)
78 | masked_x = torch.mul(self.masks, x.to("cpu").data)
79 |
80 | for i in range(0, self.n_masks, self.n_batch):
81 | input = masked_x[i : min(i + self.n_batch, self.n_masks)].to(device)
82 | out = self.model(input)
83 | probs.append(torch.softmax(out, dim=1).to("cpu").data)
84 |
85 | probs = torch.cat(probs) # shape => (n_masks, n_classes)
86 | n_classes = probs.shape[1]
87 |
88 |         # calculate saliency map using probability scores as weights
89 | saliency = torch.matmul(
90 | probs.data.transpose(0, 1), self.masks.view(self.n_masks, -1)
91 | )
92 | saliency = saliency.view((n_classes, self.input_size[0], self.input_size[1]))
93 | saliency = saliency / (self.n_masks * self.p1)
94 |
95 | # normalize
96 | m, _ = torch.min(saliency.view(n_classes, -1), dim=1)
97 | saliency -= m.view(n_classes, 1, 1)
98 | M, _ = torch.max(saliency.view(n_classes, -1), dim=1)
99 | saliency /= M.view(n_classes, 1, 1)
100 | return saliency.data
101 |
--------------------------------------------------------------------------------
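A minimal sketch of the `RISE` wrapper above, assuming the repo environment. `n_masks=500` is deliberately small to keep the example fast and is not the value used in the experiments; the random input stands in for a normalized image:

```python
import torch
from torchvision.models import resnet18

from models.rise import RISE

classifier = resnet18(num_classes=10).eval()
explainer = RISE(classifier, n_masks=500, p1=0.1, input_size=(224, 224),
                 initial_mask_size=(7, 7), n_batch=64)

x = torch.randn(1, 3, 224, 224)          # stand-in for a normalized input image
with torch.no_grad():
    saliency = explainer(x)              # (n_classes, H, W), min-max normalized per class
print(saliency.shape)                    # torch.Size([10, 224, 224])
```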
/oneshot.py:
--------------------------------------------------------------------------------
1 | """ Visualize attention map of a model with a given image.
2 |
3 | Example:
4 | ```sh
5 | poetry run python oneshot.py -c checkpoints/CUB_resnet50_Seed42/config.json --method "scorecam" \
6 | --image-path ./qual/original/painted_bunting.png --label 15 --save-path ./qual/scorecam/painted_bunting.png
7 |
8 | poetry run python oneshot.py -c checkpoints/CUB_resnet50_Seed42/config.json --method "lrp" \
9 | --skip-connection-prop-type "flows_skip" --heat-quantization \
10 | --image-path ./qual/original/painted_bunting.png --label 15 --save-path ./qual/ours/painted_bunting.png
11 |
12 | poetry run python oneshot.py -c configs/ImageNet_resnet50.json --method "lrp" \
13 | --skip-connection-prop-type "flows_skip" --heat-quantization \
14 | --image-path ./qual/original/bee.png --label 309 --save-path ./qual/ours/bee.png
15 | ```
16 |
17 | """
18 |
19 | import argparse
20 | import os
21 | from typing import Any, Dict, Optional, Union
22 |
23 | import torch
24 | import torch.nn as nn
25 | import torchvision.transforms as transforms
26 | from PIL import Image
27 | from skimage.transform import resize
28 |
29 | from data import ALL_DATASETS, get_parameter_depend_in_data_set
30 | from metrics.base import Metric
31 | from models import ALL_MODELS, create_model
32 | from models.lrp import *
33 | from src.utils import SkipConnectionPropType
34 | from utils.utils import fix_seed, parse_with_config
35 | from utils.visualize import (
36 | save_image_with_attention_map,
37 | )
38 | from visualize import ( # TODO: move these functions to src directory
39 | apply_heat_quantization,
40 | calculate_attention,
41 | remove_other_components,
42 | )
43 |
44 |
45 | def load_image(image_path: str, image_size: int = 224) -> torch.Tensor:
46 | imagenet_mean = (0.485, 0.456, 0.406)
47 | imagenet_std = (0.229, 0.224, 0.225)
48 | transform = transforms.Compose(
49 | [
50 | transforms.ToTensor(),
51 | transforms.Normalize(imagenet_mean, imagenet_std),
52 | ]
53 | )
54 | image = Image.open(image_path).convert("RGB")
55 | image = transform(image)
56 | return image
57 |
58 |
59 | def visualize(
60 | image_path: str,
61 | label: int,
62 | model: nn.Module,
63 | method: str,
64 | save_path: str,
65 | params: Dict[str, Any],
66 | device: torch.device,
67 | attention_dir: Optional[str] = None,
68 | use_c1c: bool = False,
69 | heat_quantization: bool = False,
70 | hq_level: int = 8,
71 | skip_connection_prop_type: SkipConnectionPropType = "latest",
72 | normalize: bool = False,
73 | sign: str = "all",
74 | ) -> Union[None, Metric]:
75 | model.eval()
76 | torch.cuda.memory_summary(device=device)
77 | inputs = load_image(image_path, params["input_size"][0]).unsqueeze(0).to(device)
78 | image = inputs[0].cpu().numpy()
79 | label = torch.tensor(label).to(device)
80 |
81 | attention, _ = calculate_attention(
82 | model, inputs, label, method, params, None, skip_connection_prop_type
83 | )
84 | if use_c1c:
85 | attention = resize(attention, (28, 28))
86 | attention = remove_other_components(attention, threshold=attention.mean())
87 |
88 | if heat_quantization:
89 | attention = apply_heat_quantization(attention, hq_level)
90 |
91 | save_path = save_path + ".png" if save_path[-4:] != ".png" else save_path
92 | save_image_with_attention_map(
93 | image, attention, save_path, params["mean"], params["std"], only_img=True, normalize=normalize, sign=sign
94 | )
95 |
96 |
97 | def main(args: argparse.Namespace) -> None:
98 | fix_seed(args.seed, args.no_deterministic)
99 | params = get_parameter_depend_in_data_set(args.dataset)
100 |
101 | mask_path = os.path.join(args.root_dir, "masks.npy")
102 | if not os.path.isfile(mask_path):
103 | mask_path = None
104 | rise_params = {
105 | "n_masks": args.num_masks,
106 | "p1": args.p1,
107 | "input_size": (args.image_size, args.image_size),
108 | "initial_mask_size": (args.rise_scale, args.rise_scale),
109 | "n_batch": args.batch_size,
110 | "mask_path": mask_path,
111 | }
112 | params.update(rise_params)
113 |
114 | # Create model
115 | model = create_model(
116 | args.model,
117 | num_classes=len(params["classes"]),
118 | num_channel=params["num_channel"],
119 | base_pretrained=args.base_pretrained,
120 | base_pretrained2=args.base_pretrained2,
121 | pretrained_path=args.pretrained,
122 | attention_branch=args.add_attention_branch,
123 | division_layer=args.div,
124 | theta_attention=args.theta_att,
125 | init_classifier=args.dataset != "ImageNet", # Use pretrained classifier in ImageNet
126 | )
127 | assert model is not None, "Model name is invalid"
128 |
129 | model.to(device)
130 | visualize(
131 | args.image_path,
132 | args.label,
133 | model,
134 | args.method,
135 | args.save_path,
136 | params,
137 | device,
138 | attention_dir=args.attention_dir,
139 | use_c1c=args.use_c1c,
140 | heat_quantization=args.heat_quantization,
141 | hq_level=args.hq_level,
142 | skip_connection_prop_type=args.skip_connection_prop_type,
143 | normalize=args.normalize,
144 | sign=args.sign,
145 | )
146 |
147 |
148 | def parse_args() -> argparse.Namespace:
149 | parser = argparse.ArgumentParser()
150 |
151 | parser.add_argument("-c", "--config", type=str, help="path to config file (json)")
152 |
153 | parser.add_argument("--seed", type=int, default=42)
154 | parser.add_argument("--no_deterministic", action="store_false")
155 |
156 | # Model
157 | parser.add_argument("-m", "--model", choices=ALL_MODELS, help="model name")
158 | parser.add_argument(
159 | "-add_ab",
160 | "--add_attention_branch",
161 | action="store_true",
162 | help="add Attention Branch",
163 | )
164 | parser.add_argument(
165 | "--div",
166 | type=str,
167 | choices=["layer1", "layer2", "layer3"],
168 | default="layer2",
169 | help="place to attention branch",
170 | )
171 | parser.add_argument("--base_pretrained", type=str, help="path to base pretrained")
172 | parser.add_argument(
173 | "--base_pretrained2",
174 | type=str,
175 | help="path to base pretrained2 ( after change_num_classes() )",
176 | )
177 | parser.add_argument("--pretrained", type=str, help="path to pretrained")
178 | parser.add_argument(
179 | "--orig_model",
180 | action="store_true",
181 | help="calc insdel score by using original model",
182 | )
183 | parser.add_argument(
184 | "--theta_att", type=float, default=0, help="threthold of attention branch"
185 | )
186 |
187 | # Target Image
188 | parser.add_argument("--image-path", type=str, help="path to target image")
189 | parser.add_argument("--label", type=int, help="label of target image")
190 |
191 | # Visualize option
192 | parser.add_argument("--normalize", action="store_true", help="normalize attribution")
193 | parser.add_argument("--sign", type=str, default="all", help="sign of attribution to show")
194 |
195 | # Save
196 | parser.add_argument("--save-path", type=str, help="path to save image")
197 |
198 | # Dataset
199 | parser.add_argument("--dataset", type=str, default="IDRiD", choices=ALL_DATASETS)
200 | parser.add_argument("--image_size", type=int, default=224)
201 | parser.add_argument("--batch_size", type=int, default=16)
202 | parser.add_argument(
203 | "--loss_weights",
204 | type=float,
205 | nargs="*",
206 | default=[1.0, 1.0],
207 | help="weights for label by class",
208 | )
209 |
210 | parser.add_argument("--root_dir", type=str, default="./outputs/")
211 | parser.add_argument("--visualize_only", action="store_false")
212 | parser.add_argument("--all_class", action="store_true")
213 | # recommend (step, size) in 512x512 = (1, 10000), (2, 2500), (4, 500), (8, 100), (16, 20), (32, 10), (64, 5), (128, 1)
214 | # recommend (step, size) in 224x224 = (1, 500), (2, 100), (4, 20), (8, 10), (16, 5), (32, 1)
215 | parser.add_argument("--insdel_step", type=int, default=500)
216 | parser.add_argument("--block_size", type=int, default=1)
217 |
218 | parser.add_argument(
219 | "--method",
220 | type=str,
221 | default="gradcam",
222 | )
223 |
224 | parser.add_argument("--num_masks", type=int, default=5000)
225 | parser.add_argument("--rise_scale", type=int, default=9)
226 | parser.add_argument(
227 | "--p1", type=float, default=0.3, help="percentage of mask [pixel = (0, 0, 0)]"
228 | )
229 |
230 | parser.add_argument("--attention_dir", type=str, help="path to attention npy file")
231 |
232 | parser.add_argument("--use-c1c", action="store_true", help="use C1C technique")
233 |
234 | parser.add_argument("--heat-quantization", action="store_true", help="use heat quantization technique")
235 | parser.add_argument("--hq-level", type=int, default=8, help="number of quantization level")
236 |
237 | parser.add_argument("--skip-connection-prop-type", type=str, default="latest", help="type of skip connection propagation")
238 |
239 | return parse_with_config(parser)
240 |
241 |
242 | if __name__ == "__main__":
243 | # import pdb; pdb.set_trace()
244 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
245 |
246 | main(parse_args())
247 |
--------------------------------------------------------------------------------
/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable
2 |
3 | import torch.optim as optim
4 |
5 |
6 | from optim.sam import SAM
7 |
8 | ALL_OPTIM = ["SGD", "Adam", "AdamW", "SAM"]
9 |
10 |
11 | def create_optimizer(
12 | optim_name: str,
13 | params: Iterable,
14 | lr: float,
15 | weight_decay: float = 0.9,
16 | momentum: float = 0.9,
17 | ) -> optim.Optimizer:
18 | """
19 | Create an optimizer
20 |
21 | Args:
22 |         optim_name(str)    : Name of the optimizer
23 |         params(Iterable)   : Model parameters to optimize
24 |         lr(float)          : Learning rate
25 |         weight_decay(float): Weight decay coefficient
26 |         momentum(float)    : Momentum (used by the SGD and SAM optimizers)
27 |
28 | Returns:
29 | Optimizer
30 | """
31 | assert optim_name in ALL_OPTIM
32 |
33 | if optim_name == "SGD":
34 | optimizer = optim.SGD(
35 | params, lr=lr, momentum=momentum, weight_decay=weight_decay
36 | )
37 | elif optim_name == "Adam":
38 | optimizer = optim.Adam(params, lr=lr, weight_decay=weight_decay)
39 | elif optim_name == "AdamW":
40 | optimizer = optim.AdamW(params, lr=lr, weight_decay=weight_decay)
41 | elif optim_name == "SAM":
42 | base_optimizer = optim.SGD
43 | optimizer = SAM(
44 | params, base_optimizer, lr=lr, weight_decay=weight_decay, momentum=momentum
45 | )
46 | else:
47 | raise ValueError
48 |
49 | return optimizer
50 |
--------------------------------------------------------------------------------
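A minimal sketch for `create_optimizer`; the model and hyperparameter values are illustrative, not the settings used for training:

```python
import torch.nn as nn

from optim import create_optimizer

model = nn.Linear(512, 200)
optimizer = create_optimizer(
    "SGD", model.parameters(), lr=1e-3, weight_decay=1e-4, momentum=0.9
)

# "SAM" wraps SGD (see optim/sam.py) and therefore needs a closure in step();
# the other optimizers follow the usual loss.backward(); optimizer.step() loop.
sam_optimizer = create_optimizer(
    "SAM", model.parameters(), lr=1e-3, weight_decay=1e-4, momentum=0.9
)
```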
/optim/sam.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class SAM(torch.optim.Optimizer):
5 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
6 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
7 |
8 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
9 | super(SAM, self).__init__(params, defaults)
10 |
11 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
12 | self.param_groups = self.base_optimizer.param_groups
13 |
14 | @torch.no_grad()
15 | def first_step(self, zero_grad=False):
16 | grad_norm = self._grad_norm()
17 | for group in self.param_groups:
18 | scale = group["rho"] / (grad_norm + 1e-12)
19 |
20 | for p in group["params"]:
21 | if p.grad is None:
22 | continue
23 | self.state[p]["old_p"] = p.data.clone()
24 | e_w = (
25 | (torch.pow(p, 2) if group["adaptive"] else 1.0)
26 | * p.grad
27 | * scale.to(p)
28 | )
29 | p.add_(e_w) # climb to the local maximum "w + e(w)"
30 |
31 | if zero_grad:
32 | self.zero_grad()
33 |
34 | @torch.no_grad()
35 | def second_step(self, zero_grad=False):
36 | for group in self.param_groups:
37 | for p in group["params"]:
38 | if p.grad is None:
39 | continue
40 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)"
41 |
42 | self.base_optimizer.step() # do the actual "sharpness-aware" update
43 |
44 | if zero_grad:
45 | self.zero_grad()
46 |
47 | @torch.no_grad()
48 | def step(self, closure=None):
49 | assert (
50 | closure is not None
51 | ), "Sharpness Aware Minimization requires closure, but it was not provided"
52 | closure = torch.enable_grad()(
53 | closure
54 | ) # the closure should do a full forward-backward pass
55 |
56 | self.first_step(zero_grad=True)
57 | closure()
58 | self.second_step()
59 |
60 | def _grad_norm(self):
61 | shared_device = self.param_groups[0]["params"][
62 | 0
63 | ].device # put everything on the same device, in case of model parallelism
64 | norm = torch.norm(
65 | torch.stack(
66 | [
67 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad)
68 | .norm(p=2)
69 | .to(shared_device)
70 | for group in self.param_groups
71 | for p in group["params"]
72 | if p.grad is not None
73 | ]
74 | ),
75 | p=2,
76 | )
77 | return norm
78 |
79 | def load_state_dict(self, state_dict):
80 | super().load_state_dict(state_dict)
81 | self.base_optimizer.param_groups = self.param_groups
82 |
--------------------------------------------------------------------------------
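Because `SAM.step()` asserts that a closure is provided, a training step looks like the sketch below (model, data, and hyperparameters are placeholders): the first backward pass computes gradients at w, `first_step` perturbs the weights to w + e(w), the closure recomputes gradients there, and `second_step` restores w and applies the base SGD update.

```python
import torch
import torch.nn as nn

from optim.sam import SAM

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()
optimizer = SAM(model.parameters(), torch.optim.SGD, lr=0.1, momentum=0.9)

inputs, targets = torch.randn(8, 10), torch.randint(0, 2, (8,))

def closure():
    # re-run the full forward/backward pass at the perturbed weights w + e(w)
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss

loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step(closure)
optimizer.zero_grad()
```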
/outputs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/outputs/.gitkeep
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | # line-length = 119
3 | target-version = ['py39']
4 |
5 | [tool.mypy]
6 | python_version = "3.9"
7 |
8 | [tool.ruff]
9 | # Never enforce `E501` (line length violations).
10 | ignore = ["C901", "E501", "E741", "F405", "F403", "W605"]
11 | # select = ["C", "E", "F", "I", "W"]
12 |
13 | # # Ignore import violations in all `__init__.py` files.
14 | # [tool.ruff.per-file-ignores]
15 | # "__init__.py" = ["E402", "F401", "F403", "F811"]
16 |
17 | # [tool.ruff.isort]
18 | # lines-after-imports = 2
19 |
20 | [tool.poetry]
21 | name = "lrp-for-resnet"
22 | version = "0.1.0"
23 | description = ""
24 | authors = ["Foo "]
25 | readme = "README.md"
26 |
27 | [tool.poetry.dependencies]
28 | python = ">=3.9,<3.12"
29 | torch = "2.0.0"
30 | torchvision = "0.15.1"
31 | torchaudio = "2.0.1"
32 | tqdm = "^4.65.0"
33 | matplotlib = "^3.7.1"
34 | numpy = "^1.24.3"
35 | opencv-python = "^4.7.0.72"
36 | pillow = "^9.5.0"
37 | scikit-image = "^0.20.0"
38 | scikit-learn = "^1.2.2"
39 | torchinfo = "^1.8.0"
40 | wandb = "^0.15.3"
41 | timm = "^0.9.2"
42 | grad-cam = "^1.4.6"
43 | captum = "^0.6.0"
44 | plotly = "^5.19.0"
45 | kaleido = "0.2.1"
46 |
47 | [tool.poetry.group.dev.dependencies]
48 | ruff = "^0.0.269"
49 | black = "^23.3.0"
50 | mypy = "^1.3.0"
51 |
52 | [build-system]
53 | requires = ["poetry-core"]
54 | build-backend = "poetry.core.masonry.api"
55 |
--------------------------------------------------------------------------------
/qual/original/Arabian_camel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/Arabian_camel.png
--------------------------------------------------------------------------------
/qual/original/Brandt_Cormorant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/Brandt_Cormorant.png
--------------------------------------------------------------------------------
/qual/original/Geococcyx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/Geococcyx.png
--------------------------------------------------------------------------------
/qual/original/Rock_Wren.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/Rock_Wren.png
--------------------------------------------------------------------------------
/qual/original/Savannah_Sparrow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/Savannah_Sparrow.png
--------------------------------------------------------------------------------
/qual/original/bee.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/bee.png
--------------------------------------------------------------------------------
/qual/original/bubble.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/bubble.png
--------------------------------------------------------------------------------
/qual/original/bustard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/bustard.png
--------------------------------------------------------------------------------
/qual/original/drumstick.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/drumstick.png
--------------------------------------------------------------------------------
/qual/original/oboe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/oboe.png
--------------------------------------------------------------------------------
/qual/original/ram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/ram.png
--------------------------------------------------------------------------------
/qual/original/sock.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/sock.png
--------------------------------------------------------------------------------
/qual/original/solar_dish.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/solar_dish.png
--------------------------------------------------------------------------------
/qual/original/water_ouzel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/water_ouzel.png
--------------------------------------------------------------------------------
/qual/original/wombat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/qual/original/wombat.png
--------------------------------------------------------------------------------
/scripts/calc_dataset_info.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Tuple
3 |
4 | import torch
5 | import torch.utils.data as data
6 | from torchvision import transforms
7 | from tqdm import tqdm
8 |
9 | from data import ALL_DATASETS, create_dataset, get_generator, seed_worker
10 | from utils.utils import fix_seed
11 |
12 |
13 | def calc_mean_std(
14 | dataloader: data.DataLoader, image_size
15 | ) -> Tuple[torch.Tensor, torch.Tensor]:
16 | """
17 |     Calculate the per-channel mean and standard deviation of the image dataset
18 |
19 | Args:
20 | dataloader(DataLoader): DataLoader
21 | image_size(int) : Image size
22 |
23 | Returns:
24 |         Mean and standard deviation of each channel
25 | Tuple[torch.Tensor, torch.Tensor]
26 | """
27 |
28 | sum = torch.tensor([0.0, 0.0, 0.0])
29 | sum_square = torch.tensor([0.0, 0.0, 0.0])
30 | total = 0
31 |
32 | for inputs, _ in tqdm(dataloader, dynamic_ncols=True):
33 |         # inputs stay on the CPU; the running sums above are CPU tensors
34 | sum += inputs.sum(axis=[0, 2, 3])
35 | sum_square += (inputs**2).sum(axis=[0, 2, 3])
36 | total += inputs.size(0)
37 |
38 | count = total * image_size * image_size
39 |
40 | total_mean = sum / count
41 | total_var = (sum_square / count) - (total_mean**2)
42 | total_std = torch.sqrt(total_var)
43 |
44 | return total_mean, total_std
45 |
46 |
47 | def main():
48 | fix_seed(args.seed, True)
49 |
50 | transform = transforms.Compose(
51 | [
52 | transforms.Resize((args.image_size, args.image_size)),
53 | transforms.ToTensor(),
54 | ]
55 | )
56 |
57 | dataset = create_dataset(args.dataset, "train", args.image_size, transform)
58 | dataloader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, worker_init_fn=seed_worker, generator=get_generator())
59 |
60 | mean, std = calc_mean_std(dataloader, args.image_size)
61 | print(f"mean: {mean}")
62 | print(f"std: {std}")
63 |
64 |
65 | if __name__ == "__main__":
66 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
67 |
68 | parser = argparse.ArgumentParser()
69 |
70 | parser.add_argument("--seed", default=42, type=int)
71 |
72 | parser.add_argument("--dataset", type=str, choices=ALL_DATASETS)
73 | parser.add_argument("--image_size", type=int, default=224)
74 | parser.add_argument("--batch_size", default=32, type=int)
75 |
76 | args = parser.parse_args()
77 |
78 | main()
79 |
--------------------------------------------------------------------------------
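The script accumulates per-channel sums and squared sums and recovers the statistics as mean = sum / count and std = sqrt(E[x^2] - E[x]^2). A tiny self-contained check of that identity on random data, independent of the repo's datasets:

```python
import torch

x = torch.rand(16, 3, 8, 8)                    # stand-in batch of images
count = x.size(0) * x.size(2) * x.size(3)      # pixels per channel

mean = x.sum(dim=[0, 2, 3]) / count
var = (x ** 2).sum(dim=[0, 2, 3]) / count - mean ** 2
std = var.sqrt()

assert torch.allclose(mean, x.mean(dim=[0, 2, 3]), atol=1e-5)
assert torch.allclose(std, x.std(dim=[0, 2, 3], unbiased=False), atol=1e-4)
```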
/scripts/visualize_transforms.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from typing import Tuple
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | from data import ALL_DATASETS, create_dataloader_dict, get_parameter_depend_in_data_set
9 | from utils.utils import fix_seed, reverse_normalize
10 |
11 |
12 | def save_normalized_image(
13 | image: np.ndarray,
14 | fname: str,
15 | mean: Tuple[float, float, float],
16 | std: Tuple[float, float, float],
17 | ) -> None:
18 | image = reverse_normalize(image.copy(), mean, std)
19 | image = np.transpose(image, (1, 2, 0))
20 |
21 | fig, ax = plt.subplots()
22 | ax.imshow(image)
23 |
24 | plt.savefig(fname)
25 | plt.clf()
26 | plt.close()
27 |
28 |
29 | def main():
30 | fix_seed(args.seed, True)
31 |
32 | dataloader = create_dataloader_dict(args.dataset, 1, args.image_size)
33 | params = get_parameter_depend_in_data_set(args.dataset)
34 |
35 | save_dir = os.path.join(args.root_dir, "transforms")
36 | if not os.path.isdir(save_dir):
37 | os.makedirs(save_dir)
38 |
39 | for phase, inputs in dataloader.items():
40 | if not phase == "Train":
41 | continue
42 |
43 | for i, data in enumerate(inputs):
44 | image = data[0].cpu().numpy()[0]
45 | save_fname = os.path.join(save_dir, f"{i}.png")
46 | save_normalized_image(image, save_fname, params["mean"], params["std"])
47 |
48 |
49 | if __name__ == "__main__":
50 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
51 |
52 | parser = argparse.ArgumentParser()
53 |
54 | parser.add_argument("--seed", default=42, type=int)
55 |
56 | parser.add_argument("--dataset", type=str, default="IDRiD", choices=ALL_DATASETS)
57 | parser.add_argument("--image_size", type=int, default=224)
58 |
59 | parser.add_argument("--root_dir", type=str, default="./outputs/")
60 |
61 | args = parser.parse_args()
62 |
63 | main()
64 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keio-smilab24/LRP-for-ResNet/cadd126f7bf8b3fee32a748eca14189a088ea373/src/__init__.py
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | """Script holds data loader methods.
2 | """
3 | import argparse
4 |
5 | import torch
6 | import torchvision
7 | from torch.utils import data
8 |
9 | from data import get_generator, seed_worker
10 |
11 |
12 | def get_data_loader(config: argparse.Namespace) -> torch.utils.data.DataLoader:
13 | """Creates dataloader for networks from PyTorch's Model Zoo.
14 |
15 | Data loader uses mean and standard deviation for ImageNet.
16 |
17 | Args:
18 | config: Argparse namespace object.
19 |
20 | Returns:
21 | Data loader object.
22 |
23 | """
24 | input_dir = config.input_dir
25 | batch_size = config.batch_size
26 |
27 | mean = [0.485, 0.456, 0.406]
28 | std = [0.229, 0.224, 0.225]
29 |
30 | transforms = []
31 |
32 | if config.resize:
33 | transforms += [
34 | torchvision.transforms.Resize(size=int(1.1 * config.resize)),
35 | torchvision.transforms.CenterCrop(size=config.resize),
36 | ]
37 |
38 | transforms += [
39 | torchvision.transforms.ToTensor(),
40 | torchvision.transforms.Normalize(mean, std),
41 | ]
42 |
43 | transform = torchvision.transforms.Compose(transforms=transforms)
44 | dataset = torchvision.datasets.ImageFolder(root=input_dir, transform=transform)
45 |
46 | data_loader = data.DataLoader(dataset, batch_size=batch_size, worker_init_fn=seed_worker, generator=get_generator())
47 |
48 | return data_loader
49 |
--------------------------------------------------------------------------------
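A minimal sketch for `get_data_loader`; the `input_dir` below is a hypothetical ImageFolder-style directory (one sub-directory per class), and `resize=224` reproduces the resize-to-246 / center-crop-to-224 pipeline above:

```python
import argparse

from src.data import get_data_loader

config = argparse.Namespace(
    input_dir="./datasets/example_imagefolder",  # hypothetical path
    batch_size=32,
    resize=224,
)
data_loader = get_data_loader(config)

for images, labels in data_loader:
    print(images.shape, labels.shape)            # e.g. torch.Size([32, 3, 224, 224])
    break
```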
/src/data_processing.py:
--------------------------------------------------------------------------------
1 | """Script with method for pre- and post-processing."""
2 | import argparse
3 |
4 | import cv2
5 | import numpy
6 | import torch
7 | import torchvision.transforms
8 |
9 |
10 | class DataProcessing:
11 | def __init__(self, config: argparse.Namespace, device: torch.device) -> None:
12 | """Initializes data processing class."""
13 |
14 | mean = [0.485, 0.456, 0.406]
15 | std = [0.229, 0.224, 0.225]
16 |
17 | transforms = [
18 | torchvision.transforms.ToPILImage(),
19 | ]
20 |
21 | if config.resize:
22 | transforms += [
23 | torchvision.transforms.Resize(size=int(1.1 * config.resize)),
24 | torchvision.transforms.CenterCrop(size=config.resize),
25 | ]
26 |
27 | transforms += [
28 | torchvision.transforms.ToTensor(),
29 | torchvision.transforms.Normalize(mean, std),
30 | ]
31 |
32 | self.transform = torchvision.transforms.Compose(transforms=transforms)
33 | self.device = device
34 |
35 | def preprocess(self, frame: numpy.ndarray) -> torch.Tensor:
36 | """Preprocesses frame captured by webcam."""
37 | return self.transform(frame).to(self.device)[None, ...]
38 |
39 | def postprocess(self, relevance_scores: torch.Tensor):
40 | """Normalizes relevance scores and applies colormap."""
41 | relevance_scores = relevance_scores.numpy()
42 | relevance_scores = cv2.normalize(
43 | relevance_scores, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8UC1
44 | )
45 | relevance_scores = cv2.applyColorMap(relevance_scores, cv2.COLORMAP_HOT)
46 | return relevance_scores
47 |
--------------------------------------------------------------------------------
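A minimal sketch for `DataProcessing`; the random uint8 array stands in for a webcam frame and the random tensor for a relevance map, purely for illustration:

```python
import argparse

import numpy as np
import torch

from src.data_processing import DataProcessing

device = torch.device("cpu")
config = argparse.Namespace(resize=224)
processing = DataProcessing(config, device)

frame = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # stand-in webcam frame
x = processing.preprocess(frame)              # (1, 3, 224, 224), ImageNet-normalized
print(x.shape)

relevance = torch.rand(224, 224)              # stand-in relevance scores
heatmap = processing.postprocess(relevance)   # uint8 heatmap via cv2.COLORMAP_HOT
print(heatmap.shape)                          # (224, 224, 3)
```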
/src/lrp.py:
--------------------------------------------------------------------------------
1 | """Class for layer-wise relevance propagation.
2 |
3 | Layer-wise relevance propagation for VGG-like networks from PyTorch's Model Zoo.
4 | Implementation can be adapted to work with other architectures as well by adding the corresponding operations.
5 |
6 | Typical usage example:
7 |
8 | model = torchvision.models.vgg16(pretrained=True)
9 | lrp_model = LRPModel(model)
10 | r = lrp_model.forward(x)
11 |
12 | """
13 | from copy import deepcopy
14 | from typing import Union
15 |
16 | import torch
17 | from torch import nn
18 | from torchvision.models.resnet import BasicBlock, Bottleneck
19 | from torchvision.models.vgg import VGG
20 |
21 | from models import OneWayResNet
22 | from src.utils import SkipConnectionPropType, layers_lookup
23 |
24 |
25 | class LRPModel(nn.Module):
26 | """Class wraps PyTorch model to perform layer-wise relevance propagation."""
27 |
28 |     def __init__(self, model: torch.nn.Module, rel_pass_ratio: float = 0.0, skip_connection_prop: SkipConnectionPropType = "latest") -> None:
29 | super().__init__()
30 | self.model: Union[VGG, OneWayResNet] = model
31 | self.rel_pass_ratio = rel_pass_ratio
32 | self.skip_connection_prop = skip_connection_prop
33 |
34 | self.model.eval() # self.model.train() activates dropout / batch normalization etc.!
35 |
36 |         # Parse network (architecture must be VGG-like or a OneWayResNet)
37 | self.layers = self._get_layer_operations()
38 |
39 | # Create LRP network
40 | self.lrp_layers = self._create_lrp_model()
41 |
42 | def _create_lrp_model(self) -> torch.nn.ModuleList:
43 | """Method builds the model for layer-wise relevance propagation.
44 |
45 | Returns:
46 | LRP-model as module list.
47 |
48 | """
49 | # Clone layers from original model. This is necessary as we might modify the weights.
50 | layers = deepcopy(self.layers)
51 | lookup_table = layers_lookup(self.skip_connection_prop)
52 |
53 | # Run backwards through layers
54 | for i, layer in enumerate(layers[::-1]):
55 | try:
56 | layers[i] = lookup_table[layer.__class__](layer=layer, top_k=self.rel_pass_ratio)
57 | except KeyError:
58 | message = (
59 | f"Layer-wise relevance propagation not implemented for "
60 | f"{layer.__class__.__name__} layer."
61 | )
62 | raise NotImplementedError(message)
63 |
64 | return layers
65 |
66 | def _get_layer_operations(self) -> torch.nn.ModuleList:
67 | """Get all network operations and store them in a list.
68 |
69 | This method is adapted to VGG networks from PyTorch's Model Zoo.
70 | Modify this method to work also for other networks.
71 |
72 | Returns:
73 | Layers of original model stored in module list.
74 |
75 | """
76 | layers = torch.nn.ModuleList()
77 |
78 | # Parse VGG, OneWayResNet
79 | for layer in self.model.features:
80 | is_resnet_tower = isinstance(layer, nn.Sequential) and (isinstance(layer[0], BasicBlock) or isinstance(layer[0], Bottleneck))
81 | if is_resnet_tower:
82 | for sub_layer in layer:
83 | assert isinstance(sub_layer, BasicBlock) or isinstance(sub_layer, Bottleneck)
84 | layers.append(sub_layer)
85 | else:
86 | layers.append(layer)
87 |
88 | layers.append(self.model.avgpool)
89 | layers.append(torch.nn.Flatten(start_dim=1))
90 |
91 | for layer in self.model.classifier:
92 | layers.append(layer)
93 |
94 | return layers
95 |
96 |     def forward(self, x: torch.Tensor, topk: int = -1) -> torch.Tensor:
97 | """Forward method that first performs standard inference followed by layer-wise relevance propagation.
98 |
99 | Args:
100 | x: Input tensor representing an image / images (N, C, H, W).
101 |
102 | Returns:
103 |             Tensor holding relevance scores with spatial dimensions (H, W) (channels summed, batch dimension squeezed).
104 |
105 | """
106 | activations = list()
107 |
108 | # Run inference and collect activations.
109 | with torch.no_grad():
110 |             # Replacing the image with ones avoids using image information for the relevance computation.
111 | activations.append(torch.ones_like(x))
112 | for layer in self.layers:
113 | x = layer.forward(x)
114 | activations.append(x)
115 |
116 | # Reverse order of activations to run backwards through model
117 | activations = activations[::-1]
118 | activations = [a.data.requires_grad_(True) for a in activations]
119 |
120 | # Initial relevance scores are the network's output activations
121 | relevance = torch.softmax(activations.pop(0), dim=-1) # Unsupervised
122 | if topk != -1:
123 | relevance_zero = torch.zeros_like(relevance)
124 | top_k_indices = torch.topk(relevance, topk).indices
125 | for index in top_k_indices:
126 | relevance_zero[..., index] = relevance[..., index]
127 | relevance = relevance_zero
128 |
129 | # Perform relevance propagation
130 | for i, layer in enumerate(self.lrp_layers):
131 | a = activations.pop(0)
132 | try:
133 | relevance = layer.forward(a, relevance)
134 | except RuntimeError:
135 | print(f"RuntimeError at layer {i}.\n"
136 | f"Layer: {layer.__class__.__name__}\n"
137 | f"Relevance shape: {relevance.shape}\n"
138 | f"Activation shape: {activations[0].shape}\n")
139 | exit(1)
140 |
141 | # # disturb
142 | # disturbance = (torch.rand(relevance.shape).to(relevance.device) - 0.5)
143 | # relevance = relevance + disturbance * 5e-4
144 |
145 | ### patch for debug and/or analysis ###
146 |             # Debug/analysis escape hatch: uncomment the lines below to inspect
147 |             # relevance scores at intermediate (critical) layers
148 | #
149 | # if i == 3: # stop at 3, 7, 11, 15, 18 (1st, 5th, 9th, 13th, 16th last bottleneck block)
150 | # break
151 | #
152 |             ### the value of i was determined manually by listing each layer with the following code
153 | # ```python
154 | # print(len(self.lrp_layers)) # 23 for ResNet50
155 | # for i, layer in enumerate(self.lrp_layers):
156 | # print(f"{i}:\n {layer.__class__.__name__}")
157 | # ```
158 | # RelevancePropagationBottleneckFlowsPureSkip x16 (i = 3-18)
159 |
160 | return relevance.permute(0, 2, 3, 1).sum(dim=-1).squeeze().detach().cpu()
161 |
162 |
163 | # legacy code
164 | class LRPModules(nn.Module):
165 | """Class wraps PyTorch model to perform layer-wise relevance propagation."""
166 |
167 | def __init__(
168 | self, layers: nn.ModuleList, out_relevance: torch.Tensor, top_k: float = 0.0
169 | ) -> None:
170 | super().__init__()
171 | self.top_k = top_k
172 |
173 | # Parse network
174 | self.layers = layers
175 | self.out_relevance = out_relevance
176 |
177 | # Create LRP network
178 | self.lrp_layers = self._create_lrp_model()
179 |
180 | def _create_lrp_model(self) -> nn.ModuleList:
181 | """Method builds the model for layer-wise relevance propagation.
182 |
183 | Returns:
184 | LRP-model as module list.
185 |
186 | """
187 | # Clone layers from original model. This is necessary as we might modify the weights.
188 | layers = deepcopy(self.layers)
189 | lookup_table = layers_lookup()
190 |
191 | # Run backwards through layers
192 | for i, layer in enumerate(layers[::-1]):
193 | try:
194 | layers[i] = lookup_table[layer.__class__](layer=layer, top_k=self.top_k)
195 | except KeyError:
196 | message = (
197 | f"Layer-wise relevance propagation not implemented for "
198 | f"{layer.__class__.__name__} layer."
199 | )
200 | raise NotImplementedError(message)
201 |
202 | return layers
203 |
204 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
205 | """Forward method that first performs standard inference followed by layer-wise relevance propagation.
206 |
207 | Args:
208 | x: Input tensor representing an image / images (N, C, H, W).
209 |
210 | Returns:
211 |             Tensor holding relevance scores with spatial dimensions (H, W) (channels summed, batch dimension squeezed).
212 |
213 | """
214 | activations = list()
215 |
216 | # Run inference and collect activations.
217 | with torch.no_grad():
218 |             # Replacing the image with ones avoids using image information for the relevance computation.
219 | activations.append(torch.ones_like(x))
220 | for layer in self.layers:
221 | x = layer.forward(x)
222 | activations.append(x)
223 |
224 | # Reverse order of activations to run backwards through model
225 | activations = activations[::-1]
226 | activations = [a.data.requires_grad_(True) for a in activations]
227 |
228 | # Initial relevance scores are the network's output activations
229 | relevance = torch.softmax(activations.pop(0), dim=-1) # Unsupervised
230 | if self.out_relevance is not None:
231 | relevance = self.out_relevance.to(relevance.device)
232 |
233 | # Perform relevance propagation
234 | for i, layer in enumerate(self.lrp_layers):
235 | relevance = layer.forward(activations.pop(0), relevance)
236 |
237 | return relevance.permute(0, 2, 3, 1).sum(dim=-1).squeeze().detach().cpu()
238 |
239 |
240 | def basic_lrp(
241 | model, image, rel_pass_ratio=1.0, topk=1, skip_connection_prop: SkipConnectionPropType = "latest"
242 | ):
243 | lrp_model = LRPModel(model, rel_pass_ratio=rel_pass_ratio, skip_connection_prop=skip_connection_prop)
244 | R = lrp_model.forward(image, topk)
245 | return R
246 |
247 |
248 | # Legacy code -----------------------
249 | def resnet_lrp(model, image, topk=0.2):
250 | output = model(image)
251 | score, class_index = torch.max(output, 1)
252 | R = torch.zeros_like(output)
253 | R[0, class_index] = score
254 |
255 | post_modules = divide_module_by_name(model, "avgpool")
256 | new_post = post_modules[:-1]
257 | new_post.append(torch.nn.Flatten(start_dim=1))
258 | new_post.append(post_modules[-1])
259 | post_modules = new_post
260 |
261 | post_lrp = LRPModules(post_modules, R, top_k=topk)
262 | R = post_lrp.forward(post_modules[0].activations)
263 |
264 | R = resnet_layer_lrp(model.layer4, R, top_k=topk)
265 | R = resnet_layer_lrp(model.layer3, R, top_k=topk)
266 | R = resnet_layer_lrp(model.layer2, R, top_k=topk)
267 | R = resnet_layer_lrp(model.layer1, R, top_k=topk)
268 |
269 | pre_modules = divide_module_by_name(model, "layer1", before_module=True)
270 | pre_lrp = LRPModules(pre_modules, R, top_k=topk)
271 | R = pre_lrp.forward(image)
272 |
273 | return R
274 |
275 |
276 | def abn_lrp(model, image, topk=0.2):
277 | output = model(image)
278 | score, class_index = torch.max(output, 1)
279 | R = torch.zeros_like(output)
280 | R[0, class_index] = score
281 |
282 | #########################
283 | ### Perception Branch ###
284 | #########################
285 | post_modules = nn.ModuleList(
286 | [
287 | model.perception_branch[2],
288 | model.perception_branch[3],
289 | model.perception_branch[4],
290 | ]
291 | )
292 | new_post = post_modules[:-1]
293 | new_post.append(torch.nn.Flatten(start_dim=1))
294 | new_post.append(post_modules[-1])
295 | post_modules = new_post
296 |
297 | post_lrp = LRPModules(post_modules, R, top_k=topk)
298 | R_pb = post_lrp.forward(post_modules[0].activations)
299 |
300 | for sequential_blocks in model.perception_branch[:2][::-1]:
301 | R_pb = resnet_layer_lrp(sequential_blocks, R_pb, topk)
302 |
303 | #########################
304 | ### Attention Branch ###
305 | #########################
306 | # h -> layer1, bn1, conv1, relu, conv4, avgpool, flatten
307 | ab_modules = nn.ModuleList(
308 | [
309 | model.attention_branch.bn1,
310 | model.attention_branch.conv1,
311 | model.attention_branch.relu,
312 | model.attention_branch.conv4,
313 | model.attention_branch.avgpool,
314 | model.attention_branch.flatten,
315 | ]
316 | )
317 | ab_lrp = LRPModules(ab_modules, R, top_k=topk)
318 | R_ab = ab_lrp.forward(model.attention_branch.bn1_activation)
319 | R_ab = resnet_layer_lrp(model.attention_branch.layer1, R_ab, topk)
320 |
321 | #########################
322 | ### Feature Extractor ###
323 | #########################
324 | R_fe_out = R_pb + R_ab
325 | R = resnet_layer_lrp(model.feature_extractor[-1], R_fe_out, topk)
326 | R = resnet_layer_lrp(model.feature_extractor[-2], R, topk)
327 |
328 | pre_modules = nn.ModuleList(
329 | [
330 | model.feature_extractor[0],
331 | model.feature_extractor[1],
332 | model.feature_extractor[2],
333 | model.feature_extractor[3],
334 | ]
335 | )
336 | pre_lrp = LRPModules(pre_modules, R, top_k=topk)
337 | R = pre_lrp.forward(image)
338 |
339 | return R
340 |
341 |
342 | def resnet_layer_lrp(
343 | layer: nn.Sequential, out_relevance: torch.Tensor, top_k: float = 0.0
344 | ):
345 | for res_block in layer[::-1]:
346 | inputs = res_block.activations
347 |
348 | identify = out_relevance
349 | if res_block.downsample is not None:
350 | downsample = nn.ModuleList(
351 | [res_block.downsample[0], res_block.downsample[1]]
352 | )
353 | skip_lrp = LRPModules(downsample, identify, top_k=top_k)
354 | skip_relevance = skip_lrp.forward(inputs)
355 | else:
356 | skip_relevance = identify
357 |
358 | main_modules = nn.ModuleList()
359 | for name, module in res_block._modules.items():
360 | if name == "downsample":
361 | continue
362 | main_modules.append(module)
363 | main_lrp = LRPModules(main_modules, identify, top_k=top_k)
364 | main_relevance = main_lrp.forward(inputs)
365 |
366 | gamma = 0.5
367 | out_relevance = gamma * main_relevance + (1 - gamma) * skip_relevance
368 | return out_relevance
369 |
370 |
371 | def divide_module_by_name(model, module_name: str, before_module: bool = False):
372 | use_module = before_module
373 | modules = nn.ModuleList()
374 | for name, module in model._modules.items():
375 | if name == module_name:
376 | use_module = not use_module
377 | if not use_module:
378 | continue
379 | modules.append(module)
380 |
381 | return modules
382 |
--------------------------------------------------------------------------------
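
A usage sketch for `basic_lrp`, following the "typical usage" note in the module docstring and assuming the module is importable as `src.lrp`; VGG-16 from torchvision stands in for the model and a random tensor for a preprocessed image:

```python
import torch
import torchvision

from src.lrp import basic_lrp  # assumed import path

model = torchvision.models.vgg16(pretrained=True)
image = torch.rand(1, 3, 224, 224)  # stand-in for a normalized ImageNet image

# rel_pass_ratio < 1.0 keeps only the largest relevance values at each layer;
# topk=1 restricts the initial relevance to the top-1 predicted class.
relevance = basic_lrp(model, image, rel_pass_ratio=1.0, topk=1)
print(relevance.shape)  # (224, 224) relevance map on the CPU
```
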
/src/lrp_filter.py:
--------------------------------------------------------------------------------
1 | """Implements filter method for relevance scores.
2 | """
3 | import torch
4 |
5 |
6 | def relevance_filter(r: torch.Tensor, top_k_percent: float = 1.0) -> torch.Tensor:
7 | """Filter that allows largest k percent values to pass for each batch dimension.
8 |
9 | Filter keeps k% of the largest tensor elements. Other tensor elements are set to
10 | zero. Here, k = 1 means that all relevance scores are passed on to the next layer.
11 |
12 | Args:
13 | r: Tensor holding relevance scores of current layer.
14 | top_k_percent: Proportion of top k values that is passed on.
15 |
16 | Returns:
17 | Tensor of same shape as input tensor.
18 |
19 | """
20 | assert 0.0 < top_k_percent <= 1.0
21 |
22 | if top_k_percent < 1.0:
23 | size = r.size()
24 | r = r.flatten(start_dim=1)
25 | num_elements = r.size(-1)
26 | k = max(1, int(top_k_percent * num_elements))
27 | top_k = torch.topk(input=r, k=k, dim=-1)
28 | r = torch.zeros_like(r)
29 | r.scatter_(dim=1, index=top_k.indices, src=top_k.values)
30 | return r.view(size)
31 | else:
32 | return r
33 |
--------------------------------------------------------------------------------
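
A small sketch of the filter's behaviour, assuming the module is importable as `src.lrp_filter`; with `top_k_percent=0.1` only the largest 10% of entries per sample survive:

```python
import torch

from src.lrp_filter import relevance_filter  # assumed import path

r = torch.rand(2, 4, 8, 8)  # batch of relevance scores
filtered = relevance_filter(r, top_k_percent=0.1)

print(filtered.shape)                            # unchanged: (2, 4, 8, 8)
print((filtered != 0).sum(dim=(1, 2, 3)))        # 25 nonzeros per sample (10% of 256)
print(torch.equal(relevance_filter(r, 1.0), r))  # top_k_percent=1.0 is a no-op
```
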
/src/lrp_layers.py:
--------------------------------------------------------------------------------
1 | """Layers for layer-wise relevance propagation.
2 |
3 | Layers for layer-wise relevance propagation can be modified.
4 |
5 | """
6 | import torch
7 | from torch import nn
8 | from torchvision.models.resnet import BasicBlock, Bottleneck
9 |
10 | from src.lrp_filter import relevance_filter
11 |
12 |
13 | class RelevancePropagationBasicBlock(nn.Module):
14 | def __init__(
15 | self,
16 | layer: BasicBlock,
17 | eps: float = 1.0e-05,
18 | top_k: float = 0.0,
19 | ) -> None:
20 | super().__init__()
21 | self.layers = [
22 | layer.conv1,
23 | layer.bn1,
24 | layer.relu,
25 | layer.conv2,
26 | layer.bn2,
27 | ]
28 | self.downsample = layer.downsample
29 | self.relu = layer.relu
30 | self.eps = eps
31 | self.top_k = top_k
32 |
33 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor:
34 | if self.downsample is None:
35 | return torch.ones_like(input_)
36 | mainstream = input_
37 | shortcut = input_
38 | for layer in self.layers:
39 | mainstream = layer(mainstream)
40 | if self.downsample is not None:
41 | shortcut = self.downsample(shortcut)
42 | assert mainstream.shape == shortcut.shape
43 | return mainstream.abs() / (shortcut.abs() + mainstream.abs())
44 |
45 | def mainstream_backward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
46 | with torch.no_grad():
47 | activations = [a]
48 | for layer in self.layers:
49 | activations.append(layer.forward(activations[-1]))
50 |
51 | activations.pop() # ignore output of this basic block
52 | activations = [a.data.requires_grad_(True) for a in activations]
53 |
54 |         # Only the main path is propagated here; the shortcut is handled in shortcut_backward()
55 | r_out = r
56 | for layer in self.layers[::-1]:
57 | a = activations.pop()
58 | if self.top_k:
59 | r_out = relevance_filter(r_out, top_k_percent=self.top_k)
60 |
61 | if isinstance(layer, nn.Conv2d):
62 | r_in = RelevancePropagationConv2d(layer, eps=self.eps, top_k=self.top_k)(
63 | a, r_out
64 | )
65 | elif isinstance(layer, nn.BatchNorm2d):
66 | r_in = RelevancePropagationBatchNorm2d(layer, top_k=self.top_k)(a, r_out)
67 | elif isinstance(layer, nn.ReLU):
68 | r_in = RelevancePropagationReLU(layer, top_k=self.top_k)(a, r_out)
69 | else:
70 | raise NotImplementedError
71 | r_out = r_in
72 | return r_in
73 |
74 | def shortcut_backward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
75 | if self.downsample is None:
76 | return r
77 | a = a.data.requires_grad_(True)
78 | assert isinstance(self.downsample[0], nn.Conv2d)
79 | return RelevancePropagationConv2d(self.downsample[0], eps=self.eps, top_k=self.top_k)(a, r)
80 |
81 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
82 | ratio = self._calc_mainstream_flow_ratio(a)
83 | assert r.shape == ratio.shape
84 | r_mainstream = ratio * r
85 | r_shortcut = (1 - ratio) * r
86 | r_mainstream = self.mainstream_backward(a, r_mainstream)
87 | r_shortcut = self.shortcut_backward(a, r_shortcut)
88 | return r_mainstream + r_shortcut
89 |
90 |
91 | class RelevancePropagationBasicBlockSimple(RelevancePropagationBasicBlock):
92 | """ Relevance propagation for BasicBlock Proto A
93 | Divide relevance score for plain shortcuts
94 | """
95 | def __init__(self, layer: BasicBlock, eps: float = 1.0e-05, top_k: float = 0.0) -> None:
96 | super().__init__(layer=layer, eps=eps, top_k=top_k)
97 |
98 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor:
99 | if self.downsample is None:
100 | return torch.ones_like(input_)
101 | return torch.full_like(self.downsample(input_), 0.5)
102 |
103 |
104 | class RelevancePropagationBasicBlockFlowsPureSkip(RelevancePropagationBasicBlock):
105 | """ Relevance propagation for BasicBlock Proto A
106 | Divide relevance score for plain shortcuts
107 | """
108 | def __init__(self, layer: BasicBlock, eps: float = 1.0e-05, top_k: float = 0.0) -> None:
109 | super().__init__(layer=layer, eps=eps, top_k=top_k)
110 |
111 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor:
112 | mainstream = input_
113 | shortcut = input_
114 | for layer in self.layers:
115 | mainstream = layer(mainstream)
116 | if self.downsample is not None:
117 | shortcut = self.downsample(shortcut)
118 | assert mainstream.shape == shortcut.shape
119 | return mainstream.abs() / (shortcut.abs() + mainstream.abs())
120 |
121 |
122 | class RelevancePropagationBasicBlockSimpleFlowsPureSkip(RelevancePropagationBasicBlock):
123 | """ Relevance propagation for BasicBlock Proto A
124 | Divide relevance score for plain shortcuts
125 | """
126 | def __init__(self, layer: BasicBlock, eps: float = 1.0e-05, top_k: float = 0.0) -> None:
127 | super().__init__(layer=layer, eps=eps, top_k=top_k)
128 |
129 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor:
130 | if self.downsample is None:
131 | return torch.full_like(input_, 0.5)
132 | return torch.full_like(self.downsample(input_), 0.5)
133 |
134 |
135 | class RelevancePropagationBottleneck(nn.Module):
136 | def __init__(
137 | self,
138 | layer: Bottleneck,
139 | eps: float = 1.0e-05,
140 | top_k: float = 0.0,
141 | ) -> None:
142 | super().__init__()
143 | self.layers = [
144 | layer.conv1,
145 | layer.bn1,
146 | layer.relu,
147 | layer.conv2,
148 | layer.bn2,
149 | layer.relu,
150 | layer.conv3,
151 | layer.bn3,
152 | ]
153 | self.downsample = layer.downsample
154 | self.relu = layer.relu
155 | self.eps = eps
156 | self.top_k = top_k
157 |
158 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor:
159 | if self.downsample is None:
160 | return torch.ones_like(input_)
161 | mainstream = input_
162 | shortcut = input_
163 | for layer in self.layers:
164 | mainstream = layer(mainstream)
165 | if self.downsample is not None:
166 | shortcut = self.downsample(shortcut)
167 | assert mainstream.shape == shortcut.shape
168 | return mainstream.abs() / (shortcut.abs() + mainstream.abs())
169 |
170 | def mainstream_backward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
171 | with torch.no_grad():
172 | activations = [a]
173 | for layer in self.layers:
174 | activations.append(layer.forward(activations[-1]))
175 |
176 | activations.pop() # ignore output of this bottleneck block
177 | activations = [a.data.requires_grad_(True) for a in activations]
178 |
179 |         # Only the main path is propagated here; the shortcut is handled in shortcut_backward()
180 | r_out = r
181 | for layer in self.layers[::-1]:
182 | a = activations.pop()
183 | if self.top_k:
184 | r_out = relevance_filter(r_out, top_k_percent=self.top_k)
185 |
186 | if isinstance(layer, nn.Conv2d):
187 | r_in = RelevancePropagationConv2d(layer, eps=self.eps, top_k=self.top_k)(
188 | a, r_out
189 | )
190 | elif isinstance(layer, nn.BatchNorm2d):
191 | r_in = RelevancePropagationBatchNorm2d(layer, top_k=self.top_k)(a, r_out)
192 | elif isinstance(layer, nn.ReLU):
193 | r_in = RelevancePropagationReLU(layer, top_k=self.top_k)(a, r_out)
194 | else:
195 | raise NotImplementedError
196 | r_out = r_in
197 | return r_in
198 |
199 | def shortcut_backward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
200 | if self.downsample is None:
201 | return r
202 | a = a.data.requires_grad_(True)
203 | assert isinstance(self.downsample[0], nn.Conv2d)
204 | return RelevancePropagationConv2d(self.downsample[0], eps=self.eps, top_k=self.top_k)(a, r)
205 |
206 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
207 | ratio = self._calc_mainstream_flow_ratio(a)
208 | assert r.shape == ratio.shape
209 | r_mainstream = ratio * r
210 | r_shortcut = (1 - ratio) * r
211 | r_mainstream = self.mainstream_backward(a, r_mainstream)
212 | r_shortcut = self.shortcut_backward(a, r_shortcut)
213 | return r_mainstream + r_shortcut
214 |
215 |
216 | class RelevancePropagationBottleneckSimple(RelevancePropagationBottleneck):
217 | """ Relevance propagation for Bottleneck Proto A
218 | Divide relevance score for plain shortcuts
219 | """
220 | def __init__(self, layer: Bottleneck, eps: float = 1.0e-05, top_k: float = 0.0) -> None:
221 | super().__init__(layer=layer, eps=eps, top_k=top_k)
222 |
223 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor:
224 | if self.downsample is None:
225 | return torch.ones_like(input_)
226 | return torch.full_like(self.downsample(input_), 0.5)
227 |
228 |
229 | class RelevancePropagationBottleneckFlowsPureSkip(RelevancePropagationBottleneck):
230 | """ Relevance propagation for Bottleneck Proto A
231 | Divide relevance score for plain shortcuts
232 | """
233 | def __init__(self, layer: Bottleneck, eps: float = 1.0e-05, top_k: float = 0.0) -> None:
234 | super().__init__(layer=layer, eps=eps, top_k=top_k)
235 |
236 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor:
237 | mainstream = input_
238 | shortcut = input_
239 | for layer in self.layers:
240 | mainstream = layer(mainstream)
241 | if self.downsample is not None:
242 | shortcut = self.downsample(shortcut)
243 | assert mainstream.shape == shortcut.shape
244 | return mainstream.abs() / (shortcut.abs() + mainstream.abs())
245 |
246 |
247 | class RelevancePropagationBottleneckSimpleFlowsPureSkip(RelevancePropagationBottleneck):
248 | """ Relevance propagation for Bottleneck Proto A
249 | Divide relevance score for plain shortcuts
250 | """
251 | def __init__(self, layer: Bottleneck, eps: float = 1.0e-05, top_k: float = 0.0) -> None:
252 | super().__init__(layer=layer, eps=eps, top_k=top_k)
253 |
254 | def _calc_mainstream_flow_ratio(self, input_: torch.Tensor) -> torch.Tensor:
255 | if self.downsample is None:
256 | return torch.full_like(input_, 0.5)
257 | return torch.full_like(self.downsample(input_), 0.5)
258 |
259 |
260 | class RelevancePropagationAdaptiveAvgPool2d(nn.Module):
261 | """Layer-wise relevance propagation for 2D adaptive average pooling.
262 |
263 | Attributes:
264 | layer: 2D adaptive average pooling layer.
265 | eps: A value added to the denominator for numerical stability.
266 |
267 | """
268 |
269 | def __init__(
270 | self,
271 | layer: torch.nn.AdaptiveAvgPool2d,
272 | eps: float = 1.0e-05,
273 | top_k: float = 0.0,
274 | ) -> None:
275 | super().__init__()
276 | self.layer = layer
277 | self.eps = eps
278 | self.top_k = top_k
279 |
280 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
281 | if self.top_k:
282 | r = relevance_filter(r, top_k_percent=self.top_k)
283 | z = self.layer.forward(a) + self.eps
284 | s = (r / z).data
285 | (z * s).sum().backward()
286 | c = a.grad
287 | r = (a * c).data
288 | return r
289 |
290 |
291 | class RelevancePropagationAvgPool2d(nn.Module):
292 | """Layer-wise relevance propagation for 2D average pooling.
293 |
294 | Attributes:
295 | layer: 2D average pooling layer.
296 | eps: A value added to the denominator for numerical stability.
297 |
298 | """
299 |
300 | def __init__(
301 | self, layer: torch.nn.AvgPool2d, eps: float = 1.0e-05, top_k: float = 0.0
302 | ) -> None:
303 | super().__init__()
304 | self.layer = layer
305 | self.eps = eps
306 | self.top_k = top_k
307 |
308 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
309 | if self.top_k:
310 | r = relevance_filter(r, top_k_percent=self.top_k)
311 | z = self.layer.forward(a) + self.eps
312 | s = (r / z).data
313 | (z * s).sum().backward()
314 | c = a.grad
315 | r = (a * c).data
316 | return r
317 |
318 |
319 | class RelevancePropagationMaxPool2d(nn.Module):
320 | """Layer-wise relevance propagation for 2D max pooling.
321 |
322 | Optionally substitutes max pooling by average pooling layers.
323 |
324 | Attributes:
325 | layer: 2D max pooling layer.
326 | eps: a value added to the denominator for numerical stability.
327 |
328 | """
329 |
330 | def __init__(
331 | self,
332 | layer: torch.nn.MaxPool2d,
333 | mode: str = "avg",
334 | eps: float = 1.0e-05,
335 | top_k: float = 0.0,
336 | ) -> None:
337 | super().__init__()
338 |
339 | if mode == "avg":
340 | self.layer = torch.nn.AvgPool2d(kernel_size=(2, 2))
341 | elif mode == "max":
342 | self.layer = layer
343 |
344 | self.eps = eps
345 | self.top_k = top_k
346 |
347 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
348 | if self.top_k:
349 | r = relevance_filter(r, top_k_percent=self.top_k)
350 | z = self.layer.forward(a) + self.eps
351 | s = (r / z).data
352 | (z * s).sum().backward()
353 | c = a.grad
354 | r = (a * c).data
355 | # print(f"maxpool2d {r.min()}, {r.max()}")
356 | return r
357 |
358 |
359 | class RelevancePropagationConv2d(nn.Module):
360 | """Layer-wise relevance propagation for 2D convolution.
361 |
362 | Optionally modifies layer weights according to propagation rule. Here z^+-rule
363 |
364 | Attributes:
365 | layer: 2D convolutional layer.
366 | eps: a value added to the denominator for numerical stability.
367 |
368 | """
369 |
370 | def __init__(
371 | self,
372 | layer: torch.nn.Conv2d,
373 | mode: str = "z_plus",
374 | eps: float = 1.0e-05,
375 | top_k: float = 0.0,
376 | ) -> None:
377 | super().__init__()
378 |
379 | self.layer = layer
380 |
381 | if mode == "z_plus":
382 | self.layer.weight = torch.nn.Parameter(self.layer.weight.clamp(min=0.0))
383 |
384 | self.eps = eps
385 | self.top_k = top_k
386 |
387 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
388 | if self.top_k:
389 | r = relevance_filter(r, top_k_percent=self.top_k)
390 | z = self.layer.forward(a) + self.eps
391 | s = (r / z).data
392 | (z * s).sum().backward()
393 | c = a.grad
394 | r = (a * c).data
395 | # print(f"before norm: {r.sum()}")
396 | # r = (r - r.min()) / (r.max() - r.min())
397 | # print(f"after norm: {r.sum()}\n")
398 | if r.shape != a.shape:
399 | raise RuntimeError("r.shape != a.shape")
400 | return r
401 |
402 |
403 | class RelevancePropagationLinear(nn.Module):
404 | """Layer-wise relevance propagation for linear transformation.
405 |
406 | Optionally modifies layer weights according to propagation rule. Here z^+-rule
407 |
408 | Attributes:
409 | layer: linear transformation layer.
410 | eps: a value added to the denominator for numerical stability.
411 |
412 | """
413 |
414 | def __init__(
415 | self,
416 | layer: torch.nn.Linear,
417 | mode: str = "z_plus",
418 | eps: float = 1.0e-05,
419 | top_k: float = 0.0,
420 | ) -> None:
421 | super().__init__()
422 |
423 | self.layer = layer
424 |
425 | if mode == "z_plus":
426 | self.layer.weight = torch.nn.Parameter(self.layer.weight.clamp(min=0.0))
427 | self.layer.bias = torch.nn.Parameter(torch.zeros_like(self.layer.bias))
428 |
429 | self.eps = eps
430 | self.top_k = top_k
431 |
432 | @torch.no_grad()
433 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
434 | if self.top_k:
435 | r = relevance_filter(r, top_k_percent=self.top_k)
436 | z = self.layer.forward(a) + self.eps
437 | s = r / z
438 | c = torch.mm(s, self.layer.weight)
439 | r = (a * c).data
440 | # print(f"Linear {r.min()}, {r.max()}")
441 | return r
442 |
443 |
444 | class RelevancePropagationFlatten(nn.Module):
445 | """Layer-wise relevance propagation for flatten operation.
446 |
447 | Attributes:
448 | layer: flatten layer.
449 |
450 | """
451 |
452 | def __init__(self, layer: torch.nn.Flatten, top_k: float = 0.0) -> None:
453 | super().__init__()
454 | self.layer = layer
455 |
456 | @torch.no_grad()
457 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
458 | r = r.view(size=a.shape)
459 | return r
460 |
461 |
462 | class RelevancePropagationReLU(nn.Module):
463 | """Layer-wise relevance propagation for ReLU activation.
464 |
465 | Passes the relevance scores without modification. Might be of use later.
466 |
467 | """
468 |
469 | def __init__(self, layer: torch.nn.ReLU, top_k: float = 0.0) -> None:
470 | super().__init__()
471 |
472 | @torch.no_grad()
473 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
474 | return r
475 |
476 |
477 | class RelevancePropagationBatchNorm2d(nn.Module):
478 | """Layer-wise relevance propagation for ReLU activation.
479 |
480 | Passes the relevance scores without modification. Might be of use later.
481 |
482 | """
483 |
484 | def __init__(self, layer: torch.nn.BatchNorm2d, top_k: float = 0.0) -> None:
485 | super().__init__()
486 |
487 | @torch.no_grad()
488 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
489 | return r
490 |
491 |
492 | class RelevancePropagationDropout(nn.Module):
493 | """Layer-wise relevance propagation for dropout layer.
494 |
495 | Passes the relevance scores without modification. Might be of use later.
496 |
497 | """
498 |
499 | def __init__(self, layer: torch.nn.Dropout, top_k: float = 0.0) -> None:
500 | super().__init__()
501 |
502 | @torch.no_grad()
503 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
504 | return r
505 |
506 |
507 | class RelevancePropagationIdentity(nn.Module):
508 | """Identity layer for relevance propagation.
509 |
510 | Passes relevance scores without modifying them.
511 |
512 | """
513 |
514 | def __init__(self, layer: nn.Module, top_k: float = 0.0) -> None:
515 | super().__init__()
516 |
517 | @torch.no_grad()
518 | def forward(self, a: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
519 | return r
520 |
--------------------------------------------------------------------------------
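
A sketch of the z^+-rule propagation used by the conv/pool layers above, shown on a single `Conv2d` and assuming the module is importable as `src.lrp_layers`. Note that `RelevancePropagationConv2d` clamps the wrapped layer's weights in place, which is why `LRPModel` deep-copies the parsed layers first:

```python
import torch
from torch import nn

from src.lrp_layers import RelevancePropagationConv2d  # assumed import path

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
lrp_conv = RelevancePropagationConv2d(conv, mode="z_plus", top_k=0.0)

# Activations must be leaf tensors with requires_grad, as prepared in LRPModel.forward.
a = torch.rand(1, 3, 16, 16).requires_grad_(True)
r_out = torch.rand(1, 8, 16, 16)  # relevance arriving from the layer above

r_in = lrp_conv(a, r_out)  # z = conv(a) + eps, s = r / z, r_in = a * d(z * s)/da
print(r_in.shape)          # same shape as the activation: (1, 3, 16, 16)
```
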
/src/utils.py:
--------------------------------------------------------------------------------
1 | """Script with helper function."""
2 | from typing import Literal
3 |
4 | from models.lrp import *
5 | from src.lrp_layers import *
6 |
7 | SkipConnectionPropType = Literal["simple", "flows_skip", "flows_skip_simple", "latest"]
8 |
9 | def layers_lookup(version: SkipConnectionPropType = "latest") -> dict:
10 | """Lookup table to map network layer to associated LRP operation.
11 |
12 | Returns:
13 | Dictionary holding class mappings.
14 | """
15 |
16 | # For the purpose of the ablation study on relevance propagation for skip connections
17 | if version == "simple":
18 | return layers_lookup_simple()
19 | elif version == "flows_skip":
20 | return layers_lookup_flows_pure_skip()
21 | elif version == "flows_skip_simple":
22 | return layers_lookup_simple_flows_pure_skip()
23 | elif version == "latest":
24 | return layers_lookup_latest()
25 | else:
26 | raise ValueError("Invalid version was specified.")
27 |
28 |
29 | def layers_lookup_latest() -> dict:
30 | lookup_table = {
31 | torch.nn.modules.linear.Linear: RelevancePropagationLinear,
32 | torch.nn.modules.conv.Conv2d: RelevancePropagationConv2d,
33 | torch.nn.modules.activation.ReLU: RelevancePropagationReLU,
34 | torch.nn.modules.dropout.Dropout: RelevancePropagationDropout,
35 | torch.nn.modules.flatten.Flatten: RelevancePropagationFlatten,
36 | torch.nn.modules.pooling.AvgPool2d: RelevancePropagationAvgPool2d,
37 | torch.nn.modules.pooling.MaxPool2d: RelevancePropagationMaxPool2d,
38 | torch.nn.modules.pooling.AdaptiveAvgPool2d: RelevancePropagationAdaptiveAvgPool2d,
39 | AdaptiveAvgPool2dWithActivation: RelevancePropagationAdaptiveAvgPool2d,
40 | torch.nn.BatchNorm2d: RelevancePropagationBatchNorm2d,
41 | BatchNorm2dWithActivation: RelevancePropagationBatchNorm2d,
42 | BasicBlock: RelevancePropagationBasicBlock,
43 | Bottleneck: RelevancePropagationBottleneck,
44 | }
45 | return lookup_table
46 |
47 |
48 | def layers_lookup_simple() -> dict:
49 | lookup_table = {
50 | torch.nn.modules.linear.Linear: RelevancePropagationLinear,
51 | torch.nn.modules.conv.Conv2d: RelevancePropagationConv2d,
52 | torch.nn.modules.activation.ReLU: RelevancePropagationReLU,
53 | torch.nn.modules.dropout.Dropout: RelevancePropagationDropout,
54 | torch.nn.modules.flatten.Flatten: RelevancePropagationFlatten,
55 | torch.nn.modules.pooling.AvgPool2d: RelevancePropagationAvgPool2d,
56 | torch.nn.modules.pooling.MaxPool2d: RelevancePropagationMaxPool2d,
57 | torch.nn.modules.pooling.AdaptiveAvgPool2d: RelevancePropagationAdaptiveAvgPool2d,
58 | AdaptiveAvgPool2dWithActivation: RelevancePropagationAdaptiveAvgPool2d,
59 | torch.nn.BatchNorm2d: RelevancePropagationBatchNorm2d,
60 | BatchNorm2dWithActivation: RelevancePropagationBatchNorm2d,
61 | BasicBlock: RelevancePropagationBasicBlockSimple,
62 | Bottleneck: RelevancePropagationBottleneckSimple,
63 | }
64 | return lookup_table
65 |
66 |
67 | def layers_lookup_flows_pure_skip() -> dict:
68 | lookup_table = {
69 | torch.nn.modules.linear.Linear: RelevancePropagationLinear,
70 | torch.nn.modules.conv.Conv2d: RelevancePropagationConv2d,
71 | torch.nn.modules.activation.ReLU: RelevancePropagationReLU,
72 | torch.nn.modules.dropout.Dropout: RelevancePropagationDropout,
73 | torch.nn.modules.flatten.Flatten: RelevancePropagationFlatten,
74 | torch.nn.modules.pooling.AvgPool2d: RelevancePropagationAvgPool2d,
75 | torch.nn.modules.pooling.MaxPool2d: RelevancePropagationMaxPool2d,
76 | torch.nn.modules.pooling.AdaptiveAvgPool2d: RelevancePropagationAdaptiveAvgPool2d,
77 | AdaptiveAvgPool2dWithActivation: RelevancePropagationAdaptiveAvgPool2d,
78 | torch.nn.BatchNorm2d: RelevancePropagationBatchNorm2d,
79 | BatchNorm2dWithActivation: RelevancePropagationBatchNorm2d,
80 | BasicBlock: RelevancePropagationBasicBlockFlowsPureSkip,
81 | Bottleneck: RelevancePropagationBottleneckFlowsPureSkip,
82 | }
83 | return lookup_table
84 |
85 |
86 | def layers_lookup_simple_flows_pure_skip() -> dict:
87 | lookup_table = {
88 | torch.nn.modules.linear.Linear: RelevancePropagationLinear,
89 | torch.nn.modules.conv.Conv2d: RelevancePropagationConv2d,
90 | torch.nn.modules.activation.ReLU: RelevancePropagationReLU,
91 | torch.nn.modules.dropout.Dropout: RelevancePropagationDropout,
92 | torch.nn.modules.flatten.Flatten: RelevancePropagationFlatten,
93 | torch.nn.modules.pooling.AvgPool2d: RelevancePropagationAvgPool2d,
94 | torch.nn.modules.pooling.MaxPool2d: RelevancePropagationMaxPool2d,
95 | torch.nn.modules.pooling.AdaptiveAvgPool2d: RelevancePropagationAdaptiveAvgPool2d,
96 | AdaptiveAvgPool2dWithActivation: RelevancePropagationAdaptiveAvgPool2d,
97 | torch.nn.BatchNorm2d: RelevancePropagationBatchNorm2d,
98 | BatchNorm2dWithActivation: RelevancePropagationBatchNorm2d,
99 | BasicBlock: RelevancePropagationBasicBlockSimpleFlowsPureSkip,
100 | Bottleneck: RelevancePropagationBottleneckSimpleFlowsPureSkip,
101 | }
102 | return lookup_table
103 |
--------------------------------------------------------------------------------
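
A sketch of how the lookup is used to choose an LRP rule for a residual block, assuming the module is importable as `src.utils` and the repository's `models` package is available; the `skip_connection_prop` variants only change the block-level classes:

```python
from torchvision.models.resnet import Bottleneck

from src.utils import layers_lookup  # assumed import path; requires the repository's models package

for version in ("latest", "simple", "flows_skip", "flows_skip_simple"):
    lookup = layers_lookup(version)
    print(f"{version:18s} -> {lookup[Bottleneck].__name__}")

# latest             -> RelevancePropagationBottleneck
# simple             -> RelevancePropagationBottleneckSimple
# flows_skip         -> RelevancePropagationBottleneckFlowsPureSkip
# flows_skip_simple  -> RelevancePropagationBottleneckSimpleFlowsPureSkip
```
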
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os
4 | import random
5 | from typing import Dict, Iterable, List, Optional, Tuple
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | from torch.utils.data import DataLoader
13 | from torchinfo import summary
14 | from tqdm import tqdm
15 |
16 | import wandb
17 | from data import ALL_DATASETS, create_dataloader_dict, get_parameter_depend_in_data_set
18 | from evaluate import test
19 | from metrics.base import Metric
20 | from models import ALL_MODELS, create_model
21 | from models.attention_branch import AttentionBranchModel
22 | from models.lrp import BottleneckWithActivation
23 | from optim import ALL_OPTIM, create_optimizer
24 | from optim.sam import SAM
25 | from utils.loss import calculate_loss
26 | from utils.utils import fix_seed, module_generator, parse_with_config, save_json
27 | from utils.visualize import save_attention_map
28 |
29 |
30 | class EarlyStopping:
31 | """
32 | Attributes:
33 | patience(int): How long to wait after last time validation loss improved.
34 | delta(float) : Minimum change in the monitored quantity to qualify as an improvement.
35 | save_dir(str): Directory to save a model when improvement is found.
36 | """
37 |
38 | def __init__(
39 | self, patience: int = 7, delta: float = 0, save_dir: str = "."
40 | ) -> None:
41 | self.patience = patience
42 | self.delta = delta
43 | self.save_dir = save_dir
44 |
45 | self.counter: int = 0
46 | self.early_stop: bool = False
47 |         self.best_val_loss: float = np.inf
48 |
49 | def __call__(self, val_loss: float, net: nn.Module) -> str:
50 | if val_loss + self.delta < self.best_val_loss:
51 | log = f"({self.best_val_loss:.5f} --> {val_loss:.5f})"
52 | self._save_checkpoint(net)
53 | self.best_val_loss = val_loss
54 | self.counter = 0
55 | return log
56 |
57 | self.counter += 1
58 | log = f"(> {self.best_val_loss:.5f} {self.counter}/{self.patience})"
59 | if self.counter >= self.patience:
60 | self.early_stop = True
61 | return log
62 |
63 | def _save_checkpoint(self, net: nn.Module) -> None:
64 | save_path = os.path.join(self.save_dir, "checkpoint.pt")
65 | torch.save(net.state_dict(), save_path)
66 |
67 |
68 | def set_parameter_trainable(module: nn.Module, is_trainable: bool = True) -> None:
69 | """
70 | Set all parameters of the module to is_trainable(bool)
71 |
72 | Args:
73 | module(nn.Module): Target module
74 | is_trainable(bool): Whether to train the parameters
75 | """
76 | for param in module.parameters():
77 | param.requires_grad = is_trainable
78 |
79 |
80 | def set_trainable_bottlenecks(model: nn.Module, num_trainable: int) -> None:
81 |     # Unfreeze the last num_trainable BottleneckWithActivation blocks, scanning the model from the back
82 | count = 0
83 | for child in reversed(list(model.children())):
84 | # Go through each layer in the current sequential block
85 | for layer in reversed(list(child.children())):
86 |             if isinstance(layer, BottleneckWithActivation):
87 | for param in layer.parameters():
88 | param.requires_grad = True
89 | count += 1
90 | if count == num_trainable:
91 | return
92 |
93 |
94 | def freeze_model(
95 | model: nn.Module,
96 | num_trainable_module: int = 0,
97 | fe_trainable: bool = False,
98 | ab_trainable: bool = False,
99 | perception_trainable: bool = False,
100 | final_trainable: bool = True,
101 | ) -> None:
102 | """
103 |     Freeze the model.
104 |     First the whole model is frozen and only the final layer is left trainable;
105 |     then the last num_trainable_module modules are unfrozen, counting from the back.
106 |
107 | Args:
108 | num_trainable_module(int): Number of trainable modules
109 | fe_trainable(bool): Whether to train the Feature Extractor
110 | ab_trainable(bool): Whether to train the Attention Branch
111 | perception_trainable(bool): Whether to train the Perception Branch
112 |
113 | Note:
114 |         (fe|ab|perception)_trainable is only used when the model is an AttentionBranchModel;
115 |         num_trainable_module takes precedence over the flags above
116 | """
117 | if isinstance(model, AttentionBranchModel):
118 | set_parameter_trainable(model.feature_extractor, fe_trainable)
119 | set_parameter_trainable(model.attention_branch, ab_trainable)
120 | set_parameter_trainable(model.perception_branch, perception_trainable)
121 | modules = module_generator(model.perception_branch, reverse=True)
122 | else:
123 | if num_trainable_module < 0:
124 | set_parameter_trainable(model)
125 | return
126 | set_parameter_trainable(model, is_trainable=False)
127 | modules = module_generator(model, reverse=True)
128 |
129 |     final_layer = next(modules)
130 | set_parameter_trainable(final_layer, final_trainable)
131 | # set_parameter_trainable(model.perception_branch[0], False)
132 | # set_parameter_trainable(model.feature_extractor[0], True)
133 | # set_parameter_trainable(model.feature_extractor[1], True)
134 |
135 | # for i, module in enumerate(modules):
136 | # if num_trainable_module <= i:
137 | # break
138 |
139 | # set_parameter_trainable(module)
140 | set_trainable_bottlenecks(model, num_trainable_module)
141 |
142 |
143 | def setting_learning_rate(
144 | model: nn.Module, lr: float, lr_linear: float, lr_ab: Optional[float] = None
145 | ) -> Iterable:
146 | """
147 | Set learning rate for each layer
148 |
149 | Args:
150 | model (nn.Module): Model to set learning rate
151 |         lr(float)       : Learning rate for all layers except the final layer and the Attention Branch
152 | lr_linear(float): Learning rate for the last layer
153 | lr_ab(float) : Learning rate for Attention Branch
154 |
155 | Returns:
156 | Iterable with learning rate
157 | It is given to the argument of optim.Optimizer
158 | """
159 | if isinstance(model, AttentionBranchModel):
160 | if lr_ab is None:
161 | lr_ab = lr_linear
162 | params = [
163 | {"params": model.attention_branch.parameters(), "lr": lr_ab},
164 | {"params": model.perception_branch[:-1].parameters(), "lr": lr},
165 | {"params": model.perception_branch[-1].parameters(), "lr": lr_linear},
166 | ]
167 | else:
168 | try:
169 | params = [
170 | {"params": model[:-1].parameters(), "lr": lr},
171 | {"params": model[-1].parameters(), "lr": lr_linear},
172 | ]
173 | except TypeError:
174 | params = [{"params": model.parameters(), "lr": lr}]
175 |
176 | return params
177 |
178 |
179 | def wandb_log(loss: float, metrics: Metric, phase: str) -> None:
180 | """
181 | Output logs to wandb
182 | Add phase to each metric for easy understanding
183 | (e.g. Acc -> Train_Acc)
184 |
185 | Args:
186 | loss(float) : Loss value
187 |         metrics(Metric): Evaluation metrics
188 | phase(str) : train / val / test
189 | """
190 | log_items = {f"{phase}_loss": loss}
191 |
192 | for metric, value in metrics.score().items():
193 | log_items[f"{phase}_{metric}"] = value
194 |
195 | wandb.log(log_items)
196 |
197 |
198 | def train_insdel(
199 | model: nn.Module,
200 | images: torch.Tensor,
201 | labels: torch.Tensor,
202 | criterion: nn.modules.loss._Loss,
203 | mode: str,
204 | theta_dist: List[float] = [0.3, 0.5, 0.7],
205 | ):
206 | assert isinstance(model, AttentionBranchModel)
207 | attention_map = model.attention_branch.attention
208 | attention_map = F.interpolate(attention_map, images.shape[2:])
209 | att_base = attention_map.max()
210 |
211 | theta = random.choice(theta_dist)
212 | assert mode in ["insertion", "deletion"]
213 | if mode == "insertion":
214 | labels = torch.ones_like(labels)
215 | attention_map = torch.where(attention_map > att_base * theta, 1.0, 0.0)
216 | if mode == "deletion":
217 | labels = torch.zeros_like(labels)
218 | attention_map = torch.where(attention_map > att_base * theta, 0.0, 1.0)
219 |
220 | inputs = images * attention_map
221 | output = model(inputs.float())
222 | loss = criterion(output, labels)
223 |
224 | return loss
225 |
226 |
227 | def train(
228 | dataloader: DataLoader,
229 | model: nn.Module,
230 | criterion: nn.modules.loss._Loss,
231 | optimizer: optim.Optimizer,
232 | metric: Metric,
233 | lambdas: Optional[Dict[str, float]] = None,
234 | saliency: bool = False,
235 | ) -> Tuple[float, Metric]:
236 | total = 0
237 | total_loss: float = 0
238 | torch.autograd.set_detect_anomaly(True)
239 |
240 | model.train()
241 | for data_ in tqdm(dataloader, desc="Train: ", dynamic_ncols=True):
242 | inputs, labels = (
243 | data_[0].to(device),
244 | data_[1].to(device),
245 | )
246 | optimizer.zero_grad()
247 | outputs = model(inputs)
248 |
249 | loss = calculate_loss(criterion, outputs, labels, model, lambdas)
250 | loss.backward()
251 | total_loss += loss.item()
252 |
253 | metric.evaluate(outputs, labels)
254 |
255 | # When the optimizer is SAM, backward twice
256 | if isinstance(optimizer, SAM):
257 | optimizer.first_step(zero_grad=True)
258 | loss_sam = calculate_loss(criterion, model(inputs), labels, model, lambdas)
259 | loss_sam.backward()
260 | optimizer.second_step(zero_grad=True)
261 | else:
262 | optimizer.step()
263 |
264 | total += labels.size(0)
265 |
266 | train_loss = total_loss / total
267 | return train_loss, metric
268 |
269 |
270 | def main(args: argparse.Namespace):
271 | now = datetime.datetime.now()
272 | now_str = now.strftime("%Y-%m-%d_%H%M%S")
273 |
274 | fix_seed(args.seed, args.no_deterministic)
275 |
276 | # Create dataloaders
277 | dataloader_dict = create_dataloader_dict(
278 | args.dataset,
279 | args.batch_size,
280 | args.image_size,
281 | train_ratio=args.train_ratio,
282 | )
283 | data_params = get_parameter_depend_in_data_set(
284 | args.dataset, pos_weight=torch.Tensor(args.loss_weights).to(device)
285 | )
286 |
287 | # Create a model
288 | model = create_model(
289 | args.model,
290 | num_classes=len(data_params["classes"]),
291 | num_channel=data_params["num_channel"],
292 | base_pretrained=args.base_pretrained,
293 | base_pretrained2=args.base_pretrained2,
294 | pretrained_path=args.pretrained,
295 | attention_branch=args.add_attention_branch,
296 | division_layer=args.div,
297 | theta_attention=args.theta_att,
298 | )
299 | assert model is not None, "Model name is invalid"
300 | freeze_model(
301 | model,
302 | args.trainable_module,
303 | "fe" not in args.freeze,
304 | "ab" not in args.freeze,
305 | "pb" not in args.freeze,
306 | "linear" not in args.freeze,
307 | )
308 |
309 | # Setup optimizer and scheduler
310 | params = setting_learning_rate(model, args.lr, args.lr_linear, args.lr_ab)
311 | optimizer = create_optimizer(
312 | args.optimizer, params, args.lr, args.weight_decay, args.momentum
313 | )
314 |     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=args.factor, patience=args.scheduler_patience)
315 | # scheduler = CosineLRScheduler(optimizer, t_initial=100, lr_min=args.min_lr, warmup_t=10, warmup_prefix=True)
316 |
317 | if args.saliency_guided:
318 | set_parameter_trainable(model.perception_branch[0], False)
319 |
320 | criterion = data_params["criterion"]
321 | metric = data_params["metric"]
322 |
323 | # Create run_name (for save_dir / wandb)
324 |     if args.config is not None:
325 | config_file = os.path.basename(args.config)
326 | run_name = os.path.splitext(config_file)[0]
327 | else:
328 | run_name = args.model
329 | run_name += ["", f"_div{args.div}"][args.add_attention_branch]
330 | run_name = f"{run_name}_{now_str}"
331 | if args.run_name is not None:
332 | run_name = args.run_name
333 |
334 | save_dir = os.path.join(args.save_dir, run_name)
335 | assert not os.path.isdir(save_dir)
336 | os.makedirs(save_dir)
337 | best_path = os.path.join(save_dir, "best.pt")
338 |
339 | configs = vars(args)
340 | configs.pop("config") # To prevent the old config from being included in the new config applied with the model
341 |
342 | early_stopping = EarlyStopping(
343 | patience=args.early_stopping_patience, save_dir=save_dir
344 | )
345 |
346 | wandb.init(project=args.dataset, name=run_name, notes=args.notes)
347 | wandb.config.update(configs)
348 | configs["pretrained"] = best_path
349 | save_json(configs, os.path.join(save_dir, "config.json"))
350 |
351 |     # Display model details (torchinfo summary)
352 | summary(
353 | model,
354 | (args.batch_size, data_params["num_channel"], args.image_size, args.image_size),
355 | )
356 |
357 | lambdas = {"att": args.lambda_att}
358 |
359 | save_test_acc = 0
360 | model.to(device)
361 | for epoch in range(args.epochs):
362 | print(f"\n[Epoch {epoch+1}]")
363 | for phase, dataloader in dataloader_dict.items():
364 | if phase == "Train":
365 | loss, metric = train(
366 | dataloader,
367 | model,
368 | criterion,
369 | optimizer,
370 | metric,
371 | lambdas=lambdas,
372 | )
373 | else:
374 | loss, metric = test(
375 | dataloader,
376 | model,
377 | criterion,
378 | metric,
379 | device,
380 | phase,
381 | lambdas=lambdas,
382 | )
383 |
384 | metric_log = metric.log()
385 | log = f"{phase}\t| {metric_log} Loss: {loss:.5f} "
386 |
387 | wandb_log(loss, metric, phase)
388 |
389 | if phase == "Val":
390 | early_stopping_log = early_stopping(loss, model)
391 | log += early_stopping_log
392 | scheduler.step(loss)
393 |
394 | print(log)
395 | if phase == "Test" and not early_stopping.early_stop:
396 | save_test_acc = metric.acc()
397 |
398 | metric.clear()
399 | if args.add_attention_branch:
400 | save_attention_map(
401 | model.attention_branch.attention[0][0], "attention.png"
402 | )
403 |
404 | if early_stopping.early_stop:
405 | print("Early Stopping")
406 | model.load_state_dict(torch.load(os.path.join(save_dir, "checkpoint.pt")))
407 | break
408 |
409 | torch.save(model.state_dict(), os.path.join(save_dir, "best.pt"))
410 | configs["test_acc"] = save_test_acc.item()
411 | save_json(configs, os.path.join(save_dir, "config.json"))
412 | wandb.log({"final_test_acc": save_test_acc})
413 | print("Training Finished")
414 |
415 |
416 | def parse_args():
417 | parser = argparse.ArgumentParser()
418 |
419 | parser.add_argument("-c", "--config", type=str, help="path to config file (json)")
420 |
421 | parser.add_argument("--seed", type=int, default=42)
422 | parser.add_argument("--no_deterministic", action="store_false")
423 |
424 | parser.add_argument("-n", "--notes", type=str, default="")
425 |
426 | # Model
427 | parser.add_argument("-m", "--model", choices=ALL_MODELS, help="model name")
428 | parser.add_argument(
429 | "-add_ab",
430 | "--add_attention_branch",
431 | action="store_true",
432 | help="add Attention Branch",
433 | )
434 | parser.add_argument(
435 | "--div",
436 | type=str,
437 | choices=["layer1", "layer2", "layer3"],
438 | default="layer2",
439 | help="place to attention branch",
440 | )
441 | parser.add_argument("--base_pretrained", type=str, help="path to base pretrained")
442 | parser.add_argument(
443 | "--base_pretrained2",
444 | type=str,
445 | help="path to base pretrained2 ( after change_num_classes() )",
446 | )
447 | parser.add_argument("--pretrained", type=str, help="path to pretrained")
448 | parser.add_argument(
449 | "--theta_att", type=float, default=0, help="threthold of attention branch"
450 | )
451 |
452 | # Freeze
453 | parser.add_argument(
454 | "--freeze",
455 | type=str,
456 | nargs="*",
457 | choices=["fe", "ab", "pb", "linear"],
458 | default=[],
459 | help="freezing layer",
460 | )
461 | parser.add_argument(
462 | "--trainable_module",
463 | type=int,
464 | default=-1,
465 | help="number of trainable modules, -1: all trainable",
466 | )
467 |
468 | # Dataset
469 | parser.add_argument("--dataset", type=str, default="IDRiD", choices=ALL_DATASETS)
470 | parser.add_argument("--image_size", type=int, default=224)
471 | parser.add_argument("--batch_size", type=int, default=32)
472 | parser.add_argument(
473 | "--train_ratio", type=float, default=0.8, help="ratio for train val split"
474 | )
475 | parser.add_argument(
476 | "--loss_weights",
477 | type=float,
478 | nargs="*",
479 | default=[1.0, 1.0],
480 | help="weights for label by class",
481 | )
482 |
483 | # Optimizer
484 | parser.add_argument("--epochs", type=int, default=200)
485 | parser.add_argument(
486 | "-optim", "--optimizer", type=str, default="AdamW", choices=ALL_OPTIM
487 | )
488 | parser.add_argument(
489 | "--lr",
490 | "--learning_rate",
491 | type=float,
492 | default=1e-4,
493 | )
494 | parser.add_argument(
495 | "--lr_linear",
496 | type=float,
497 | default=1e-3,
498 | )
499 | parser.add_argument(
500 | "--lr_ab",
501 | "--lr_attention_branch",
502 | type=float,
503 | default=1e-3,
504 | )
505 | parser.add_argument(
506 | "--min_lr",
507 | type=float,
508 | default=1e-6,
509 | )
510 | parser.add_argument(
511 | "--momentum",
512 | type=float,
513 | default=0.9,
514 | )
515 | parser.add_argument(
516 | "--weight_decay",
517 | type=float,
518 | default=0.01,
519 | )
520 | parser.add_argument(
521 | "--factor", type=float, default=0.3333, help="new_lr = lr * factor"
522 | )
523 | parser.add_argument(
524 | "--scheduler_patience",
525 | type=int,
526 | default=2,
527 | help="Number of epochs with no improvement after which learning rate will be reduced",
528 | )
529 |
530 | parser.add_argument(
531 | "--lambda_att", type=float, default=0.1, help="weights for attention loss"
532 | )
533 |
534 | parser.add_argument(
535 | "--early_stopping_patience", type=int, default=6, help="Early Stopping patience"
536 | )
537 | parser.add_argument(
538 | "--save_dir", type=str, default="checkpoints", help="path to save checkpoints"
539 | )
540 |
541 | parser.add_argument(
542 | "--run_name", type=str, help="save in save_dir/run_name and wandb name"
543 | )
544 |
545 | return parse_with_config(parser)
546 |
547 |
548 | if __name__ == "__main__":
549 | # import pdb; pdb.set_trace()
550 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
551 |
552 | main(parse_args())
553 |
--------------------------------------------------------------------------------
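
A usage sketch for the `EarlyStopping` helper above, assuming the repository root is on `PYTHONPATH` and its dependencies (wandb, torchinfo, etc.) are installed so that `train` can be imported; the model and validation-loss curve are fabricated and checkpoints go to a temporary directory:

```python
import tempfile

import torch.nn as nn

from train import EarlyStopping  # assumed import path (repository root on PYTHONPATH)

model = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as save_dir:
    early_stopping = EarlyStopping(patience=2, save_dir=save_dir)

    # The loss improves for two epochs, then stagnates; with patience=2 the loop
    # stops after the second non-improving epoch, so 0.95 is never evaluated.
    for val_loss in [1.0, 0.8, 0.9, 0.85, 0.95]:
        log = early_stopping(val_loss, model)
        print(f"val_loss={val_loss:.2f} {log}")
        if early_stopping.early_stop:
            print("Early Stopping")
            break
```
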
/utils/loss.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn
6 |
7 | from models.attention_branch import AttentionBranchModel
8 |
9 |
10 | def criterion_with_cast_targets(
11 | criterion: nn.modules.loss._Loss, preds: torch.Tensor, targets: torch.Tensor
12 | ) -> torch.Tensor:
13 | """
14 |     Calculate the loss after casting the targets to the type the criterion expects
15 |
16 | Args:
17 | criterion(Loss): loss function
18 | preds(Tensor) : prediction
19 | targets(Tensor): label
20 |
21 | Returns:
22 | torch.Tensor: loss value
23 |
24 | Note:
25 |         Different loss functions require different target types, so the targets are converted accordingly
26 | """
27 | if isinstance(criterion, nn.CrossEntropyLoss):
28 | # targets = F.one_hot(targets, num_classes=2)
29 | targets = targets.long()
30 |
31 | if isinstance(criterion, nn.BCEWithLogitsLoss):
32 | targets = F.one_hot(targets, num_classes=2)
33 | targets = targets.to(preds.dtype)
34 |
35 | return criterion(preds, targets)
36 |
37 |
38 | def calculate_loss(
39 | criterion: nn.modules.loss._Loss,
40 | outputs: torch.Tensor,
41 | targets: torch.Tensor,
42 | model: nn.Module,
43 | lambdas: Optional[Dict[str, float]] = None,
44 | ) -> torch.Tensor:
45 | """
46 | Calculate the loss
47 | Add the attention loss when AttentionBranchModel
48 |
49 | Args:
50 | criterion(Loss) : Loss function
51 |         outputs(Tensor) : Prediction
52 | targets(Tensor) : Label
53 | model(nn.Module) : Model that made the prediction
54 | lambdas(Dict[str, float]): Weight of each term of the loss
55 |
56 | Returns:
57 | torch.Tensor: Loss value
58 | """
59 | loss = criterion_with_cast_targets(criterion, outputs, targets)
60 |
61 | # Attention Loss
62 | if isinstance(model, AttentionBranchModel):
63 | keys = ["att", "var"]
64 | if lambdas is None:
65 | lambdas = {key: 1 for key in keys}
66 | for key in keys:
67 | if key not in lambdas:
68 | lambdas[key] = 1
69 |
70 | attention_loss = criterion_with_cast_targets(
71 | criterion, model.attention_pred, targets
72 | )
73 | # loss = loss + attention_loss
74 | # attention = model.attention_branch.attention
75 | # _, _, W, H = attention.size()
76 |
77 | # att_sum = torch.sum(attention, dim=(-1, -2))
78 | # attention_loss = torch.mean(att_sum / (W * H))
79 | loss = loss + lambdas["att"] * attention_loss
80 | # attention = model.attention_branch.attention
81 | # attention_varmean = attention.var(dim=(1, 2, 3)).mean()
82 | # loss = loss - lambdas["var"] * attention_varmean
83 |
84 | return loss
85 |
--------------------------------------------------------------------------------
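
A sketch of `calculate_loss` on a plain (non-attention-branch) classifier, assuming `utils.loss` is importable from the repository root; in this case it reduces to the criterion with the targets cast to `long`, and the attention term is only added for an `AttentionBranchModel`:

```python
import torch
import torch.nn as nn

from utils.loss import calculate_loss  # assumed import path (repository root on PYTHONPATH)

model = nn.Linear(8, 2)  # plain classifier without an attention branch
criterion = nn.CrossEntropyLoss()

inputs = torch.rand(4, 8)
targets = torch.tensor([0, 1, 1, 0], dtype=torch.float)  # float labels are cast to long internally
outputs = model(inputs)

loss = calculate_loss(criterion, outputs, targets, model, lambdas={"att": 0.1})
loss.backward()
print(loss.item())
```
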
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import random
4 | import sys
5 | from typing import Dict, List, Tuple, Union
6 |
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 |
12 | def fix_seed(seed: int, deterministic: bool = False) -> None:
13 | # random
14 | random.seed(seed)
15 | # numpy
16 | np.random.seed(seed)
17 | # pytorch
18 | torch.manual_seed(seed)
19 | torch.cuda.manual_seed_all(seed)
20 | # cudnn
21 | torch.backends.cudnn.deterministic = deterministic
22 | torch.backends.cudnn.benchmark = not deterministic
23 |
24 |
25 | def reverse_normalize(
26 | x: np.ndarray,
27 | mean: Union[Tuple[float], Tuple[float, float, float]],
28 | std: Union[Tuple[float], Tuple[float, float, float]],
29 | ):
30 | """
31 | Undo the normalization applied to an image (per-channel x * std + mean)
32 |
33 | Args:
34 | x(ndarray) : Matrix that has been normalized
35 | mean(Tuple): Mean specified at the time of normalization
36 | std(Tuple) : Standard deviation specified at the time of normalization
37 | """
38 | if x.shape[0] == 1:
39 | x = x * std + mean
40 | return x
41 | x[0, :, :] = x[0, :, :] * std[0] + mean[0]
42 | x[1, :, :] = x[1, :, :] * std[1] + mean[1]
43 | x[2, :, :] = x[2, :, :] * std[2] + mean[2]
44 |
45 | return x
46 |
47 |
48 | def module_generator(model: nn.Module, reverse: bool = False):
49 | """
50 | Generator over the leaf modules of a model; handles nested Sequential blocks
51 | Note that layers cannot be accessed by index
52 |
53 | Args:
54 | model (nn.Module): Model
55 | reverse(bool) : Whether to reverse
56 |
57 | Yields:
58 | Each layer of the model
59 | """
60 | modules = list(model.children())
61 | if reverse:
62 | modules = modules[::-1]
63 |
64 | for module in modules:
65 | if list(module.children()):
66 | yield from module_generator(module, reverse)
67 | continue
68 | yield module
69 |
70 |
71 | def save_json(data: Union[List, Dict], save_path: str) -> None:
72 | """
73 | Save list/dict to json
74 |
75 | Args:
76 | data (List/Dict): Data to save
77 | save_path(str) : Path to save (including extension)
78 | """
79 | with open(save_path, "w") as f:
80 | json.dump(data, f, indent=4)
81 |
82 |
83 | def softmax_image(image: torch.Tensor) -> torch.Tensor:
84 | image_size = image.size()
85 | if len(image_size) == 4:
86 | B, C, W, H = image_size
87 | elif len(image_size) == 3:
88 | B = 1
89 | C, W, H = image_size
90 | else:
91 | raise ValueError(f"Expected a 3D or 4D tensor, got shape {tuple(image_size)}")
92 |
93 | image = image.view(B, C, W * H)
94 | image = torch.softmax(image, dim=-1)
95 |
96 | image = image.view(B, C, W, H)
97 | if len(image_size) == 3:
98 | image = image[0]
99 |
100 | return image
101 |
102 |
103 | def tensor_to_numpy(tensor: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
104 | if isinstance(tensor, torch.Tensor):
105 | result: np.ndarray = tensor.cpu().detach().numpy()
106 | else:
107 | result = tensor
108 |
109 | return result
110 |
111 |
112 | def parse_with_config(parser: argparse.ArgumentParser) -> argparse.Namespace:
113 | """
114 | Combine argparse command-line arguments with a JSON config file
115 |
116 | Args:
117 | parser(ArgumentParser)
118 |
119 | Returns:
120 | Namespace: can be used in the same way as a regular argparse result
121 |
122 | Note:
123 | Values specified in the arguments take precedence over the config
124 | """
125 | args, _unknown = parser.parse_known_args()
126 | if args.config is not None:
127 | config_args = json.load(open(args.config))
128 | override_keys = {
129 | arg[2:].split("=")[0] for arg in sys.argv[1:] if arg.startswith("--")
130 | }
131 | for k, v in config_args.items():
132 | if k not in override_keys:
133 | setattr(args, k, v)
134 | return args
135 |
--------------------------------------------------------------------------------
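
A minimal sketch of the precedence rule documented in parse_with_config: command-line values win over JSON config values, which in turn win over argparse defaults. The config path and keys below are hypothetical:

    import argparse
    import json
    import sys
    import tempfile

    from utils.utils import parse_with_config

    # Hypothetical config: image_size is overridden on the command line, batch_size is not.
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump({"image_size": 512, "batch_size": 8}, f)
        config_path = f.name

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str)
    parser.add_argument("--image_size", type=int, default=224)
    parser.add_argument("--batch_size", type=int, default=16)

    sys.argv = ["demo", "--config", config_path, "--image_size", "224"]
    args = parse_with_config(parser)
    print(args.image_size, args.batch_size)  # 224 8: CLI beats config, config beats default
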
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Literal, Optional, Tuple, Union
3 |
4 | import cv2
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import plotly.graph_objects as go
8 | import torch
9 |
10 | from utils.utils import reverse_normalize, tensor_to_numpy
11 |
12 |
13 | def save_attention_map(attention: Union[np.ndarray, torch.Tensor], fname: str) -> None:
14 | attention = tensor_to_numpy(attention)
15 |
16 | fig, ax = plt.subplots()
17 |
18 | min_att = attention.min()
19 | max_att = attention.max()
20 |
21 | im = ax.imshow(
22 | attention, interpolation="nearest", cmap="jet", vmin=min_att, vmax=max_att
23 | )
24 | fig.colorbar(im)
25 | plt.savefig(fname)
26 | plt.clf()
27 | plt.close()
28 |
29 |
30 | def save_image_with_attention_map(
31 | image: np.ndarray,
32 | attention: np.ndarray,
33 | fname: str,
34 | mean: Tuple[float, float, float],
35 | std: Tuple[float, float, float],
36 | only_img: bool = False,
37 | normalize: bool = False,
38 | sign: Literal["all", "positive", "negative", "absolute_value"] = "all",
39 | ) -> None:
40 | if len(attention.shape) == 3:
41 | attention = attention[0]
42 |
43 | image = image[:3]
44 | mean = mean[:3]
45 | std = std[:3]
46 |
47 | # attention = (attention - attention.min()) / (attention.max() - attention.min())
48 | if normalize:
49 | attention = normalize_attr(attention, sign)
50 |
51 | # image : (C, W, H)
52 | attention = cv2.resize(attention, dsize=(image.shape[1], image.shape[2]))
53 | image = reverse_normalize(image.copy(), mean, std)
54 | image = np.transpose(image, (1, 2, 0))
55 | image = np.clip(image, 0, 1)
56 |
57 | fig, ax = plt.subplots()
58 | if only_img:
59 | ax.axis('off') # No axes for a cleaner look
60 | if image.shape[2] == 1:
61 | ax.imshow(image, cmap="gray", vmin=0, vmax=1)
62 | else:
63 | ax.imshow(image, vmin=0, vmax=1)
64 |
65 | im = ax.imshow(attention, cmap="jet", alpha=0.4, vmin=attention.min(), vmax=attention.max())
66 | if not only_img:
67 | fig.colorbar(im)
68 |
69 | if only_img:
70 | plt.savefig(fname, bbox_inches='tight', pad_inches=0)
71 | else:
72 | plt.savefig(fname)
73 |
74 | plt.clf()
75 | plt.close()
76 |
77 |
78 | def save_image(image: np.ndarray, fname: str, mean: Tuple[float, float, float], std: Tuple[float, float, float]) -> None:
79 | # image : (C, W, H)
80 | image = reverse_normalize(image.copy(), mean, std)
81 | image = image.clip(0, 1)
82 | image = np.transpose(image, (1, 2, 0))
83 |
84 | # Convert image from [0, 1] float to [0, 255] uint8 for saving
85 | image = (image * 255).astype(np.uint8)
86 | # image = image.astype(np.uint8)
87 |
88 | if image.shape[2] == 1: # if grayscale
89 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
90 | else:
91 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
92 |
93 | cv2.imwrite(fname, image)
94 |
95 |
96 | def simple_plot_and_save(x, y, save_path):
97 | fig = go.Figure()
98 | fig.add_trace(go.Scatter(x=x, y=x, mode='lines', line=dict(color='#999999', width=2)))
99 | fig.add_trace(go.Scatter(x=x, y=y, mode='markers', marker=dict(color='#82B366', size=12)))
100 | fig.update_layout(
101 | plot_bgcolor='white',
102 | showlegend=False,
103 | # xaxis_title=r"$\Huge {\sum_i R^{(\textrm{First block})}_i}$",
104 | # yaxis_title=r"$\Huge {p(\hat y_c)}$",
105 | xaxis=dict(
106 | title_standoff=28,
107 | tickfont=dict(family="Times New Roman", size=30, color="black"),
108 | linecolor='black',
109 | showgrid=False,
110 | ticks='outside',
111 | tickcolor='black',
112 | ),
113 | yaxis=dict(
114 | title_standoff=26,
115 | tickfont=dict(family="Times New Roman", size=30, color="black"),
116 | linecolor='black',
117 | showgrid=False,
118 | ticks='outside',
119 | tickcolor='black',
120 | scaleanchor="x",
121 | scaleratio=1,
122 | ),
123 | autosize=False,
124 | width=600,
125 | height=600,
126 | # margin=dict(l=115, r=5, b=115, t=5),
127 | margin=dict(l=5, r=5, b=5, t=5),
128 | )
129 | fig.write_image(save_path)
130 |
131 |
132 | def simple_plot_and_save_legacy(x, y, save_path):
133 | plt.figure(figsize=(6, 6))
134 | plt.scatter(x, y, color='orange')
135 | plt.plot(x, x, color='#999999') # y=x line
136 | plt.gca().spines['top'].set_visible(False)
137 | plt.gca().spines['right'].set_visible(False)
138 | plt.gca().set_aspect('equal', adjustable='box')
139 | plt.savefig(save_path)
140 |
141 |
142 | def save_data_as_plot(
143 | data: np.ndarray,
144 | fname: str,
145 | x: Optional[np.ndarray] = None,
146 | label: Optional[str] = None,
147 | xlim: Optional[Union[int, float]] = None,
148 | ) -> None:
149 | """
150 | Save data as plot
151 |
152 | Args:
153 | data(ndarray): Data to save
154 | fname(str) : File name to save
155 | x(ndarray) : X-axis
156 | label(str) : Label of legend
157 | """
158 | fig, ax = plt.subplots()
159 |
160 | if x is None:
161 | x = range(len(data))
162 |
163 | ax.plot(x, data, label=label)
164 |
165 | xmax = len(data) if xlim is None else xlim
166 | ax.set_xlim(0, xmax)
167 | ax.set_ylim(-0.05, 1.05)
168 |
169 | plt.legend()
170 | plt.savefig(fname, bbox_inches="tight", pad_inches=0.05)
171 | plt.clf()
172 | plt.close()
173 |
174 |
175 | """followings are borrowed & modified from captum.attr._utils"""
176 |
177 | def _normalize_scale(attr: np.ndarray, scale_factor: float):
178 | assert scale_factor != 0, "Cannot normalize by scale factor = 0"
179 | if abs(scale_factor) < 1e-5:
180 | warnings.warn(
181 | "Attempting to normalize by value approximately 0, visualized results"
182 | "may be misleading. This likely means that attribution values are all"
183 | "close to 0."
184 | )
185 | attr_norm = attr / scale_factor
186 | return np.clip(attr_norm, -1, 1)
187 |
188 |
189 | def _cumulative_sum_threshold(values: np.ndarray, percentile: Union[int, float]):
190 | # given values should be non-negative
191 | assert percentile >= 0 and percentile <= 100, (
192 | "Percentile for thresholding must be " "between 0 and 100 inclusive."
193 | )
194 | sorted_vals = np.sort(values.flatten())
195 | cum_sums = np.cumsum(sorted_vals)
196 | threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0]
197 | return sorted_vals[threshold_id]
198 |
199 |
200 | def normalize_attr(
201 | attr: np.ndarray,
202 | sign: str,
203 | outlier_perc: Union[int, float] = 2,
204 | reduction_axis: Optional[int] = None,
205 | ):
206 | attr_combined = attr
207 | if reduction_axis is not None:
208 | attr_combined = np.sum(attr, axis=reduction_axis)
209 |
210 | # Choose appropriate signed values and rescale, removing given outlier percentage.
211 | if sign == "all":
212 | threshold = _cumulative_sum_threshold(np.abs(attr_combined), 100 - outlier_perc)
213 | elif sign == "positive":
214 | attr_combined = (attr_combined > 0) * attr_combined
215 | threshold = _cumulative_sum_threshold(attr_combined, 100 - outlier_perc)
216 | elif sign == "negative":
217 | attr_combined = (attr_combined < 0) * attr_combined
218 | threshold = -1 * _cumulative_sum_threshold(
219 | np.abs(attr_combined), 100 - outlier_perc
220 | )
221 | elif sign == "absolute_value":
222 | attr_combined = np.abs(attr_combined)
223 | threshold = _cumulative_sum_threshold(attr_combined, 100 - outlier_perc)
224 | else:
225 | raise AssertionError("Visualize Sign type is not valid.")
226 | return _normalize_scale(attr_combined, threshold)
227 |
--------------------------------------------------------------------------------
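
A minimal sketch combining normalize_attr and save_image_with_attention_map; the random arrays and the ImageNet mean/std are stand-ins, not values taken from this repository's configs:

    import numpy as np

    from utils.visualize import normalize_attr, save_image_with_attention_map

    rng = np.random.default_rng(0)
    attribution = rng.normal(size=(224, 224)).astype(np.float32)  # stand-in attribution map
    image = rng.normal(size=(3, 224, 224)).astype(np.float32)     # stand-in normalized image

    # Keep positive evidence only and rescale by the cumulative-sum threshold (outlier_perc=2).
    attr_pos = normalize_attr(attribution, sign="positive", outlier_perc=2)

    save_image_with_attention_map(
        image,
        attr_pos,
        "example_overlay.png",
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        normalize=False,  # already normalized above
    )
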
/visualize.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from typing import Any, Dict, Optional, Tuple, Union
4 |
5 | import captum
6 | import cv2
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.utils.data as data
11 | from pytorch_grad_cam import GradCAM
12 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
13 | from skimage.transform import resize
14 | from torchinfo import summary
15 | from torchvision.models.resnet import ResNet
16 | from tqdm import tqdm
17 |
18 | import wandb
19 | from data import ALL_DATASETS, create_dataloader_dict, get_parameter_depend_in_data_set
20 | from metrics.base import Metric
21 | from metrics.patch_insdel import PatchInsertionDeletion
22 | from models import ALL_MODELS, OneWayResNet, create_model
23 | from models.attention_branch import AttentionBranchModel
24 | from models.lrp import *
25 | from models.rise import RISE
26 | from src.lrp import abn_lrp, basic_lrp, resnet_lrp
27 | from src.utils import SkipConnectionPropType
28 | from utils.utils import fix_seed, parse_with_config
29 | from utils.visualize import (
30 | save_image,
31 | save_image_with_attention_map,
32 | simple_plot_and_save,
33 | )
34 |
35 |
36 | def calculate_attention(
37 | model: nn.Module,
38 | image: torch.Tensor,
39 | label: torch.Tensor,
40 | method: str,
41 | rise_params: Optional[Dict[str, Any]],
42 | fname: Optional[str],
43 | skip_connection_prop_type: SkipConnectionPropType = "latest",
44 | ) -> Tuple[np.ndarray, Optional[float]]:
45 | relevance_out = None
46 | if method.lower() == "abn":
47 | assert isinstance(model, AttentionBranchModel)
48 | model(image)
49 | attentions = model.attention_branch.attention # (1, W, W)
50 | attention = attentions[0]
51 | attention: np.ndarray = attention[0].detach().cpu().numpy()
52 | elif method.lower() == "rise":
53 | assert rise_params is not None
54 | rise_model = RISE(
55 | model,
56 | n_masks=rise_params["n_masks"],
57 | p1=rise_params["p1"],
58 | input_size=rise_params["input_size"],
59 | initial_mask_size=rise_params["initial_mask_size"],
60 | n_batch=rise_params["n_batch"],
61 | mask_path=rise_params["mask_path"],
62 | )
63 | attentions = rise_model(image) # (N_class, W, H)
64 | attention = attentions[label]
65 | attention: np.ndarray = attention.cpu().numpy()
66 | elif method.lower() == "npy":
67 | assert fname is not None
68 | attention: np.ndarray = np.load(fname)
69 | elif method.lower() == "gradcam":
70 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1)
71 | pred_label = tpl.indices[0].item()
72 | relevance_out = tpl.values[0].item()
73 | target_layers = [model.layer4[-1]]
74 | cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
75 | targets = [ClassifierOutputTarget(pred_label)]
76 | grayscale_cam = cam(input_tensor=image, targets=targets)
77 | attention = grayscale_cam[0, :]
78 | elif method.lower() == "scorecam":
79 | from src import scorecam as scam
80 | resnet_model_dict = dict(type='resnet50', arch=model, layer_name='layer4', input_size=(224, 224))
81 | cam = scam.ScoreCAM(resnet_model_dict)
82 | attention = cam(image)
83 | attention = attention[0].detach().cpu().numpy()
84 | elif method.lower() == "lrp": # ours
85 | if isinstance(model, ResNet):
86 | model = OneWayResNet(model)
87 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1)
88 | pred_label = tpl.indices[0].item()
89 | relevance_out = tpl.values[0].item()
90 | attention: np.ndarray = basic_lrp(
91 | model, image, rel_pass_ratio=1.0, topk=1, skip_connection_prop=skip_connection_prop_type
92 | ).detach().cpu().numpy()
93 | elif method.lower() == "captumlrp":
94 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1)
95 | pred_label = tpl.indices[0].item()
96 | relevance_out = tpl.values[0].item()
97 | attention = captum.attr.LRP(model).attribute(image, target=pred_label)
98 | attention = attention[0].detach().cpu().numpy()
99 | elif method.lower() == "captumlrp-positive":
100 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1)
101 | pred_label = tpl.indices[0].item()
102 | relevance_out = tpl.values[0].item()
103 | attention = captum.attr.LRP(model).attribute(image, target=pred_label)
104 | attention = attention[0].detach().cpu().numpy().clip(0, None)
105 | elif method.lower() in ["captumgradxinput", "gradxinput"]:
106 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1)
107 | pred_label = tpl.indices[0].item()
108 | relevance_out = tpl.values[0].item()
109 | attention = captum.attr.InputXGradient(model).attribute(image, target=pred_label)
110 | attention = attention[0].detach().cpu().numpy()
111 | elif method.lower() in ["captumig", "ig"]:
112 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1)
113 | pred_label = tpl.indices[0].item()
114 | relevance_out = tpl.values[0].item()
115 | attention = captum.attr.IntegratedGradients(model).attribute(image, target=pred_label)
116 | attention = attention[0].detach().cpu().numpy()
117 | elif method.lower() in ["captumig-positive", "ig-positive"]:
118 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1)
119 | pred_label = tpl.indices[0].item()
120 | relevance_out = tpl.values[0].item()
121 | attention = captum.attr.IntegratedGradients(model).attribute(image, target=pred_label)
122 | attention = attention[0].detach().cpu().numpy().clip(0, None)
123 | elif method.lower() in ["captumguidedbackprop", "guidedbackprop", "gbp"]:
124 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1)
125 | pred_label = tpl.indices[0].item()
126 | relevance_out = tpl.values[0].item()
127 | attention = captum.attr.GuidedBackprop(model).attribute(image, target=pred_label)
128 | attention = attention[0].detach().cpu().numpy()
129 | elif method.lower() in ["captumguidedbackprop-positive", "guidedbackprop-positive", "gbp-positive"]:
130 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1)
131 | pred_label = tpl.indices[0].item()
132 | relevance_out = tpl.values[0].item()
133 | attention = captum.attr.GuidedBackprop(model).attribute(image, target=pred_label)
134 | attention = attention[0].detach().cpu().numpy().clip(0, None)
135 | # elif method.lower() in ["smoothgrad", "vargrad"]: # OOM...
136 | # tpl = torch.softmax(model(image), dim=-1).max(dim=-1)
137 | # pred_label = tpl.indices[0].item()
138 | # relevance_out = tpl.values[0].item()
139 | # ig = captum.attr.IntegratedGradients(model)
140 | # attention = captum.attr.NoiseTunnel(ig).attribute(image, nt_type=method.lower(), nt_samples=5, target=pred_label)
141 | # attention = attention[0].detach().cpu().numpy()
142 | elif method.lower() == "lime":
143 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1)
144 | pred_label = tpl.indices[0].item()
145 | relevance_out = tpl.values[0].item()
146 | attention = captum.attr.Lime(model).attribute(image, target=pred_label)
147 | attention = attention[0].detach().cpu().numpy()
148 | elif method.lower() == "captumgradcam":
149 | tpl = torch.softmax(model(image), dim=-1).max(dim=-1)
150 | pred_label = tpl.indices[0].item()
151 | relevance_out = tpl.values[0].item()
152 | attention = captum.attr.LayerGradCam(model, model.layer4[-1]).attribute(image, target=pred_label)
153 | attention = attention[0].detach().cpu().numpy()
154 | else:
155 | raise ValueError(f"Invalid method was requested: {method}")
156 |
157 | return attention, relevance_out
158 |
159 |
160 | def apply_gaussian_to_max_pixel(mask, kernel_size=5, sigma=1):
161 | # Find max value's indices
162 | max_val_indices = np.where(mask == np.amax(mask))
163 |
164 | # Initialize a new mask with the same size of the original mask
165 | new_mask = np.zeros(mask.shape)
166 |
167 | # Assign the max value of the original mask to the corresponding location of the new mask
168 | new_mask[max_val_indices] = mask[max_val_indices]
169 |
170 | # Apply Gaussian blur
171 | blurred_mask = cv2.GaussianBlur(new_mask, (kernel_size, kernel_size), sigma)
172 |
173 | return blurred_mask
174 |
175 |
176 | def remove_other_components(mask, threshold=0.5):
177 | # Binarize the mask
178 | binary_mask = (mask > threshold).astype(np.uint8)
179 |
180 | # Detect connected components
181 | num_labels, labels = cv2.connectedComponents(binary_mask)
182 |
183 | # Calculate the maximum value in the original mask for each component
184 | max_values = [np.max(mask[labels == i]) for i in range(num_labels)]
185 |
186 | # Find the component with the largest maximum value
187 | largest_max_value_component = np.argmax(max_values)
188 | # second_index = np.argsort(max_values)[-2]
189 | # third_index = np.argsort(max_values)[-3]
190 | # print(max_values)
191 | # print(largest_max_value_component)
192 | # print(second_index)
193 |
194 | # Create a new mask where all components other than the one with the largest max value are removed
195 | first_mask = np.where(labels == largest_max_value_component, mask*3, 0)
196 | # second_mask = np.where(labels == second_index, mask, 0)
197 | # third_mask = np.where(labels == third_index, mask / 3, 0)
198 | new_mask = first_mask
199 | # new_mask = first_mask + second_mask + third_mask
200 |
201 | return new_mask
202 |
203 |
204 | def apply_heat_quantization(attention, q_level: int = 8):
205 | max_ = attention.max()
206 | min_ = attention.min()
207 |
208 | # quantization
209 | bins = np.linspace(min_, max_, q_level)
210 |
211 | # apply quantization
212 | for i in range(q_level - 1):
213 | attention[(attention >= bins[i]) & (attention < bins[i + 1])] = bins[i]
214 |
215 | return attention
216 |
217 |
218 | # @torch.no_grad()
219 | def visualize(
220 | dataloader: data.DataLoader,
221 | model: nn.Module,
222 | method: str,
223 | batch_size: int,
224 | patch_size: int,
225 | step: int,
226 | save_dir: str,
227 | all_class: bool,
228 | params: Dict[str, Any],
229 | device: torch.device,
230 | evaluate: bool = False,
231 | attention_dir: Optional[str] = None,
232 | use_c1c: bool = False,
233 | heat_quantization: bool = False,
234 | hq_level: int = 8,
235 | skip_connection_prop_type: SkipConnectionPropType = "latest",
236 | data_limit: int = -1,
237 | ) -> Union[None, Metric]:
238 | if evaluate:
239 | metrics = PatchInsertionDeletion(
240 | model, batch_size, patch_size, step, params["name"], device
241 | )
242 | insdel_save_dir = os.path.join(save_dir, "insdel")
243 | if not os.path.isdir(insdel_save_dir):
244 | os.makedirs(insdel_save_dir)
245 |
246 | rel_ins = []
247 | rel_outs = []
248 |
249 | counter = {}
250 |
251 | model.eval()
252 | # # print accuracy on test set
253 | # total = 0
254 | # correct = 0
255 | # for i, data_ in enumerate(
256 | # tqdm(dataloader, desc="Count failures", dynamic_ncols=True)
257 | # ):
258 | # inputs, labels = (
259 | # data_[0].to(device),
260 | # data_[1].to(device),
261 | # )
262 | # outputs = model(inputs)
263 | # _, predicted = torch.max(outputs.data, 1)
264 | # total += labels.size(0)
265 | # correct += (predicted == labels).sum().item()
266 | # print(f"Total: {total}")
267 | # print(f"Correct: {correct}")
268 | # print(f"Accuracy: {correct / total:.4f}")
269 |
270 | # Inference time estimation
271 | elapsed_time_sum = 0.0
272 | start_event = torch.cuda.Event(enable_timing=True)
273 | end_event = torch.cuda.Event(enable_timing=True)
274 | torch.cuda.empty_cache()
275 | batch = torch.randn(1, 3, 224, 224).to(device)
276 | for _ in range(100):
277 | start_event.record()
278 | model.forward(batch)
279 | end_event.record()
280 | torch.cuda.synchronize()
281 | elapsed_time_sum += start_event.elapsed_time(end_event)
282 | print(f"Average inference time: {elapsed_time_sum / 100}")
283 |
284 | # Elapsed time to inference + generate explanation
285 | elapsed_time_sum = 0.0
286 | start_event = torch.cuda.Event(enable_timing=True)
287 | end_event = torch.cuda.Event(enable_timing=True)
288 | torch.cuda.empty_cache()
289 | image = torch.randn(1, 3, 224, 224).to(device)
290 | # Fall back to the model itself when it is not a torchvision ResNet (avoids an undefined name below)
291 | oneway_model = OneWayResNet(model) if isinstance(model, ResNet) else model
292 | for _ in range(100):
293 | start_event.record()
294 | basic_lrp(
295 | oneway_model, image, rel_pass_ratio=1.0, topk=1, skip_connection_prop=skip_connection_prop_type
296 | )
297 | end_event.record()
298 | torch.cuda.synchronize()
299 | elapsed_time_sum += start_event.elapsed_time(end_event)
300 | print(f"Average time to inference and generate an attribution map: {elapsed_time_sum / 100}")
301 |
302 | for i, data_ in enumerate(
303 | tqdm(dataloader, desc="Visualizing: ", dynamic_ncols=True)
304 | ):
305 | if data_limit > 0 and i >= data_limit:
306 | break
307 | torch.cuda.memory_summary(device=device)
308 | inputs, labels = (
309 | data_[0].to(device),
310 | data_[1].to(device),
311 | )
312 | image: np.ndarray = inputs[0].cpu().numpy()
313 | label: torch.Tensor = labels[0]
314 |
315 | if label != 1 and not all_class:
316 | continue
317 |
318 | counter.setdefault(label.item(), 0)
319 | counter[label.item()] += 1
320 | n_eval_per_class = 1000 / len(params["classes"])
321 | # n_eval_per_class = 5 # TODO: try
322 | if counter[label.item()] > n_eval_per_class:
323 | continue
324 |
325 | base_fname = f"{i+1}_{params['classes'][label]}"
326 |
327 | attention_fname = None
328 | if attention_dir is not None:
329 | attention_fname = os.path.join(attention_dir, f"{base_fname}.npy")
330 |
331 | attention, rel_ = calculate_attention(
332 | model, inputs, label, method, params, attention_fname, skip_connection_prop_type
333 | )
334 | if rel_ is not None:
335 | rel_ins.append(attention.sum().item())
336 | rel_outs.append(rel_)
337 | if use_c1c:
338 | attention = resize(attention, (28, 28))
339 | attention = remove_other_components(attention, threshold=attention.mean())
340 |
341 | if heat_quantization:
342 | attention = apply_heat_quantization(attention, hq_level)
343 |
344 | if attention is None:
345 | continue
346 | if method == "RISE":
347 | np.save(f"{save_dir}/{base_fname}.npy", attention)
348 |
349 | if evaluate:
350 | metrics.evaluate(
351 | image.copy(),
352 | attention,
353 | label,
354 | )
355 | metrics.save_roc_curve(insdel_save_dir)
356 | base_fname = f"{base_fname}_{metrics.ins_auc - metrics.del_auc:.4f}"
357 | if i % 50 == 0:
358 | print(metrics.log())
359 |
360 | save_fname = os.path.join(save_dir, f"{base_fname}.png")
361 | save_image_with_attention_map(
362 | image, attention, save_fname, params["mean"], params["std"]
363 | )
364 |
365 | save_image(image, save_fname[:-4]+".original.png", params["mean"], params["std"])
366 |
367 | # Plot conservations
368 | conservation_dir = os.path.join(save_dir, "conservation")
369 | if not os.path.isdir(conservation_dir):
370 | os.makedirs(conservation_dir)
371 | simple_plot_and_save(rel_ins, rel_outs, os.path.join(conservation_dir, "plot.png"))
372 |
373 | if evaluate:
374 | return metrics
375 |
376 |
377 | def main(args: argparse.Namespace) -> None:
378 | fix_seed(args.seed, args.no_deterministic)
379 |
380 | # Create the dataset
381 | dataloader_dict = create_dataloader_dict(
382 | args.dataset, 1, args.image_size, only_test=True, shuffle_val=True, dataloader_seed=args.seed
383 | )
384 | dataloader = dataloader_dict["Test"]
385 | assert isinstance(dataloader, data.DataLoader)
386 |
387 | params = get_parameter_depend_in_data_set(args.dataset)
388 |
389 | mask_path = os.path.join(args.root_dir, "masks.npy")
390 | if not os.path.isfile(mask_path):
391 | mask_path = None
392 | rise_params = {
393 | "n_masks": args.num_masks,
394 | "p1": args.p1,
395 | "input_size": (args.image_size, args.image_size),
396 | "initial_mask_size": (args.rise_scale, args.rise_scale),
397 | "n_batch": args.batch_size,
398 | "mask_path": mask_path,
399 | }
400 | params.update(rise_params)
401 |
402 | # Create the model
403 | model = create_model(
404 | args.model,
405 | num_classes=len(params["classes"]),
406 | num_channel=params["num_channel"],
407 | base_pretrained=args.base_pretrained,
408 | base_pretrained2=args.base_pretrained2,
409 | pretrained_path=args.pretrained,
410 | attention_branch=args.add_attention_branch,
411 | division_layer=args.div,
412 | theta_attention=args.theta_att,
413 | init_classifier=args.dataset != "ImageNet", # Use pretrained classifier in ImageNet
414 | )
415 | assert model is not None, "Model name is invalid"
416 |
417 | # Derive run_name from the pretrained path
418 | # checkpoints/run_name/checkpoint.pt -> run_name
419 | run_name = args.pretrained.split(os.sep)[-2] if args.pretrained is not None else "pretrained"
420 | save_dir = os.path.join(
421 | "outputs",
422 | f"{run_name}_{args.notes}_{args.method}{args.block_size}",
423 | )
424 | if not os.path.isdir(save_dir):
425 | os.makedirs(save_dir)
426 |
427 | summary(
428 | model,
429 | (args.batch_size, params["num_channel"], args.image_size, args.image_size),
430 | )
431 |
432 | model.to(device)
433 |
434 | wandb.init(project=args.dataset, name=run_name, notes=args.notes)
435 | wandb.config.update(vars(args))
436 |
437 | metrics = visualize(
438 | dataloader,
439 | model,
440 | args.method,
441 | args.batch_size,
442 | args.block_size,
443 | args.insdel_step,
444 | save_dir,
445 | args.all_class,
446 | params,
447 | device,
448 | args.visualize_only,
449 | attention_dir=args.attention_dir,
450 | use_c1c=args.use_c1c,
451 | heat_quantization=args.heat_quantization,
452 | hq_level=args.hq_level,
453 | skip_connection_prop_type=args.skip_connection_prop_type,
454 | data_limit=args.data_limit,
455 | )
456 |
457 | if hasattr(args, "test_acc"):
458 | print(f"Test Acc: {args.test_acc}")
459 |
460 | if metrics is not None:
461 | print(metrics.log())
462 | for key, value in metrics.score().items():
463 | wandb.run.summary[key] = value
464 |
465 |
466 | def parse_args() -> argparse.Namespace:
467 | parser = argparse.ArgumentParser()
468 |
469 | parser.add_argument("-c", "--config", type=str, help="path to config file (json)")
470 |
471 | parser.add_argument("--seed", type=int, default=42)
472 | parser.add_argument("--no_deterministic", action="store_false")
473 |
474 | parser.add_argument("-n", "--notes", type=str, default="")
475 |
476 | # Model
477 | parser.add_argument("-m", "--model", choices=ALL_MODELS, help="model name")
478 | parser.add_argument(
479 | "-add_ab",
480 | "--add_attention_branch",
481 | action="store_true",
482 | help="add Attention Branch",
483 | )
484 | parser.add_argument(
485 | "--div",
486 | type=str,
487 | choices=["layer1", "layer2", "layer3"],
488 | default="layer2",
489 | help="place to attention branch",
490 | )
491 | parser.add_argument("--base_pretrained", type=str, help="path to base pretrained")
492 | parser.add_argument(
493 | "--base_pretrained2",
494 | type=str,
495 | help="path to base pretrained2 ( after change_num_classes() )",
496 | )
497 | parser.add_argument("--pretrained", type=str, help="path to pretrained")
498 | parser.add_argument(
499 | "--orig_model",
500 | action="store_true",
501 | help="calc insdel score by using original model",
502 | )
503 | parser.add_argument(
504 | "--theta_att", type=float, default=0, help="threthold of attention branch"
505 | )
506 |
507 | # Dataset
508 | parser.add_argument("--dataset", type=str, default="IDRiD", choices=ALL_DATASETS)
509 | parser.add_argument("--data_limit", type=int, default=-1)
510 | parser.add_argument("--image_size", type=int, default=224)
511 | parser.add_argument("--batch_size", type=int, default=16)
512 | parser.add_argument(
513 | "--loss_weights",
514 | type=float,
515 | nargs="*",
516 | default=[1.0, 1.0],
517 | help="weights for label by class",
518 | )
519 |
520 | parser.add_argument("--root_dir", type=str, default="./outputs/")
521 | parser.add_argument("--visualize_only", action="store_false")
522 | parser.add_argument("--all_class", action="store_true")
523 | # recommend (step, size) in 512x512 = (1, 10000), (2, 2500), (4, 500), (8, 100), (16, 20), (32, 10), (64, 5), (128, 1)
524 | # recommend (step, size) in 224x224 = (1, 500), (2, 100), (4, 20), (8, 10), (16, 5), (32, 1)
525 | parser.add_argument("--insdel_step", type=int, default=500)
526 | parser.add_argument("--block_size", type=int, default=1)
527 |
528 | parser.add_argument(
529 | "--method",
530 | type=str,
531 | default="gradcam",
532 | )
533 | parser.add_argument("--normalize", action="store_true")
534 | parser.add_argument("--sign", type=str, default="all")
535 |
536 | parser.add_argument("--num_masks", type=int, default=5000)
537 | parser.add_argument("--rise_scale", type=int, default=9)
538 | parser.add_argument(
539 | "--p1", type=float, default=0.3, help="percentage of mask [pixel = (0, 0, 0)]"
540 | )
541 |
542 | parser.add_argument("--attention_dir", type=str, help="path to attention npy file")
543 |
544 | parser.add_argument("--use-c1c", action="store_true", help="use C1C technique")
545 |
546 | parser.add_argument("--heat-quantization", action="store_true", help="use heat quantization technique")
547 | parser.add_argument("--hq-level", type=int, default=8, help="number of quantization level")
548 |
549 | parser.add_argument("--skip-connection-prop-type", type=str, default="latest", help="type of skip connection propagation")
550 |
551 | return parse_with_config(parser)
552 |
553 |
554 | if __name__ == "__main__":
555 | # import pdb; pdb.set_trace()
556 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
557 |
558 | main(parse_args())
559 |
--------------------------------------------------------------------------------
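
For a quick single-image check outside the full visualize() loop, calculate_attention can be called directly. The sketch below is assumption-laden: it uses a stock torchvision ResNet-50 (the IMAGENET1K_V1 weights enum requires torchvision >= 0.13) and a random tensor in place of a normalized test image, together with the repository's own "lrp" method:

    import torch
    from torchvision.models import resnet50

    from visualize import calculate_attention

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = resnet50(weights="IMAGENET1K_V1").to(device).eval()

    image = torch.randn(1, 3, 224, 224, device=device)  # stand-in for a normalized test image
    label = torch.tensor(0, device=device)

    # rise_params and fname are only consumed by the "rise" and "npy" branches, so None is fine here.
    attention, relevance_out = calculate_attention(
        model,
        image,
        label,
        method="lrp",
        rise_params=None,
        fname=None,
        skip_connection_prop_type="latest",
    )
    print(attention.shape, relevance_out)
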