├── .gitattributes
├── .github
└── workflows
│ └── python-app.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── cam.py
├── examples
├── augsmooth.jpg
├── both.png
├── both_detection.png
├── cam_gb_dog.jpg
├── cars_segmentation.png
├── cat.jpg
├── clip_cat.jpg
├── clip_dog.jpg
├── dff1.png
├── dff2.png
├── dog.jpg
├── dog_cat.jfif
├── dogs.png
├── dogs_gradcam++_resnet50.jpg
├── dogs_gradcam++_vgg16.jpg
├── dogs_gradcam_resnet50.jpg
├── dogs_gradcam_vgg16.jpg
├── dogs_scorecam_resnet50.jpg
├── dogs_scorecam_vgg16.jpg
├── eigenaug.jpg
├── eigensmooth.jpg
├── embeddings.png
├── horses.jpg
├── metrics.png
├── multiorgan_segmentation.gif
├── nosmooth.jpg
├── resnet50_cat_ablationcam_cam.jpg
├── resnet50_cat_gradcam_cam.jpg
├── resnet50_cat_scorecam_cam.jpg
├── resnet50_dog_ablationcam_cam.jpg
├── resnet50_dog_gradcam_cam.jpg
├── resnet50_dog_scorecam_cam.jpg
├── resnet_horses_ablationcam_cam.jpg
├── resnet_horses_gradcam++_cam.jpg
├── resnet_horses_gradcam_cam.jpg
├── resnet_horses_horses_eigencam_cam.jpg
├── resnet_horses_scorecam_cam.jpg
├── road.png
├── swinT_cat_ablationcam_cam.jpg
├── swinT_cat_gradcam_cam.jpg
├── swinT_cat_scorecam_cam.jpg
├── swinT_dog_ablationcam_cam.jpg
├── swinT_dog_gradcam_cam.jpg
├── swinT_dog_scorecam_cam.jpg
├── vgg_horses_ablationcam_cam.jpg
├── vgg_horses_eigencam_cam.jpg
├── vgg_horses_gradcam++_cam.jpg
├── vgg_horses_gradcam_cam.jpg
├── vgg_horses_scorecam_cam.jpg
├── vit_cat_ablationcam_cam.jpg
├── vit_cat_gradcam_cam.jpg
├── vit_cat_scorecam_cam.jpg
├── vit_dog_ablationcam_cam.jpg
├── vit_dog_gradcam_cam.jpg
├── vit_dog_scorecam_cam.jpg
└── yolo_eigencam.png
├── pyproject.toml
├── pytorch_grad_cam
├── __init__.py
├── ablation_cam.py
├── ablation_cam_multilayer.py
├── ablation_layer.py
├── activations_and_gradients.py
├── base_cam.py
├── eigen_cam.py
├── eigen_grad_cam.py
├── feature_factorization
│ ├── __init__.py
│ └── deep_feature_factorization.py
├── fem.py
├── finer_cam.py
├── fullgrad_cam.py
├── grad_cam.py
├── grad_cam_elementwise.py
├── grad_cam_plusplus.py
├── guided_backprop.py
├── hirescam.py
├── kpca_cam.py
├── layer_cam.py
├── metrics
│ ├── __init__.py
│ ├── cam_mult_image.py
│ ├── perturbation_confidence.py
│ └── road.py
├── random_cam.py
├── score_cam.py
├── shapley_cam.py
├── sobel_cam.py
├── utils
│ ├── __init__.py
│ ├── find_layers.py
│ ├── image.py
│ ├── model_targets.py
│ ├── reshape_transforms.py
│ └── svd_on_activations.py
└── xgrad_cam.py
├── requirements.txt
├── setup.cfg
├── setup.py
├── tests
├── test_context_release.py
├── test_one_channel.py
└── test_run_all_models.py
├── tutorials
├── CAM Metrics And Tuning Tutorial.ipynb
├── Class Activation Maps for Object Detection With Faster RCNN.ipynb
├── Class Activation Maps for Semantic Segmentation.ipynb
├── Deep Feature Factorizations.ipynb
├── EigenCAM for YOLO5.ipynb
├── HuggingFace.ipynb
├── Pixel Attribution for embeddings.ipynb
├── bbox.png
├── huggingface_dff.png
├── multimage.png
├── puppies.jpg
└── vision_transformers.md
└── usage_examples
├── clip_example.py
├── swinT_example.py
└── vit_example.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | tutorials/*.ipynb linguist-documentation=true
--------------------------------------------------------------------------------
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Tests
5 |
6 | on: [push, pull_request]
7 |
8 | jobs:
9 | build:
10 |
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - uses: actions/checkout@v2
15 | - name: Set up Python 3.8
16 | uses: actions/setup-python@v2
17 | with:
18 | python-version: 3.8
19 | - name: Install dependencies
20 | run: |
21 | python -m pip install --upgrade pip
22 | pip install flake8 pytest==5.3.2 pytest-faulthandler psutil
23 | pip install -e .
24 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
25 | - name: Lint with flake8
26 | run: |
27 | # stop the build if there are Python syntax errors or undefined names
28 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
29 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
30 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
31 | - name: Test with pytest
32 | run: |
33 | pytest
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 | dist
3 | *.egg-info
4 | **/*.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jacob Gildenblat
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | prune tutorials
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [MIT License](https://opensource.org/licenses/MIT)
2 |
3 | [Downloads](https://pepy.tech/project/grad-cam)
5 |
6 | # Advanced AI explainability for PyTorch
7 |
8 | `pip install grad-cam`
9 |
10 | Documentation with advanced tutorials: [https://jacobgil.github.io/pytorch-gradcam-book](https://jacobgil.github.io/pytorch-gradcam-book)
11 |
12 |
13 | This is a package with state-of-the-art methods for Explainable AI for computer vision.
14 | This can be used for diagnosing model predictions, either in production or while
15 | developing models.
16 | The aim is also to serve as a benchmark of algorithms and metrics for research of new explainability methods.
17 |
18 | ⭐ Comprehensive collection of Pixel Attribution methods for Computer Vision.
19 |
20 | ⭐ Tested on many Common CNN Networks and Vision Transformers.
21 |
22 | ⭐ Advanced use cases: Works with Classification, Object Detection, Semantic Segmentation, Embedding-similarity and more.
23 |
24 | ⭐ Includes smoothing methods to make the CAMs look nice.
25 |
26 | ⭐ High performance: full support for batches of images in all methods.
27 |
28 | ⭐ Includes metrics for checking if you can trust the explanations, and tuning them for best performance.
29 |
30 |
31 | 
33 |
34 | | Method | What it does |
35 | |---------------------|-----------------------------------------------------------------------------------------------------------------------------|
36 | | GradCAM | Weight the 2D activations by the average gradient |
37 | | HiResCAM | Like GradCAM but element-wise multiply the activations with the gradients; provably guaranteed faithfulness for certain models |
38 | | GradCAMElementWise | Like GradCAM but element-wise multiply the activations with the gradients then apply a ReLU operation before summing |
39 | | GradCAM++ | Like GradCAM but uses second order gradients |
40 | | XGradCAM | Like GradCAM but scale the gradients by the normalized activations |
41 | | AblationCAM | Zero out activations and measure how the output drops (this repository includes a fast batched implementation) |
42 | | ScoreCAM            | Perturb the image by the scaled activations and measure how the output drops                                                 |
43 | | EigenCAM            | Takes the first principal component of the 2D Activations (no class discrimination, but seems to give great results)         |
44 | | EigenGradCAM        | Like EigenCAM but with class discrimination: First principal component of Activations*Grad. Looks like GradCAM, but cleaner  |
45 | | LayerCAM            | Spatially weight the activations by positive gradients. Works especially well in lower layers                                |
46 | | FullGrad | Computes the gradients of the biases from all over the network, and then sums them |
47 | | Deep Feature Factorizations | Non Negative Matrix Factorization on the 2D activations |
48 | | KPCA-CAM | Like EigenCAM but with Kernel PCA instead of PCA |
49 | | FEM                 | A gradient-free method that binarizes activations by an activation > mean + k * std rule.                                    |
50 | | ShapleyCAM | Weight the activations using the gradient and Hessian-vector product.|
51 | | FinerCAM | Improves fine-grained classification by comparing similar classes, suppressing shared features and highlighting discriminative details. |
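
All of these methods share the same constructor and call interface, so swapping one for another is a one-line change. A minimal sketch (assuming `model` and `target_layers` are defined as in the usage example below):

```python
from pytorch_grad_cam import GradCAM, ScoreCAM

# Any method from the table above is constructed the same way:
cam = ScoreCAM(model=model, target_layers=target_layers)
```
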
52 | ## Visual Examples
53 |
54 | | What makes the network think the image label is 'pug, pug-dog' | What makes the network think the image label is 'tabby, tabby cat' | Combining Grad-CAM with Guided Backpropagation for the 'pug, pug-dog' class |
55 | | ---------------------------------------------------------------|--------------------|-----------------------------------------------------------------------------|
56 | | *(image)* | *(image)* | *(image)* |
57 |
58 | ## Object Detection and Semantic Segmentation
59 | | Object Detection | Semantic Segmentation |
60 | | -----------------|-----------------------|
61 | | *(image)* | *(image)* |
62 |
63 | | 3D Medical Semantic Segmentation |
64 | | -------------------------- |
65 | | *(image)* |
66 |
67 | ## Explaining similarity to other images / embeddings
68 |
69 |
70 | ## Deep Feature Factorization
71 |
72 |
73 |
74 | ## CLIP
75 | | Explaining the text prompt "a dog" | Explaining the text prompt "a cat" |
76 | | -----------------------------------|------------------------------------|
77 | | *(image)* | *(image)* |
78 |
79 | ## Classification
80 |
81 | #### Resnet50:
82 | | Category | Image | GradCAM | AblationCAM | ScoreCAM |
83 | | ---------|-------|----------|------------|------------|
84 | | Dog | *(image)* | *(image)* | *(image)* | *(image)* |
85 | | Cat | *(image)* | *(image)* | *(image)* | *(image)* |
86 |
87 | #### Vision Transformer (Deit Tiny):
88 | | Category | Image | GradCAM | AblationCAM | ScoreCAM |
89 | | ---------|-------|----------|------------|------------|
90 | | Dog | *(image)* | *(image)* | *(image)* | *(image)* |
91 | | Cat | *(image)* | *(image)* | *(image)* | *(image)* |
92 |
93 | #### Swin Transformer (Tiny window:7 patch:4 input-size:224):
94 | | Category | Image | GradCAM | AblationCAM | ScoreCAM |
95 | | ---------|-------|----------|------------|------------|
96 | | Dog | *(image)* | *(image)* | *(image)* | *(image)* |
97 | | Cat | *(image)* | *(image)* | *(image)* | *(image)* |
98 |
99 |
100 | # Metrics and Evaluation for XAI
101 |
102 |
103 |
104 |
105 | ----------
106 |
107 | # Usage examples
108 |
109 | ```python
110 | from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
111 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
112 | from pytorch_grad_cam.utils.image import show_cam_on_image
113 | from torchvision.models import resnet50
114 |
115 | model = resnet50(pretrained=True)
116 | target_layers = [model.layer4[-1]]
117 | input_tensor = ...  # Create an input tensor image for your model.
118 | # Note: input_tensor can be a batch tensor with several images!
119 |
120 | # We have to specify the target we want to generate the CAM for.
121 | targets = [ClassifierOutputTarget(281)]
122 |
123 | # Construct the CAM object once, and then re-use it on many images.
124 | with GradCAM(model=model, target_layers=target_layers) as cam:
125 | # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
126 | grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
127 | # In this example grayscale_cam has only one image in the batch:
128 | grayscale_cam = grayscale_cam[0, :]
129 | visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)  # rgb_img: the original image as float32 in [0, 1]
130 | # You can also get the model outputs without having to redo inference
131 | model_outputs = cam.outputs
132 | ```
133 |
134 | [cam.py](https://github.com/jacobgil/pytorch-grad-cam/blob/master/cam.py) has a more detailed usage example.
135 |
136 | ----------
137 | # Choosing the layer(s) to extract activations from
138 | You need to choose the target layer to compute the CAM for.
139 | Some common choices are:
140 | - FasterRCNN: model.backbone
141 | - Resnet18 and 50: model.layer4[-1]
142 | - VGG, densenet161 and mobilenet: model.features[-1]
143 | - mnasnet1_0: model.layers[-1]
144 | - ViT: model.blocks[-1].norm1
145 | - SwinT: model.layers[-1].blocks[-1].norm1
146 |
147 |
148 | If you pass a list with several layers, the CAM will be averaged across them.
149 | This can be useful if you're not sure what layer will perform best.
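
As a minimal sketch (with torchvision's resnet50; the choice of the last two bottleneck blocks here is illustrative):

```python
from pytorch_grad_cam import GradCAM
from torchvision.models import resnet50

model = resnet50(pretrained=True).eval()
# The CAM is computed per layer and then averaged:
target_layers = [model.layer4[-2], model.layer4[-1]]
cam = GradCAM(model=model, target_layers=target_layers)
```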
150 |
151 | ----------
152 |
153 | # Adapting for new architectures and tasks
154 |
155 | Methods like GradCAM were designed for, and were originally mostly applied to, classification models,
156 | and specifically CNN classification models.
157 | However, you can also use this package on new architectures like Vision Transformers, and on non-classification tasks like Object Detection or Semantic Segmentation.
158 |
159 | To be able to adapt to non-standard cases, we have two concepts:
160 | - The reshape transform - how do we convert the activations to represent spatial images?
161 | - The model targets - what exactly should the explainability method try to explain?
162 |
163 | ## The reshape_transform argument
164 | In a CNN, the intermediate activations in the model are a multi-channel image with dimensions channels x rows x cols,
165 | and the various explainability methods work with these to produce a new image.
166 |
167 | In the case of another architecture, like the Vision Transformer, the shape might be different, like (rows x cols + 1) x channels, or something else.
168 | The reshape transform converts the activations back into a multi-channel image, for example by removing the class token in a vision transformer.
169 | For examples, check [here](https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/reshape_transforms.py)
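
As a minimal sketch (assuming a ViT with a 14x14 grid of patch tokens plus a leading class token), a reshape transform could look like:

```python
def reshape_transform(tensor, height=14, width=14):
    # Drop the class token and fold the patch tokens back into a 2D grid
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    # Move the channels first, like CNN activations: (batch, channels, height, width)
    result = result.transpose(2, 3).transpose(1, 2)
    return result
```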
170 |
171 | ## The model_target argument
172 | The model target is just a callable that takes the model output and filters it down to the specific scalar output we want to explain.
173 |
174 | For classification tasks, the model target will typically be the output from a specific category.
175 | The `targets` parameter passed to the CAM method can then use `ClassifierOutputTarget`:
176 | ```python
177 | targets = [ClassifierOutputTarget(281)]
178 | ```
179 |
180 | However, for more advanced cases you might want a different behaviour.
181 | Check [here](https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/model_targets.py) for more examples.
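
For instance, a hypothetical target that explains the summed score of several categories could be a sketch like this:

```python
class SumOfCategoryScores:
    """Hypothetical target: explain the sum of the scores of several categories."""
    def __init__(self, categories):
        self.categories = categories

    def __call__(self, model_output):
        # model_output is the model's output for a single image in the batch
        return sum(model_output[c] for c in self.categories)

targets = [SumOfCategoryScores([281, 282])]
```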
182 |
183 | ----------
184 |
185 | # Tutorials
186 | Here you can find detailed examples of how to use this for various custom use cases like object detection:
187 |
188 | These point to the new documentation jupyter-book for fast rendering.
189 | The jupyter notebooks themselves can be found under the tutorials folder in the git repository.
190 |
191 | - [Notebook tutorial: XAI Recipes for the HuggingFace 🤗 Image Classification Models]()
192 |
193 | - [Notebook tutorial: Deep Feature Factorizations for better model explainability]()
194 |
195 | - [Notebook tutorial: Class Activation Maps for Object Detection with Faster-RCNN]()
196 |
197 | - [Notebook tutorial: Class Activation Maps for YOLO5]()
198 |
199 | - [Notebook tutorial: Class Activation Maps for Semantic Segmentation]()
200 |
201 | - [Notebook tutorial: Adapting pixel attribution methods for embedding outputs from models]()
202 |
203 | - [Notebook tutorial: May the best explanation win. CAM Metrics and Tuning]()
204 |
205 | - [How it works with Vision/SwinT transformers](tutorials/vision_transformers.md)
206 |
207 |
208 | ----------
209 |
210 | # Guided backpropagation
211 |
212 | ```python
213 | from pytorch_grad_cam import GuidedBackpropReLUModel
214 | from pytorch_grad_cam.utils.image import (
215 | show_cam_on_image, deprocess_image, preprocess_image
216 | )
217 | gb_model = GuidedBackpropReLUModel(model=model, device=next(model.parameters()).device)
218 | gb = gb_model(input_tensor, target_category=None)
219 |
220 | cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
221 | cam_gb = deprocess_image(cam_mask * gb)
222 | result = deprocess_image(gb)
223 | ```
224 |
225 | ----------
226 |
227 | # Metrics and evaluating the explanations
228 |
229 | ```python
230 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputSoftmaxTarget
231 | from pytorch_grad_cam.metrics.cam_mult_image import CamMultImageConfidenceChange
232 | # Create the metric target, often the confidence drop in a score of some category
233 | metric_target = ClassifierOutputSoftmaxTarget(281)
234 | scores, batch_visualizations = CamMultImageConfidenceChange()(input_tensor,
235 | inverse_cams, targets, model, return_visualization=True)
236 | visualization = deprocess_image(batch_visualizations[0, :])
237 |
238 | # State of the art metric: Remove and Debias
239 | from pytorch_grad_cam.metrics.road import ROADMostRelevantFirst, ROADLeastRelevantFirst
240 | cam_metric = ROADMostRelevantFirst(percentile=75)
241 | scores, perturbation_visualizations = cam_metric(input_tensor,
242 | grayscale_cams, targets, model, return_visualization=True)
243 |
244 | # You can also average across different percentiles, and combine
245 | # (LeastRelevantFirst - MostRelevantFirst) / 2
246 | from pytorch_grad_cam.metrics.road import (ROADMostRelevantFirstAverage,
247 |                                            ROADLeastRelevantFirstAverage,
248 |                                            ROADCombined)
249 | cam_metric = ROADCombined(percentiles=[20, 40, 60, 80])
250 | scores = cam_metric(input_tensor, grayscale_cams, targets, model)
251 | ```
252 |
253 |
254 | # Smoothing to get nice looking CAMs
255 |
256 | To reduce noise in the CAMs, and make them fit the objects better,
257 | two smoothing methods are supported:
258 |
259 | - `aug_smooth=True`
260 |
261 | Test time augmentation: increases the run time by a factor of 6.
262 |
263 | Applies a combination of horizontal flips, and multiplying the image
264 | by [1.0, 1.1, 0.9].
265 |
266 | This has the effect of better centering the CAM around the objects.
267 |
268 |
269 | - `eigen_smooth=True`
270 |
271 | First principal component of `activations*weights`
272 |
273 | This has the effect of removing a lot of noise.
274 |
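Both options are flags on the call itself; a minimal sketch, reusing the `cam`, `input_tensor` and `targets` objects from the usage example above:

```python
grayscale_cam = cam(input_tensor=input_tensor, targets=targets,
                    aug_smooth=True, eigen_smooth=True)
```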
275 |
276 | |AblationCAM | aug smooth | eigen smooth | aug+eigen smooth|
277 | |------------|------------|--------------|--------------------|
278 | | *(image)* | *(image)* | *(image)* | *(image)* |
279 |
280 | ----------
281 |
282 | # Running the example script:
283 |
284 | Usage: `python cam.py --image-path <path_to_image> --method <method> --output-dir <output_directory>`
285 |
286 |
287 | To use with a specific device, like cpu, cuda, cuda:0, mps or hpu:
288 | `python cam.py --image-path <path_to_image> --device cuda --output-dir <output_directory>`
289 |
290 | ----------
291 |
292 | You can choose between:
293 |
294 | `GradCAM` , `HiResCAM`, `ScoreCAM`, `GradCAMPlusPlus`, `AblationCAM`, `XGradCAM` , `LayerCAM`, `FullGrad`, `EigenCAM`, `ShapleyCAM`, and `FinerCAM`.
295 |
296 | Some methods like ScoreCAM and AblationCAM require a large number of forward passes,
297 | and have a batched implementation.
298 |
299 | You can control the batch size with
300 | `cam.batch_size = `
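
For example (64 here is just an illustrative value; pick what fits in your GPU memory):

```python
with ScoreCAM(model=model, target_layers=target_layers) as cam:
    cam.batch_size = 64
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
```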
301 |
302 | ----------
303 |
304 | ## Citation
305 | If you use this for research, please cite. Here is an example BibTeX entry:
306 |
307 | ```
308 | @misc{jacobgilpytorchcam,
309 | title={PyTorch library for CAM methods},
310 | author={Jacob Gildenblat and contributors},
311 | year={2021},
312 | publisher={GitHub},
313 | howpublished={\url{https://github.com/jacobgil/pytorch-grad-cam}},
314 | }
315 | ```
316 |
317 | ----------
318 |
319 | # References
320 | https://arxiv.org/abs/1610.02391
321 | `Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
322 | Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, Dhruv Batra`
323 |
324 | https://arxiv.org/abs/2011.08891
325 | `Use HiResCAM instead of Grad-CAM for faithful explanations of convolutional neural networks
326 | Rachel L. Draelos, Lawrence Carin`
327 |
328 | https://arxiv.org/abs/1710.11063
329 | `Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks
330 | Aditya Chattopadhyay, Anirban Sarkar, Prantik Howlader, Vineeth N Balasubramanian`
331 |
332 | https://arxiv.org/abs/1910.01279
333 | `Score-CAM: Score-Weighted Visual Explanations for Convolutional Neural Networks
334 | Haofan Wang, Zifan Wang, Mengnan Du, Fan Yang, Zijian Zhang, Sirui Ding, Piotr Mardziel, Xia Hu`
335 |
336 | https://ieeexplore.ieee.org/abstract/document/9093360/
337 | `Ablation-cam: Visual explanations for deep convolutional network via gradient-free localization.
338 | Saurabh Desai and Harish G Ramaswamy. In WACV, pages 972–980, 2020`
339 |
340 | https://arxiv.org/abs/2008.02312
341 | `Axiom-based Grad-CAM: Towards Accurate Visualization and Explanation of CNNs
342 | Ruigang Fu, Qingyong Hu, Xiaohu Dong, Yulan Guo, Yinghui Gao, Biao Li`
343 |
344 | https://arxiv.org/abs/2008.00299
345 | `Eigen-CAM: Class Activation Map using Principal Components
346 | Mohammed Bany Muhammad, Mohammed Yeasin`
347 |
348 | http://mftp.mmcheng.net/Papers/21TIP_LayerCAM.pdf
349 | `LayerCAM: Exploring Hierarchical Class Activation Maps for Localization
350 | Peng-Tao Jiang; Chang-Bin Zhang; Qibin Hou; Ming-Ming Cheng; Yunchao Wei`
351 |
352 | https://arxiv.org/abs/1905.00780
353 | `Full-Gradient Representation for Neural Network Visualization
354 | Suraj Srinivas, Francois Fleuret`
355 |
356 | https://arxiv.org/abs/1806.10206
357 | `Deep Feature Factorization For Concept Discovery
358 | Edo Collins, Radhakrishna Achanta, Sabine Süsstrunk`
359 |
360 | https://arxiv.org/abs/2410.00267
361 | `KPCA-CAM: Visual Explainability of Deep Computer Vision Models using Kernel PCA
362 | Sachin Karmani, Thanushon Sivakaran, Gaurav Prasad, Mehmet Ali, Wenbo Yang, Sheyang Tang`
363 |
364 | https://hal.science/hal-02963298/document
365 | `Features Understanding in 3D CNNs for Actions Recognition in Video
366 | Kazi Ahmed Asif Fuad, Pierre-Etienne Martin, Romain Giot, Romain
367 | Bourqui, Jenny Benois-Pineau, Akka Zemmar`
368 |
369 | https://arxiv.org/abs/2501.06261
370 | `CAMs as Shapley Value-based Explainers
371 | Huaiguang Cai`
372 |
373 |
374 | https://arxiv.org/pdf/2501.11309
375 | `Finer-CAM : Spotting the Difference Reveals Finer Details for Visual Explanation`
376 | `Ziheng Zhang*, Jianyang Gu*, Arpita Chowdhury, Zheda Mai, David Carlyn,Tanya Berger-Wolf, Yu Su, Wei-Lun Chao`
377 |
--------------------------------------------------------------------------------
/cam.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from torchvision import models
7 | from pytorch_grad_cam import (
8 | GradCAM, FEM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
9 | AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
10 | LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM,
11 | FinerCAM
12 | )
13 | from pytorch_grad_cam import GuidedBackpropReLUModel
14 | from pytorch_grad_cam.utils.image import (
15 | show_cam_on_image, deprocess_image, preprocess_image
16 | )
17 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, ClassifierOutputReST
18 |
19 |
20 | def get_args():
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--device', type=str, default='cpu',
23 | help='Torch device to use')
24 | parser.add_argument(
25 | '--image-path',
26 | type=str,
27 | default='./examples/both.png',
28 | help='Input image path')
29 | parser.add_argument('--aug-smooth', action='store_true',
30 | help='Apply test time augmentation to smooth the CAM')
31 | parser.add_argument(
32 | '--eigen-smooth',
33 | action='store_true',
34 | help='Reduce noise by taking the first principal component '
35 | 'of cam_weights*activations')
36 | parser.add_argument('--method', type=str, default='gradcam',
37 | choices=[
38 | 'gradcam', 'fem', 'hirescam', 'gradcam++',
39 | 'scorecam', 'xgradcam', 'ablationcam',
40 | 'eigencam', 'eigengradcam', 'layercam',
41 | 'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam',
42 | 'finercam'
43 | ],
44 | help='CAM method')
45 |
46 | parser.add_argument('--output-dir', type=str, default='output',
47 | help='Output directory to save the images')
48 | args = parser.parse_args()
49 |
50 | if args.device:
51 | print(f'Using device "{args.device}" for acceleration')
52 | else:
53 | print('Using CPU for computation')
54 |
55 | return args
56 |
57 |
58 | if __name__ == '__main__':
59 | """ python cam.py -image-path
60 | Example usage of loading an image and computing:
61 | 1. CAM
62 | 2. Guided Back Propagation
63 | 3. Combining both
64 | """
65 |
66 | args = get_args()
67 | methods = {
68 | "gradcam": GradCAM,
69 | "hirescam": HiResCAM,
70 | "scorecam": ScoreCAM,
71 | "gradcam++": GradCAMPlusPlus,
72 | "ablationcam": AblationCAM,
73 | "xgradcam": XGradCAM,
74 | "eigencam": EigenCAM,
75 | "eigengradcam": EigenGradCAM,
76 | "layercam": LayerCAM,
77 | "fullgrad": FullGrad,
78 | "fem": FEM,
79 | "gradcamelementwise": GradCAMElementWise,
80 | 'kpcacam': KPCA_CAM,
81 | 'shapleycam': ShapleyCAM,
82 | 'finercam': FinerCAM
83 | }
84 |
85 | if args.device=='hpu':
86 | import habana_frameworks.torch.core as htcore
87 |
88 | model = models.resnet50(pretrained=True).to(torch.device(args.device)).eval()
89 |
90 | # Choose the target layer you want to compute the visualization for.
91 | # Usually this will be the last convolutional layer in the model.
92 | # Some common choices can be:
93 | # Resnet18 and 50: model.layer4
94 | # VGG, densenet161: model.features[-1]
95 | # mnasnet1_0: model.layers[-1]
96 | # You can print the model to help choose the layer
97 | # You can pass a list with several target layers,
98 | # in that case the CAMs will be computed per layer and then aggregated.
99 | # You can also try selecting all layers of a certain type, with e.g:
100 | # from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
101 | # find_layer_types_recursive(model, [torch.nn.ReLU])
102 |
103 | target_layers = [model.layer4]
104 |
105 | rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1]
106 | rgb_img = np.float32(rgb_img) / 255
107 | input_tensor = preprocess_image(rgb_img,
108 | mean=[0.485, 0.456, 0.406],
109 | std=[0.229, 0.224, 0.225]).to(args.device)
110 |
111 | # We have to specify the target we want to generate
112 | # the Class Activation Maps for.
113 | # If targets is None, the highest scoring category (for every member in the batch) will be used.
114 | # You can target specific categories by
115 | # targets = [ClassifierOutputTarget(281)]
116 | # targets = [ClassifierOutputReST(281)]
117 | targets = None
118 |
119 | # Using the with statement ensures the context is freed, and you can
120 | # recreate different CAM objects in a loop.
121 | cam_algorithm = methods[args.method]
122 | with cam_algorithm(model=model,
123 | target_layers=target_layers) as cam:
124 |
125 | # AblationCAM and ScoreCAM have batched implementations.
126 | # You can override the internal batch size for faster computation.
127 | cam.batch_size = 32
128 | grayscale_cam = cam(input_tensor=input_tensor,
129 | targets=targets,
130 | aug_smooth=args.aug_smooth,
131 | eigen_smooth=args.eigen_smooth)
132 |
133 | grayscale_cam = grayscale_cam[0, :]
134 |
135 | cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
136 | cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
137 |
138 | gb_model = GuidedBackpropReLUModel(model=model, device=args.device)
139 | gb = gb_model(input_tensor, target_category=None)
140 |
141 | cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
142 | cam_gb = deprocess_image(cam_mask * gb)
143 | gb = deprocess_image(gb)
144 |
145 | os.makedirs(args.output_dir, exist_ok=True)
146 |
147 | cam_output_path = os.path.join(args.output_dir, f'{args.method}_cam.jpg')
148 | gb_output_path = os.path.join(args.output_dir, f'{args.method}_gb.jpg')
149 | cam_gb_output_path = os.path.join(args.output_dir, f'{args.method}_cam_gb.jpg')
150 |
151 | cv2.imwrite(cam_output_path, cam_image)
152 | cv2.imwrite(gb_output_path, gb)
153 | cv2.imwrite(cam_gb_output_path, cam_gb)
--------------------------------------------------------------------------------
/examples/augsmooth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/augsmooth.jpg
--------------------------------------------------------------------------------
/examples/both.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/both.png
--------------------------------------------------------------------------------
/examples/both_detection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/both_detection.png
--------------------------------------------------------------------------------
/examples/cam_gb_dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/cam_gb_dog.jpg
--------------------------------------------------------------------------------
/examples/cars_segmentation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/cars_segmentation.png
--------------------------------------------------------------------------------
/examples/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/cat.jpg
--------------------------------------------------------------------------------
/examples/clip_cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/clip_cat.jpg
--------------------------------------------------------------------------------
/examples/clip_dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/clip_dog.jpg
--------------------------------------------------------------------------------
/examples/dff1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dff1.png
--------------------------------------------------------------------------------
/examples/dff2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dff2.png
--------------------------------------------------------------------------------
/examples/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dog.jpg
--------------------------------------------------------------------------------
/examples/dog_cat.jfif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dog_cat.jfif
--------------------------------------------------------------------------------
/examples/dogs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs.png
--------------------------------------------------------------------------------
/examples/dogs_gradcam++_resnet50.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs_gradcam++_resnet50.jpg
--------------------------------------------------------------------------------
/examples/dogs_gradcam++_vgg16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs_gradcam++_vgg16.jpg
--------------------------------------------------------------------------------
/examples/dogs_gradcam_resnet50.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs_gradcam_resnet50.jpg
--------------------------------------------------------------------------------
/examples/dogs_gradcam_vgg16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs_gradcam_vgg16.jpg
--------------------------------------------------------------------------------
/examples/dogs_scorecam_resnet50.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs_scorecam_resnet50.jpg
--------------------------------------------------------------------------------
/examples/dogs_scorecam_vgg16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs_scorecam_vgg16.jpg
--------------------------------------------------------------------------------
/examples/eigenaug.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/eigenaug.jpg
--------------------------------------------------------------------------------
/examples/eigensmooth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/eigensmooth.jpg
--------------------------------------------------------------------------------
/examples/embeddings.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/embeddings.png
--------------------------------------------------------------------------------
/examples/horses.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/horses.jpg
--------------------------------------------------------------------------------
/examples/metrics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/metrics.png
--------------------------------------------------------------------------------
/examples/multiorgan_segmentation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/multiorgan_segmentation.gif
--------------------------------------------------------------------------------
/examples/nosmooth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/nosmooth.jpg
--------------------------------------------------------------------------------
/examples/resnet50_cat_ablationcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet50_cat_ablationcam_cam.jpg
--------------------------------------------------------------------------------
/examples/resnet50_cat_gradcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet50_cat_gradcam_cam.jpg
--------------------------------------------------------------------------------
/examples/resnet50_cat_scorecam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet50_cat_scorecam_cam.jpg
--------------------------------------------------------------------------------
/examples/resnet50_dog_ablationcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet50_dog_ablationcam_cam.jpg
--------------------------------------------------------------------------------
/examples/resnet50_dog_gradcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet50_dog_gradcam_cam.jpg
--------------------------------------------------------------------------------
/examples/resnet50_dog_scorecam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet50_dog_scorecam_cam.jpg
--------------------------------------------------------------------------------
/examples/resnet_horses_ablationcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet_horses_ablationcam_cam.jpg
--------------------------------------------------------------------------------
/examples/resnet_horses_gradcam++_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet_horses_gradcam++_cam.jpg
--------------------------------------------------------------------------------
/examples/resnet_horses_gradcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet_horses_gradcam_cam.jpg
--------------------------------------------------------------------------------
/examples/resnet_horses_horses_eigencam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet_horses_horses_eigencam_cam.jpg
--------------------------------------------------------------------------------
/examples/resnet_horses_scorecam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet_horses_scorecam_cam.jpg
--------------------------------------------------------------------------------
/examples/road.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/road.png
--------------------------------------------------------------------------------
/examples/swinT_cat_ablationcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/swinT_cat_ablationcam_cam.jpg
--------------------------------------------------------------------------------
/examples/swinT_cat_gradcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/swinT_cat_gradcam_cam.jpg
--------------------------------------------------------------------------------
/examples/swinT_cat_scorecam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/swinT_cat_scorecam_cam.jpg
--------------------------------------------------------------------------------
/examples/swinT_dog_ablationcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/swinT_dog_ablationcam_cam.jpg
--------------------------------------------------------------------------------
/examples/swinT_dog_gradcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/swinT_dog_gradcam_cam.jpg
--------------------------------------------------------------------------------
/examples/swinT_dog_scorecam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/swinT_dog_scorecam_cam.jpg
--------------------------------------------------------------------------------
/examples/vgg_horses_ablationcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vgg_horses_ablationcam_cam.jpg
--------------------------------------------------------------------------------
/examples/vgg_horses_eigencam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vgg_horses_eigencam_cam.jpg
--------------------------------------------------------------------------------
/examples/vgg_horses_gradcam++_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vgg_horses_gradcam++_cam.jpg
--------------------------------------------------------------------------------
/examples/vgg_horses_gradcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vgg_horses_gradcam_cam.jpg
--------------------------------------------------------------------------------
/examples/vgg_horses_scorecam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vgg_horses_scorecam_cam.jpg
--------------------------------------------------------------------------------
/examples/vit_cat_ablationcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vit_cat_ablationcam_cam.jpg
--------------------------------------------------------------------------------
/examples/vit_cat_gradcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vit_cat_gradcam_cam.jpg
--------------------------------------------------------------------------------
/examples/vit_cat_scorecam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vit_cat_scorecam_cam.jpg
--------------------------------------------------------------------------------
/examples/vit_dog_ablationcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vit_dog_ablationcam_cam.jpg
--------------------------------------------------------------------------------
/examples/vit_dog_gradcam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vit_dog_gradcam_cam.jpg
--------------------------------------------------------------------------------
/examples/vit_dog_scorecam_cam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vit_dog_scorecam_cam.jpg
--------------------------------------------------------------------------------
/examples/yolo_eigencam.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/yolo_eigencam.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=42",
4 | "wheel"
5 | ]
6 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/pytorch_grad_cam/__init__.py:
--------------------------------------------------------------------------------
1 | from pytorch_grad_cam.grad_cam import GradCAM
2 | from pytorch_grad_cam.finer_cam import FinerCAM
3 | from pytorch_grad_cam.shapley_cam import ShapleyCAM
4 | from pytorch_grad_cam.fem import FEM
5 | from pytorch_grad_cam.hirescam import HiResCAM
6 | from pytorch_grad_cam.grad_cam_elementwise import GradCAMElementWise
7 | from pytorch_grad_cam.ablation_layer import AblationLayer, AblationLayerVit, AblationLayerFasterRCNN
8 | from pytorch_grad_cam.ablation_cam import AblationCAM
9 | from pytorch_grad_cam.xgrad_cam import XGradCAM
10 | from pytorch_grad_cam.grad_cam_plusplus import GradCAMPlusPlus
11 | from pytorch_grad_cam.score_cam import ScoreCAM
12 | from pytorch_grad_cam.layer_cam import LayerCAM
13 | from pytorch_grad_cam.eigen_cam import EigenCAM
14 | from pytorch_grad_cam.eigen_grad_cam import EigenGradCAM
15 | from pytorch_grad_cam.kpca_cam import KPCA_CAM
16 | from pytorch_grad_cam.random_cam import RandomCAM
17 | from pytorch_grad_cam.fullgrad_cam import FullGrad
18 | from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel
19 | from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
20 | from pytorch_grad_cam.feature_factorization.deep_feature_factorization import DeepFeatureFactorization, run_dff_on_image
21 | import pytorch_grad_cam.utils.model_targets
22 | import pytorch_grad_cam.utils.reshape_transforms
23 | import pytorch_grad_cam.metrics.cam_mult_image
24 | import pytorch_grad_cam.metrics.road
25 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/ablation_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import tqdm
4 | from typing import Callable, List
5 | from pytorch_grad_cam.base_cam import BaseCAM
6 | from pytorch_grad_cam.utils.find_layers import replace_layer_recursive
7 | from pytorch_grad_cam.ablation_layer import AblationLayer
8 |
9 |
10 | """ Implementation of AblationCAM
11 | https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf
12 |
13 | Ablate individual activations, and then measure the drop in the target scores.
14 |
15 | In the current implementation, the target layer activations are cached, so they won't be re-computed.
16 | However layers before it, if any, will not be cached.
17 | This means that if the target layer is a large block, for example model.features (in VGG), there will
18 | be a large saving in run time.
19 |
20 | Since we have to go over many channels and ablate them, and every channel ablation requires a forward pass,
21 | it would be nice if we could avoid doing that for channels that won't contribute anyway, making it much faster.
22 | The parameter ratio_channels_to_ablate controls how many channels should be ablated, using an experimental method
23 | (to be improved). The default 1.0 value means that all channels will be ablated.
24 | """
25 |
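# A minimal usage sketch (assumptions: a torchvision ResNet-50 and a preprocessed
# input batch; ratio_channels_to_ablate=0.5 is an illustrative value, not the default):
#
#     cam = AblationCAM(model=model,
#                       target_layers=[model.layer4[-1]],
#                       batch_size=32,
#                       ratio_channels_to_ablate=0.5)
#     grayscale_cam = cam(input_tensor=input_tensor)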
26 |
27 | class AblationCAM(BaseCAM):
28 | def __init__(self,
29 | model: torch.nn.Module,
30 | target_layers: List[torch.nn.Module],
31 | reshape_transform: Callable = None,
32 | ablation_layer: torch.nn.Module = AblationLayer(),
33 | batch_size: int = 32,
34 | ratio_channels_to_ablate: float = 1.0) -> None:
35 |
36 | super(AblationCAM, self).__init__(model,
37 | target_layers,
38 | reshape_transform,
39 | uses_gradients=False)
40 | self.batch_size = batch_size
41 | self.ablation_layer = ablation_layer
42 | self.ratio_channels_to_ablate = ratio_channels_to_ablate
43 |
44 | def save_activation(self, module, input, output) -> None:
45 | """ Helper function to save the raw activations from the target layer """
46 | self.activations = output
47 |
48 | def assemble_ablation_scores(self,
49 | new_scores: list,
50 | original_score: float,
51 | ablated_channels: np.ndarray,
52 | number_of_channels: int) -> np.ndarray:
53 | """ Take the value from the channels that were ablated,
54 | and just set the original score for the channels that were skipped """
55 |
56 | index = 0
57 | result = []
58 | sorted_indices = np.argsort(ablated_channels)
59 | ablated_channels = ablated_channels[sorted_indices]
60 | new_scores = np.float32(new_scores)[sorted_indices]
61 |
62 | for i in range(number_of_channels):
63 | if index < len(ablated_channels) and ablated_channels[index] == i:
64 | weight = new_scores[index]
65 | index = index + 1
66 | else:
67 | weight = original_score
68 | result.append(weight)
69 |
70 | return result
71 |
72 | def get_cam_weights(self,
73 | input_tensor: torch.Tensor,
74 | target_layer: torch.nn.Module,
75 | targets: List[Callable],
76 | activations: torch.Tensor,
77 | grads: torch.Tensor) -> np.ndarray:
78 |
79 | # Do a forward pass, compute the target scores, and cache the
80 | # activations
81 | handle = target_layer.register_forward_hook(self.save_activation)
82 | with torch.no_grad():
83 | outputs = self.model(input_tensor)
84 | handle.remove()
85 | original_scores = np.float32(
86 | [target(output).cpu().item() for target, output in zip(targets, outputs)])
87 |
88 | # Replace the layer with the ablation layer.
89 | # When we finish, we will replace it back, so the
90 | # original model is unchanged.
91 | ablation_layer = self.ablation_layer
92 | replace_layer_recursive(self.model, target_layer, ablation_layer)
93 |
94 | number_of_channels = activations.shape[1]
95 | weights = []
96 | # This is a "gradient free" method, so we don't need gradients here.
97 | with torch.no_grad():
98 | # Loop over each of the batch images and ablate activations for it.
99 | for batch_index, (target, tensor) in enumerate(
100 | zip(targets, input_tensor)):
101 | new_scores = []
102 | batch_tensor = tensor.repeat(self.batch_size, 1, 1, 1)
103 |
104 | # Check which channels should be ablated. Normally this will be all channels,
105 | # But we can also try to speed this up by using a low
106 | # ratio_channels_to_ablate.
107 | channels_to_ablate = ablation_layer.activations_to_be_ablated(
108 | activations[batch_index, :], self.ratio_channels_to_ablate)
109 | number_channels_to_ablate = len(channels_to_ablate)
110 |
111 | for i in tqdm.tqdm(
112 | range(
113 | 0,
114 | number_channels_to_ablate,
115 | self.batch_size)):
116 | if i + self.batch_size > number_channels_to_ablate:
117 | batch_tensor = batch_tensor[:(
118 | number_channels_to_ablate - i)]
119 |
120 | # Change the state of the ablation layer so it ablates the next channels.
121 | # TBD: Move this into the ablation layer forward pass.
122 | ablation_layer.set_next_batch(
123 | input_batch_index = batch_index,
124 | activations = self.activations,
125 | num_channels_to_ablate = batch_tensor.size(0))
126 | score = [target(o).cpu().item()
127 | for o in self.model(batch_tensor)]
128 | new_scores.extend(score)
129 | ablation_layer.indices = ablation_layer.indices[batch_tensor.size(
130 | 0):]
131 |
132 | new_scores = self.assemble_ablation_scores(
133 | new_scores,
134 | original_scores[batch_index],
135 | channels_to_ablate,
136 | number_of_channels)
137 | weights.extend(new_scores)
138 |
139 | weights = np.float32(weights)
140 | weights = weights.reshape(activations.shape[:2])
141 | original_scores = original_scores[:, None]
142 | weights = (original_scores - weights) / original_scores
143 |
144 | # Replace the model back to the original state
145 | replace_layer_recursive(self.model, ablation_layer, target_layer)
146 | # Returning the weights from new_scores
147 | return weights
148 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/ablation_cam_multilayer.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import tqdm
5 | from pytorch_grad_cam.base_cam import BaseCAM
6 |
7 |
8 | class AblationLayer(torch.nn.Module):
9 | def __init__(self, layer, reshape_transform, indices):
10 | super(AblationLayer, self).__init__()
11 |
12 | self.layer = layer
13 | self.reshape_transform = reshape_transform
14 | # The channels to zero out:
15 | self.indices = indices
16 |
17 | def forward(self, x):
18 | return self.__call__(x)
19 |
20 | def __call__(self, x):
21 | output = self.layer(x)
22 |
23 | # Hack to work with ViT,
24 | # Since the activation channels are last and not first like in CNNs
25 | # Probably should remove it?
26 | if self.reshape_transform is not None:
27 | output = output.transpose(1, 2)
28 |
29 | for i in range(output.size(0)):
30 |
31 | # Commonly the minimum activation will be 0,
32 | # And then it makes sense to zero it out.
33 | # However depending on the architecture,
34 | # If the values can be negative, we use very negative values
35 | # to perform the ablation, deviating from the paper.
36 | if torch.min(output) == 0:
37 | output[i, self.indices[i], :] = 0
38 | else:
39 | ABLATION_VALUE = 1e5
40 | output[i, self.indices[i], :] = torch.min(
41 | output) - ABLATION_VALUE
42 |
43 | if self.reshape_transform is not None:
44 | output = output.transpose(2, 1)
45 |
46 | return output
47 |
48 |
49 | def replace_layer_recursive(model, old_layer, new_layer):
50 | for name, layer in model._modules.items():
51 | if layer == old_layer:
52 | model._modules[name] = new_layer
53 | return True
54 | elif replace_layer_recursive(layer, old_layer, new_layer):
55 | return True
56 | return False
57 |
58 |
59 | class AblationCAM(BaseCAM):
60 | def __init__(self, model, target_layers,
61 | reshape_transform=None):
62 | super(AblationCAM, self).__init__(model, target_layers,
63 | reshape_transform)
64 |
65 | if len(target_layers) > 1:
66 | print(
67 |                 "Warning: you are using AblationCAM with more than one layer. "
68 | "This is supported only if all layers have the same output shape")
69 |
70 | def set_ablation_layers(self):
71 | self.ablation_layers = []
72 | for target_layer in self.target_layers:
73 | ablation_layer = AblationLayer(target_layer,
74 | self.reshape_transform, indices=[])
75 | self.ablation_layers.append(ablation_layer)
76 | replace_layer_recursive(self.model, target_layer, ablation_layer)
77 |
78 | def unset_ablation_layers(self):
79 | # replace the model back to the original state
80 | for ablation_layer, target_layer in zip(
81 | self.ablation_layers, self.target_layers):
82 | replace_layer_recursive(self.model, ablation_layer, target_layer)
83 |
84 | def set_ablation_layer_batch_indices(self, indices):
85 | for ablation_layer in self.ablation_layers:
86 | ablation_layer.indices = indices
87 |
88 | def trim_ablation_layer_batch_indices(self, keep):
89 | for ablation_layer in self.ablation_layers:
90 | ablation_layer.indices = ablation_layer.indices[:keep]
91 |
92 | def get_cam_weights(self,
93 | input_tensor,
94 | target_category,
95 | activations,
96 | grads):
97 | with torch.no_grad():
98 | outputs = self.model(input_tensor).cpu().numpy()
99 | original_scores = []
100 | for i in range(input_tensor.size(0)):
101 | original_scores.append(outputs[i, target_category[i]])
102 | original_scores = np.float32(original_scores)
103 |
104 | self.set_ablation_layers()
105 |
106 | if hasattr(self, "batch_size"):
107 | BATCH_SIZE = self.batch_size
108 | else:
109 | BATCH_SIZE = 32
110 |
111 | number_of_channels = activations.shape[1]
112 | weights = []
113 |
114 | with torch.no_grad():
115 | # Iterate over the input batch
116 | for tensor, category in zip(input_tensor, target_category):
117 | batch_tensor = tensor.repeat(BATCH_SIZE, 1, 1, 1)
118 | for i in tqdm.tqdm(range(0, number_of_channels, BATCH_SIZE)):
119 | self.set_ablation_layer_batch_indices(
120 | list(range(i, i + BATCH_SIZE)))
121 |
122 | if i + BATCH_SIZE > number_of_channels:
123 | keep = number_of_channels - i
124 | batch_tensor = batch_tensor[:keep]
125 |                         self.trim_ablation_layer_batch_indices(keep)
126 | score = self.model(batch_tensor)[:, category].cpu().numpy()
127 | weights.extend(score)
128 |
129 | weights = np.float32(weights)
130 | weights = weights.reshape(activations.shape[:2])
131 | original_scores = original_scores[:, None]
132 | weights = (original_scores - weights) / original_scores
133 |
134 | # replace the model back to the original state
135 | self.unset_ablation_layers()
136 | return weights
137 |
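138 |
139 | # A minimal sketch (illustrative only) of replace_layer_recursive: it swaps
140 | # a module in-place anywhere in the model tree, and the same call with the
141 | # arguments reversed restores the original model.
142 | if __name__ == "__main__":
143 |     model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
144 |     target, stub = model[1], torch.nn.Identity()
145 |     replace_layer_recursive(model, target, stub)
146 |     assert model[1] is stub
147 |     replace_layer_recursive(model, stub, target)  # swap back
148 |     assert model[1] is target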
--------------------------------------------------------------------------------
/pytorch_grad_cam/ablation_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 | import numpy as np
4 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
5 |
6 |
7 | class AblationLayer(torch.nn.Module):
8 | def __init__(self):
9 | super(AblationLayer, self).__init__()
10 |
11 | def objectiveness_mask_from_svd(self, activations, threshold=0.01):
12 |         """ Experimental method to get a binary mask used to decide if an activation is worth ablating.
13 |             The idea is to apply the EigenCAM method by doing PCA on the activations.
14 |             Then we create a binary mask by comparing to a low threshold.
15 |             Areas that are masked out are probably not interesting anyway.
16 | """
17 |
18 | projection = get_2d_projection(activations[None, :])[0, :]
19 | projection = np.abs(projection)
20 | projection = projection - projection.min()
21 | projection = projection / projection.max()
22 | projection = projection > threshold
23 | return projection
24 |
25 | def activations_to_be_ablated(
26 | self,
27 | activations,
28 | ratio_channels_to_ablate=1.0):
29 |         """ Experimental method to choose which activation channels are worth ablating.
30 |             Create a binary CAM mask with objectiveness_mask_from_svd.
31 |             Score each activation channel by how much of its values fall inside the mask.
32 | Then keep the top channels.
33 |
34 | """
35 | if ratio_channels_to_ablate == 1.0:
36 | self.indices = np.int32(range(activations.shape[0]))
37 | return self.indices
38 |
39 | projection = self.objectiveness_mask_from_svd(activations)
40 |
41 | scores = []
42 | for channel in activations:
43 | normalized = np.abs(channel)
44 | normalized = normalized - normalized.min()
45 | normalized = normalized / np.max(normalized)
46 | score = (projection * normalized).sum() / normalized.sum()
47 | scores.append(score)
48 | scores = np.float32(scores)
49 |
50 | indices = list(np.argsort(scores))
51 |         high_score_indices = indices[::-1][
52 |             : int(len(indices) * ratio_channels_to_ablate)]
53 |         low_score_indices = indices[
54 |             : int(len(indices) * ratio_channels_to_ablate)]
55 |
56 | self.indices = np.int32(high_score_indices + low_score_indices)
57 | return self.indices
58 |
59 | def set_next_batch(
60 | self,
61 | input_batch_index,
62 | activations,
63 | num_channels_to_ablate):
64 | """ This creates the next batch of activations from the layer.
65 | Just take corresponding batch member from activations, and repeat it num_channels_to_ablate times.
66 | """
67 | self.activations = activations[input_batch_index, :, :, :].clone(
68 | ).unsqueeze(0).repeat(num_channels_to_ablate, 1, 1, 1)
69 |
70 | def __call__(self, x):
71 | output = self.activations
72 | for i in range(output.size(0)):
73 |             # Commonly the minimum activation will be 0,
74 |             # and then it makes sense to zero it out.
75 |             # However, depending on the architecture,
76 |             # the values can be negative; in that case we use very negative
77 |             # values to perform the ablation, deviating from the paper.
78 | if torch.min(output) == 0:
79 | output[i, self.indices[i], :] = 0
80 | else:
81 | ABLATION_VALUE = 1e7
82 | output[i, self.indices[i], :] = torch.min(
83 | output) - ABLATION_VALUE
84 |
85 | return output
86 |
87 |
88 | class AblationLayerVit(AblationLayer):
89 | def __init__(self):
90 | super(AblationLayerVit, self).__init__()
91 |
92 | def __call__(self, x):
93 | output = self.activations
94 | output = output.transpose(1, len(output.shape) - 1)
95 | for i in range(output.size(0)):
96 |
97 |             # Commonly the minimum activation will be 0,
98 |             # and then it makes sense to zero it out.
99 |             # However, depending on the architecture,
100 |             # the values can be negative; in that case we use very negative
101 |             # values to perform the ablation, deviating from the paper.
102 | if torch.min(output) == 0:
103 | output[i, self.indices[i], :] = 0
104 | else:
105 | ABLATION_VALUE = 1e7
106 | output[i, self.indices[i], :] = torch.min(
107 | output) - ABLATION_VALUE
108 |
109 | output = output.transpose(len(output.shape) - 1, 1)
110 |
111 | return output
112 |
113 | def set_next_batch(
114 | self,
115 | input_batch_index,
116 | activations,
117 | num_channels_to_ablate):
118 | """ This creates the next batch of activations from the layer.
119 | Just take corresponding batch member from activations, and repeat it num_channels_to_ablate times.
120 | """
121 | repeat_params = [num_channels_to_ablate] + \
122 | len(activations.shape[:-1]) * [1]
123 | self.activations = activations[input_batch_index, :, :].clone(
124 | ).unsqueeze(0).repeat(*repeat_params)
125 |
126 |
127 | class AblationLayerFasterRCNN(AblationLayer):
128 | def __init__(self):
129 | super(AblationLayerFasterRCNN, self).__init__()
130 |
131 | def set_next_batch(
132 | self,
133 | input_batch_index,
134 | activations,
135 | num_channels_to_ablate):
136 | """ Extract the next batch member from activations,
137 | and repeat it num_channels_to_ablate times.
138 | """
139 | self.activations = OrderedDict()
140 | for key, value in activations.items():
141 | fpn_activation = value[input_batch_index,
142 | :, :, :].clone().unsqueeze(0)
143 | self.activations[key] = fpn_activation.repeat(
144 | num_channels_to_ablate, 1, 1, 1)
145 |
146 | def __call__(self, x):
147 | result = self.activations
148 | layers = {0: '0', 1: '1', 2: '2', 3: '3', 4: 'pool'}
149 | num_channels_to_ablate = result['pool'].size(0)
150 | for i in range(num_channels_to_ablate):
151 | pyramid_layer = int(self.indices[i] / 256)
152 | index_in_pyramid_layer = int(self.indices[i] % 256)
153 | result[layers[pyramid_layer]][i,
154 | index_in_pyramid_layer, :, :] = -1000
155 | return result
156 |
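157 |
158 | # A minimal sketch (illustrative, assuming a CNN-style batch x channels x
159 | # h x w layout) of how set_next_batch and __call__ cooperate: the layer
160 | # replays one stored activation num_channels_to_ablate times, zeroing a
161 | # different channel in each replica.
162 | if __name__ == "__main__":
163 |     layer = AblationLayer()
164 |     activations = torch.relu(torch.randn(1, 4, 8, 8))  # ReLU-like, min is 0
165 |     layer.indices = [0, 1, 2]
166 |     layer.set_next_batch(input_batch_index=0, activations=activations,
167 |                          num_channels_to_ablate=3)
168 |     out = layer(None)  # the input is ignored; stored activations are replayed
169 |     print(out.shape)   # torch.Size([3, 4, 8, 8]); replica i has channel i zeroed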
--------------------------------------------------------------------------------
/pytorch_grad_cam/activations_and_gradients.py:
--------------------------------------------------------------------------------
1 | class ActivationsAndGradients:
2 | """ Class for extracting activations and
3 |         registering gradients from targeted intermediate layers """
4 |
5 | def __init__(self, model, target_layers, reshape_transform, detach=True):
6 | self.model = model
7 | self.gradients = []
8 | self.activations = []
9 | self.reshape_transform = reshape_transform
10 | self.detach = detach
11 | self.handles = []
12 | for target_layer in target_layers:
13 | self.handles.append(
14 | target_layer.register_forward_hook(self.save_activation))
15 | # Because of https://github.com/pytorch/pytorch/issues/61519,
16 | # we don't use backward hook to record gradients.
17 | self.handles.append(
18 | target_layer.register_forward_hook(self.save_gradient))
19 |
20 | def save_activation(self, module, input, output):
21 | activation = output
22 | if self.detach:
23 | if self.reshape_transform is not None:
24 | activation = self.reshape_transform(activation)
25 | self.activations.append(activation.cpu().detach())
26 | else:
27 | self.activations.append(activation)
28 |
29 | def save_gradient(self, module, input, output):
30 | if not hasattr(output, "requires_grad") or not output.requires_grad:
31 |             # Hooks can only be registered on tensors that require grad.
32 | return
33 |
34 | # Gradients are computed in reverse order
35 | def _store_grad(grad):
36 | if self.detach:
37 | if self.reshape_transform is not None:
38 | grad = self.reshape_transform(grad)
39 | self.gradients = [grad.cpu().detach()] + self.gradients
40 | else:
41 | self.gradients = [grad] + self.gradients
42 |
43 | output.register_hook(_store_grad)
44 |
45 | def __call__(self, x):
46 | self.gradients = []
47 | self.activations = []
48 | return self.model(x)
49 |
50 | def release(self):
51 | for handle in self.handles:
52 | handle.remove()
53 |
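54 |
55 | # A minimal sketch (illustrative only) of using ActivationsAndGradients
56 | # directly: hook a layer, run a forward pass, backprop a scalar, and the
57 | # hooked layer's activations and gradients become available.
58 | if __name__ == "__main__":
59 |     import torch
60 |     model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU(),
61 |                                 torch.nn.AdaptiveAvgPool2d(1),
62 |                                 torch.nn.Flatten(), torch.nn.Linear(8, 2))
63 |     wrapper = ActivationsAndGradients(model, [model[0]], reshape_transform=None)
64 |     output = wrapper(torch.randn(1, 3, 16, 16))
65 |     output[0, 0].backward()
66 |     print(wrapper.activations[0].shape, wrapper.gradients[0].shape)
67 |     wrapper.release()  # always remove the hooks when done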
--------------------------------------------------------------------------------
/pytorch_grad_cam/base_cam.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | import ttach as tta
6 |
7 | from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
8 | from pytorch_grad_cam.utils.image import scale_cam_image
9 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
10 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
11 |
12 |
13 | class BaseCAM:
14 | def __init__(
15 | self,
16 | model: torch.nn.Module,
17 | target_layers: List[torch.nn.Module],
18 | reshape_transform: Callable = None,
19 | compute_input_gradient: bool = False,
20 | uses_gradients: bool = True,
21 | tta_transforms: Optional[tta.Compose] = None,
22 | detach: bool = True,
23 | ) -> None:
24 | self.model = model.eval()
25 | self.target_layers = target_layers
26 |
27 | # Use the same device as the model.
28 | self.device = next(self.model.parameters()).device
29 | if 'hpu' in str(self.device):
30 | try:
31 | import habana_frameworks.torch.core as htcore
32 | except ImportError as error:
33 | error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}."
34 | raise error
35 | self.__htcore = htcore
36 | self.reshape_transform = reshape_transform
37 | self.compute_input_gradient = compute_input_gradient
38 | self.uses_gradients = uses_gradients
39 | if tta_transforms is None:
40 | self.tta_transforms = tta.Compose(
41 | [
42 | tta.HorizontalFlip(),
43 | tta.Multiply(factors=[0.9, 1, 1.1]),
44 | ]
45 | )
46 | else:
47 | self.tta_transforms = tta_transforms
48 |
49 | self.detach = detach
50 | self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform, self.detach)
51 |
52 | """ Get a vector of weights for every channel in the target layer.
53 |         CAM methods that compute per-channel weights
54 |         will typically only need to implement this function. """
55 |
56 | def get_cam_weights(
57 | self,
58 | input_tensor: torch.Tensor,
59 | target_layers: List[torch.nn.Module],
60 | targets: List[torch.nn.Module],
61 | activations: torch.Tensor,
62 | grads: torch.Tensor,
63 | ) -> np.ndarray:
64 |         raise NotImplementedError("Not Implemented")
65 |
66 | def get_cam_image(
67 | self,
68 | input_tensor: torch.Tensor,
69 | target_layer: torch.nn.Module,
70 | targets: List[torch.nn.Module],
71 | activations: torch.Tensor,
72 | grads: torch.Tensor,
73 | eigen_smooth: bool = False,
74 | ) -> np.ndarray:
75 | weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
76 | if isinstance(activations, torch.Tensor):
77 | activations = activations.cpu().detach().numpy()
78 | # 2D conv
79 | if len(activations.shape) == 4:
80 | weighted_activations = weights[:, :, None, None] * activations
81 | # 3D conv
82 | elif len(activations.shape) == 5:
83 | weighted_activations = weights[:, :, None, None, None] * activations
84 | else:
85 |             raise ValueError(f"Invalid activation shape. Got {len(activations.shape)}.")
86 |
87 | if eigen_smooth:
88 | cam = get_2d_projection(weighted_activations)
89 | else:
90 | cam = weighted_activations.sum(axis=1)
91 | return cam
92 |
93 | def forward(
94 | self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
95 | ) -> np.ndarray:
96 | input_tensor = input_tensor.to(self.device)
97 |
98 | if self.compute_input_gradient:
99 | input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)
100 |
101 | self.outputs = outputs = self.activations_and_grads(input_tensor)
102 |
103 | if targets is None:
104 | target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
105 | targets = [ClassifierOutputTarget(category) for category in target_categories]
106 |
107 | if self.uses_gradients:
108 | self.model.zero_grad()
109 | loss = sum([target(output) for target, output in zip(targets, outputs)])
110 | if self.detach:
111 | loss.backward(retain_graph=True)
112 | else:
113 | # keep the computational graph, create_graph = True is needed for hvp
114 | torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
115 | # When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle"
116 | # loss.backward(retain_graph=True, create_graph=True)
117 | if 'hpu' in str(self.device):
118 | self.__htcore.mark_step()
119 |
120 | # In most of the saliency attribution papers, the saliency is
121 | # computed with a single target layer.
122 | # Commonly it is the last convolutional layer.
123 | # Here we support passing a list with multiple target layers.
124 | # It will compute the saliency image for every image,
125 | # and then aggregate them (with a default mean aggregation).
126 | # This gives you more flexibility in case you just want to
127 | # use all conv layers for example, all Batchnorm layers,
128 | # or something else.
129 | cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
130 | return self.aggregate_multi_layers(cam_per_layer)
131 |
132 | def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]:
133 | if len(input_tensor.shape) == 4:
134 | width, height = input_tensor.size(-1), input_tensor.size(-2)
135 | return width, height
136 | elif len(input_tensor.shape) == 5:
137 | depth, width, height = input_tensor.size(-1), input_tensor.size(-2), input_tensor.size(-3)
138 | return depth, width, height
139 | else:
140 | raise ValueError("Invalid input_tensor shape. Only 2D or 3D images are supported.")
141 |
142 | def compute_cam_per_layer(
143 | self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
144 | ) -> np.ndarray:
145 | if self.detach:
146 | activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
147 | grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
148 | else:
149 | activations_list = [a for a in self.activations_and_grads.activations]
150 | grads_list = [g for g in self.activations_and_grads.gradients]
151 | target_size = self.get_target_width_height(input_tensor)
152 |
153 | cam_per_target_layer = []
154 | # Loop over the saliency image from every layer
155 | for i in range(len(self.target_layers)):
156 | target_layer = self.target_layers[i]
157 | layer_activations = None
158 | layer_grads = None
159 | if i < len(activations_list):
160 | layer_activations = activations_list[i]
161 | if i < len(grads_list):
162 | layer_grads = grads_list[i]
163 |
164 | cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
165 | cam = np.maximum(cam, 0)
166 | scaled = scale_cam_image(cam, target_size)
167 | cam_per_target_layer.append(scaled[:, None, :])
168 |
169 | return cam_per_target_layer
170 |
171 | def aggregate_multi_layers(self, cam_per_target_layer: np.ndarray) -> np.ndarray:
172 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
173 | cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
174 | result = np.mean(cam_per_target_layer, axis=1)
175 | return scale_cam_image(result)
176 |
177 | def forward_augmentation_smoothing(
178 | self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
179 | ) -> np.ndarray:
180 | cams = []
181 | for transform in self.tta_transforms:
182 | augmented_tensor = transform.augment_image(input_tensor)
183 | cam = self.forward(augmented_tensor, targets, eigen_smooth)
184 |
185 | # The ttach library expects a tensor of size BxCxHxW
186 | cam = cam[:, None, :, :]
187 | cam = torch.from_numpy(cam)
188 | cam = transform.deaugment_mask(cam)
189 |
190 | # Back to numpy float32, HxW
191 | cam = cam.numpy()
192 | cam = cam[:, 0, :, :]
193 | cams.append(cam)
194 |
195 | cam = np.mean(np.float32(cams), axis=0)
196 | return cam
197 |
198 | def __call__(
199 | self,
200 | input_tensor: torch.Tensor,
201 | targets: List[torch.nn.Module] = None,
202 | aug_smooth: bool = False,
203 | eigen_smooth: bool = False,
204 | ) -> np.ndarray:
205 | # Smooth the CAM result with test time augmentation
206 | if aug_smooth is True:
207 | return self.forward_augmentation_smoothing(input_tensor, targets, eigen_smooth)
208 |
209 | return self.forward(input_tensor, targets, eigen_smooth)
210 |
211 | def __del__(self):
212 | self.activations_and_grads.release()
213 |
214 | def __enter__(self):
215 | return self
216 |
217 | def __exit__(self, exc_type, exc_value, exc_tb):
218 | self.activations_and_grads.release()
219 | if isinstance(exc_value, IndexError):
220 | # Handle IndexError here...
221 | print(f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
222 | return True
223 |
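224 |
225 | # A minimal usage sketch (illustrative; torchvision's resnet50 is used
226 | # purely as an example model). The context manager releases the hooks on
227 | # exit, via __exit__ above.
228 | if __name__ == "__main__":
229 |     from torchvision.models import resnet50
230 |     from pytorch_grad_cam import GradCAM
231 |     model = resnet50(weights=None).eval()
232 |     input_tensor = torch.randn(1, 3, 224, 224)
233 |     with GradCAM(model=model, target_layers=[model.layer4[-1]]) as cam:
234 |         grayscale_cam = cam(input_tensor=input_tensor,
235 |                             targets=[ClassifierOutputTarget(281)])
236 |     print(grayscale_cam.shape)  # (1, 224, 224): one CAM per input image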
--------------------------------------------------------------------------------
/pytorch_grad_cam/eigen_cam.py:
--------------------------------------------------------------------------------
1 | from pytorch_grad_cam.base_cam import BaseCAM
2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
3 |
4 | # https://arxiv.org/abs/2008.00299
5 |
6 |
7 | class EigenCAM(BaseCAM):
8 | def __init__(self, model, target_layers,
9 | reshape_transform=None):
10 | super(EigenCAM, self).__init__(model,
11 | target_layers,
12 | reshape_transform,
13 | uses_gradients=False)
14 |
15 | def get_cam_image(self,
16 | input_tensor,
17 | target_layer,
18 | target_category,
19 | activations,
20 | grads,
21 | eigen_smooth):
22 | return get_2d_projection(activations)
23 |
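24 |
25 | # A minimal re-implementation sketch (for illustration, not the library
26 | # code) of what get_2d_projection computes: flatten the maps to a
27 | # (h*w) x channels matrix, take the first right-singular vector via SVD,
28 | # and project the activations onto it to get a single 2D saliency map.
29 | if __name__ == "__main__":
30 |     import numpy as np
31 |     activations = np.random.rand(1, 8, 7, 7).astype(np.float32)
32 |     flat = activations[0].reshape(8, -1).T       # (h*w) x channels
33 |     flat = flat - flat.mean(axis=0)              # center, as in PCA
34 |     _, _, vt = np.linalg.svd(flat, full_matrices=False)
35 |     projection = (flat @ vt[0]).reshape(7, 7)    # first principal component
36 |     print(projection.shape)                      # (7, 7)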
--------------------------------------------------------------------------------
/pytorch_grad_cam/eigen_grad_cam.py:
--------------------------------------------------------------------------------
1 | from pytorch_grad_cam.base_cam import BaseCAM
2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
3 |
4 | # Like Eigen CAM: https://arxiv.org/abs/2008.00299
5 | # But multiply the activations x gradients
6 |
7 |
8 | class EigenGradCAM(BaseCAM):
9 | def __init__(self, model, target_layers,
10 | reshape_transform=None):
11 | super(EigenGradCAM, self).__init__(model, target_layers,
12 | reshape_transform)
13 |
14 | def get_cam_image(self,
15 | input_tensor,
16 | target_layer,
17 | target_category,
18 | activations,
19 | grads,
20 | eigen_smooth):
21 | return get_2d_projection(grads * activations)
22 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/feature_factorization/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/pytorch_grad_cam/feature_factorization/__init__.py
--------------------------------------------------------------------------------
/pytorch_grad_cam/feature_factorization/deep_feature_factorization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 | from typing import Callable, List, Tuple, Optional
5 | from sklearn.decomposition import NMF
6 | from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
7 | from pytorch_grad_cam.utils.image import scale_cam_image, create_labels_legend, show_factorization_on_image
8 |
9 |
10 | def dff(activations: np.ndarray, n_components: int = 5):
11 | """ Compute Deep Feature Factorization on a 2d Activations tensor.
12 |
13 | :param activations: A numpy array of shape batch x channels x height x width
14 | :param n_components: The number of components for the non negative matrix factorization
15 | :returns: A tuple of the concepts (a numpy array with shape channels x components),
16 |               and the explanation heatmaps (a numpy array with shape batch x height x width)
17 | """
18 |
19 | batch_size, channels, h, w = activations.shape
20 | reshaped_activations = activations.transpose((1, 0, 2, 3))
21 | reshaped_activations[np.isnan(reshaped_activations)] = 0
22 | reshaped_activations = reshaped_activations.reshape(
23 | reshaped_activations.shape[0], -1)
24 | offset = reshaped_activations.min(axis=-1)
25 | reshaped_activations = reshaped_activations - offset[:, None]
26 |
27 | model = NMF(n_components=n_components, init='random', random_state=0)
28 | W = model.fit_transform(reshaped_activations)
29 | H = model.components_
30 | concepts = W + offset[:, None]
31 | explanations = H.reshape(n_components, batch_size, h, w)
32 | explanations = explanations.transpose((1, 0, 2, 3))
33 | return concepts, explanations
34 |
35 |
36 | class DeepFeatureFactorization:
37 | """ Deep Feature Factorization: https://arxiv.org/abs/1806.10206
38 |     This takes a model, computes the 2D activations for a target layer,
39 |     and runs Non-Negative Matrix Factorization on the activations.
40 |
41 | Optionally it runs a computation on the concept embeddings,
42 | like running a classifier on them.
43 |
44 |     The explanation heatmaps are scaled to the range [0, 1]
45 | and to the input tensor width and height.
46 | """
47 |
48 | def __init__(self,
49 | model: torch.nn.Module,
50 | target_layer: torch.nn.Module,
51 | reshape_transform: Callable = None,
52 | computation_on_concepts=None
53 | ):
54 | self.model = model
55 | self.computation_on_concepts = computation_on_concepts
56 | self.activations_and_grads = ActivationsAndGradients(
57 | self.model, [target_layer], reshape_transform)
58 |
59 | def __call__(self,
60 | input_tensor: torch.Tensor,
61 | n_components: int = 16):
62 | batch_size, channels, h, w = input_tensor.size()
63 | _ = self.activations_and_grads(input_tensor)
64 |
65 | with torch.no_grad():
66 | activations = self.activations_and_grads.activations[0].cpu(
67 | ).numpy()
68 |
69 | concepts, explanations = dff(activations, n_components=n_components)
70 |
71 | processed_explanations = []
72 |
73 | for batch in explanations:
74 | processed_explanations.append(scale_cam_image(batch, (w, h)))
75 |
76 | if self.computation_on_concepts:
77 | with torch.no_grad():
78 | concept_tensors = torch.from_numpy(
79 | np.float32(concepts).transpose((1, 0)))
80 | concept_outputs = self.computation_on_concepts(
81 | concept_tensors).cpu().numpy()
82 | return concepts, processed_explanations, concept_outputs
83 | else:
84 | return concepts, processed_explanations
85 |
86 | def __del__(self):
87 | self.activations_and_grads.release()
88 |
89 | def __exit__(self, exc_type, exc_value, exc_tb):
90 | self.activations_and_grads.release()
91 | if isinstance(exc_value, IndexError):
92 | # Handle IndexError here...
93 | print(
94 | f"An exception occurred in ActivationSummary with block: {exc_type}. Message: {exc_value}")
95 | return True
96 |
97 |
98 | def run_dff_on_image(model: torch.nn.Module,
99 | target_layer: torch.nn.Module,
100 | classifier: torch.nn.Module,
101 | img_pil: Image,
102 | img_tensor: torch.Tensor,
103 |                      reshape_transform: Optional[Callable] = None,
104 | n_components: int = 5,
105 | top_k: int = 2) -> np.ndarray:
106 | """ Helper function to create a Deep Feature Factorization visualization for a single image.
107 | TBD: Run this on a batch with several images.
108 | """
109 | rgb_img_float = np.array(img_pil) / 255
110 | dff = DeepFeatureFactorization(model=model,
111 | reshape_transform=reshape_transform,
112 | target_layer=target_layer,
113 | computation_on_concepts=classifier)
114 |
115 | concepts, batch_explanations, concept_outputs = dff(
116 | img_tensor[None, :], n_components)
117 |
118 | concept_outputs = torch.softmax(
119 | torch.from_numpy(concept_outputs),
120 |         dim=-1).numpy()
121 | concept_label_strings = create_labels_legend(concept_outputs,
122 | labels=model.config.id2label,
123 | top_k=top_k)
124 | visualization = show_factorization_on_image(
125 | rgb_img_float,
126 | batch_explanations[0],
127 | image_weight=0.3,
128 | concept_labels=concept_label_strings)
129 |
130 | result = np.hstack((np.array(img_pil), visualization))
131 | return result
132 |
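133 |
134 | # A minimal sketch (illustrative only) of calling dff directly on a fake
135 | # activations array, to show the shapes it returns:
136 | if __name__ == "__main__":
137 |     fake_activations = np.abs(np.random.randn(2, 64, 7, 7)).astype(np.float32)
138 |     concepts, explanations = dff(fake_activations, n_components=3)
139 |     print(concepts.shape)      # (64, 3): one concept vector per component
140 |     print(explanations.shape)  # (2, 3, 7, 7): per-image component heatmaps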
--------------------------------------------------------------------------------
/pytorch_grad_cam/fem.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_grad_cam.base_cam import BaseCAM
3 |
4 | """Feature Explanation Method.
5 | Fuad, K. A. A., Martin, P. E., Giot, R., Bourqui, R., Benois-Pineau, J., & Zemmari, A. (2020, November). Features understanding in 3D CNNS for actions recognition in video. In 2020 Tenth International Conference on Image Processing Theory, Tools and Applications (IPTA) (pp. 1-6). IEEE.
6 | https://hal.science/hal-02963298/document
7 | """
8 |
9 | class FEM(BaseCAM):
10 | def __init__(self, model, target_layers,
11 | reshape_transform=None, k=2):
12 | super(FEM, self).__init__(model,
13 | target_layers,
14 | reshape_transform,
15 | uses_gradients=False)
16 | self.k = k
17 |
18 | def get_cam_image(self,
19 | input_tensor,
20 | target_layer,
21 | target_category,
22 | activations,
23 | grads,
24 | eigen_smooth):
25 |
26 |
27 | # 2D image
28 | if len(activations.shape) == 4:
29 | axis = (2, 3)
30 | # 3D image
31 | elif len(activations.shape) == 5:
32 | axis = (2, 3, 4)
33 | else:
34 |             raise ValueError("Invalid activations shape. "
35 | "Shape of activations should be 4 (2D image) or 5 (3D image).")
36 | means = np.mean(activations, axis=axis)
37 | stds = np.std(activations, axis=axis)
38 | # k sigma rule:
39 | # Add extra dimensions to match activations shape
40 | th = means + self.k * stds
41 | weights_shape = list(means.shape) + [1] * len(axis)
42 | th = th.reshape(weights_shape)
43 | binary_mask = activations > th
44 | weights = binary_mask.mean(axis=axis)
45 | return (weights.reshape(weights_shape) * activations).sum(axis=1)
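46 |
47 | # A minimal sketch (illustrative only) of the k-sigma rule used above: a
48 | # position counts as "important" when its activation exceeds the channel's
49 | # mean + k * std, and each channel is weighted by the fraction of its
50 | # positions that do.
51 | if __name__ == "__main__":
52 |     activations = np.random.randn(1, 8, 7, 7).astype(np.float32)
53 |     k = 2
54 |     means = activations.mean(axis=(2, 3), keepdims=True)
55 |     stds = activations.std(axis=(2, 3), keepdims=True)
56 |     weights = (activations > means + k * stds).mean(axis=(2, 3))
57 |     print(weights.shape)  # (1, 8): one weight per channel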
--------------------------------------------------------------------------------
/pytorch_grad_cam/finer_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from typing import List, Callable
4 | from pytorch_grad_cam.base_cam import BaseCAM
5 | from pytorch_grad_cam import GradCAM
6 | from pytorch_grad_cam.utils.model_targets import FinerWeightedTarget
7 |
8 | class FinerCAM:
9 | def __init__(self, model: torch.nn.Module, target_layers: List[torch.nn.Module], reshape_transform: Callable = None, base_method=GradCAM):
10 | self.base_cam = base_method(model, target_layers, reshape_transform)
11 | self.compute_input_gradient = self.base_cam.compute_input_gradient
12 | self.uses_gradients = self.base_cam.uses_gradients
13 |
14 | def __call__(self, *args, **kwargs):
15 | return self.forward(*args, **kwargs)
16 |
17 | def forward(self, input_tensor: torch.Tensor, targets: List[torch.nn.Module] = None, eigen_smooth: bool = False,
18 | alpha: float = 1, comparison_categories: List[int] = [1, 2, 3], target_idx: int = None
19 | ) -> np.ndarray:
20 | input_tensor = input_tensor.to(self.base_cam.device)
21 |
22 | if self.compute_input_gradient:
23 | input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)
24 |
25 | outputs = self.base_cam.activations_and_grads(input_tensor)
26 |
27 | if targets is None:
28 | output_data = outputs.detach().cpu().numpy()
29 | target_logits = np.max(output_data, axis=-1) if target_idx is None else output_data[:, target_idx]
30 | # Sort class indices for each sample based on the absolute difference
31 | # between the class scores and the target logit, in ascending order.
32 | # The most similar classes (smallest difference) appear first.
33 | sorted_indices = np.argsort(np.abs(output_data - target_logits[:, None]), axis=-1)
34 | targets = [FinerWeightedTarget(int(sorted_indices[i, 0]),
35 | [int(sorted_indices[i, idx]) for idx in comparison_categories],
36 | alpha)
37 | for i in range(output_data.shape[0])]
38 |
39 | if self.uses_gradients:
40 | self.base_cam.model.zero_grad()
41 | loss = sum([target(output) for target, output in zip(targets, outputs)])
42 | if self.base_cam.detach:
43 | loss.backward(retain_graph=True)
44 | else:
45 | # keep the computational graph, create_graph = True is needed for hvp
46 | torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
47 | # When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle"
48 | # loss.backward(retain_graph=True, create_graph=True)
49 | if 'hpu' in str(self.base_cam.device):
50 |             self.base_cam._BaseCAM__htcore.mark_step()  # __htcore is name-mangled in BaseCAM
51 |
52 | cam_per_layer = self.base_cam.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
53 | return self.base_cam.aggregate_multi_layers(cam_per_layer)
--------------------------------------------------------------------------------
/pytorch_grad_cam/fullgrad_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from pytorch_grad_cam.base_cam import BaseCAM
4 | from pytorch_grad_cam.utils.find_layers import find_layer_predicate_recursive
5 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
6 | from pytorch_grad_cam.utils.image import scale_accross_batch_and_channels, scale_cam_image
7 |
8 | # https://arxiv.org/abs/1905.00780
9 |
10 |
11 | class FullGrad(BaseCAM):
12 | def __init__(self, model, target_layers,
13 | reshape_transform=None):
14 | if len(target_layers) > 0:
15 | print(
16 | "Warning: target_layers is ignored in FullGrad. All bias layers will be used instead")
17 |
18 | def layer_with_2D_bias(layer):
19 | bias_target_layers = [torch.nn.Conv2d, torch.nn.BatchNorm2d]
20 | if type(layer) in bias_target_layers and layer.bias is not None:
21 | return True
22 | return False
23 | target_layers = find_layer_predicate_recursive(
24 | model, layer_with_2D_bias)
25 | super(
26 | FullGrad,
27 | self).__init__(
28 | model,
29 | target_layers,
30 | reshape_transform,
31 | compute_input_gradient=True)
32 | self.bias_data = [self.get_bias_data(
33 | layer).cpu().numpy() for layer in target_layers]
34 |
35 | def get_bias_data(self, layer):
36 | # Borrowed from official paper impl:
37 | # https://github.com/idiap/fullgrad-saliency/blob/master/saliency/tensor_extractor.py#L47
38 | if isinstance(layer, torch.nn.BatchNorm2d):
39 | bias = - (layer.running_mean * layer.weight
40 | / torch.sqrt(layer.running_var + layer.eps)) + layer.bias
41 | return bias.data
42 | else:
43 | return layer.bias.data
44 |
45 | def compute_cam_per_layer(
46 | self,
47 | input_tensor,
48 | target_category,
49 | eigen_smooth):
50 | input_grad = input_tensor.grad.data.cpu().numpy()
51 | grads_list = [g.cpu().data.numpy() for g in
52 | self.activations_and_grads.gradients]
53 | cam_per_target_layer = []
54 | target_size = self.get_target_width_height(input_tensor)
55 |
56 | gradient_multiplied_input = input_grad * input_tensor.data.cpu().numpy()
57 | gradient_multiplied_input = np.abs(gradient_multiplied_input)
58 | gradient_multiplied_input = scale_accross_batch_and_channels(
59 | gradient_multiplied_input,
60 | target_size)
61 | cam_per_target_layer.append(gradient_multiplied_input)
62 |
63 | # Loop over the saliency image from every layer
64 |         assert len(self.bias_data) == len(grads_list)
65 | for bias, grads in zip(self.bias_data, grads_list):
66 | bias = bias[None, :, None, None]
67 | # In the paper they take the absolute value,
68 |             # but possibly taking only the positive gradients will work
69 | # better.
70 | bias_grad = np.abs(bias * grads)
71 | result = scale_accross_batch_and_channels(
72 | bias_grad, target_size)
73 | result = np.sum(result, axis=1)
74 | cam_per_target_layer.append(result[:, None, :])
75 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
76 | if eigen_smooth:
77 | # Resize to a smaller image, since this method typically has a very large number of channels,
78 | # and then consumes a lot of memory
79 | cam_per_target_layer = scale_accross_batch_and_channels(
80 | cam_per_target_layer, (target_size[0] // 8, target_size[1] // 8))
81 | cam_per_target_layer = get_2d_projection(cam_per_target_layer)
82 | cam_per_target_layer = cam_per_target_layer[:, None, :, :]
83 | cam_per_target_layer = scale_accross_batch_and_channels(
84 | cam_per_target_layer,
85 | target_size)
86 | else:
87 | cam_per_target_layer = np.sum(
88 | cam_per_target_layer, axis=1)[:, None, :]
89 |
90 | return cam_per_target_layer
91 |
92 | def aggregate_multi_layers(self, cam_per_target_layer):
93 | result = np.sum(cam_per_target_layer, axis=1)
94 | return scale_cam_image(result)
95 |
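96 |
97 | # A minimal sketch (illustrative only) checking the BatchNorm bias folding
98 | # in get_bias_data: in eval mode BN(x) = gamma * (x - mean) / sqrt(var +
99 | # eps) + beta, so feeding zeros exposes the effective additive bias
100 | # beta - mean * gamma / sqrt(var + eps).
101 | if __name__ == "__main__":
102 |     bn = torch.nn.BatchNorm2d(4).eval()
103 |     with torch.no_grad():
104 |         bn.running_mean.uniform_(-1, 1)
105 |         bn.running_var.uniform_(0.5, 1.5)
106 |     folded = -(bn.running_mean * bn.weight
107 |                / torch.sqrt(bn.running_var + bn.eps)) + bn.bias
108 |     out = bn(torch.zeros(1, 4, 1, 1))
109 |     assert torch.allclose(out.view(-1), folded, atol=1e-6)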
--------------------------------------------------------------------------------
/pytorch_grad_cam/grad_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pytorch_grad_cam.base_cam import BaseCAM
4 |
5 |
6 | class GradCAM(BaseCAM):
7 | def __init__(self, model, target_layers,
8 | reshape_transform=None):
9 | super(
10 | GradCAM,
11 | self).__init__(
12 | model,
13 | target_layers,
14 | reshape_transform)
15 |
16 | def get_cam_weights(self,
17 | input_tensor,
18 | target_layer,
19 | target_category,
20 | activations,
21 | grads):
22 | # 2D image
23 | if len(grads.shape) == 4:
24 | return np.mean(grads, axis=(2, 3))
25 |
26 | # 3D image
27 | elif len(grads.shape) == 5:
28 | return np.mean(grads, axis=(2, 3, 4))
29 |
30 | else:
31 |             raise ValueError("Invalid grads shape. "
32 | "Shape of grads should be 4 (2D image) or 5 (3D image).")
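33 |
34 | # A minimal sketch (illustrative only) of the Grad-CAM recipe this class
35 | # implements: global-average-pool the gradients per channel to get the
36 | # weights, then take a weighted sum of the activation maps (followed by a
37 | # ReLU in BaseCAM.compute_cam_per_layer).
38 | if __name__ == "__main__":
39 |     grads = np.random.randn(1, 8, 7, 7).astype(np.float32)
40 |     activations = np.random.randn(1, 8, 7, 7).astype(np.float32)
41 |     weights = grads.mean(axis=(2, 3))                 # batch x channels
42 |     cam = (weights[:, :, None, None] * activations).sum(axis=1)
43 |     cam = np.maximum(cam, 0)
44 |     print(cam.shape)                                  # (1, 7, 7)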
--------------------------------------------------------------------------------
/pytorch_grad_cam/grad_cam_elementwise.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_grad_cam.base_cam import BaseCAM
3 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
4 |
5 |
6 | class GradCAMElementWise(BaseCAM):
7 | def __init__(self, model, target_layers,
8 | reshape_transform=None):
9 | super(
10 | GradCAMElementWise,
11 | self).__init__(
12 | model,
13 | target_layers,
14 | reshape_transform)
15 |
16 | def get_cam_image(self,
17 | input_tensor,
18 | target_layer,
19 | target_category,
20 | activations,
21 | grads,
22 | eigen_smooth):
23 | elementwise_activations = np.maximum(grads * activations, 0)
24 |
25 | if eigen_smooth:
26 | cam = get_2d_projection(elementwise_activations)
27 | else:
28 | cam = elementwise_activations.sum(axis=1)
29 | return cam
30 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/grad_cam_plusplus.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_grad_cam.base_cam import BaseCAM
3 |
4 | # https://arxiv.org/abs/1710.11063
5 |
6 |
7 | class GradCAMPlusPlus(BaseCAM):
8 | def __init__(self, model, target_layers,
9 | reshape_transform=None):
10 | super(GradCAMPlusPlus, self).__init__(model, target_layers,
11 | reshape_transform)
12 |
13 | def get_cam_weights(self,
14 | input_tensor,
15 | target_layers,
16 | target_category,
17 | activations,
18 | grads):
19 | grads_power_2 = grads**2
20 | grads_power_3 = grads_power_2 * grads
21 | # Equation 19 in https://arxiv.org/abs/1710.11063
22 | sum_activations = np.sum(activations, axis=(2, 3))
23 | eps = 0.000001
24 | aij = grads_power_2 / (2 * grads_power_2 +
25 | sum_activations[:, :, None, None] * grads_power_3 + eps)
26 | # Now bring back the ReLU from eq.7 in the paper,
27 | # And zero out aijs where the activations are 0
28 | aij = np.where(grads != 0, aij, 0)
29 |
30 | weights = np.maximum(grads, 0) * aij
31 | weights = np.sum(weights, axis=(2, 3))
32 | return weights
33 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/guided_backprop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Function
4 | from pytorch_grad_cam.utils.find_layers import replace_all_layer_type_recursive
5 |
6 |
7 | class GuidedBackpropReLU(Function):
8 | @staticmethod
9 | def forward(self, input_img):
10 | positive_mask = (input_img > 0).type_as(input_img)
11 | output = torch.addcmul(
12 | torch.zeros(
13 | input_img.size()).type_as(input_img),
14 | input_img,
15 | positive_mask)
16 | self.save_for_backward(input_img, output)
17 | return output
18 |
19 | @staticmethod
20 | def backward(self, grad_output):
21 | input_img, output = self.saved_tensors
22 | grad_input = None
23 |
24 | positive_mask_1 = (input_img > 0).type_as(grad_output)
25 | positive_mask_2 = (grad_output > 0).type_as(grad_output)
26 | grad_input = torch.addcmul(
27 | torch.zeros(
28 | input_img.size()).type_as(input_img),
29 | torch.addcmul(
30 | torch.zeros(
31 | input_img.size()).type_as(input_img),
32 | grad_output,
33 | positive_mask_1),
34 | positive_mask_2)
35 | return grad_input
36 |
37 |
38 | class GuidedBackpropReLUasModule(torch.nn.Module):
39 | def __init__(self):
40 | super(GuidedBackpropReLUasModule, self).__init__()
41 |
42 | def forward(self, input_img):
43 | return GuidedBackpropReLU.apply(input_img)
44 |
45 |
46 | class GuidedBackpropReLUModel:
47 | def __init__(self, model, device):
48 | self.model = model
49 | self.model.eval()
50 | self.device = next(self.model.parameters()).device
51 |
52 | def forward(self, input_img):
53 | return self.model(input_img)
54 |
55 | def recursive_replace_relu_with_guidedrelu(self, module_top):
56 |
57 | for idx, module in module_top._modules.items():
58 | self.recursive_replace_relu_with_guidedrelu(module)
59 | if module.__class__.__name__ == 'ReLU':
60 | module_top._modules[idx] = GuidedBackpropReLU.apply
61 |
62 |
63 | def recursive_replace_guidedrelu_with_relu(self, module_top):
64 | try:
65 | for idx, module in module_top._modules.items():
66 | self.recursive_replace_guidedrelu_with_relu(module)
67 | if module == GuidedBackpropReLU.apply:
68 | module_top._modules[idx] = torch.nn.ReLU()
69 | except BaseException:
70 | pass
71 |
72 | def __call__(self, input_img, target_category=None):
73 | replace_all_layer_type_recursive(self.model,
74 | torch.nn.ReLU,
75 | GuidedBackpropReLUasModule())
76 |
77 |
78 | input_img = input_img.to(self.device)
79 |
80 | input_img = input_img.requires_grad_(True)
81 |
82 | output = self.forward(input_img)
83 |
84 | if target_category is None:
85 | target_category = np.argmax(output.cpu().data.numpy())
86 |
87 | loss = output[0, target_category]
88 | loss.backward(retain_graph=True)
89 |
90 | output = input_img.grad.cpu().data.numpy()
91 | output = output[0, :, :, :]
92 | output = output.transpose((1, 2, 0))
93 |
94 | replace_all_layer_type_recursive(self.model,
95 | GuidedBackpropReLUasModule,
96 | torch.nn.ReLU())
97 |
98 | return output
99 |
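100 |
101 | # A minimal usage sketch (illustrative; torchvision's resnet50 is used
102 | # purely as an example model): the ReLUs are temporarily swapped for
103 | # guided-backprop ReLUs, the input gradient is returned as an HxWxC numpy
104 | # array, and the original ReLUs are restored afterwards.
105 | if __name__ == "__main__":
106 |     from torchvision.models import resnet50
107 |     model = resnet50(weights=None).eval()
108 |     gb_model = GuidedBackpropReLUModel(model=model, device="cpu")
109 |     gb = gb_model(torch.randn(1, 3, 224, 224), target_category=281)
110 |     print(gb.shape)  # (224, 224, 3)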
--------------------------------------------------------------------------------
/pytorch_grad_cam/hirescam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_grad_cam.base_cam import BaseCAM
3 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
4 |
5 |
6 | class HiResCAM(BaseCAM):
7 | def __init__(self, model, target_layers,
8 | reshape_transform=None):
9 | super(
10 | HiResCAM,
11 | self).__init__(
12 | model,
13 | target_layers,
14 | reshape_transform)
15 |
16 | def get_cam_image(self,
17 | input_tensor,
18 | target_layer,
19 | target_category,
20 | activations,
21 | grads,
22 | eigen_smooth):
23 | elementwise_activations = grads * activations
24 |
25 | if eigen_smooth:
26 | print(
27 | "Warning: HiResCAM's faithfulness guarantees do not hold if smoothing is applied")
28 | cam = get_2d_projection(elementwise_activations)
29 | else:
30 | cam = elementwise_activations.sum(axis=1)
31 | return cam
32 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/kpca_cam.py:
--------------------------------------------------------------------------------
1 | from pytorch_grad_cam.base_cam import BaseCAM
2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection_kernel
3 |
4 | class KPCA_CAM(BaseCAM):
5 | def __init__(self, model, target_layers,
6 | reshape_transform=None, kernel='sigmoid', gamma=None):
7 | super(KPCA_CAM, self).__init__(model,
8 | target_layers,
9 | reshape_transform,
10 | uses_gradients=False)
11 | self.kernel=kernel
12 | self.gamma=gamma
13 |
14 | def get_cam_image(self,
15 | input_tensor,
16 | target_layer,
17 | target_category,
18 | activations,
19 | grads,
20 | eigen_smooth):
21 | return get_2d_projection_kernel(activations, self.kernel, self.gamma)
22 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/layer_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_grad_cam.base_cam import BaseCAM
3 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
4 |
5 | # https://ieeexplore.ieee.org/document/9462463
6 |
7 |
8 | class LayerCAM(BaseCAM):
9 | def __init__(
10 | self,
11 | model,
12 | target_layers,
13 | reshape_transform=None):
14 | super(
15 | LayerCAM,
16 | self).__init__(
17 | model,
18 | target_layers,
19 | reshape_transform)
20 |
21 | def get_cam_image(self,
22 | input_tensor,
23 | target_layer,
24 | target_category,
25 | activations,
26 | grads,
27 | eigen_smooth):
28 | spatial_weighted_activations = np.maximum(grads, 0) * activations
29 |
30 | if eigen_smooth:
31 | cam = get_2d_projection(spatial_weighted_activations)
32 | else:
33 | cam = spatial_weighted_activations.sum(axis=1)
34 | return cam
35 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/pytorch_grad_cam/metrics/__init__.py
--------------------------------------------------------------------------------
/pytorch_grad_cam/metrics/cam_mult_image.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import List, Callable
4 | from pytorch_grad_cam.metrics.perturbation_confidence import PerturbationConfidenceMetric
5 |
6 |
7 | def multiply_tensor_with_cam(input_tensor: torch.Tensor,
8 | cam: torch.Tensor):
9 | """ Multiply an input tensor (after normalization)
10 | with a pixel attribution map
11 | """
12 | return input_tensor * cam
13 |
14 |
15 | class CamMultImageConfidenceChange(PerturbationConfidenceMetric):
16 | def __init__(self):
17 | super(CamMultImageConfidenceChange,
18 | self).__init__(multiply_tensor_with_cam)
19 |
20 |
21 | class DropInConfidence(CamMultImageConfidenceChange):
22 | def __init__(self):
23 | super(DropInConfidence, self).__init__()
24 |
25 | def __call__(self, *args, **kwargs):
26 | scores = super(DropInConfidence, self).__call__(*args, **kwargs)
27 | scores = -scores
28 | return np.maximum(scores, 0)
29 |
30 |
31 | class IncreaseInConfidence(CamMultImageConfidenceChange):
32 | def __init__(self):
33 | super(IncreaseInConfidence, self).__init__()
34 |
35 | def __call__(self, *args, **kwargs):
36 | scores = super(IncreaseInConfidence, self).__call__(*args, **kwargs)
37 | return np.float32(scores > 0)
38 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/metrics/perturbation_confidence.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import List, Callable
4 |
5 | import numpy as np
6 | import cv2
7 |
8 |
9 | class PerturbationConfidenceMetric:
10 | def __init__(self, perturbation):
11 | self.perturbation = perturbation
12 |
13 | def __call__(self, input_tensor: torch.Tensor,
14 | cams: np.ndarray,
15 | targets: List[Callable],
16 | model: torch.nn.Module,
17 | return_visualization=False,
18 | return_diff=True):
19 |
20 | if return_diff:
21 | with torch.no_grad():
22 | outputs = model(input_tensor)
23 | scores = [target(output).cpu().numpy()
24 | for target, output in zip(targets, outputs)]
25 | scores = np.float32(scores)
26 |
27 | batch_size = input_tensor.size(0)
28 | perturbated_tensors = []
29 | for i in range(batch_size):
30 | cam = cams[i]
31 | tensor = self.perturbation(input_tensor[i, ...].cpu(),
32 | torch.from_numpy(cam))
33 | tensor = tensor.to(input_tensor.device)
34 | perturbated_tensors.append(tensor.unsqueeze(0))
35 | perturbated_tensors = torch.cat(perturbated_tensors)
36 |
37 | with torch.no_grad():
38 | outputs_after_imputation = model(perturbated_tensors)
39 | scores_after_imputation = [
40 | target(output).cpu().numpy() for target, output in zip(
41 | targets, outputs_after_imputation)]
42 | scores_after_imputation = np.float32(scores_after_imputation)
43 |
44 | if return_diff:
45 | result = scores_after_imputation - scores
46 | else:
47 | result = scores_after_imputation
48 |
49 | if return_visualization:
50 | return result, perturbated_tensors
51 | else:
52 | return result
53 |
54 |
55 | class RemoveMostRelevantFirst:
56 | def __init__(self, percentile, imputer):
57 | self.percentile = percentile
58 | self.imputer = imputer
59 |
60 | def __call__(self, input_tensor, mask):
61 | imputer = self.imputer
62 | if self.percentile != 'auto':
63 | threshold = np.percentile(mask.cpu().numpy(), self.percentile)
64 | binary_mask = np.float32(mask < threshold)
65 | else:
66 | _, binary_mask = cv2.threshold(
67 | np.uint8(mask * 255), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
68 |
69 | binary_mask = torch.from_numpy(binary_mask)
70 | binary_mask = binary_mask.to(mask.device)
71 | return imputer(input_tensor, binary_mask)
72 |
73 |
74 | class RemoveLeastRelevantFirst(RemoveMostRelevantFirst):
75 | def __init__(self, percentile, imputer):
76 | super(RemoveLeastRelevantFirst, self).__init__(percentile, imputer)
77 |
78 | def __call__(self, input_tensor, mask):
79 | return super(RemoveLeastRelevantFirst, self).__call__(
80 | input_tensor, 1 - mask)
81 |
82 |
83 | class AveragerAcrossThresholds:
84 | def __init__(
85 | self,
86 | imputer,
87 | percentiles=[
88 | 10,
89 | 20,
90 | 30,
91 | 40,
92 | 50,
93 | 60,
94 | 70,
95 | 80,
96 | 90]):
97 | self.imputer = imputer
98 | self.percentiles = percentiles
99 |
100 | def __call__(self,
101 | input_tensor: torch.Tensor,
102 | cams: np.ndarray,
103 | targets: List[Callable],
104 | model: torch.nn.Module):
105 | scores = []
106 | for percentile in self.percentiles:
107 | imputer = self.imputer(percentile)
108 | scores.append(imputer(input_tensor, cams, targets, model))
109 | return np.mean(np.float32(scores), axis=0)
110 |
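111 |
112 | # A minimal sketch (illustrative only): RemoveMostRelevantFirst keeps the
113 | # pixels below the given CAM percentile and hands the resulting binary
114 | # mask to an imputer. A trivial multiply-by-mask imputer stands in for
115 | # the ROAD imputer here.
116 | if __name__ == "__main__":
117 |     def imputer(tensor, mask):  # stand-in imputer
118 |         return tensor * mask
119 |     remove = RemoveMostRelevantFirst(percentile=80, imputer=imputer)
120 |     perturbed = remove(torch.randn(3, 8, 8), torch.rand(8, 8))
121 |     print(perturbed.shape)  # torch.Size([3, 8, 8]); top 20% of pixels zeroed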
--------------------------------------------------------------------------------
/pytorch_grad_cam/metrics/road.py:
--------------------------------------------------------------------------------
1 | # A Consistent and Efficient Evaluation Strategy for Attribution Methods
2 | # https://arxiv.org/abs/2202.00449
3 | # Taken from https://raw.githubusercontent.com/tleemann/road_evaluation/main/imputations.py
4 | # MIT License
5 |
6 | # Copyright (c) 2022 Tobias Leemann
7 |
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 |
15 | # The above copyright notice and this permission notice shall be included in all
16 | # copies or substantial portions of the Software.
17 |
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | # SOFTWARE.
25 |
26 |
27 | # Implementations of our imputation models.
28 | import torch
29 | import numpy as np
30 | from scipy.sparse import lil_matrix, csc_matrix
31 | from scipy.sparse.linalg import spsolve
32 | from typing import List, Callable
33 | from pytorch_grad_cam.metrics.perturbation_confidence import PerturbationConfidenceMetric, \
34 | AveragerAcrossThresholds, \
35 | RemoveMostRelevantFirst, \
36 | RemoveLeastRelevantFirst
37 |
38 | # The weights of the surrounding pixels
39 | neighbors_weights = [((1, 1), 1 / 12),
40 | ((0, 1), 1 / 6),
41 | ((-1, 1), 1 / 12),
42 | ((1, -1), 1 / 12),
43 | ((0, -1), 1 / 6),
44 | ((-1, -1), 1 / 12),
45 | ((1, 0), 1 / 6),
46 | ((-1, 0), 1 / 6)]
47 |
48 |
49 | class NoisyLinearImputer:
50 | def __init__(self,
51 | noise: float = 0.01,
52 | weighting: List[float] = neighbors_weights):
53 | """
54 | Noisy linear imputation.
55 | noise: magnitude of noise to add (absolute, set to 0 for no noise)
56 | weighting: Weights of the neighboring pixels in the computation.
57 | List of tuples of (offset, weight)
58 | """
59 | self.noise = noise
60 |         self.weighting = weighting
61 |
62 | @staticmethod
63 | def add_offset_to_indices(indices, offset, mask_shape):
64 | """ Add the corresponding offset to the indices.
65 | Return new indices plus a valid bit-vector. """
66 | cord1 = indices % mask_shape[1]
67 | cord0 = indices // mask_shape[1]
68 | cord0 += offset[0]
69 | cord1 += offset[1]
70 | valid = ((cord0 < 0) | (cord1 < 0) |
71 | (cord0 >= mask_shape[0]) |
72 | (cord1 >= mask_shape[1]))
73 | return ~valid, indices + offset[0] * mask_shape[1] + offset[1]
74 |
75 | @staticmethod
76 | def setup_sparse_system(mask, img, neighbors_weights):
77 | """ Vectorized version to set up the equation system.
78 | mask: (H, W)-tensor of missing pixels.
79 | Image: (H, W, C)-tensor of all values.
80 | Return (N,N)-System matrix, (N,C)-Right hand side for each of the C channels.
81 | """
82 | maskflt = mask.flatten()
83 | imgflat = img.reshape((img.shape[0], -1))
84 | # Indices that are imputed in the flattened mask:
85 | indices = np.argwhere(maskflt == 0).flatten()
86 | coords_to_vidx = np.zeros(len(maskflt), dtype=int)
87 | coords_to_vidx[indices] = np.arange(len(indices))
88 | numEquations = len(indices)
89 | # System matrix:
90 | A = lil_matrix((numEquations, numEquations))
91 | b = np.zeros((numEquations, img.shape[0]))
92 | # Sum of weights assigned:
93 | sum_neighbors = np.ones(numEquations)
94 | for n in neighbors_weights:
95 | offset, weight = n[0], n[1]
96 | # Take out outliers
97 | valid, new_coords = NoisyLinearImputer.add_offset_to_indices(
98 | indices, offset, mask.shape)
99 | valid_coords = new_coords[valid]
100 | valid_ids = np.argwhere(valid == 1).flatten()
101 | # Add values to the right hand-side
102 | has_values_coords = valid_coords[maskflt[valid_coords] > 0.5]
103 | has_values_ids = valid_ids[maskflt[valid_coords] > 0.5]
104 | b[has_values_ids, :] -= weight * imgflat[:, has_values_coords].T
105 | # Add weights to the system (left hand side)
106 | # Find coordinates in the system.
107 | has_no_values = valid_coords[maskflt[valid_coords] < 0.5]
108 | variable_ids = coords_to_vidx[has_no_values]
109 | has_no_values_ids = valid_ids[maskflt[valid_coords] < 0.5]
110 | A[has_no_values_ids, variable_ids] = weight
111 | # Reduce weight for invalid
112 | sum_neighbors[np.argwhere(valid == 0).flatten()] = \
113 | sum_neighbors[np.argwhere(valid == 0).flatten()] - weight
114 |
115 | A[np.arange(numEquations), np.arange(numEquations)] = -sum_neighbors
116 | return A, b
117 |
118 | def __call__(self, img: torch.Tensor, mask: torch.Tensor):
119 |         """ Our linear imputation scheme.
120 |
121 |         This is the function that does the linear infilling.
122 |         img: original image, (C,H,W)-tensor
123 |         mask: mask, (H,W)-tensor
124 |
125 |         """
126 | imgflt = img.reshape(img.shape[0], -1)
127 | maskflt = mask.reshape(-1)
128 | # Indices that need to be imputed.
129 | indices_linear = np.argwhere(maskflt == 0).flatten()
130 | # Set up sparse equation system, solve system.
131 | A, b = NoisyLinearImputer.setup_sparse_system(
132 | mask.numpy(), img.numpy(), neighbors_weights)
133 | res = torch.tensor(spsolve(csc_matrix(A), b), dtype=torch.float)
134 |
135 | # Fill the values with the solution of the system.
136 | img_infill = imgflt.clone()
137 | img_infill[:, indices_linear] = res.t() + self.noise * \
138 | torch.randn_like(res.t())
139 |
140 | return img_infill.reshape_as(img)
141 |
142 |
143 | class ROADMostRelevantFirst(PerturbationConfidenceMetric):
144 | def __init__(self, percentile=80):
145 | super(ROADMostRelevantFirst, self).__init__(
146 | RemoveMostRelevantFirst(percentile, NoisyLinearImputer()))
147 |
148 |
149 | class ROADLeastRelevantFirst(PerturbationConfidenceMetric):
150 | def __init__(self, percentile=20):
151 | super(ROADLeastRelevantFirst, self).__init__(
152 | RemoveLeastRelevantFirst(percentile, NoisyLinearImputer()))
153 |
154 |
155 | class ROADMostRelevantFirstAverage(AveragerAcrossThresholds):
156 | def __init__(self, percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]):
157 | super(ROADMostRelevantFirstAverage, self).__init__(
158 | ROADMostRelevantFirst, percentiles)
159 |
160 |
161 | class ROADLeastRelevantFirstAverage(AveragerAcrossThresholds):
162 | def __init__(self, percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]):
163 | super(ROADLeastRelevantFirstAverage, self).__init__(
164 | ROADLeastRelevantFirst, percentiles)
165 |
166 |
167 | class ROADCombined:
168 | def __init__(self, percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]):
169 | self.percentiles = percentiles
170 | self.morf_averager = ROADMostRelevantFirstAverage(percentiles)
171 | self.lerf_averager = ROADLeastRelevantFirstAverage(percentiles)
172 |
173 | def __call__(self,
174 | input_tensor: torch.Tensor,
175 | cams: np.ndarray,
176 | targets: List[Callable],
177 | model: torch.nn.Module):
178 |
179 | scores_lerf = self.lerf_averager(input_tensor, cams, targets, model)
180 | scores_morf = self.morf_averager(input_tensor, cams, targets, model)
181 | return (scores_lerf - scores_morf) / 2
182 |
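183 | # Usage sketch (illustrative only, not part of the library): computing the
184 | # combined ROAD score for a CAM heatmap. `model`, `input_tensor` and a
185 | # (batch, H, W) `grayscale_cams` array are assumed to exist already:
186 | #
187 | #     from pytorch_grad_cam.metrics.road import ROADCombined
188 | #     from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
189 | #
190 | #     metric = ROADCombined(percentiles=[20, 40, 60, 80])
191 | #     targets = [ClassifierOutputTarget(281)]  # e.g. ImageNet "tabby cat"
192 | #     scores = metric(input_tensor, grayscale_cams, targets, model)
193 | #     # Higher is better: LeRF confidence should drop less than MoRF.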
--------------------------------------------------------------------------------
/pytorch_grad_cam/random_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_grad_cam.base_cam import BaseCAM
3 |
4 |
5 | class RandomCAM(BaseCAM):
6 | def __init__(self, model, target_layers,
7 | reshape_transform=None):
8 | super(
9 | RandomCAM,
10 | self).__init__(
11 | model,
12 | target_layers,
13 | reshape_transform)
14 |
15 | def get_cam_weights(self,
16 | input_tensor,
17 | target_layer,
18 | target_category,
19 | activations,
20 | grads):
21 | return np.random.uniform(-1, 1, size=(grads.shape[0], grads.shape[1]))
22 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/score_cam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 | from pytorch_grad_cam.base_cam import BaseCAM
4 |
5 |
6 | class ScoreCAM(BaseCAM):
7 | def __init__(
8 | self,
9 | model,
10 | target_layers,
11 | reshape_transform=None):
12 | super(ScoreCAM, self).__init__(model,
13 | target_layers,
14 | reshape_transform=reshape_transform,
15 | uses_gradients=False)
16 |
17 | def get_cam_weights(self,
18 | input_tensor,
19 | target_layer,
20 | targets,
21 | activations,
22 | grads):
23 | with torch.no_grad():
24 | upsample = torch.nn.UpsamplingBilinear2d(
25 | size=input_tensor.shape[-2:])
26 | activation_tensor = torch.from_numpy(activations)
27 | activation_tensor = activation_tensor.to(self.device)
28 |
29 | upsampled = upsample(activation_tensor)
30 |
31 | maxs = upsampled.view(upsampled.size(0),
32 | upsampled.size(1), -1).max(dim=-1)[0]
33 | mins = upsampled.view(upsampled.size(0),
34 | upsampled.size(1), -1).min(dim=-1)[0]
35 |
36 | maxs, mins = maxs[:, :, None, None], mins[:, :, None, None]
37 |             upsampled = (upsampled - mins) / (maxs - mins + 1e-8)  # min-max normalize each map to [0, 1]
38 |
39 |             # Mask the input with each normalized activation map.
40 |             input_tensors = input_tensor[:, None, :, :] * upsampled[:, :, None, :, :]
41 |
42 | if hasattr(self, "batch_size"):
43 | BATCH_SIZE = self.batch_size
44 | else:
45 | BATCH_SIZE = 16
46 |
47 | scores = []
48 | for target, tensor in zip(targets, input_tensors):
49 | for i in tqdm.tqdm(range(0, tensor.size(0), BATCH_SIZE)):
50 | batch = tensor[i: i + BATCH_SIZE, :]
51 | outputs = [target(o).cpu().item()
52 | for o in self.model(batch)]
53 | scores.extend(outputs)
54 | scores = torch.Tensor(scores)
55 | scores = scores.view(activations.shape[0], activations.shape[1])
56 |             weights = torch.nn.Softmax(dim=-1)(scores).numpy()  # softmax over maps -> per-channel weights
57 | return weights
58 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/shapley_cam.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Tuple
2 | from pytorch_grad_cam.base_cam import BaseCAM
3 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
4 | import torch
5 | import numpy as np
6 |
7 | """
8 | Weights the activation maps using the gradient and Hessian-Vector product.
9 | This method (https://arxiv.org/abs/2501.06261) reinterpret CAM methods (include GradCAM, HiResCAM and the original CAM) from a Shapley value perspective.
10 | """
11 | class ShapleyCAM(BaseCAM):
12 | def __init__(self, model, target_layers,
13 | reshape_transform=None):
14 | super(
15 | ShapleyCAM,
16 | self).__init__(
17 | model = model,
18 | target_layers = target_layers,
19 | reshape_transform = reshape_transform,
20 | compute_input_gradient = True,
21 | uses_gradients = True,
22 | detach = False)
23 |
24 | def get_cam_weights(self,
25 | input_tensor,
26 | target_layer,
27 | target_category,
28 | activations,
29 | grads):
30 |
31 | hvp = torch.autograd.grad(
32 | outputs=grads,
33 | inputs=activations,
34 | grad_outputs=activations,
35 | retain_graph=False,
36 | allow_unused=True
37 | )[0]
38 | # print(torch.max(hvp[0]).item()) # check if hvp is not all zeros
39 | if hvp is None:
40 | hvp = torch.tensor(0).to(self.device)
41 | else:
42 | if self.activations_and_grads.reshape_transform is not None:
43 | hvp = self.activations_and_grads.reshape_transform(hvp)
44 |
45 | if self.activations_and_grads.reshape_transform is not None:
46 | activations = self.activations_and_grads.reshape_transform(activations)
47 | grads = self.activations_and_grads.reshape_transform(grads)
48 |
49 | weight = (grads - 0.5 * hvp).detach().cpu().numpy()
50 | # 2D image
51 | if len(activations.shape) == 4:
52 | weight = np.mean(weight, axis=(2, 3))
53 | return weight
54 | # 3D image
55 | elif len(activations.shape) == 5:
56 | weight = np.mean(weight, axis=(2, 3, 4))
57 | return weight
58 | else:
59 |             raise ValueError("Invalid grads shape. "
60 |                              "Shape of grads should be 4 (2D image) or 5 (3D image).")
61 |
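62 | # Usage sketch (illustrative only, not part of the library). The paper pairs
63 | # ShapleyCAM with the ClassifierOutputReST target proposed in the same work;
64 | # `model`, `target_layers` and `input_tensor` are assumed to exist:
65 | #
66 | #     from pytorch_grad_cam.utils.model_targets import ClassifierOutputReST
67 | #
68 | #     cam = ShapleyCAM(model=model, target_layers=target_layers)
69 | #     grayscale_cam = cam(input_tensor=input_tensor,
70 | #                         targets=[ClassifierOutputReST(281)])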
--------------------------------------------------------------------------------
/pytorch_grad_cam/sobel_cam.py:
--------------------------------------------------------------------------------
1 | import cv2
2 |
3 |
4 | def sobel_cam(img):
5 | gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
6 | grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
7 | grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
8 | abs_grad_x = cv2.convertScaleAbs(grad_x)
9 | abs_grad_y = cv2.convertScaleAbs(grad_y)
10 | grad = cv2.addWeighted(abs_grad_x, 0.5, abs_grad_y, 0.5, 0)
11 | return grad
12 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from pytorch_grad_cam.utils.image import deprocess_image
2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
3 | from pytorch_grad_cam.utils import model_targets
4 | from pytorch_grad_cam.utils import reshape_transforms
5 |
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/find_layers.py:
--------------------------------------------------------------------------------
1 | def replace_layer_recursive(model, old_layer, new_layer):
2 | for name, layer in model._modules.items():
3 | if layer == old_layer:
4 | model._modules[name] = new_layer
5 | return True
6 | elif replace_layer_recursive(layer, old_layer, new_layer):
7 | return True
8 | return False
9 |
10 |
11 | def replace_all_layer_type_recursive(model, old_layer_type, new_layer):
12 | for name, layer in model._modules.items():
13 | if isinstance(layer, old_layer_type):
14 | model._modules[name] = new_layer
15 | replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
16 |
17 |
18 | def find_layer_types_recursive(model, layer_types):
19 | def predicate(layer):
20 | return type(layer) in layer_types
21 | return find_layer_predicate_recursive(model, predicate)
22 |
23 |
24 | def find_layer_predicate_recursive(model, predicate):
25 | result = []
26 | for name, layer in model._modules.items():
27 | if predicate(layer):
28 | result.append(layer)
29 | result.extend(find_layer_predicate_recursive(layer, predicate))
30 | return result
31 |
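32 | # Usage sketch (illustrative only, not part of the library): collecting
33 | # candidate CAM target layers, assuming a torchvision-style `model`:
34 | #
35 | #     import torch
36 | #     convs = find_layer_types_recursive(model, [torch.nn.Conv2d])
37 | #     target_layers = [convs[-1]]  # use the last convolution as the target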
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/image.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, List
3 |
4 | import cv2
5 | import matplotlib
6 | import numpy as np
7 | import torch
8 | from matplotlib import pyplot as plt
9 | from matplotlib.lines import Line2D
10 | from scipy.ndimage import zoom
11 | from torchvision.transforms import Compose, Normalize, ToTensor
12 |
13 |
14 | def preprocess_image(
15 | img: np.ndarray, mean=[
16 | 0.5, 0.5, 0.5], std=[
17 | 0.5, 0.5, 0.5]) -> torch.Tensor:
18 | preprocessing = Compose([
19 | ToTensor(),
20 | Normalize(mean=mean, std=std)
21 | ])
22 | return preprocessing(img.copy()).unsqueeze(0)
23 |
24 |
25 | def deprocess_image(img):
26 | """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """
27 | img = img - np.mean(img)
28 | img = img / (np.std(img) + 1e-5)
29 | img = img * 0.1
30 | img = img + 0.5
31 | img = np.clip(img, 0, 1)
32 | return np.uint8(img * 255)
33 |
34 |
35 | def show_cam_on_image(img: np.ndarray,
36 | mask: np.ndarray,
37 | use_rgb: bool = False,
38 | colormap: int = cv2.COLORMAP_JET,
39 | image_weight: float = 0.5) -> np.ndarray:
40 | """ This function overlays the cam mask on the image as an heatmap.
41 | By default the heatmap is in BGR format.
42 |
43 | :param img: The base image in RGB or BGR format.
44 | :param mask: The cam mask.
45 | :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
46 | :param colormap: The OpenCV colormap to be used.
47 | :param image_weight: The final result is image_weight * img + (1-image_weight) * mask.
48 | :returns: The default image with the cam overlay.
49 | """
50 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
51 | if use_rgb:
52 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
53 | heatmap = np.float32(heatmap) / 255
54 |
55 | if np.max(img) > 1:
56 |         raise Exception(
57 |             "The input image should be np.float32 in the range [0, 1]")
58 |
59 | if image_weight < 0 or image_weight > 1:
60 |         raise Exception(
61 |             "image_weight should be in the range [0, 1]. "
62 |             f"Got: {image_weight}")
63 |
64 | cam = (1 - image_weight) * heatmap + image_weight * img
65 | cam = cam / np.max(cam)
66 | return np.uint8(255 * cam)
67 |
68 |
69 | def create_labels_legend(concept_scores: np.ndarray,
70 | labels: Dict[int, str],
71 | top_k=2):
72 | concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
73 | concept_labels_topk = []
74 | for concept_index in range(concept_categories.shape[0]):
75 | categories = concept_categories[concept_index, :]
76 | concept_labels = []
77 | for category in categories:
78 | score = concept_scores[concept_index, category]
79 | label = f"{','.join(labels[category].split(',')[:3])}:{score:.2f}"
80 | concept_labels.append(label)
81 | concept_labels_topk.append("\n".join(concept_labels))
82 | return concept_labels_topk
83 |
84 |
85 | def show_factorization_on_image(img: np.ndarray,
86 | explanations: np.ndarray,
87 | colors: List[np.ndarray] = None,
88 | image_weight: float = 0.5,
89 | concept_labels: List = None) -> np.ndarray:
90 | """ Color code the different component heatmaps on top of the image.
91 |     Every component color code will be magnified according to the heatmap intensity
92 |     (by modifying the V channel in the HSV color space),
93 |     and optionally create a legend that shows the labels.
94 |
95 | Since different factorization component heatmaps can overlap in principle,
96 | we need a strategy to decide how to deal with the overlaps.
97 |     This keeps the component that has the higher value in its heatmap.
98 |
99 |     :param img: The base image, in RGB format.
100 |     :param explanations: A tensor of shape num_components x height x width, with the component visualizations.
101 | :param colors: List of R, G, B colors to be used for the components.
102 | If None, will use the gist_rainbow cmap as a default.
103 | :param image_weight: The final result is image_weight * img + (1-image_weight) * visualization.
104 |     :param concept_labels: A list of strings for every component. If this is passed, a legend that shows
105 |                            the labels and their colors will be added to the image.
106 | :returns: The visualized image.
107 | """
108 | n_components = explanations.shape[0]
109 | if colors is None:
110 | # taken from https://github.com/edocollins/DFF/blob/master/utils.py
111 | _cmap = plt.cm.get_cmap('gist_rainbow')
112 | colors = [
113 | np.array(
114 | _cmap(i)) for i in np.arange(
115 | 0,
116 | 1,
117 | 1.0 /
118 | n_components)]
119 | concept_per_pixel = explanations.argmax(axis=0)
120 | masks = []
121 | for i in range(n_components):
122 | mask = np.zeros(shape=(img.shape[0], img.shape[1], 3))
123 | mask[:, :, :] = colors[i][:3]
124 | explanation = explanations[i]
125 | explanation[concept_per_pixel != i] = 0
126 | mask = np.uint8(mask * 255)
127 | mask = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV)
128 | mask[:, :, 2] = np.uint8(255 * explanation)
129 | mask = cv2.cvtColor(mask, cv2.COLOR_HSV2RGB)
130 | mask = np.float32(mask) / 255
131 | masks.append(mask)
132 |
133 | mask = np.sum(np.float32(masks), axis=0)
134 | result = img * image_weight + mask * (1 - image_weight)
135 | result = np.uint8(result * 255)
136 |
137 | if concept_labels is not None:
138 | px = 1 / plt.rcParams['figure.dpi'] # pixel in inches
139 | fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px))
140 | plt.rcParams['legend.fontsize'] = int(
141 | 14 * result.shape[0] / 256 / max(1, n_components / 6))
142 | lw = 5 * result.shape[0] / 256
143 | lines = [Line2D([0], [0], color=colors[i], lw=lw)
144 | for i in range(n_components)]
145 | plt.legend(lines,
146 | concept_labels,
147 | mode="expand",
148 | fancybox=True,
149 | shadow=True)
150 |
151 | plt.tight_layout(pad=0, w_pad=0, h_pad=0)
152 | plt.axis('off')
153 | fig.canvas.draw()
154 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
155 | plt.close(fig=fig)
156 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
157 | data = cv2.resize(data, (result.shape[1], result.shape[0]))
158 | result = np.hstack((result, data))
159 | return result
160 |
161 |
162 | def scale_cam_image(cam, target_size=None):
163 | result = []
164 | for img in cam:
165 | img = img - np.min(img)
166 | img = img / (1e-7 + np.max(img))
167 | if target_size is not None:
168 | if len(img.shape) > 2:
169 | img = zoom(np.float32(img), [
170 | (t_s / i_s) for i_s, t_s in zip(img.shape, target_size[::-1])])
171 | else:
172 | img = cv2.resize(np.float32(img), target_size)
173 |
174 | result.append(img)
175 | result = np.float32(result)
176 |
177 | return result
178 |
179 |
180 | def scale_accross_batch_and_channels(tensor, target_size):
181 | batch_size, channel_size = tensor.shape[:2]
182 | reshaped_tensor = tensor.reshape(
183 | batch_size * channel_size, *tensor.shape[2:])
184 | result = scale_cam_image(reshaped_tensor, target_size)
185 | result = result.reshape(
186 | batch_size,
187 | channel_size,
188 | target_size[1],
189 | target_size[0])
190 | return result
191 |
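192 | # Usage sketch (illustrative only, not part of the library): preparing an
193 | # input and overlaying a CAM heatmap. `grayscale_cam` is assumed to be an
194 | # (H, W) float array in [0, 1], e.g. one slice of a CAM batch:
195 | #
196 | #     rgb_img = cv2.imread("examples/both.png")[:, :, ::-1]
197 | #     rgb_img = np.float32(rgb_img) / 255
198 | #     input_tensor = preprocess_image(rgb_img,
199 | #                                     mean=[0.485, 0.456, 0.406],
200 | #                                     std=[0.229, 0.224, 0.225])
201 | #     visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)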
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/model_targets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torchvision
4 |
5 |
6 | class ClassifierOutputTarget:
7 | def __init__(self, category):
8 | self.category = category
9 |
10 | def __call__(self, model_output):
11 | if len(model_output.shape) == 1:
12 | return model_output[self.category]
13 | return model_output[:, self.category]
14 |
15 |
16 | class ClassifierOutputSoftmaxTarget:
17 | def __init__(self, category):
18 | self.category = category
19 |
20 | def __call__(self, model_output):
21 | if len(model_output.shape) == 1:
22 | return torch.softmax(model_output, dim=-1)[self.category]
23 | return torch.softmax(model_output, dim=-1)[:, self.category]
24 |
25 |
26 | class ClassifierOutputReST:
27 | """
28 |     Uses both the pre-softmax and post-softmax outputs, as proposed in https://arxiv.org/abs/2501.06261
29 | """
30 | def __init__(self, category):
31 | self.category = category
32 | def __call__(self, model_output):
33 | if len(model_output.shape) == 1:
34 | target = torch.tensor([self.category], device=model_output.device)
35 | model_output = model_output.unsqueeze(0)
36 | return model_output[0][self.category] - torch.nn.functional.cross_entropy(model_output, target)
37 | else:
38 | target = torch.tensor([self.category] * model_output.shape[0], device=model_output.device)
39 | return model_output[:,self.category] - torch.nn.functional.cross_entropy(model_output, target)
40 |
41 |
42 | class BinaryClassifierOutputTarget:
43 | def __init__(self, category):
44 | self.category = category
45 |
46 | def __call__(self, model_output):
47 | if self.category == 1:
48 | sign = 1
49 | else:
50 | sign = -1
51 | return model_output * sign
52 |
53 |
54 | class SoftmaxOutputTarget:
55 | def __init__(self):
56 | pass
57 |
58 | def __call__(self, model_output):
59 | return torch.softmax(model_output, dim=-1)
60 |
61 |
62 | class RawScoresOutputTarget:
63 | def __init__(self):
64 | pass
65 |
66 | def __call__(self, model_output):
67 | return model_output
68 |
69 |
70 | class SemanticSegmentationTarget:
71 | """ Gets a binary spatial mask and a category,
72 | And return the sum of the category scores,
73 | of the pixels in the mask. """
74 |
75 | def __init__(self, category, mask):
76 | self.category = category
77 | self.mask = torch.from_numpy(mask)
78 |
79 | def __call__(self, model_output):
80 | return (model_output[self.category, :, :] * self.mask.to(model_output.device)).sum()
81 |
82 |
83 | class FasterRCNNBoxScoreTarget:
84 | """ For every original detected bounding box specified in "bounding boxes",
85 | assign a score on how the current bounding boxes match it,
86 | 1. In IOU
87 | 2. In the classification score.
88 | If there is not a large enough overlap, or the category changed,
89 | assign a score of 0.
90 |
91 | The total score is the sum of all the box scores.
92 | """
93 |
94 | def __init__(self, labels, bounding_boxes, iou_threshold=0.5):
95 | self.labels = labels
96 | self.bounding_boxes = bounding_boxes
97 | self.iou_threshold = iou_threshold
98 |
99 | def __call__(self, model_outputs):
100 | output = torch.Tensor([0])
101 | if torch.cuda.is_available():
102 | output = output.cuda()
103 | elif torch.backends.mps.is_available():
104 | output = output.to("mps")
105 |
106 | if len(model_outputs["boxes"]) == 0:
107 | return output
108 |
109 | for box, label in zip(self.bounding_boxes, self.labels):
110 | box = torch.Tensor(box[None, :])
111 | if torch.cuda.is_available():
112 | box = box.cuda()
113 | elif torch.backends.mps.is_available():
114 | box = box.to("mps")
115 |
116 | ious = torchvision.ops.box_iou(box, model_outputs["boxes"])
117 | index = ious.argmax()
118 | if ious[0, index] > self.iou_threshold and model_outputs["labels"][index] == label:
119 | score = ious[0, index] + model_outputs["scores"][index]
120 | output = output + score
121 | return output
122 |
123 | class FinerWeightedTarget:
124 | """
125 | Computes a weighted difference between a primary category and a set of comparison categories.
126 |
127 | This target calculates the difference between the score for the main category and each of the comparison categories.
128 | It obtains a weight for each comparison category from the softmax probabilities of the model output and computes a
129 | weighted difference scaled by a comparison strength factor alpha.
130 | """
131 | def __init__(self, main_category, comparison_categories, alpha):
132 | self.main_category = main_category
133 | self.comparison_categories = comparison_categories
134 | self.alpha = alpha
135 |
136 | def __call__(self, model_output):
137 | select = lambda idx: model_output[idx] if model_output.ndim == 1 else model_output[..., idx]
138 |
139 | wn = select(self.main_category)
140 |
141 | prob = torch.softmax(model_output, dim=-1)
142 |
143 | weights = [prob[idx] if model_output.ndim == 1 else prob[..., idx] for idx in self.comparison_categories]
144 | numerator = sum(w * (wn - self.alpha * select(idx)) for w, idx in zip(weights, self.comparison_categories))
145 | denominator = sum(weights)
146 |
147 | return numerator / (denominator + 1e-9)
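148 | 
149 | # Usage sketches (illustrative only, not part of the library). Targets are
150 | # passed to a CAM object (`cam` below is an assumed CAM instance, e.g.
151 | # GradCAM) to choose what the explanation should maximize:
152 | #
153 | #     targets = [ClassifierOutputTarget(281)]         # a fixed class index
154 | #     targets = [ClassifierOutputSoftmaxTarget(281)]  # the same, after softmax
155 | #
156 | # For semantic segmentation, sum one class's scores inside a mask (here
157 | # `car_mask` is an assumed (H, W) float numpy array and `car_category` an
158 | # assumed class index):
159 | #
160 | #     targets = [SemanticSegmentationTarget(car_category, car_mask)]
161 | #     grayscale_cam = cam(input_tensor=input_tensor, targets=targets)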
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/reshape_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def fasterrcnn_reshape_transform(x):
5 | target_size = x['pool'].size()[-2:]
6 | activations = []
7 | for key, value in x.items():
8 | activations.append(
9 | torch.nn.functional.interpolate(
10 | torch.abs(value),
11 | target_size,
12 | mode='bilinear'))
13 | activations = torch.cat(activations, axis=1)
14 | return activations
15 |
16 |
17 | def swinT_reshape_transform(tensor, height=7, width=7):
18 | result = tensor.reshape(tensor.size(0),
19 | height, width, tensor.size(2))
20 |
21 | # Bring the channels to the first dimension,
22 | # like in CNNs.
23 | result = result.transpose(2, 3).transpose(1, 2)
24 | return result
25 |
26 |
27 | def vit_reshape_transform(tensor, height=14, width=14):
28 | result = tensor[:, 1:, :].reshape(tensor.size(0),
29 | height, width, tensor.size(2))
30 |
31 | # Bring the channels to the first dimension,
32 | # like in CNNs.
33 | result = result.transpose(2, 3).transpose(1, 2)
34 | return result
35 |
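36 | # Shape sketch (illustrative only, not part of the library): for a ViT with
37 | # a class token and a 14x14 patch grid, the transform drops the class token
38 | # and restores a CNN-style (batch, channels, height, width) layout:
39 | #
40 | #     tokens = torch.zeros(1, 197, 192)        # (batch, 1 + 14*14, dim)
41 | #     spatial = vit_reshape_transform(tokens)  # -> (1, 192, 14, 14)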
--------------------------------------------------------------------------------
/pytorch_grad_cam/utils/svd_on_activations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.decomposition import KernelPCA
3 |
4 |
5 | def get_2d_projection(activation_batch):
6 | # TBD: use pytorch batch svd implementation
7 | activation_batch[np.isnan(activation_batch)] = 0
8 | projections = []
9 | for activations in activation_batch:
10 | reshaped_activations = (activations).reshape(
11 | activations.shape[0], -1).transpose()
12 |         # Centering before the SVD seems to be important here;
13 |         # otherwise the returned projection can come out negative.
14 | reshaped_activations = reshaped_activations - \
15 | reshaped_activations.mean(axis=0)
16 | U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True)
17 | projection = reshaped_activations @ VT[0, :]
18 | projection = projection.reshape(activations.shape[1:])
19 | projections.append(projection)
20 | return np.float32(projections)
21 |
22 |
23 |
24 | def get_2d_projection_kernel(activation_batch, kernel='sigmoid', gamma=None):
25 | activation_batch[np.isnan(activation_batch)] = 0
26 | projections = []
27 | for activations in activation_batch:
28 | reshaped_activations = activations.reshape(activations.shape[0], -1).transpose()
29 | reshaped_activations = reshaped_activations - reshaped_activations.mean(axis=0)
30 | # Apply Kernel PCA
31 | kpca = KernelPCA(n_components=1, kernel=kernel, gamma=gamma)
32 | projection = kpca.fit_transform(reshaped_activations)
33 | projection = projection.reshape(activations.shape[1:])
34 | projections.append(projection)
35 | return np.float32(projections)
36 |
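37 | # Usage sketch (illustrative only, not part of the library): projecting a
38 | # random activation batch onto its first principal component, as EigenCAM does:
39 | #
40 | #     acts = np.random.randn(2, 512, 7, 7).astype(np.float32)  # (B, C, H, W)
41 | #     proj = get_2d_projection(acts)                           # -> (2, 7, 7)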
--------------------------------------------------------------------------------
/pytorch_grad_cam/xgrad_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_grad_cam.base_cam import BaseCAM
3 |
4 |
5 | class XGradCAM(BaseCAM):
6 | def __init__(
7 | self,
8 | model,
9 | target_layers,
10 | reshape_transform=None):
11 | super(
12 | XGradCAM,
13 | self).__init__(
14 | model,
15 | target_layers,
16 | reshape_transform)
17 |
18 | def get_cam_weights(self,
19 | input_tensor,
20 | target_layer,
21 | target_category,
22 | activations,
23 | grads):
24 | sum_activations = np.sum(activations, axis=(2, 3))
25 | eps = 1e-7
26 | weights = grads * activations / \
27 | (sum_activations[:, :, None, None] + eps)
28 | weights = weights.sum(axis=(2, 3))
29 | return weights
30 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | Pillow
3 | torch>=1.7.1
4 | torchvision>=0.8.2
5 | ttach
6 | tqdm
7 | opencv-python
8 | matplotlib
9 | scikit-learn
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = grad-cam
3 | version = 1.1.0
4 | author = Jacob Gildenblat
5 | author_email = jacob.gildenblat@gmail.com
6 | description = Many Class Activation Map methods implemented in Pytorch. Including Grad-CAM, Grad-CAM++, Score-CAM, Ablation-CAM and XGrad-CAM
7 | long_description = file: README.md
8 | long_description_content_type = text/markdown
9 | url = https://github.com/jacobgil/pytorch-grad-cam
10 | project_urls =
11 | Bug Tracker = https://github.com/jacobgil/pytorch-grad-cam/issues
12 | classifiers =
13 | Programming Language :: Python :: 3
14 | License :: OSI Approved :: MIT License
15 | Operating System :: OS Independent
16 |
17 | [options]
18 | packages = find:
19 | python_requires = >=3.6
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open('README.md', mode='r', encoding='utf-8') as fh:
4 | long_description = fh.read()
5 |
6 | with open("requirements.txt", "r") as f:
7 | requirements = f.readlines()
8 |
9 | setuptools.setup(
10 | name='grad-cam',
11 | version='1.5.5',
12 | author='Jacob Gildenblat',
13 | author_email='jacob.gildenblat@gmail.com',
14 | description='Many Class Activation Map methods implemented in Pytorch for classification, segmentation, object detection and more',
15 | long_description=long_description,
16 | long_description_content_type='text/markdown',
17 | url='https://github.com/jacobgil/pytorch-grad-cam',
18 | project_urls={
19 | 'Bug Tracker': 'https://github.com/jacobgil/pytorch-grad-cam/issues',
20 | },
21 | classifiers=[
22 | 'Programming Language :: Python :: 3',
23 | 'License :: OSI Approved :: MIT License',
24 | 'Operating System :: OS Independent',
25 | ],
26 | packages=setuptools.find_packages(
27 | exclude=["*tutorials*"]),
28 | python_requires='>=3.8',
29 | install_requires=requirements)
30 |
--------------------------------------------------------------------------------
/tests/test_context_release.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torchvision
3 | import torch
4 | import cv2
5 | import psutil
6 | from pytorch_grad_cam import GradCAM, \
7 | ScoreCAM, \
8 | GradCAMPlusPlus, \
9 | AblationCAM, \
10 | XGradCAM, \
11 | EigenCAM, \
12 | EigenGradCAM, \
13 | LayerCAM, \
14 | FullGrad
15 | from pytorch_grad_cam.utils.image import show_cam_on_image, \
16 | preprocess_image
17 |
18 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
19 |
20 | torch.manual_seed(0)
21 |
22 |
23 | @pytest.fixture
24 | def numpy_image():
25 | return cv2.imread("examples/both.png")
26 |
27 |
28 | @pytest.mark.parametrize("cnn_model,target_layer_names", [
29 | (torchvision.models.resnet18, ["layer4[-1]"])
30 | ])
31 | @pytest.mark.parametrize("batch_size,width,height", [
32 | (1, 224, 224)
33 | ])
34 | @pytest.mark.parametrize("target_category", [
35 | 100
36 | ])
37 | @pytest.mark.parametrize("aug_smooth", [
38 | False
39 | ])
40 | @pytest.mark.parametrize("eigen_smooth", [
41 | False
42 | ])
43 | @pytest.mark.parametrize("cam_method",
44 | [GradCAM])
45 | def test_memory_usage_in_loop(numpy_image, batch_size, width, height,
46 | cnn_model, target_layer_names, cam_method,
47 | target_category, aug_smooth, eigen_smooth):
48 | img = cv2.resize(numpy_image, (width, height))
49 | input_tensor = preprocess_image(img)
50 | input_tensor = input_tensor.repeat(batch_size, 1, 1, 1)
51 | model = cnn_model(pretrained=True)
52 | target_layers = []
53 | for layer in target_layer_names:
54 | target_layers.append(eval(f"model.{layer}"))
55 | targets = [ClassifierOutputTarget(target_category)
56 | for _ in range(batch_size)]
57 | initial_memory = 0
58 | for i in range(100):
59 | with cam_method(model=model,
60 | target_layers=target_layers) as cam:
61 | grayscale_cam = cam(input_tensor=input_tensor,
62 | targets=targets,
63 | aug_smooth=aug_smooth,
64 | eigen_smooth=eigen_smooth)
65 |
66 | if i == 0:
67 | initial_memory = psutil.virtual_memory()[2]
68 | assert(psutil.virtual_memory()[2] <= initial_memory * 1.5)
69 |
--------------------------------------------------------------------------------
/tests/test_one_channel.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torchvision
3 | import torch
4 | import cv2
5 | import numpy as np
6 | from pytorch_grad_cam import GradCAM, \
7 | ScoreCAM, \
8 | GradCAMPlusPlus, \
9 | AblationCAM, \
10 | XGradCAM, \
11 | EigenCAM, \
12 | EigenGradCAM, \
13 | LayerCAM, \
14 | FullGrad
15 | from pytorch_grad_cam.utils.image import show_cam_on_image, \
16 | preprocess_image
17 |
18 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
19 |
20 | torch.manual_seed(0)
21 |
22 |
23 | @pytest.fixture
24 | def numpy_image():
25 | return cv2.imread("examples/both.png")
26 |
27 |
28 | @pytest.mark.parametrize("cam_method",
29 | [GradCAM])
30 | def test_one_channel_input(numpy_image, cam_method):
31 | model = torchvision.models.resnet18(pretrained=False)
32 | model.conv1 = torch.nn.Conv2d(
33 | 1, 64, kernel_size=(
34 | 7, 7), stride=(
35 | 2, 2), padding=(
36 | 3, 3), bias=False)
37 | target_layers = [model.layer4]
38 | gray_img = numpy_image[:, :, 0]
39 | input_tensor = torch.from_numpy(
40 | np.float32(gray_img)).unsqueeze(0).unsqueeze(0)
41 | input_tensor = input_tensor.repeat(16, 1, 1, 1)
42 | targets = None
43 | with cam_method(model=model,
44 | target_layers=target_layers) as cam:
45 | grayscale_cam = cam(input_tensor=input_tensor,
46 | targets=targets)
47 | print(grayscale_cam.shape)
48 |
--------------------------------------------------------------------------------
/tests/test_run_all_models.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torchvision
3 | import torch
4 | import cv2
5 | from pytorch_grad_cam import GradCAM, \
6 | ScoreCAM, \
7 | GradCAMPlusPlus, \
8 | AblationCAM, \
9 | XGradCAM, \
10 | EigenCAM, \
11 | EigenGradCAM, \
12 | LayerCAM, \
13 | FullGrad, \
14 | KPCA_CAM
15 | from pytorch_grad_cam.utils.image import show_cam_on_image, \
16 | preprocess_image
17 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
18 |
19 |
20 | @pytest.fixture
21 | def numpy_image():
22 | return cv2.imread("examples/both.png")
23 |
24 |
25 | @pytest.mark.parametrize("cnn_model,target_layer_names", [
26 | (torchvision.models.resnet18, ["layer4[-1]", "layer4[-2]"]),
27 | (torchvision.models.vgg11, ["features[-1]"])
28 | ])
29 | @pytest.mark.parametrize("batch_size,width,height", [
30 | (2, 32, 32),
31 | (1, 32, 40)
32 | ])
33 | @pytest.mark.parametrize("target_category", [
34 | None,
35 | 100
36 | ])
37 | @pytest.mark.parametrize("aug_smooth", [
38 | False
39 | ])
40 | @pytest.mark.parametrize("eigen_smooth", [
41 | True,
42 | False
43 | ])
44 | @pytest.mark.parametrize("cam_method",
45 | [ScoreCAM,
46 | AblationCAM,
47 | GradCAM,
49 | GradCAMPlusPlus,
50 | XGradCAM,
51 | EigenCAM,
52 | EigenGradCAM,
53 | LayerCAM,
54 | FullGrad,
55 | KPCA_CAM])
56 |
57 | def test_all_cam_models_can_run(numpy_image, batch_size, width, height,
58 | cnn_model, target_layer_names, cam_method,
59 | target_category, aug_smooth, eigen_smooth):
60 | img = cv2.resize(numpy_image, (width, height))
61 | input_tensor = preprocess_image(img)
62 | input_tensor = input_tensor.repeat(batch_size, 1, 1, 1)
63 |
64 | model = cnn_model(pretrained=True)
65 | target_layers = []
66 | for layer in target_layer_names:
67 | target_layers.append(eval(f"model.{layer}"))
68 |
69 | cam = cam_method(model=model,
70 | target_layers=target_layers)
71 | cam.batch_size = 4
72 | if target_category is None:
73 | targets = None
74 | else:
75 | targets = [ClassifierOutputTarget(target_category)
76 | for _ in range(batch_size)]
77 |
78 | grayscale_cam = cam(input_tensor=input_tensor,
79 | targets=targets,
80 | aug_smooth=aug_smooth,
81 | eigen_smooth=eigen_smooth)
82 | assert(grayscale_cam.shape[0] == input_tensor.shape[0])
83 | assert(grayscale_cam.shape[1:] == input_tensor.shape[2:])
84 |
--------------------------------------------------------------------------------
/tutorials/bbox.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/tutorials/bbox.png
--------------------------------------------------------------------------------
/tutorials/huggingface_dff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/tutorials/huggingface_dff.png
--------------------------------------------------------------------------------
/tutorials/multimage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/tutorials/multimage.png
--------------------------------------------------------------------------------
/tutorials/puppies.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/tutorials/puppies.jpg
--------------------------------------------------------------------------------
/tutorials/vision_transformers.md:
--------------------------------------------------------------------------------
1 | # How does it work with Vision Transformers
2 |
3 | *See [usage_examples/vit_example.py](../usage_examples/vit_example.py)*
4 |
5 | In ViT the outputs of the layers typically have shape BATCH x 197 x 192.
6 | In the dimension with 197, the first element represents the class token, and the rest represent the 14x14 patches in the image.
7 | We can treat the last 196 elements as a 14x14 spatial image, with 192 channels.
8 |
9 | To reshape the activations and gradients to 2D spatial images,
10 | we can pass the CAM constructor a reshape_transform function.
11 |
12 | This can also be a starting point for other architectures that will come in the future.
13 |
14 | ```python
15 | def reshape_transform(tensor, height=14, width=14):
16 |     result = tensor[:, 1:, :].reshape(tensor.size(0),
17 |                                       height, width, tensor.size(2))
18 | 
19 |     # Bring the channels to the first dimension,
20 |     # like in CNNs.
21 |     result = result.transpose(2, 3).transpose(1, 2)
22 |     return result
23 | 
24 | cam = GradCAM(model=model, target_layers=target_layers,
25 |               reshape_transform=reshape_transform)
26 | ```
27 |
28 | ### Which target_layer should we choose for Vision Transformers?
29 |
30 | Since the final classification is done on the class token computed in the last attention block,
31 | the output will not be affected by the 14x14 channels in the last layer.
32 | The gradient of the output with respect to them will be 0!
33 | 
34 | We should choose any layer before the final attention block, for example:
35 | ```python
36 | target_layers = [model.blocks[-1].norm1]
37 | ```
38 |
39 | ----------
40 |
41 | # How does it work with Swin Transformers
42 |
43 | *See [usage_examples/swinT_example.py](../usage_examples/swinT_example.py)*
44 |
45 | In the Swin Transformer base model the outputs of the layers typically have shape BATCH x 49 x 1024.
46 | We can treat these 49 elements as a 7x7 spatial image, with 1024 channels.
47 |
48 | To reshape the activations and gradients to 2D spatial images,
49 | we can pass the CAM constructor a reshape_transform function.
50 |
51 | This can also be a starting point for other architectures that will come in the future.
52 |
53 | ```python
54 | def reshape_transform(tensor, height=7, width=7):
55 |     result = tensor.reshape(tensor.size(0),
56 |                             height, width, tensor.size(2))
57 | 
58 |     # Bring the channels to the first dimension,
59 |     # like in CNNs.
60 |     result = result.transpose(2, 3).transpose(1, 2)
61 |     return result
62 | 
63 | cam = GradCAM(model=model, target_layers=target_layers,
64 |               reshape_transform=reshape_transform)
65 | ```
66 |
67 | ### Which target_layer should we choose for Swin Transformers?
68 |
69 | Unlike ViT, the Swin Transformer does not contain a `cls_token`,
70 | so we use all the 7x7 spatial tokens we get from the last block of the last layer.
71 | 
72 | We should choose any layer before the final attention block, for example:
73 | ```python
74 | target_layers = [model.layers[-1].blocks[-1].norm1]
75 | ```
76 |
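77 | Both examples above assume a 224x224 input (a 14x14 patch grid for ViT, a 7x7 grid for Swin).
78 | For other input sizes you can bind different height and width values to the same
79 | reshape_transform, for example with functools.partial. A minimal sketch; the 24x24 grid
80 | below assumes a hypothetical 384x384 ViT with 16x16 patches:
81 | 
82 | ```python
83 | from functools import partial
84 | 
85 | # Assumed: 384x384 input, 16x16 patches -> 24x24 token grid.
86 | reshape_384 = partial(reshape_transform, height=24, width=24)
87 | 
88 | cam = GradCAM(model=model, target_layers=target_layers,
89 |               reshape_transform=reshape_384)
90 | ```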
--------------------------------------------------------------------------------
/usage_examples/clip_example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
8 | try:
9 | from transformers import CLIPProcessor, CLIPModel
10 | except ImportError:
11 | print("The transformers package is not installed. Please install it to use CLIP.")
12 | exit(1)
13 |
14 |
15 | from pytorch_grad_cam import GradCAM, \
16 | ScoreCAM, \
17 | GradCAMPlusPlus, \
18 | AblationCAM, \
19 | XGradCAM, \
20 | EigenCAM, \
21 | EigenGradCAM, \
22 | LayerCAM, \
23 | FullGrad
24 |
25 | from pytorch_grad_cam.utils.image import show_cam_on_image, \
26 | preprocess_image
27 | from pytorch_grad_cam.ablation_layer import AblationLayerVit
28 |
29 |
30 | def get_args():
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--device', type=str, default='cpu',
33 | help='Torch device to use')
34 | parser.add_argument(
35 | '--image-path',
36 | type=str,
37 | default='./examples/both.png',
38 | help='Input image path')
39 | parser.add_argument(
40 | '--labels',
41 | type=str,
42 | nargs='+',
43 | default=["a cat", "a dog", "a car", "a person", "a shoe"],
44 |         help='Recognition labels to score the image against'
45 | )
46 |
47 | parser.add_argument('--aug_smooth', action='store_true',
48 | help='Apply test time augmentation to smooth the CAM')
49 | parser.add_argument(
50 | '--eigen_smooth',
51 | action='store_true',
52 |         help='Reduce noise by taking the first principal component '
53 |              'of cam_weights*activations')
54 |
55 | parser.add_argument(
56 | '--method',
57 | type=str,
58 | default='gradcam',
59 | help='Can be gradcam/gradcam++/scorecam/xgradcam/ablationcam')
60 |
61 | args = parser.parse_args()
62 | if args.device:
63 | print(f'Using device "{args.device}" for acceleration')
64 | else:
65 | print('Using CPU for computation')
66 |
67 | return args
68 |
69 |
70 | def reshape_transform(tensor, height=16, width=16):
71 | result = tensor[:, 1:, :].reshape(tensor.size(0),
72 | height, width, tensor.size(2))
73 |
74 | # Bring the channels to the first dimension,
75 | # like in CNNs.
76 | result = result.transpose(2, 3).transpose(1, 2)
77 | return result
78 |
79 |
80 | class ImageClassifier(nn.Module):
81 | def __init__(self, labels):
82 | super(ImageClassifier, self).__init__()
83 | self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
84 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
85 | self.labels = labels
86 |
87 | def forward(self, x):
88 | text_inputs = self.processor(text=self.labels, return_tensors="pt", padding=True)
89 |
90 | outputs = self.clip(pixel_values=x, input_ids=text_inputs['input_ids'].to(self.clip.device),
91 | attention_mask=text_inputs['attention_mask'].to(self.clip.device))
92 |
93 | logits_per_image = outputs.logits_per_image
94 | probs = logits_per_image.softmax(dim=1)
95 |
96 | for label, prob in zip(self.labels, probs[0]):
97 | print(f"{label}: {prob:.4f}")
98 | return probs
99 |
100 |
101 | if __name__ == '__main__':
102 | """ python vit_gradcam.py --image-path
103 | Example usage of using cam-methods on a VIT network.
104 |
105 | """
106 |
107 | args = get_args()
108 | methods = \
109 | {"gradcam": GradCAM,
110 | "scorecam": ScoreCAM,
111 | "gradcam++": GradCAMPlusPlus,
112 | "ablationcam": AblationCAM,
113 | "xgradcam": XGradCAM,
114 | "eigencam": EigenCAM,
115 | "eigengradcam": EigenGradCAM,
116 | "layercam": LayerCAM,
117 | "fullgrad": FullGrad}
118 |
119 | if args.method not in list(methods.keys()):
120 | raise Exception(f"method should be one of {list(methods.keys())}")
121 |
122 | labels = args.labels
123 | model = ImageClassifier(labels).to(torch.device(args.device)).eval()
124 | print(model)
125 |
126 | target_layers = [model.clip.vision_model.encoder.layers[-1].layer_norm1]
127 |
128 | if args.method not in methods:
129 | raise Exception(f"Method {args.method} not implemented")
130 |
131 | rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1]
132 | rgb_img = cv2.resize(rgb_img, (224, 224))
133 | rgb_img = np.float32(rgb_img) / 255
134 | input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5],
135 | std=[0.5, 0.5, 0.5]).to(args.device)
136 |
137 | if args.method == "ablationcam":
138 | cam = methods[args.method](model=model,
139 | target_layers=target_layers,
140 | reshape_transform=reshape_transform,
141 | ablation_layer=AblationLayerVit())
142 | else:
143 | cam = methods[args.method](model=model,
144 | target_layers=target_layers,
145 | reshape_transform=reshape_transform)
146 |
147 |
148 |
149 | # If None, returns the map for the highest scoring category.
150 | # Otherwise, targets the requested category.
151 | #targets = [ClassifierOutputTarget(1)]
152 | targets = None
153 |
154 | # AblationCAM and ScoreCAM have batched implementations.
155 | # You can override the internal batch size for faster computation.
156 | cam.batch_size = 32
157 |
158 | grayscale_cam = cam(input_tensor=input_tensor,
159 | targets=targets,
160 | eigen_smooth=args.eigen_smooth,
161 | aug_smooth=args.aug_smooth)
162 |
163 | # Here grayscale_cam has only one image in the batch
164 | grayscale_cam = grayscale_cam[0, :]
165 |
166 | cam_image = show_cam_on_image(rgb_img, grayscale_cam)
167 | cv2.imwrite(f'{args.method}_cam.jpg', cam_image)
168 |
--------------------------------------------------------------------------------
/usage_examples/swinT_example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import torch
5 | import timm
6 |
7 | from pytorch_grad_cam import GradCAM, \
8 | ScoreCAM, \
9 | GradCAMPlusPlus, \
10 | AblationCAM, \
11 | XGradCAM, \
12 | EigenCAM, \
13 | EigenGradCAM, \
14 | LayerCAM, \
15 | FullGrad
16 |
17 | from pytorch_grad_cam.utils.image import show_cam_on_image, \
18 | preprocess_image
19 | from pytorch_grad_cam.ablation_layer import AblationLayerVit
20 |
21 |
22 | def get_args():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument('--use-cuda', action='store_true', default=False,
25 | help='Use NVIDIA GPU acceleration')
26 | parser.add_argument(
27 | '--image-path',
28 | type=str,
29 | default='./examples/both.png',
30 | help='Input image path')
31 | parser.add_argument('--aug_smooth', action='store_true',
32 | help='Apply test time augmentation to smooth the CAM')
33 | parser.add_argument(
34 | '--eigen_smooth',
35 | action='store_true',
36 |         help='Reduce noise by taking the first principal component '
37 |              'of cam_weights*activations')
38 |
39 | parser.add_argument(
40 | '--method',
41 | type=str,
42 | default='scorecam',
43 | help='Can be gradcam/gradcam++/scorecam/xgradcam/ablationcam')
44 |
45 | args = parser.parse_args()
46 | args.use_cuda = args.use_cuda and torch.cuda.is_available()
47 | if args.use_cuda:
48 | print('Using GPU for acceleration')
49 | else:
50 | print('Using CPU for computation')
51 |
52 | return args
53 |
54 |
55 | def reshape_transform(tensor, height=7, width=7):
56 | result = tensor.reshape(tensor.size(0),
57 | height, width, tensor.size(2))
58 |
59 | # Bring the channels to the first dimension,
60 | # like in CNNs.
61 | result = result.transpose(2, 3).transpose(1, 2)
62 | return result
63 |
64 |
65 | if __name__ == '__main__':
66 | """ python swinT_example.py -image-path
67 | Example usage of using cam-methods on a SwinTransformers network.
68 |
69 | """
70 |
71 | args = get_args()
72 | methods = \
73 | {"gradcam": GradCAM,
74 | "scorecam": ScoreCAM,
75 | "gradcam++": GradCAMPlusPlus,
76 | "ablationcam": AblationCAM,
77 | "xgradcam": XGradCAM,
78 | "eigencam": EigenCAM,
79 | "eigengradcam": EigenGradCAM,
80 | "layercam": LayerCAM,
81 | "fullgrad": FullGrad}
82 |
83 | if args.method not in list(methods.keys()):
84 | raise Exception(f"method should be one of {list(methods.keys())}")
85 |
86 | model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
87 | model.eval()
88 |
89 | if args.use_cuda:
90 | model = model.cuda()
91 |
92 | target_layers = [model.layers[-1].blocks[-1].norm2]
93 |
94 | if args.method not in methods:
95 | raise Exception(f"Method {args.method} not implemented")
96 |
97 | if args.method == "ablationcam":
98 | cam = methods[args.method](model=model,
99 | target_layers=target_layers,
100 | reshape_transform=reshape_transform,
101 | ablation_layer=AblationLayerVit())
102 | else:
103 | cam = methods[args.method](model=model,
104 | target_layers=target_layers,
105 | reshape_transform=reshape_transform)
106 |
107 | rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1]
108 | rgb_img = cv2.resize(rgb_img, (224, 224))
109 | rgb_img = np.float32(rgb_img) / 255
110 | input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5],
111 | std=[0.5, 0.5, 0.5])
112 |
113 | # AblationCAM and ScoreCAM have batched implementations.
114 | # You can override the internal batch size for faster computation.
115 | cam.batch_size = 32
116 |
117 | grayscale_cam = cam(input_tensor=input_tensor,
118 | targets=None,
119 | eigen_smooth=args.eigen_smooth,
120 | aug_smooth=args.aug_smooth)
121 |
122 | # Here grayscale_cam has only one image in the batch
123 | grayscale_cam = grayscale_cam[0, :]
124 |
125 | cam_image = show_cam_on_image(rgb_img, grayscale_cam)
126 | cv2.imwrite(f'{args.method}_cam.jpg', cam_image)
127 |
--------------------------------------------------------------------------------
/usage_examples/vit_example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import torch
5 |
6 | from pytorch_grad_cam import GradCAM, \
7 | ScoreCAM, \
8 | GradCAMPlusPlus, \
9 | AblationCAM, \
10 | XGradCAM, \
11 | EigenCAM, \
12 | EigenGradCAM, \
13 | LayerCAM, \
14 | FullGrad
15 |
16 | from pytorch_grad_cam import GuidedBackpropReLUModel
17 | from pytorch_grad_cam.utils.image import show_cam_on_image, \
18 | preprocess_image
19 | from pytorch_grad_cam.ablation_layer import AblationLayerVit
20 |
21 |
22 | def get_args():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument('--device', type=str, default='cpu',
25 | help='Torch device to use')
26 |
27 | parser.add_argument(
28 | '--image-path',
29 | type=str,
30 | default='./examples/both.png',
31 | help='Input image path')
32 | parser.add_argument('--aug_smooth', action='store_true',
33 | help='Apply test time augmentation to smooth the CAM')
34 | parser.add_argument(
35 | '--eigen_smooth',
36 | action='store_true',
37 |         help='Reduce noise by taking the first principal component '
38 |              'of cam_weights*activations')
39 |
40 | parser.add_argument(
41 | '--method',
42 | type=str,
43 | default='gradcam',
44 | help='Can be gradcam/gradcam++/scorecam/xgradcam/ablationcam')
45 |
46 | args = parser.parse_args()
47 | if args.device:
48 | print(f'Using device "{args.device}" for acceleration')
49 | else:
50 | print('Using CPU for computation')
51 |
52 | return args
53 |
54 |
55 | def reshape_transform(tensor, height=14, width=14):
56 | result = tensor[:, 1:, :].reshape(tensor.size(0),
57 | height, width, tensor.size(2))
58 |
59 | # Bring the channels to the first dimension,
60 | # like in CNNs.
61 | result = result.transpose(2, 3).transpose(1, 2)
62 | return result
63 |
64 |
65 | if __name__ == '__main__':
66 | """ python vit_gradcam.py --image-path
67 | Example usage of using cam-methods on a VIT network.
68 |
69 | """
70 |
71 | args = get_args()
72 | methods = \
73 | {"gradcam": GradCAM,
74 | "scorecam": ScoreCAM,
75 | "gradcam++": GradCAMPlusPlus,
76 | "ablationcam": AblationCAM,
77 | "xgradcam": XGradCAM,
78 | "eigencam": EigenCAM,
79 | "eigengradcam": EigenGradCAM,
80 | "layercam": LayerCAM,
81 | "fullgrad": FullGrad}
82 |
83 | if args.method not in list(methods.keys()):
84 | raise Exception(f"method should be one of {list(methods.keys())}")
85 |
86 | model = torch.hub.load('facebookresearch/deit:main',
87 | 'deit_tiny_patch16_224', pretrained=True).to(torch.device(args.device)).eval()
88 |
89 |
90 | target_layers = [model.blocks[-1].norm1]
91 |
92 | if args.method not in methods:
93 | raise Exception(f"Method {args.method} not implemented")
94 |
95 | if args.method == "ablationcam":
96 | cam = methods[args.method](model=model,
97 | target_layers=target_layers,
98 | reshape_transform=reshape_transform,
99 | ablation_layer=AblationLayerVit())
100 | else:
101 | cam = methods[args.method](model=model,
102 | target_layers=target_layers,
103 | reshape_transform=reshape_transform)
104 |
105 | rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1]
106 | rgb_img = cv2.resize(rgb_img, (224, 224))
107 | rgb_img = np.float32(rgb_img) / 255
108 | input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5],
109 | std=[0.5, 0.5, 0.5]).to(args.device)
110 |
111 | # If None, returns the map for the highest scoring category.
112 | # Otherwise, targets the requested category.
113 | targets = None
114 |
115 | # AblationCAM and ScoreCAM have batched implementations.
116 | # You can override the internal batch size for faster computation.
117 | cam.batch_size = 32
118 |
119 | grayscale_cam = cam(input_tensor=input_tensor,
120 | targets=targets,
121 | eigen_smooth=args.eigen_smooth,
122 | aug_smooth=args.aug_smooth)
123 |
124 | # Here grayscale_cam has only one image in the batch
125 | grayscale_cam = grayscale_cam[0, :]
126 |
127 | cam_image = show_cam_on_image(rgb_img, grayscale_cam)
128 | cv2.imwrite(f'{args.method}_cam.jpg', cam_image)
129 |
--------------------------------------------------------------------------------