├── .gitattributes ├── .github └── workflows │ └── python-app.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── cam.py ├── examples ├── augsmooth.jpg ├── both.png ├── both_detection.png ├── cam_gb_dog.jpg ├── cars_segmentation.png ├── cat.jpg ├── clip_cat.jpg ├── clip_dog.jpg ├── dff1.png ├── dff2.png ├── dog.jpg ├── dog_cat.jfif ├── dogs.png ├── dogs_gradcam++_resnet50.jpg ├── dogs_gradcam++_vgg16.jpg ├── dogs_gradcam_resnet50.jpg ├── dogs_gradcam_vgg16.jpg ├── dogs_scorecam_resnet50.jpg ├── dogs_scorecam_vgg16.jpg ├── eigenaug.jpg ├── eigensmooth.jpg ├── embeddings.png ├── horses.jpg ├── metrics.png ├── multiorgan_segmentation.gif ├── nosmooth.jpg ├── resnet50_cat_ablationcam_cam.jpg ├── resnet50_cat_gradcam_cam.jpg ├── resnet50_cat_scorecam_cam.jpg ├── resnet50_dog_ablationcam_cam.jpg ├── resnet50_dog_gradcam_cam.jpg ├── resnet50_dog_scorecam_cam.jpg ├── resnet_horses_ablationcam_cam.jpg ├── resnet_horses_gradcam++_cam.jpg ├── resnet_horses_gradcam_cam.jpg ├── resnet_horses_horses_eigencam_cam.jpg ├── resnet_horses_scorecam_cam.jpg ├── road.png ├── swinT_cat_ablationcam_cam.jpg ├── swinT_cat_gradcam_cam.jpg ├── swinT_cat_scorecam_cam.jpg ├── swinT_dog_ablationcam_cam.jpg ├── swinT_dog_gradcam_cam.jpg ├── swinT_dog_scorecam_cam.jpg ├── vgg_horses_ablationcam_cam.jpg ├── vgg_horses_eigencam_cam.jpg ├── vgg_horses_gradcam++_cam.jpg ├── vgg_horses_gradcam_cam.jpg ├── vgg_horses_scorecam_cam.jpg ├── vit_cat_ablationcam_cam.jpg ├── vit_cat_gradcam_cam.jpg ├── vit_cat_scorecam_cam.jpg ├── vit_dog_ablationcam_cam.jpg ├── vit_dog_gradcam_cam.jpg ├── vit_dog_scorecam_cam.jpg └── yolo_eigencam.png ├── pyproject.toml ├── pytorch_grad_cam ├── __init__.py ├── ablation_cam.py ├── ablation_cam_multilayer.py ├── ablation_layer.py ├── activations_and_gradients.py ├── base_cam.py ├── eigen_cam.py ├── eigen_grad_cam.py ├── feature_factorization │ ├── __init__.py │ └── deep_feature_factorization.py ├── fem.py ├── finer_cam.py ├── fullgrad_cam.py ├── grad_cam.py ├── grad_cam_elementwise.py ├── grad_cam_plusplus.py ├── guided_backprop.py ├── hirescam.py ├── kpca_cam.py ├── layer_cam.py ├── metrics │ ├── __init__.py │ ├── cam_mult_image.py │ ├── perturbation_confidence.py │ └── road.py ├── random_cam.py ├── score_cam.py ├── shapley_cam.py ├── sobel_cam.py ├── utils │ ├── __init__.py │ ├── find_layers.py │ ├── image.py │ ├── model_targets.py │ ├── reshape_transforms.py │ └── svd_on_activations.py └── xgrad_cam.py ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests ├── test_context_release.py ├── test_one_channel.py └── test_run_all_models.py ├── tutorials ├── CAM Metrics And Tuning Tutorial.ipynb ├── Class Activation Maps for Object Detection With Faster RCNN.ipynb ├── Class Activation Maps for Semantic Segmentation.ipynb ├── Deep Feature Factorizations.ipynb ├── EigenCAM for YOLO5.ipynb ├── HuggingFace.ipynb ├── Pixel Attribution for embeddings.ipynb ├── bbox.png ├── huggingface_dff.png ├── multimage.png ├── puppies.jpg └── vision_transformers.md └── usage_examples ├── clip_example.py ├── swinT_example.py └── vit_example.py /.gitattributes: -------------------------------------------------------------------------------- 1 | tutorials/*.ipynb linguist-documentation=true -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more 
information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Tests 5 | 6 | on: [push, pull_request] 7 | 8 | jobs: 9 | build: 10 | 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 3.8 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.8 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install flake8 pytest==5.3.2 pytest-faulthandler psutil 23 | pip install -e . 24 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 25 | - name: Lint with flake8 26 | run: | 27 | # stop the build if there are Python syntax errors or undefined names 28 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 29 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 30 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 31 | - name: Test with pytest 32 | run: | 33 | pytest 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | dist 3 | *.egg-info 4 | **/*.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jacob Gildenblat 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
prune tutorials
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
![Build Status](https://github.com/jacobgil/pytorch-grad-cam/workflows/Tests/badge.svg)
[![Downloads](https://static.pepy.tech/personalized-badge/grad-cam?period=month&units=international_system&left_color=black&right_color=brightgreen&left_text=Monthly%20Downloads)](https://pepy.tech/project/grad-cam)
[![Downloads](https://static.pepy.tech/personalized-badge/grad-cam?period=total&units=international_system&left_color=black&right_color=blue&left_text=Total%20Downloads)](https://pepy.tech/project/grad-cam)

# Advanced AI explainability for PyTorch

`pip install grad-cam`

Documentation with advanced tutorials: [https://jacobgil.github.io/pytorch-gradcam-book](https://jacobgil.github.io/pytorch-gradcam-book)


This is a package with state-of-the-art methods for Explainable AI for computer vision.
It can be used for diagnosing model predictions, either in production or while
developing models.
The aim is also to serve as a benchmark of algorithms and metrics for research into new explainability methods.

⭐ Comprehensive collection of Pixel Attribution methods for Computer Vision.

⭐ Tested on many common CNN networks and Vision Transformers.

⭐ Advanced use cases: works with Classification, Object Detection, Semantic Segmentation, Embedding-similarity and more.

⭐ Includes smoothing methods to make the CAMs look nice.

⭐ High performance: full support for batches of images in all methods.

⭐ Includes metrics for checking if you can trust the explanations, and for tuning them for best performance.


![visualization](https://github.com/jacobgil/jacobgil.github.io/blob/master/assets/cam_dog.gif?raw=true)

| Method | What it does |
|---------------------|-----------------------------------------------------------------------------------------------------------------------------|
| GradCAM | Weight the 2D activations by the average gradient |
| HiResCAM | Like GradCAM but element-wise multiply the activations with the gradients; provably guaranteed faithfulness for certain models |
| GradCAMElementWise | Like GradCAM but element-wise multiply the activations with the gradients, then apply a ReLU operation before summing |
| GradCAM++ | Like GradCAM but uses second order gradients |
| XGradCAM | Like GradCAM but scale the gradients by the normalized activations |
| AblationCAM | Zero out activations and measure how the output drops (this repository includes a fast batched implementation) |
| ScoreCAM | Perturb the image by the scaled activations and measure how the output drops |
| EigenCAM | Takes the first principal component of the 2D activations (no class discrimination, but seems to give great results) |
| EigenGradCAM | Like EigenCAM but with class discrimination: first principal component of Activations*Grad. Looks like GradCAM, but cleaner |
| LayerCAM | Spatially weight the activations by positive gradients. Works better especially in lower layers |
| FullGrad | Computes the gradients of the biases from all over the network, and then sums them |
| Deep Feature Factorizations | Non-Negative Matrix Factorization on the 2D activations |
| KPCA-CAM | Like EigenCAM but with Kernel PCA instead of PCA |
| FEM | A gradient-free method that binarizes activations by an activation > mean + k * std rule |
| ShapleyCAM | Weight the activations using the gradient and Hessian-vector product |
| FinerCAM | Improves fine-grained classification by comparing similar classes, suppressing shared features and highlighting discriminative details |

## Visual Examples

| What makes the network think the image label is 'pug, pug-dog' | What makes the network think the image label is 'tabby, tabby cat' | Combining Grad-CAM with Guided Backpropagation for the 'pug, pug-dog' class |
| ---------------------------------------------------------------|--------------------|-----------------------------------------------------------------------------|
| | |

## Object Detection and Semantic Segmentation
| Object Detection | Semantic Segmentation |
| -----------------|-----------------------|
| | |

| 3D Medical Semantic Segmentation |
| -------------------------- |
| |

## Explaining similarity to other images / embeddings


## Deep Feature Factorization



## CLIP
| Explaining the text prompt "a dog" | Explaining the text prompt "a cat" |
| -----------------------------------|------------------------------------|
| |

## Classification

#### Resnet50:
| Category | Image | GradCAM | AblationCAM | ScoreCAM |
| ---------|-------|----------|------------|------------|
| Dog | ![](./examples/dog_cat.jfif) | ![](./examples/resnet50_dog_gradcam_cam.jpg) | ![](./examples/resnet50_dog_ablationcam_cam.jpg) | ![](./examples/resnet50_dog_scorecam_cam.jpg) |
| Cat | ![](./examples/dog_cat.jfif?raw=true) | ![](./examples/resnet50_cat_gradcam_cam.jpg?raw=true) | ![](./examples/resnet50_cat_ablationcam_cam.jpg?raw=true) | ![](./examples/resnet50_cat_scorecam_cam.jpg) |

#### Vision Transformer (Deit Tiny):
| Category | Image | GradCAM | AblationCAM | ScoreCAM |
| ---------|-------|----------|------------|------------|
| Dog | ![](./examples/dog_cat.jfif) | ![](./examples/vit_dog_gradcam_cam.jpg) | ![](./examples/vit_dog_ablationcam_cam.jpg) | ![](./examples/vit_dog_scorecam_cam.jpg) |
| Cat | ![](./examples/dog_cat.jfif) | ![](./examples/vit_cat_gradcam_cam.jpg) | ![](./examples/vit_cat_ablationcam_cam.jpg) | ![](./examples/vit_cat_scorecam_cam.jpg) |

#### Swin Transformer (Tiny window:7 patch:4 input-size:224):
| Category | Image | GradCAM | AblationCAM | ScoreCAM |
| ---------|-------|----------|------------|------------|
| Dog | ![](./examples/dog_cat.jfif) | ![](./examples/swinT_dog_gradcam_cam.jpg) | ![](./examples/swinT_dog_ablationcam_cam.jpg) | ![](./examples/swinT_dog_scorecam_cam.jpg) |
| Cat | ![](./examples/dog_cat.jfif) | ![](./examples/swinT_cat_gradcam_cam.jpg) | ![](./examples/swinT_cat_ablationcam_cam.jpg) | ![](./examples/swinT_cat_scorecam_cam.jpg) |


# Metrics and Evaluation for XAI




----------

# Usage examples

```python
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50

model = resnet50(pretrained=True)
target_layers = [model.layer4[-1]]
input_tensor = ...  # Create an input tensor image for your model.
# Note: input_tensor can be a batch tensor with several images!

# We have to specify the target we want to generate the CAM for.
targets = [ClassifierOutputTarget(281)]

# Construct the CAM object once, and then re-use it on many images.
with GradCAM(model=model, target_layers=target_layers) as cam:
    # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
    # In this example grayscale_cam has only one image in the batch:
    grayscale_cam = grayscale_cam[0, :]
    # rgb_img is the float32 image in [0, 1] that input_tensor was created from.
    visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    # You can also get the model outputs without having to redo inference:
    model_outputs = cam.outputs
```

[cam.py](https://github.com/jacobgil/pytorch-grad-cam/blob/master/cam.py) has a more detailed usage example.

----------
# Choosing the layer(s) to extract activations from
You need to choose the target layer to compute the CAM for.
Some common choices are:
- FasterRCNN: model.backbone
- Resnet18 and 50: model.layer4[-1]
- VGG, densenet161 and mobilenet: model.features[-1]
- mnasnet1_0: model.layers[-1]
- ViT: model.blocks[-1].norm1
- SwinT: model.layers[-1].blocks[-1].norm1


If you pass a list with several layers, the CAM will be averaged across them.
This can be useful if you're not sure which layer will perform best.
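For instance, here is a minimal sketch (reusing `model`, `input_tensor` and `targets` from the usage example above) that averages the CAM over the last two ResNet-50 stages:

```python
# Averaging the CAM across two stages can hedge against a bad layer choice.
target_layers = [model.layer3[-1], model.layer4[-1]]
with GradCAM(model=model, target_layers=target_layers) as cam:
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
```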
----------

# Adapting for new architectures and tasks

Methods like GradCAM were designed for, and were originally mostly applied to, classification models,
and specifically CNN classification models.
However you can also use this package on new architectures like Vision Transformers, and on non-classification tasks like Object Detection or Semantic Segmentation.

To be able to adapt to non-standard cases, we have two concepts.
- The reshape transform - how do we convert the activations to represent spatial images?
- The model targets - what exactly should the explainability method try to explain?

## The reshape_transform argument
In a CNN the intermediate activations in the model are a multi-channel image with dimensions channel x rows x cols,
and the various explainability methods work with these to produce a new image.

In the case of another architecture, like the Vision Transformer, the shape might be different, like (rows x cols + 1) x channels, or something else.
The reshape transform converts the activations back into a multi-channel image, for example by removing the class token in a vision transformer.
For examples, check [here](https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/reshape_transforms.py)
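As a concrete sketch, a ViT reshape transform along the lines of the ones in `reshape_transforms.py` (assuming 224x224 inputs with 16x16 patches, i.e. a 14x14 token grid) looks like this:

```python
def reshape_transform(tensor, height=14, width=14):
    # Drop the class token and reshape the remaining tokens into a
    # (batch, height, width, channels) spatial grid.
    result = tensor[:, 1:, :].reshape(tensor.size(0),
                                      height, width, tensor.size(2))
    # Bring the channels to the first dimension, like in CNNs:
    # (batch, channels, height, width).
    result = result.transpose(2, 3).transpose(1, 2)
    return result
```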
## The model_target argument
The model target is just a callable that takes the model output and extracts the specific scalar output we want to explain.

For classification tasks, the model target will typically be the output for a specific category.
The `targets` parameter passed to the CAM method can then use `ClassifierOutputTarget`:
```python
targets = [ClassifierOutputTarget(281)]
```

However, for more advanced cases you might want a different behaviour.
Check [here](https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/model_targets.py) for more examples.
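For example, a custom target for semantic segmentation can sum a category's score over a pixel mask. This is a minimal sketch along the lines of the semantic segmentation tutorial; the `SemanticSegmentationTarget` name and the binary `mask` array are illustrative:

```python
import torch

class SemanticSegmentationTarget:
    """Sums the score of one category over a binary pixel mask."""
    def __init__(self, category, mask):
        self.category = category
        # mask: numpy float array of shape (rows, cols), 1 on the object.
        self.mask = torch.from_numpy(mask)

    def __call__(self, model_output):
        # model_output: (categories, rows, cols) segmentation scores.
        return (model_output[self.category, :, :] * self.mask).sum()
```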
----------

# Tutorials
Here you can find detailed examples of how to use this for various custom use cases like object detection:

These point to the new documentation jupyter-book for fast rendering.
The jupyter notebooks themselves can be found under the tutorials folder in the git repository.

- [Notebook tutorial: XAI Recipes for the HuggingFace 🤗 Image Classification Models]()

- [Notebook tutorial: Deep Feature Factorizations for better model explainability]()

- [Notebook tutorial: Class Activation Maps for Object Detection with Faster-RCNN]()

- [Notebook tutorial: Class Activation Maps for YOLO5]()

- [Notebook tutorial: Class Activation Maps for Semantic Segmentation]()

- [Notebook tutorial: Adapting pixel attribution methods for embedding outputs from models]()

- [Notebook tutorial: May the best explanation win. CAM Metrics and Tuning]()

- [How it works with Vision/SwinT transformers](tutorials/vision_transformers.md)


----------

# Guided backpropagation

```python
import cv2

from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import (
    show_cam_on_image, deprocess_image, preprocess_image
)

# Run guided backpropagation on the same device the model lives on.
gb_model = GuidedBackpropReLUModel(model=model,
                                   device=next(model.parameters()).device)
gb = gb_model(input_tensor, target_category=None)

cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
cam_gb = deprocess_image(cam_mask * gb)
result = deprocess_image(gb)
```

----------

# Metrics and evaluating the explanations

```python
from pytorch_grad_cam.utils.model_targets import ClassifierOutputSoftmaxTarget
from pytorch_grad_cam.metrics.cam_mult_image import CamMultImageConfidenceChange

# Create the metric target, often the confidence drop in a score of some category
metric_target = ClassifierOutputSoftmaxTarget(281)
scores, batch_visualizations = CamMultImageConfidenceChange()(input_tensor,
    inverse_cams, targets, model, return_visualization=True)
visualization = deprocess_image(batch_visualizations[0, :])

# State of the art metric: Remove and Debias
from pytorch_grad_cam.metrics.road import ROADMostRelevantFirst, ROADLeastRelevantFirst
cam_metric = ROADMostRelevantFirst(percentile=75)
scores, perturbation_visualizations = cam_metric(input_tensor,
    grayscale_cams, targets, model, return_visualization=True)

# You can also average across different percentiles, and combine
# (LeastRelevantFirst - MostRelevantFirst) / 2
from pytorch_grad_cam.metrics.road import (
    ROADMostRelevantFirstAverage,
    ROADLeastRelevantFirstAverage,
    ROADCombined
)
cam_metric = ROADCombined(percentiles=[20, 40, 60, 80])
scores = cam_metric(input_tensor, grayscale_cams, targets, model)
```


# Smoothing to get nice looking CAMs

To reduce noise in the CAMs, and make them fit better on the objects,
two smoothing methods are supported:

- `aug_smooth=True`

  Test time augmentation: increases the run time by x6.

  Applies a combination of horizontal flips, and multiplying the image
  by [1.0, 1.1, 0.9].

  This has the effect of better centering the CAM around the objects.


- `eigen_smooth=True`

  First principal component of `activations*weights`

  This has the effect of removing a lot of noise.


|AblationCAM | aug smooth | eigen smooth | aug+eigen smooth|
|------------|------------|--------------|--------------------|
| ![](./examples/nosmooth.jpg) | ![](./examples/augsmooth.jpg) | ![](./examples/eigensmooth.jpg) | ![](./examples/eigenaug.jpg) |

----------

# Running the example script:

Usage: `python cam.py --image-path <path_to_image> --method <method> --output-dir <output_dir_path>`


To use with a specific device, like cpu, cuda, cuda:0, mps or hpu:
`python cam.py --image-path <path_to_image> --device cuda --output-dir <output_dir_path>`

----------

You can choose between:

`GradCAM`, `HiResCAM`, `ScoreCAM`, `GradCAMPlusPlus`, `AblationCAM`, `XGradCAM`, `LayerCAM`, `FullGrad`, `EigenCAM`, `ShapleyCAM`, and `FinerCAM`.

Some methods, like ScoreCAM and AblationCAM, require a large number of forward passes,
and have a batched implementation.

You can control the batch size with
`cam.batch_size = `

----------

## Citation
If you use this for research, please cite. Here is an example BibTeX entry:

```
@misc{jacobgilpytorchcam,
  title={PyTorch library for CAM methods},
  author={Jacob Gildenblat and contributors},
  year={2021},
  publisher={GitHub},
  howpublished={\url{https://github.com/jacobgil/pytorch-grad-cam}},
}
```

----------

# References
https://arxiv.org/abs/1610.02391
321 | `Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization 322 | Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, Dhruv Batra` 323 | 324 | https://arxiv.org/abs/2011.08891
325 | `Use HiResCAM instead of Grad-CAM for faithful explanations of convolutional neural networks 326 | Rachel L. Draelos, Lawrence Carin` 327 | 328 | https://arxiv.org/abs/1710.11063
329 | `Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks 330 | Aditya Chattopadhyay, Anirban Sarkar, Prantik Howlader, Vineeth N Balasubramanian` 331 | 332 | https://arxiv.org/abs/1910.01279
333 | `Score-CAM: Score-Weighted Visual Explanations for Convolutional Neural Networks 334 | Haofan Wang, Zifan Wang, Mengnan Du, Fan Yang, Zijian Zhang, Sirui Ding, Piotr Mardziel, Xia Hu` 335 | 336 | https://ieeexplore.ieee.org/abstract/document/9093360/
337 | `Ablation-cam: Visual explanations for deep convolutional network via gradient-free localization. 338 | Saurabh Desai and Harish G Ramaswamy. In WACV, pages 972–980, 2020` 339 | 340 | https://arxiv.org/abs/2008.02312
341 | `Axiom-based Grad-CAM: Towards Accurate Visualization and Explanation of CNNs 342 | Ruigang Fu, Qingyong Hu, Xiaohu Dong, Yulan Guo, Yinghui Gao, Biao Li` 343 | 344 | https://arxiv.org/abs/2008.00299
345 | `Eigen-CAM: Class Activation Map using Principal Components 346 | Mohammed Bany Muhammad, Mohammed Yeasin` 347 | 348 | http://mftp.mmcheng.net/Papers/21TIP_LayerCAM.pdf
349 | `LayerCAM: Exploring Hierarchical Class Activation Maps for Localization 350 | Peng-Tao Jiang; Chang-Bin Zhang; Qibin Hou; Ming-Ming Cheng; Yunchao Wei` 351 | 352 | https://arxiv.org/abs/1905.00780
353 | `Full-Gradient Representation for Neural Network Visualization 354 | Suraj Srinivas, Francois Fleuret` 355 | 356 | https://arxiv.org/abs/1806.10206
357 | `Deep Feature Factorization For Concept Discovery 358 | Edo Collins, Radhakrishna Achanta, Sabine Süsstrunk` 359 | 360 | https://arxiv.org/abs/2410.00267
361 | `KPCA-CAM: Visual Explainability of Deep Computer Vision Models using Kernel PCA 362 | Sachin Karmani, Thanushon Sivakaran, Gaurav Prasad, Mehmet Ali, Wenbo Yang, Sheyang Tang` 363 | 364 | https://hal.science/hal-02963298/document
365 | `Features Understanding in 3D CNNs for Actions Recognition in Video 366 | Kazi Ahmed Asif Fuad, Pierre-Etienne Martin, Romain Giot, Romain 367 | Bourqui, Jenny Benois-Pineau, Akka Zemmar` 368 | 369 | https://arxiv.org/abs/2501.06261
370 | `CAMs as Shapley Value-based Explainers 371 | Huaiguang Cai` 372 | 373 | 374 | https://arxiv.org/pdf/2501.11309
`Finer-CAM: Spotting the Difference Reveals Finer Details for Visual Explanation
Ziheng Zhang*, Jianyang Gu*, Arpita Chowdhury, Zheda Mai, David Carlyn, Tanya Berger-Wolf, Yu Su, Wei-Lun Chao`
--------------------------------------------------------------------------------
/cam.py:
--------------------------------------------------------------------------------
import argparse
import os
import cv2
import numpy as np
import torch
from torchvision import models
from pytorch_grad_cam import (
    GradCAM, FEM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
    AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
    LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM,
    FinerCAM
)
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import (
    show_cam_on_image, deprocess_image, preprocess_image
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, ClassifierOutputReST


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu',
                        help='Torch device to use')
    parser.add_argument(
        '--image-path',
        type=str,
        default='./examples/both.png',
        help='Input image path')
    parser.add_argument('--aug-smooth', action='store_true',
                        help='Apply test time augmentation to smooth the CAM')
    parser.add_argument(
        '--eigen-smooth',
        action='store_true',
        help='Reduce noise by taking the first principal component '
             'of cam_weights*activations')
    parser.add_argument('--method', type=str, default='gradcam',
                        choices=[
                            'gradcam', 'fem', 'hirescam', 'gradcam++',
                            'scorecam', 'xgradcam', 'ablationcam',
                            'eigencam', 'eigengradcam', 'layercam',
                            'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam',
                            'finercam'
                        ],
                        help='CAM method')

    parser.add_argument('--output-dir', type=str, default='output',
                        help='Output directory to save the images')
    args = parser.parse_args()

    if args.device:
        print(f'Using device "{args.device}" for acceleration')
    else:
        print('Using CPU for computation')

    return args


if __name__ == '__main__':
    """ python cam.py --image-path <path_to_image>
    Example usage of loading an image and computing:
        1. CAM
        2. Guided Back Propagation
        3. Combining both
    """

    args = get_args()
    methods = {
        "gradcam": GradCAM,
        "hirescam": HiResCAM,
        "scorecam": ScoreCAM,
        "gradcam++": GradCAMPlusPlus,
        "ablationcam": AblationCAM,
        "xgradcam": XGradCAM,
        "eigencam": EigenCAM,
        "eigengradcam": EigenGradCAM,
        "layercam": LayerCAM,
        "fullgrad": FullGrad,
        "fem": FEM,
        "gradcamelementwise": GradCAMElementWise,
        'kpcacam': KPCA_CAM,
        'shapleycam': ShapleyCAM,
        'finercam': FinerCAM
    }

    if args.device == 'hpu':
        import habana_frameworks.torch.core as htcore

    model = models.resnet50(pretrained=True).to(torch.device(args.device)).eval()

    # Choose the target layer you want to compute the visualization for.
    # Usually this will be the last convolutional layer in the model.
    # Some common choices can be:
    # Resnet18 and 50: model.layer4
    # VGG, densenet161: model.features[-1]
    # mnasnet1_0: model.layers[-1]
    # You can print the model to help choose the layer.
    # You can pass a list with several target layers,
    # in that case the CAMs will be computed per layer and then aggregated.
    # You can also try selecting all layers of a certain type, with e.g:
    # from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
    # find_layer_types_recursive(model, [torch.nn.ReLU])

    target_layers = [model.layer4]

    rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1]
    rgb_img = np.float32(rgb_img) / 255
    input_tensor = preprocess_image(rgb_img,
                                    mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225]).to(args.device)

    # We have to specify the target we want to generate
    # the Class Activation Maps for.
    # If targets is None, the highest scoring category (for every member in the batch) will be used.
    # You can target specific categories by
    # targets = [ClassifierOutputTarget(281)]
    # targets = [ClassifierOutputReST(281)]
    targets = None

    # Using the with statement ensures the context is freed, and you can
    # recreate different CAM objects in a loop.
    cam_algorithm = methods[args.method]
    with cam_algorithm(model=model,
                       target_layers=target_layers) as cam:

        # AblationCAM and ScoreCAM have batched implementations.
        # You can override the internal batch size for faster computation.
        cam.batch_size = 32
        grayscale_cam = cam(input_tensor=input_tensor,
                            targets=targets,
                            aug_smooth=args.aug_smooth,
                            eigen_smooth=args.eigen_smooth)

        grayscale_cam = grayscale_cam[0, :]

        cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
        cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)

    gb_model = GuidedBackpropReLUModel(model=model, device=args.device)
    gb = gb_model(input_tensor, target_category=None)

    cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
    cam_gb = deprocess_image(cam_mask * gb)
    gb = deprocess_image(gb)

    os.makedirs(args.output_dir, exist_ok=True)

    cam_output_path = os.path.join(args.output_dir, f'{args.method}_cam.jpg')
    gb_output_path = os.path.join(args.output_dir, f'{args.method}_gb.jpg')
    cam_gb_output_path = os.path.join(args.output_dir, f'{args.method}_cam_gb.jpg')

    cv2.imwrite(cam_output_path, cam_image)
    cv2.imwrite(gb_output_path, gb)
    cv2.imwrite(cam_gb_output_path, cam_gb)
--------------------------------------------------------------------------------
/examples/augsmooth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/augsmooth.jpg
--------------------------------------------------------------------------------
/examples/both.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/both.png
--------------------------------------------------------------------------------
/examples/both_detection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/both_detection.png -------------------------------------------------------------------------------- /examples/cam_gb_dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/cam_gb_dog.jpg -------------------------------------------------------------------------------- /examples/cars_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/cars_segmentation.png -------------------------------------------------------------------------------- /examples/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/cat.jpg -------------------------------------------------------------------------------- /examples/clip_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/clip_cat.jpg -------------------------------------------------------------------------------- /examples/clip_dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/clip_dog.jpg -------------------------------------------------------------------------------- /examples/dff1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dff1.png -------------------------------------------------------------------------------- /examples/dff2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dff2.png -------------------------------------------------------------------------------- /examples/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dog.jpg -------------------------------------------------------------------------------- /examples/dog_cat.jfif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dog_cat.jfif -------------------------------------------------------------------------------- /examples/dogs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs.png -------------------------------------------------------------------------------- /examples/dogs_gradcam++_resnet50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs_gradcam++_resnet50.jpg 
-------------------------------------------------------------------------------- /examples/dogs_gradcam++_vgg16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs_gradcam++_vgg16.jpg -------------------------------------------------------------------------------- /examples/dogs_gradcam_resnet50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs_gradcam_resnet50.jpg -------------------------------------------------------------------------------- /examples/dogs_gradcam_vgg16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs_gradcam_vgg16.jpg -------------------------------------------------------------------------------- /examples/dogs_scorecam_resnet50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs_scorecam_resnet50.jpg -------------------------------------------------------------------------------- /examples/dogs_scorecam_vgg16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/dogs_scorecam_vgg16.jpg -------------------------------------------------------------------------------- /examples/eigenaug.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/eigenaug.jpg -------------------------------------------------------------------------------- /examples/eigensmooth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/eigensmooth.jpg -------------------------------------------------------------------------------- /examples/embeddings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/embeddings.png -------------------------------------------------------------------------------- /examples/horses.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/horses.jpg -------------------------------------------------------------------------------- /examples/metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/metrics.png -------------------------------------------------------------------------------- /examples/multiorgan_segmentation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/multiorgan_segmentation.gif 
-------------------------------------------------------------------------------- /examples/nosmooth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/nosmooth.jpg -------------------------------------------------------------------------------- /examples/resnet50_cat_ablationcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet50_cat_ablationcam_cam.jpg -------------------------------------------------------------------------------- /examples/resnet50_cat_gradcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet50_cat_gradcam_cam.jpg -------------------------------------------------------------------------------- /examples/resnet50_cat_scorecam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet50_cat_scorecam_cam.jpg -------------------------------------------------------------------------------- /examples/resnet50_dog_ablationcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet50_dog_ablationcam_cam.jpg -------------------------------------------------------------------------------- /examples/resnet50_dog_gradcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet50_dog_gradcam_cam.jpg -------------------------------------------------------------------------------- /examples/resnet50_dog_scorecam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet50_dog_scorecam_cam.jpg -------------------------------------------------------------------------------- /examples/resnet_horses_ablationcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet_horses_ablationcam_cam.jpg -------------------------------------------------------------------------------- /examples/resnet_horses_gradcam++_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet_horses_gradcam++_cam.jpg -------------------------------------------------------------------------------- /examples/resnet_horses_gradcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet_horses_gradcam_cam.jpg -------------------------------------------------------------------------------- /examples/resnet_horses_horses_eigencam_cam.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet_horses_horses_eigencam_cam.jpg -------------------------------------------------------------------------------- /examples/resnet_horses_scorecam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/resnet_horses_scorecam_cam.jpg -------------------------------------------------------------------------------- /examples/road.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/road.png -------------------------------------------------------------------------------- /examples/swinT_cat_ablationcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/swinT_cat_ablationcam_cam.jpg -------------------------------------------------------------------------------- /examples/swinT_cat_gradcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/swinT_cat_gradcam_cam.jpg -------------------------------------------------------------------------------- /examples/swinT_cat_scorecam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/swinT_cat_scorecam_cam.jpg -------------------------------------------------------------------------------- /examples/swinT_dog_ablationcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/swinT_dog_ablationcam_cam.jpg -------------------------------------------------------------------------------- /examples/swinT_dog_gradcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/swinT_dog_gradcam_cam.jpg -------------------------------------------------------------------------------- /examples/swinT_dog_scorecam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/swinT_dog_scorecam_cam.jpg -------------------------------------------------------------------------------- /examples/vgg_horses_ablationcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vgg_horses_ablationcam_cam.jpg -------------------------------------------------------------------------------- /examples/vgg_horses_eigencam_cam.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vgg_horses_eigencam_cam.jpg -------------------------------------------------------------------------------- /examples/vgg_horses_gradcam++_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vgg_horses_gradcam++_cam.jpg -------------------------------------------------------------------------------- /examples/vgg_horses_gradcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vgg_horses_gradcam_cam.jpg -------------------------------------------------------------------------------- /examples/vgg_horses_scorecam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vgg_horses_scorecam_cam.jpg -------------------------------------------------------------------------------- /examples/vit_cat_ablationcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vit_cat_ablationcam_cam.jpg -------------------------------------------------------------------------------- /examples/vit_cat_gradcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vit_cat_gradcam_cam.jpg -------------------------------------------------------------------------------- /examples/vit_cat_scorecam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vit_cat_scorecam_cam.jpg -------------------------------------------------------------------------------- /examples/vit_dog_ablationcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vit_dog_ablationcam_cam.jpg -------------------------------------------------------------------------------- /examples/vit_dog_gradcam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vit_dog_gradcam_cam.jpg -------------------------------------------------------------------------------- /examples/vit_dog_scorecam_cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/vit_dog_scorecam_cam.jpg -------------------------------------------------------------------------------- /examples/yolo_eigencam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/examples/yolo_eigencam.png 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = [
    "setuptools>=42",
    "wheel"
]
build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/pytorch_grad_cam/__init__.py:
--------------------------------------------------------------------------------
from pytorch_grad_cam.grad_cam import GradCAM
from pytorch_grad_cam.finer_cam import FinerCAM
from pytorch_grad_cam.shapley_cam import ShapleyCAM
from pytorch_grad_cam.fem import FEM
from pytorch_grad_cam.hirescam import HiResCAM
from pytorch_grad_cam.grad_cam_elementwise import GradCAMElementWise
from pytorch_grad_cam.ablation_layer import AblationLayer, AblationLayerVit, AblationLayerFasterRCNN
from pytorch_grad_cam.ablation_cam import AblationCAM
from pytorch_grad_cam.xgrad_cam import XGradCAM
from pytorch_grad_cam.grad_cam_plusplus import GradCAMPlusPlus
from pytorch_grad_cam.score_cam import ScoreCAM
from pytorch_grad_cam.layer_cam import LayerCAM
from pytorch_grad_cam.eigen_cam import EigenCAM
from pytorch_grad_cam.eigen_grad_cam import EigenGradCAM
from pytorch_grad_cam.kpca_cam import KPCA_CAM
from pytorch_grad_cam.random_cam import RandomCAM
from pytorch_grad_cam.fullgrad_cam import FullGrad
from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
from pytorch_grad_cam.feature_factorization.deep_feature_factorization import DeepFeatureFactorization, run_dff_on_image
import pytorch_grad_cam.utils.model_targets
import pytorch_grad_cam.utils.reshape_transforms
import pytorch_grad_cam.metrics.cam_mult_image
import pytorch_grad_cam.metrics.road
--------------------------------------------------------------------------------
/pytorch_grad_cam/ablation_cam.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import tqdm
from typing import Callable, List
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.find_layers import replace_layer_recursive
from pytorch_grad_cam.ablation_layer import AblationLayer


""" Implementation of AblationCAM
https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf

Ablate individual activations, and then measure the drop in the target scores.

In the current implementation, the target layer activations are cached, so they won't be re-computed.
However layers before it, if any, will not be cached.
This means that if the target layer is a large block, for example model.features (in VGG), there will
be a large saving in run time.

Since we have to go over many channels and ablate them, and every channel ablation requires a forward pass,
it would be nice if we could avoid doing that for channels that won't contribute anyway, making it much faster.
The parameter ratio_channels_to_ablate controls how many channels should be ablated, using an experimental method
(to be improved). The default 1.0 value means that all channels will be ablated.
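
A minimal usage sketch (the model, input_tensor and the 281 category index are
illustrative; ratio_channels_to_ablate is the experimental speed-up knob):

    from pytorch_grad_cam import AblationCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    cam = AblationCAM(model=model,
                      target_layers=[model.layer4[-1]],
                      batch_size=32,
                      ratio_channels_to_ablate=0.5)
    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=[ClassifierOutputTarget(281)])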
24 | """ 25 | 26 | 27 | class AblationCAM(BaseCAM): 28 | def __init__(self, 29 | model: torch.nn.Module, 30 | target_layers: List[torch.nn.Module], 31 | reshape_transform: Callable = None, 32 | ablation_layer: torch.nn.Module = AblationLayer(), 33 | batch_size: int = 32, 34 | ratio_channels_to_ablate: float = 1.0) -> None: 35 | 36 | super(AblationCAM, self).__init__(model, 37 | target_layers, 38 | reshape_transform, 39 | uses_gradients=False) 40 | self.batch_size = batch_size 41 | self.ablation_layer = ablation_layer 42 | self.ratio_channels_to_ablate = ratio_channels_to_ablate 43 | 44 | def save_activation(self, module, input, output) -> None: 45 | """ Helper function to save the raw activations from the target layer """ 46 | self.activations = output 47 | 48 | def assemble_ablation_scores(self, 49 | new_scores: list, 50 | original_score: float, 51 | ablated_channels: np.ndarray, 52 | number_of_channels: int) -> np.ndarray: 53 | """ Take the value from the channels that were ablated, 54 | and just set the original score for the channels that were skipped """ 55 | 56 | index = 0 57 | result = [] 58 | sorted_indices = np.argsort(ablated_channels) 59 | ablated_channels = ablated_channels[sorted_indices] 60 | new_scores = np.float32(new_scores)[sorted_indices] 61 | 62 | for i in range(number_of_channels): 63 | if index < len(ablated_channels) and ablated_channels[index] == i: 64 | weight = new_scores[index] 65 | index = index + 1 66 | else: 67 | weight = original_score 68 | result.append(weight) 69 | 70 | return result 71 | 72 | def get_cam_weights(self, 73 | input_tensor: torch.Tensor, 74 | target_layer: torch.nn.Module, 75 | targets: List[Callable], 76 | activations: torch.Tensor, 77 | grads: torch.Tensor) -> np.ndarray: 78 | 79 | # Do a forward pass, compute the target scores, and cache the 80 | # activations 81 | handle = target_layer.register_forward_hook(self.save_activation) 82 | with torch.no_grad(): 83 | outputs = self.model(input_tensor) 84 | handle.remove() 85 | original_scores = np.float32( 86 | [target(output).cpu().item() for target, output in zip(targets, outputs)]) 87 | 88 | # Replace the layer with the ablation layer. 89 | # When we finish, we will replace it back, so the 90 | # original model is unchanged. 91 | ablation_layer = self.ablation_layer 92 | replace_layer_recursive(self.model, target_layer, ablation_layer) 93 | 94 | number_of_channels = activations.shape[1] 95 | weights = [] 96 | # This is a "gradient free" method, so we don't need gradients here. 97 | with torch.no_grad(): 98 | # Loop over each of the batch images and ablate activations for it. 99 | for batch_index, (target, tensor) in enumerate( 100 | zip(targets, input_tensor)): 101 | new_scores = [] 102 | batch_tensor = tensor.repeat(self.batch_size, 1, 1, 1) 103 | 104 | # Check which channels should be ablated. Normally this will be all channels, 105 | # But we can also try to speed this up by using a low 106 | # ratio_channels_to_ablate. 107 | channels_to_ablate = ablation_layer.activations_to_be_ablated( 108 | activations[batch_index, :], self.ratio_channels_to_ablate) 109 | number_channels_to_ablate = len(channels_to_ablate) 110 | 111 | for i in tqdm.tqdm( 112 | range( 113 | 0, 114 | number_channels_to_ablate, 115 | self.batch_size)): 116 | if i + self.batch_size > number_channels_to_ablate: 117 | batch_tensor = batch_tensor[:( 118 | number_channels_to_ablate - i)] 119 | 120 | # Change the state of the ablation layer so it ablates the next channels. 
                    # TBD: Move this into the ablation layer forward pass.
                    ablation_layer.set_next_batch(
                        input_batch_index=batch_index,
                        activations=self.activations,
                        num_channels_to_ablate=batch_tensor.size(0))
                    score = [target(o).cpu().item()
                             for o in self.model(batch_tensor)]
                    new_scores.extend(score)
                    ablation_layer.indices = ablation_layer.indices[batch_tensor.size(0):]

                new_scores = self.assemble_ablation_scores(
                    new_scores,
                    original_scores[batch_index],
                    channels_to_ablate,
                    number_of_channels)
                weights.extend(new_scores)

        weights = np.float32(weights)
        weights = weights.reshape(activations.shape[:2])
        original_scores = original_scores[:, None]
        # The weight is the normalized drop in the score when the channel is ablated.
        weights = (original_scores - weights) / original_scores

        # Replace the model back to the original state
        replace_layer_recursive(self.model, ablation_layer, target_layer)
        # Return the normalized score drops as the CAM weights.
        return weights
--------------------------------------------------------------------------------
/pytorch_grad_cam/ablation_cam_multilayer.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import torch
import tqdm
from pytorch_grad_cam.base_cam import BaseCAM


class AblationLayer(torch.nn.Module):
    def __init__(self, layer, reshape_transform, indices):
        super(AblationLayer, self).__init__()

        self.layer = layer
        self.reshape_transform = reshape_transform
        # The channels to zero out:
        self.indices = indices

    def forward(self, x):
        return self.__call__(x)

    def __call__(self, x):
        output = self.layer(x)

        # Hack to work with ViT,
        # since the activation channels are last and not first like in CNNs.
        # Probably should remove it?
        if self.reshape_transform is not None:
            output = output.transpose(1, 2)

        for i in range(output.size(0)):

            # Commonly the minimum activation will be 0,
            # and then it makes sense to zero it out.
            # However, depending on the architecture,
            # if the values can be negative, we use very negative values
            # to perform the ablation, deviating from the paper.
            if torch.min(output) == 0:
                output[i, self.indices[i], :] = 0
            else:
                ABLATION_VALUE = 1e5
                output[i, self.indices[i], :] = torch.min(
                    output) - ABLATION_VALUE

        if self.reshape_transform is not None:
            output = output.transpose(2, 1)

        return output


def replace_layer_recursive(model, old_layer, new_layer):
    for name, layer in model._modules.items():
        if layer == old_layer:
            model._modules[name] = new_layer
            return True
        elif replace_layer_recursive(layer, old_layer, new_layer):
            return True
    return False


class AblationCAM(BaseCAM):
    def __init__(self, model, target_layers,
                 reshape_transform=None):
        super(AblationCAM, self).__init__(model, target_layers,
                                          reshape_transform)

        if len(target_layers) > 1:
            print(
                "Warning. You are using Ablation CAM with more than one layer. "
" 68 | "This is supported only if all layers have the same output shape") 69 | 70 | def set_ablation_layers(self): 71 | self.ablation_layers = [] 72 | for target_layer in self.target_layers: 73 | ablation_layer = AblationLayer(target_layer, 74 | self.reshape_transform, indices=[]) 75 | self.ablation_layers.append(ablation_layer) 76 | replace_layer_recursive(self.model, target_layer, ablation_layer) 77 | 78 | def unset_ablation_layers(self): 79 | # replace the model back to the original state 80 | for ablation_layer, target_layer in zip( 81 | self.ablation_layers, self.target_layers): 82 | replace_layer_recursive(self.model, ablation_layer, target_layer) 83 | 84 | def set_ablation_layer_batch_indices(self, indices): 85 | for ablation_layer in self.ablation_layers: 86 | ablation_layer.indices = indices 87 | 88 | def trim_ablation_layer_batch_indices(self, keep): 89 | for ablation_layer in self.ablation_layers: 90 | ablation_layer.indices = ablation_layer.indices[:keep] 91 | 92 | def get_cam_weights(self, 93 | input_tensor, 94 | target_category, 95 | activations, 96 | grads): 97 | with torch.no_grad(): 98 | outputs = self.model(input_tensor).cpu().numpy() 99 | original_scores = [] 100 | for i in range(input_tensor.size(0)): 101 | original_scores.append(outputs[i, target_category[i]]) 102 | original_scores = np.float32(original_scores) 103 | 104 | self.set_ablation_layers() 105 | 106 | if hasattr(self, "batch_size"): 107 | BATCH_SIZE = self.batch_size 108 | else: 109 | BATCH_SIZE = 32 110 | 111 | number_of_channels = activations.shape[1] 112 | weights = [] 113 | 114 | with torch.no_grad(): 115 | # Iterate over the input batch 116 | for tensor, category in zip(input_tensor, target_category): 117 | batch_tensor = tensor.repeat(BATCH_SIZE, 1, 1, 1) 118 | for i in tqdm.tqdm(range(0, number_of_channels, BATCH_SIZE)): 119 | self.set_ablation_layer_batch_indices( 120 | list(range(i, i + BATCH_SIZE))) 121 | 122 | if i + BATCH_SIZE > number_of_channels: 123 | keep = number_of_channels - i 124 | batch_tensor = batch_tensor[:keep] 125 | self.trim_ablation_layer_batch_indices(self, keep) 126 | score = self.model(batch_tensor)[:, category].cpu().numpy() 127 | weights.extend(score) 128 | 129 | weights = np.float32(weights) 130 | weights = weights.reshape(activations.shape[:2]) 131 | original_scores = original_scores[:, None] 132 | weights = (original_scores - weights) / original_scores 133 | 134 | # replace the model back to the original state 135 | self.unset_ablation_layers() 136 | return weights 137 | -------------------------------------------------------------------------------- /pytorch_grad_cam/ablation_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | import numpy as np 4 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 5 | 6 | 7 | class AblationLayer(torch.nn.Module): 8 | def __init__(self): 9 | super(AblationLayer, self).__init__() 10 | 11 | def objectiveness_mask_from_svd(self, activations, threshold=0.01): 12 | """ Experimental method to get a binary mask to compare if the activation is worth ablating. 13 | The idea is to apply the EigenCAM method by doing PCA on the activations. 14 | Then we create a binary mask by comparing to a low threshold. 15 | Areas that are masked out, are probably not interesting anyway. 
16 | """ 17 | 18 | projection = get_2d_projection(activations[None, :])[0, :] 19 | projection = np.abs(projection) 20 | projection = projection - projection.min() 21 | projection = projection / projection.max() 22 | projection = projection > threshold 23 | return projection 24 | 25 | def activations_to_be_ablated( 26 | self, 27 | activations, 28 | ratio_channels_to_ablate=1.0): 29 | """ Experimental method to get a binary mask to compare if the activation is worth ablating. 30 | Create a binary CAM mask with objectiveness_mask_from_svd. 31 | Score each Activation channel, by seeing how much of its values are inside the mask. 32 | Then keep the top channels. 33 | 34 | """ 35 | if ratio_channels_to_ablate == 1.0: 36 | self.indices = np.int32(range(activations.shape[0])) 37 | return self.indices 38 | 39 | projection = self.objectiveness_mask_from_svd(activations) 40 | 41 | scores = [] 42 | for channel in activations: 43 | normalized = np.abs(channel) 44 | normalized = normalized - normalized.min() 45 | normalized = normalized / np.max(normalized) 46 | score = (projection * normalized).sum() / normalized.sum() 47 | scores.append(score) 48 | scores = np.float32(scores) 49 | 50 | indices = list(np.argsort(scores)) 51 | high_score_indices = indices[::- 52 | 1][: int(len(indices) * 53 | ratio_channels_to_ablate)] 54 | low_score_indices = indices[: int( 55 | len(indices) * ratio_channels_to_ablate)] 56 | self.indices = np.int32(high_score_indices + low_score_indices) 57 | return self.indices 58 | 59 | def set_next_batch( 60 | self, 61 | input_batch_index, 62 | activations, 63 | num_channels_to_ablate): 64 | """ This creates the next batch of activations from the layer. 65 | Just take corresponding batch member from activations, and repeat it num_channels_to_ablate times. 66 | """ 67 | self.activations = activations[input_batch_index, :, :, :].clone( 68 | ).unsqueeze(0).repeat(num_channels_to_ablate, 1, 1, 1) 69 | 70 | def __call__(self, x): 71 | output = self.activations 72 | for i in range(output.size(0)): 73 | # Commonly the minimum activation will be 0, 74 | # And then it makes sense to zero it out. 75 | # However depending on the architecture, 76 | # If the values can be negative, we use very negative values 77 | # to perform the ablation, deviating from the paper. 78 | if torch.min(output) == 0: 79 | output[i, self.indices[i], :] = 0 80 | else: 81 | ABLATION_VALUE = 1e7 82 | output[i, self.indices[i], :] = torch.min( 83 | output) - ABLATION_VALUE 84 | 85 | return output 86 | 87 | 88 | class AblationLayerVit(AblationLayer): 89 | def __init__(self): 90 | super(AblationLayerVit, self).__init__() 91 | 92 | def __call__(self, x): 93 | output = self.activations 94 | output = output.transpose(1, len(output.shape) - 1) 95 | for i in range(output.size(0)): 96 | 97 | # Commonly the minimum activation will be 0, 98 | # And then it makes sense to zero it out. 99 | # However depending on the architecture, 100 | # If the values can be negative, we use very negative values 101 | # to perform the ablation, deviating from the paper. 102 | if torch.min(output) == 0: 103 | output[i, self.indices[i], :] = 0 104 | else: 105 | ABLATION_VALUE = 1e7 106 | output[i, self.indices[i], :] = torch.min( 107 | output) - ABLATION_VALUE 108 | 109 | output = output.transpose(len(output.shape) - 1, 1) 110 | 111 | return output 112 | 113 | def set_next_batch( 114 | self, 115 | input_batch_index, 116 | activations, 117 | num_channels_to_ablate): 118 | """ This creates the next batch of activations from the layer. 
119 |             Take the corresponding batch member from the activations and repeat it num_channels_to_ablate times.
120 |         """
121 |         repeat_params = [num_channels_to_ablate] + \
122 |             len(activations.shape[:-1]) * [1]
123 |         self.activations = activations[input_batch_index, :, :].clone(
124 |         ).unsqueeze(0).repeat(*repeat_params)
125 | 
126 | 
127 | class AblationLayerFasterRCNN(AblationLayer):
128 |     def __init__(self):
129 |         super(AblationLayerFasterRCNN, self).__init__()
130 | 
131 |     def set_next_batch(
132 |             self,
133 |             input_batch_index,
134 |             activations,
135 |             num_channels_to_ablate):
136 |         """ Extract the next batch member from activations,
137 |             and repeat it num_channels_to_ablate times.
138 |         """
139 |         self.activations = OrderedDict()
140 |         for key, value in activations.items():
141 |             fpn_activation = value[input_batch_index,
142 |                                    :, :, :].clone().unsqueeze(0)
143 |             self.activations[key] = fpn_activation.repeat(
144 |                 num_channels_to_ablate, 1, 1, 1)
145 | 
146 |     def __call__(self, x):
147 |         result = self.activations
148 |         layers = {0: '0', 1: '1', 2: '2', 3: '3', 4: 'pool'}
149 |         num_channels_to_ablate = result['pool'].size(0)
150 |         for i in range(num_channels_to_ablate):
151 |             # Each FPN level has 256 channels, so map the flat channel
152 |             # index to a (pyramid level, channel) pair.
153 |             pyramid_layer = int(self.indices[i] / 256)
154 |             index_in_pyramid_layer = int(self.indices[i] % 256)
155 |             result[layers[pyramid_layer]][i,
156 |                                           index_in_pyramid_layer, :, :] = -1000
157 |         return result
158 | 
-------------------------------------------------------------------------------- /pytorch_grad_cam/activations_and_gradients.py: --------------------------------------------------------------------------------
1 | class ActivationsAndGradients:
2 |     """ Class for extracting activations and
3 |     registering gradients from targeted intermediate layers """
4 | 
5 |     def __init__(self, model, target_layers, reshape_transform, detach=True):
6 |         self.model = model
7 |         self.gradients = []
8 |         self.activations = []
9 |         self.reshape_transform = reshape_transform
10 |         self.detach = detach
11 |         self.handles = []
12 |         for target_layer in target_layers:
13 |             self.handles.append(
14 |                 target_layer.register_forward_hook(self.save_activation))
15 |             # Because of https://github.com/pytorch/pytorch/issues/61519,
16 |             # we don't use backward hook to record gradients.
17 |             self.handles.append(
18 |                 target_layer.register_forward_hook(self.save_gradient))
19 | 
20 |     def save_activation(self, module, input, output):
21 |         activation = output
22 |         if self.detach:
23 |             if self.reshape_transform is not None:
24 |                 activation = self.reshape_transform(activation)
25 |             self.activations.append(activation.cpu().detach())
26 |         else:
27 |             self.activations.append(activation)
28 | 
29 |     def save_gradient(self, module, input, output):
30 |         if not hasattr(output, "requires_grad") or not output.requires_grad:
31 |             # Hooks can only be registered on tensors that require grad.
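  |             # (This happens, for example, when the forward pass runs under
  |             # torch.no_grad() or when the target layer is frozen.)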
32 |             return
33 | 
34 |         # Gradients are computed in reverse order
35 |         def _store_grad(grad):
36 |             if self.detach:
37 |                 if self.reshape_transform is not None:
38 |                     grad = self.reshape_transform(grad)
39 |                 self.gradients = [grad.cpu().detach()] + self.gradients
40 |             else:
41 |                 self.gradients = [grad] + self.gradients
42 | 
43 |         output.register_hook(_store_grad)
44 | 
45 |     def __call__(self, x):
46 |         self.gradients = []
47 |         self.activations = []
48 |         return self.model(x)
49 | 
50 |     def release(self):
51 |         for handle in self.handles:
52 |             handle.remove()
53 | 
-------------------------------------------------------------------------------- /pytorch_grad_cam/base_cam.py: --------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Tuple
2 | 
3 | import numpy as np
4 | import torch
5 | import ttach as tta
6 | 
7 | from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
8 | from pytorch_grad_cam.utils.image import scale_cam_image
9 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
10 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
11 | 
12 | 
13 | class BaseCAM:
14 |     def __init__(
15 |         self,
16 |         model: torch.nn.Module,
17 |         target_layers: List[torch.nn.Module],
18 |         reshape_transform: Callable = None,
19 |         compute_input_gradient: bool = False,
20 |         uses_gradients: bool = True,
21 |         tta_transforms: Optional[tta.Compose] = None,
22 |         detach: bool = True,
23 |     ) -> None:
24 |         self.model = model.eval()
25 |         self.target_layers = target_layers
26 | 
27 |         # Use the same device as the model.
28 |         self.device = next(self.model.parameters()).device
29 |         if 'hpu' in str(self.device):
30 |             try:
31 |                 import habana_frameworks.torch.core as htcore
32 |             except ImportError as error:
33 |                 error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}."
34 |                 raise error
35 |             self.__htcore = htcore
36 |         self.reshape_transform = reshape_transform
37 |         self.compute_input_gradient = compute_input_gradient
38 |         self.uses_gradients = uses_gradients
39 |         if tta_transforms is None:
40 |             self.tta_transforms = tta.Compose(
41 |                 [
42 |                     tta.HorizontalFlip(),
43 |                     tta.Multiply(factors=[0.9, 1, 1.1]),
44 |                 ]
45 |             )
46 |         else:
47 |             self.tta_transforms = tta_transforms
48 | 
49 |         self.detach = detach
50 |         self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform, self.detach)
51 | 
52 |     """ Get a vector of weights for every channel in the target layer.
53 |         Methods that compute one weight per channel
54 |         will typically only need to implement this function. 
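  |         For example, GradCAM implements it as np.mean(grads, axis=(2, 3)),
  |         giving one scalar weight per activation channel.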
""" 55 | 56 | def get_cam_weights( 57 | self, 58 | input_tensor: torch.Tensor, 59 | target_layers: List[torch.nn.Module], 60 | targets: List[torch.nn.Module], 61 | activations: torch.Tensor, 62 | grads: torch.Tensor, 63 | ) -> np.ndarray: 64 | raise Exception("Not Implemented") 65 | 66 | def get_cam_image( 67 | self, 68 | input_tensor: torch.Tensor, 69 | target_layer: torch.nn.Module, 70 | targets: List[torch.nn.Module], 71 | activations: torch.Tensor, 72 | grads: torch.Tensor, 73 | eigen_smooth: bool = False, 74 | ) -> np.ndarray: 75 | weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads) 76 | if isinstance(activations, torch.Tensor): 77 | activations = activations.cpu().detach().numpy() 78 | # 2D conv 79 | if len(activations.shape) == 4: 80 | weighted_activations = weights[:, :, None, None] * activations 81 | # 3D conv 82 | elif len(activations.shape) == 5: 83 | weighted_activations = weights[:, :, None, None, None] * activations 84 | else: 85 | raise ValueError(f"Invalid activation shape. Get {len(activations.shape)}.") 86 | 87 | if eigen_smooth: 88 | cam = get_2d_projection(weighted_activations) 89 | else: 90 | cam = weighted_activations.sum(axis=1) 91 | return cam 92 | 93 | def forward( 94 | self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False 95 | ) -> np.ndarray: 96 | input_tensor = input_tensor.to(self.device) 97 | 98 | if self.compute_input_gradient: 99 | input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True) 100 | 101 | self.outputs = outputs = self.activations_and_grads(input_tensor) 102 | 103 | if targets is None: 104 | target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1) 105 | targets = [ClassifierOutputTarget(category) for category in target_categories] 106 | 107 | if self.uses_gradients: 108 | self.model.zero_grad() 109 | loss = sum([target(output) for target, output in zip(targets, outputs)]) 110 | if self.detach: 111 | loss.backward(retain_graph=True) 112 | else: 113 | # keep the computational graph, create_graph = True is needed for hvp 114 | torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True) 115 | # When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle" 116 | # loss.backward(retain_graph=True, create_graph=True) 117 | if 'hpu' in str(self.device): 118 | self.__htcore.mark_step() 119 | 120 | # In most of the saliency attribution papers, the saliency is 121 | # computed with a single target layer. 122 | # Commonly it is the last convolutional layer. 123 | # Here we support passing a list with multiple target layers. 124 | # It will compute the saliency image for every image, 125 | # and then aggregate them (with a default mean aggregation). 126 | # This gives you more flexibility in case you just want to 127 | # use all conv layers for example, all Batchnorm layers, 128 | # or something else. 
129 | cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth) 130 | return self.aggregate_multi_layers(cam_per_layer) 131 | 132 | def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]: 133 | if len(input_tensor.shape) == 4: 134 | width, height = input_tensor.size(-1), input_tensor.size(-2) 135 | return width, height 136 | elif len(input_tensor.shape) == 5: 137 | depth, width, height = input_tensor.size(-1), input_tensor.size(-2), input_tensor.size(-3) 138 | return depth, width, height 139 | else: 140 | raise ValueError("Invalid input_tensor shape. Only 2D or 3D images are supported.") 141 | 142 | def compute_cam_per_layer( 143 | self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool 144 | ) -> np.ndarray: 145 | if self.detach: 146 | activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations] 147 | grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients] 148 | else: 149 | activations_list = [a for a in self.activations_and_grads.activations] 150 | grads_list = [g for g in self.activations_and_grads.gradients] 151 | target_size = self.get_target_width_height(input_tensor) 152 | 153 | cam_per_target_layer = [] 154 | # Loop over the saliency image from every layer 155 | for i in range(len(self.target_layers)): 156 | target_layer = self.target_layers[i] 157 | layer_activations = None 158 | layer_grads = None 159 | if i < len(activations_list): 160 | layer_activations = activations_list[i] 161 | if i < len(grads_list): 162 | layer_grads = grads_list[i] 163 | 164 | cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth) 165 | cam = np.maximum(cam, 0) 166 | scaled = scale_cam_image(cam, target_size) 167 | cam_per_target_layer.append(scaled[:, None, :]) 168 | 169 | return cam_per_target_layer 170 | 171 | def aggregate_multi_layers(self, cam_per_target_layer: np.ndarray) -> np.ndarray: 172 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) 173 | cam_per_target_layer = np.maximum(cam_per_target_layer, 0) 174 | result = np.mean(cam_per_target_layer, axis=1) 175 | return scale_cam_image(result) 176 | 177 | def forward_augmentation_smoothing( 178 | self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False 179 | ) -> np.ndarray: 180 | cams = [] 181 | for transform in self.tta_transforms: 182 | augmented_tensor = transform.augment_image(input_tensor) 183 | cam = self.forward(augmented_tensor, targets, eigen_smooth) 184 | 185 | # The ttach library expects a tensor of size BxCxHxW 186 | cam = cam[:, None, :, :] 187 | cam = torch.from_numpy(cam) 188 | cam = transform.deaugment_mask(cam) 189 | 190 | # Back to numpy float32, HxW 191 | cam = cam.numpy() 192 | cam = cam[:, 0, :, :] 193 | cams.append(cam) 194 | 195 | cam = np.mean(np.float32(cams), axis=0) 196 | return cam 197 | 198 | def __call__( 199 | self, 200 | input_tensor: torch.Tensor, 201 | targets: List[torch.nn.Module] = None, 202 | aug_smooth: bool = False, 203 | eigen_smooth: bool = False, 204 | ) -> np.ndarray: 205 | # Smooth the CAM result with test time augmentation 206 | if aug_smooth is True: 207 | return self.forward_augmentation_smoothing(input_tensor, targets, eigen_smooth) 208 | 209 | return self.forward(input_tensor, targets, eigen_smooth) 210 | 211 | def __del__(self): 212 | self.activations_and_grads.release() 213 | 214 | def __enter__(self): 215 | return self 216 | 217 | def __exit__(self, exc_type, 
exc_value, exc_tb): 218 | self.activations_and_grads.release() 219 | if isinstance(exc_value, IndexError): 220 | # Handle IndexError here... 221 | print(f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}") 222 | return True 223 | -------------------------------------------------------------------------------- /pytorch_grad_cam/eigen_cam.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.base_cam import BaseCAM 2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 3 | 4 | # https://arxiv.org/abs/2008.00299 5 | 6 | 7 | class EigenCAM(BaseCAM): 8 | def __init__(self, model, target_layers, 9 | reshape_transform=None): 10 | super(EigenCAM, self).__init__(model, 11 | target_layers, 12 | reshape_transform, 13 | uses_gradients=False) 14 | 15 | def get_cam_image(self, 16 | input_tensor, 17 | target_layer, 18 | target_category, 19 | activations, 20 | grads, 21 | eigen_smooth): 22 | return get_2d_projection(activations) 23 | -------------------------------------------------------------------------------- /pytorch_grad_cam/eigen_grad_cam.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.base_cam import BaseCAM 2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 3 | 4 | # Like Eigen CAM: https://arxiv.org/abs/2008.00299 5 | # But multiply the activations x gradients 6 | 7 | 8 | class EigenGradCAM(BaseCAM): 9 | def __init__(self, model, target_layers, 10 | reshape_transform=None): 11 | super(EigenGradCAM, self).__init__(model, target_layers, 12 | reshape_transform) 13 | 14 | def get_cam_image(self, 15 | input_tensor, 16 | target_layer, 17 | target_category, 18 | activations, 19 | grads, 20 | eigen_smooth): 21 | return get_2d_projection(grads * activations) 22 | -------------------------------------------------------------------------------- /pytorch_grad_cam/feature_factorization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/pytorch_grad_cam/feature_factorization/__init__.py -------------------------------------------------------------------------------- /pytorch_grad_cam/feature_factorization/deep_feature_factorization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | from typing import Callable, List, Tuple, Optional 5 | from sklearn.decomposition import NMF 6 | from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients 7 | from pytorch_grad_cam.utils.image import scale_cam_image, create_labels_legend, show_factorization_on_image 8 | 9 | 10 | def dff(activations: np.ndarray, n_components: int = 5): 11 | """ Compute Deep Feature Factorization on a 2d Activations tensor. 
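  |     Note: the activations are shifted per channel to be non-negative before
  |     factorizing, since NMF is only defined for non-negative matrices.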
12 | 
13 |     :param activations: A numpy array of shape batch x channels x height x width
14 |     :param n_components: The number of components for the non-negative matrix factorization
15 |     :returns: A tuple of the concepts (a numpy array with shape channels x components),
16 |               and the explanation heatmaps (a numpy array with shape batch x height x width)
17 |     """
18 | 
19 |     batch_size, channels, h, w = activations.shape
20 |     reshaped_activations = activations.transpose((1, 0, 2, 3))
21 |     reshaped_activations[np.isnan(reshaped_activations)] = 0
22 |     reshaped_activations = reshaped_activations.reshape(
23 |         reshaped_activations.shape[0], -1)
24 |     offset = reshaped_activations.min(axis=-1)
25 |     reshaped_activations = reshaped_activations - offset[:, None]
26 | 
27 |     model = NMF(n_components=n_components, init='random', random_state=0)
28 |     W = model.fit_transform(reshaped_activations)
29 |     H = model.components_
30 |     concepts = W + offset[:, None]
31 |     explanations = H.reshape(n_components, batch_size, h, w)
32 |     explanations = explanations.transpose((1, 0, 2, 3))
33 |     return concepts, explanations
34 | 
35 | 
36 | class DeepFeatureFactorization:
37 |     """ Deep Feature Factorization: https://arxiv.org/abs/1806.10206
38 |         This takes a model, computes the 2D activations for a target layer,
39 |         and runs Non-Negative Matrix Factorization on the activations.
40 | 
41 |         Optionally it runs a computation on the concept embeddings,
42 |         like running a classifier on them.
43 | 
44 |         The explanation heatmaps are scaled to the range [0, 1]
45 |         and to the input tensor width and height.
46 |     """
47 | 
48 |     def __init__(self,
49 |                  model: torch.nn.Module,
50 |                  target_layer: torch.nn.Module,
51 |                  reshape_transform: Callable = None,
52 |                  computation_on_concepts=None
53 |                  ):
54 |         self.model = model
55 |         self.computation_on_concepts = computation_on_concepts
56 |         self.activations_and_grads = ActivationsAndGradients(
57 |             self.model, [target_layer], reshape_transform)
58 | 
59 |     def __call__(self,
60 |                  input_tensor: torch.Tensor,
61 |                  n_components: int = 16):
62 |         batch_size, channels, h, w = input_tensor.size()
63 |         _ = self.activations_and_grads(input_tensor)
64 | 
65 |         with torch.no_grad():
66 |             activations = self.activations_and_grads.activations[0].cpu(
67 |             ).numpy()
68 | 
69 |         concepts, explanations = dff(activations, n_components=n_components)
70 | 
71 |         processed_explanations = []
72 | 
73 |         for batch in explanations:
74 |             processed_explanations.append(scale_cam_image(batch, (w, h)))
75 | 
76 |         if self.computation_on_concepts:
77 |             with torch.no_grad():
78 |                 concept_tensors = torch.from_numpy(
79 |                     np.float32(concepts).transpose((1, 0)))
80 |                 concept_outputs = self.computation_on_concepts(
81 |                     concept_tensors).cpu().numpy()
82 |             return concepts, processed_explanations, concept_outputs
83 |         else:
84 |             return concepts, processed_explanations
85 | 
86 |     def __del__(self):
87 |         self.activations_and_grads.release()
88 | 
89 |     def __exit__(self, exc_type, exc_value, exc_tb):
90 |         self.activations_and_grads.release()
91 |         if isinstance(exc_value, IndexError):
92 |             # Handle IndexError here...
93 |             print(
94 |                 f"An exception occurred in ActivationSummary with block: {exc_type}. 
Message: {exc_value}") 95 | return True 96 | 97 | 98 | def run_dff_on_image(model: torch.nn.Module, 99 | target_layer: torch.nn.Module, 100 | classifier: torch.nn.Module, 101 | img_pil: Image, 102 | img_tensor: torch.Tensor, 103 | reshape_transform=Optional[Callable], 104 | n_components: int = 5, 105 | top_k: int = 2) -> np.ndarray: 106 | """ Helper function to create a Deep Feature Factorization visualization for a single image. 107 | TBD: Run this on a batch with several images. 108 | """ 109 | rgb_img_float = np.array(img_pil) / 255 110 | dff = DeepFeatureFactorization(model=model, 111 | reshape_transform=reshape_transform, 112 | target_layer=target_layer, 113 | computation_on_concepts=classifier) 114 | 115 | concepts, batch_explanations, concept_outputs = dff( 116 | img_tensor[None, :], n_components) 117 | 118 | concept_outputs = torch.softmax( 119 | torch.from_numpy(concept_outputs), 120 | axis=-1).numpy() 121 | concept_label_strings = create_labels_legend(concept_outputs, 122 | labels=model.config.id2label, 123 | top_k=top_k) 124 | visualization = show_factorization_on_image( 125 | rgb_img_float, 126 | batch_explanations[0], 127 | image_weight=0.3, 128 | concept_labels=concept_label_strings) 129 | 130 | result = np.hstack((np.array(img_pil), visualization)) 131 | return result 132 | -------------------------------------------------------------------------------- /pytorch_grad_cam/fem.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | 4 | """Feature Explanation Method. 5 | Fuad, K. A. A., Martin, P. E., Giot, R., Bourqui, R., Benois-Pineau, J., & Zemmari, A. (2020, November). Features understanding in 3D CNNS for actions recognition in video. In 2020 Tenth International Conference on Image Processing Theory, Tools and Applications (IPTA) (pp. 1-6). IEEE. 6 | https://hal.science/hal-02963298/document 7 | """ 8 | 9 | class FEM(BaseCAM): 10 | def __init__(self, model, target_layers, 11 | reshape_transform=None, k=2): 12 | super(FEM, self).__init__(model, 13 | target_layers, 14 | reshape_transform, 15 | uses_gradients=False) 16 | self.k = k 17 | 18 | def get_cam_image(self, 19 | input_tensor, 20 | target_layer, 21 | target_category, 22 | activations, 23 | grads, 24 | eigen_smooth): 25 | 26 | 27 | # 2D image 28 | if len(activations.shape) == 4: 29 | axis = (2, 3) 30 | # 3D image 31 | elif len(activations.shape) == 5: 32 | axis = (2, 3, 4) 33 | else: 34 | raise ValueError("Invalid activations shape." 
35 | "Shape of activations should be 4 (2D image) or 5 (3D image).") 36 | means = np.mean(activations, axis=axis) 37 | stds = np.std(activations, axis=axis) 38 | # k sigma rule: 39 | # Add extra dimensions to match activations shape 40 | th = means + self.k * stds 41 | weights_shape = list(means.shape) + [1] * len(axis) 42 | th = th.reshape(weights_shape) 43 | binary_mask = activations > th 44 | weights = binary_mask.mean(axis=axis) 45 | return (weights.reshape(weights_shape) * activations).sum(axis=1) -------------------------------------------------------------------------------- /pytorch_grad_cam/finer_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from typing import List, Callable 4 | from pytorch_grad_cam.base_cam import BaseCAM 5 | from pytorch_grad_cam import GradCAM 6 | from pytorch_grad_cam.utils.model_targets import FinerWeightedTarget 7 | 8 | class FinerCAM: 9 | def __init__(self, model: torch.nn.Module, target_layers: List[torch.nn.Module], reshape_transform: Callable = None, base_method=GradCAM): 10 | self.base_cam = base_method(model, target_layers, reshape_transform) 11 | self.compute_input_gradient = self.base_cam.compute_input_gradient 12 | self.uses_gradients = self.base_cam.uses_gradients 13 | 14 | def __call__(self, *args, **kwargs): 15 | return self.forward(*args, **kwargs) 16 | 17 | def forward(self, input_tensor: torch.Tensor, targets: List[torch.nn.Module] = None, eigen_smooth: bool = False, 18 | alpha: float = 1, comparison_categories: List[int] = [1, 2, 3], target_idx: int = None 19 | ) -> np.ndarray: 20 | input_tensor = input_tensor.to(self.base_cam.device) 21 | 22 | if self.compute_input_gradient: 23 | input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True) 24 | 25 | outputs = self.base_cam.activations_and_grads(input_tensor) 26 | 27 | if targets is None: 28 | output_data = outputs.detach().cpu().numpy() 29 | target_logits = np.max(output_data, axis=-1) if target_idx is None else output_data[:, target_idx] 30 | # Sort class indices for each sample based on the absolute difference 31 | # between the class scores and the target logit, in ascending order. 32 | # The most similar classes (smallest difference) appear first. 
33 | sorted_indices = np.argsort(np.abs(output_data - target_logits[:, None]), axis=-1) 34 | targets = [FinerWeightedTarget(int(sorted_indices[i, 0]), 35 | [int(sorted_indices[i, idx]) for idx in comparison_categories], 36 | alpha) 37 | for i in range(output_data.shape[0])] 38 | 39 | if self.uses_gradients: 40 | self.base_cam.model.zero_grad() 41 | loss = sum([target(output) for target, output in zip(targets, outputs)]) 42 | if self.base_cam.detach: 43 | loss.backward(retain_graph=True) 44 | else: 45 | # keep the computational graph, create_graph = True is needed for hvp 46 | torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True) 47 | # When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle" 48 | # loss.backward(retain_graph=True, create_graph=True) 49 | if 'hpu' in str(self.base_cam.device): 50 | self.base_cam.__htcore.mark_step() 51 | 52 | cam_per_layer = self.base_cam.compute_cam_per_layer(input_tensor, targets, eigen_smooth) 53 | return self.base_cam.aggregate_multi_layers(cam_per_layer) -------------------------------------------------------------------------------- /pytorch_grad_cam/fullgrad_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pytorch_grad_cam.base_cam import BaseCAM 4 | from pytorch_grad_cam.utils.find_layers import find_layer_predicate_recursive 5 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 6 | from pytorch_grad_cam.utils.image import scale_accross_batch_and_channels, scale_cam_image 7 | 8 | # https://arxiv.org/abs/1905.00780 9 | 10 | 11 | class FullGrad(BaseCAM): 12 | def __init__(self, model, target_layers, 13 | reshape_transform=None): 14 | if len(target_layers) > 0: 15 | print( 16 | "Warning: target_layers is ignored in FullGrad. 
All bias layers will be used instead.")
17 | 
18 |         def layer_with_2D_bias(layer):
19 |             bias_target_layers = [torch.nn.Conv2d, torch.nn.BatchNorm2d]
20 |             if type(layer) in bias_target_layers and layer.bias is not None:
21 |                 return True
22 |             return False
23 |         target_layers = find_layer_predicate_recursive(
24 |             model, layer_with_2D_bias)
25 |         super(
26 |             FullGrad,
27 |             self).__init__(
28 |             model,
29 |             target_layers,
30 |             reshape_transform,
31 |             compute_input_gradient=True)
32 |         self.bias_data = [self.get_bias_data(
33 |             layer).cpu().numpy() for layer in target_layers]
34 | 
35 |     def get_bias_data(self, layer):
36 |         # Borrowed from the official paper implementation:
37 |         # https://github.com/idiap/fullgrad-saliency/blob/master/saliency/tensor_extractor.py#L47
38 |         if isinstance(layer, torch.nn.BatchNorm2d):
39 |             bias = - (layer.running_mean * layer.weight
40 |                       / torch.sqrt(layer.running_var + layer.eps)) + layer.bias
41 |             return bias.data
42 |         else:
43 |             return layer.bias.data
44 | 
45 |     def compute_cam_per_layer(
46 |             self,
47 |             input_tensor,
48 |             target_category,
49 |             eigen_smooth):
50 |         input_grad = input_tensor.grad.data.cpu().numpy()
51 |         grads_list = [g.cpu().data.numpy() for g in
52 |                       self.activations_and_grads.gradients]
53 |         cam_per_target_layer = []
54 |         target_size = self.get_target_width_height(input_tensor)
55 | 
56 |         gradient_multiplied_input = input_grad * input_tensor.data.cpu().numpy()
57 |         gradient_multiplied_input = np.abs(gradient_multiplied_input)
58 |         gradient_multiplied_input = scale_accross_batch_and_channels(
59 |             gradient_multiplied_input,
60 |             target_size)
61 |         cam_per_target_layer.append(gradient_multiplied_input)
62 | 
63 |         # Loop over the saliency image from every layer
64 |         assert len(self.bias_data) == len(grads_list)
65 |         for bias, grads in zip(self.bias_data, grads_list):
66 |             bias = bias[None, :, None, None]
67 |             # In the paper they take the absolute value,
68 |             # but possibly taking only the positive gradients will work
69 |             # better.
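  |             # (A positive-only variant would be np.maximum(bias * grads, 0);
  |             # an untested alternative, not what the paper prescribes.)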
70 | bias_grad = np.abs(bias * grads) 71 | result = scale_accross_batch_and_channels( 72 | bias_grad, target_size) 73 | result = np.sum(result, axis=1) 74 | cam_per_target_layer.append(result[:, None, :]) 75 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) 76 | if eigen_smooth: 77 | # Resize to a smaller image, since this method typically has a very large number of channels, 78 | # and then consumes a lot of memory 79 | cam_per_target_layer = scale_accross_batch_and_channels( 80 | cam_per_target_layer, (target_size[0] // 8, target_size[1] // 8)) 81 | cam_per_target_layer = get_2d_projection(cam_per_target_layer) 82 | cam_per_target_layer = cam_per_target_layer[:, None, :, :] 83 | cam_per_target_layer = scale_accross_batch_and_channels( 84 | cam_per_target_layer, 85 | target_size) 86 | else: 87 | cam_per_target_layer = np.sum( 88 | cam_per_target_layer, axis=1)[:, None, :] 89 | 90 | return cam_per_target_layer 91 | 92 | def aggregate_multi_layers(self, cam_per_target_layer): 93 | result = np.sum(cam_per_target_layer, axis=1) 94 | return scale_cam_image(result) 95 | -------------------------------------------------------------------------------- /pytorch_grad_cam/grad_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pytorch_grad_cam.base_cam import BaseCAM 4 | 5 | 6 | class GradCAM(BaseCAM): 7 | def __init__(self, model, target_layers, 8 | reshape_transform=None): 9 | super( 10 | GradCAM, 11 | self).__init__( 12 | model, 13 | target_layers, 14 | reshape_transform) 15 | 16 | def get_cam_weights(self, 17 | input_tensor, 18 | target_layer, 19 | target_category, 20 | activations, 21 | grads): 22 | # 2D image 23 | if len(grads.shape) == 4: 24 | return np.mean(grads, axis=(2, 3)) 25 | 26 | # 3D image 27 | elif len(grads.shape) == 5: 28 | return np.mean(grads, axis=(2, 3, 4)) 29 | 30 | else: 31 | raise ValueError("Invalid grads shape." 
32 | "Shape of grads should be 4 (2D image) or 5 (3D image).") -------------------------------------------------------------------------------- /pytorch_grad_cam/grad_cam_elementwise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 4 | 5 | 6 | class GradCAMElementWise(BaseCAM): 7 | def __init__(self, model, target_layers, 8 | reshape_transform=None): 9 | super( 10 | GradCAMElementWise, 11 | self).__init__( 12 | model, 13 | target_layers, 14 | reshape_transform) 15 | 16 | def get_cam_image(self, 17 | input_tensor, 18 | target_layer, 19 | target_category, 20 | activations, 21 | grads, 22 | eigen_smooth): 23 | elementwise_activations = np.maximum(grads * activations, 0) 24 | 25 | if eigen_smooth: 26 | cam = get_2d_projection(elementwise_activations) 27 | else: 28 | cam = elementwise_activations.sum(axis=1) 29 | return cam 30 | -------------------------------------------------------------------------------- /pytorch_grad_cam/grad_cam_plusplus.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | 4 | # https://arxiv.org/abs/1710.11063 5 | 6 | 7 | class GradCAMPlusPlus(BaseCAM): 8 | def __init__(self, model, target_layers, 9 | reshape_transform=None): 10 | super(GradCAMPlusPlus, self).__init__(model, target_layers, 11 | reshape_transform) 12 | 13 | def get_cam_weights(self, 14 | input_tensor, 15 | target_layers, 16 | target_category, 17 | activations, 18 | grads): 19 | grads_power_2 = grads**2 20 | grads_power_3 = grads_power_2 * grads 21 | # Equation 19 in https://arxiv.org/abs/1710.11063 22 | sum_activations = np.sum(activations, axis=(2, 3)) 23 | eps = 0.000001 24 | aij = grads_power_2 / (2 * grads_power_2 + 25 | sum_activations[:, :, None, None] * grads_power_3 + eps) 26 | # Now bring back the ReLU from eq.7 in the paper, 27 | # And zero out aijs where the activations are 0 28 | aij = np.where(grads != 0, aij, 0) 29 | 30 | weights = np.maximum(grads, 0) * aij 31 | weights = np.sum(weights, axis=(2, 3)) 32 | return weights 33 | -------------------------------------------------------------------------------- /pytorch_grad_cam/guided_backprop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Function 4 | from pytorch_grad_cam.utils.find_layers import replace_all_layer_type_recursive 5 | 6 | 7 | class GuidedBackpropReLU(Function): 8 | @staticmethod 9 | def forward(self, input_img): 10 | positive_mask = (input_img > 0).type_as(input_img) 11 | output = torch.addcmul( 12 | torch.zeros( 13 | input_img.size()).type_as(input_img), 14 | input_img, 15 | positive_mask) 16 | self.save_for_backward(input_img, output) 17 | return output 18 | 19 | @staticmethod 20 | def backward(self, grad_output): 21 | input_img, output = self.saved_tensors 22 | grad_input = None 23 | 24 | positive_mask_1 = (input_img > 0).type_as(grad_output) 25 | positive_mask_2 = (grad_output > 0).type_as(grad_output) 26 | grad_input = torch.addcmul( 27 | torch.zeros( 28 | input_img.size()).type_as(input_img), 29 | torch.addcmul( 30 | torch.zeros( 31 | input_img.size()).type_as(input_img), 32 | grad_output, 33 | positive_mask_1), 34 | positive_mask_2) 35 | return grad_input 36 | 37 | 38 | class GuidedBackpropReLUasModule(torch.nn.Module): 39 | 
def __init__(self): 40 | super(GuidedBackpropReLUasModule, self).__init__() 41 | 42 | def forward(self, input_img): 43 | return GuidedBackpropReLU.apply(input_img) 44 | 45 | 46 | class GuidedBackpropReLUModel: 47 | def __init__(self, model, device): 48 | self.model = model 49 | self.model.eval() 50 | self.device = next(self.model.parameters()).device 51 | 52 | def forward(self, input_img): 53 | return self.model(input_img) 54 | 55 | def recursive_replace_relu_with_guidedrelu(self, module_top): 56 | 57 | for idx, module in module_top._modules.items(): 58 | self.recursive_replace_relu_with_guidedrelu(module) 59 | if module.__class__.__name__ == 'ReLU': 60 | module_top._modules[idx] = GuidedBackpropReLU.apply 61 | print("b") 62 | 63 | def recursive_replace_guidedrelu_with_relu(self, module_top): 64 | try: 65 | for idx, module in module_top._modules.items(): 66 | self.recursive_replace_guidedrelu_with_relu(module) 67 | if module == GuidedBackpropReLU.apply: 68 | module_top._modules[idx] = torch.nn.ReLU() 69 | except BaseException: 70 | pass 71 | 72 | def __call__(self, input_img, target_category=None): 73 | replace_all_layer_type_recursive(self.model, 74 | torch.nn.ReLU, 75 | GuidedBackpropReLUasModule()) 76 | 77 | 78 | input_img = input_img.to(self.device) 79 | 80 | input_img = input_img.requires_grad_(True) 81 | 82 | output = self.forward(input_img) 83 | 84 | if target_category is None: 85 | target_category = np.argmax(output.cpu().data.numpy()) 86 | 87 | loss = output[0, target_category] 88 | loss.backward(retain_graph=True) 89 | 90 | output = input_img.grad.cpu().data.numpy() 91 | output = output[0, :, :, :] 92 | output = output.transpose((1, 2, 0)) 93 | 94 | replace_all_layer_type_recursive(self.model, 95 | GuidedBackpropReLUasModule, 96 | torch.nn.ReLU()) 97 | 98 | return output 99 | -------------------------------------------------------------------------------- /pytorch_grad_cam/hirescam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 4 | 5 | 6 | class HiResCAM(BaseCAM): 7 | def __init__(self, model, target_layers, 8 | reshape_transform=None): 9 | super( 10 | HiResCAM, 11 | self).__init__( 12 | model, 13 | target_layers, 14 | reshape_transform) 15 | 16 | def get_cam_image(self, 17 | input_tensor, 18 | target_layer, 19 | target_category, 20 | activations, 21 | grads, 22 | eigen_smooth): 23 | elementwise_activations = grads * activations 24 | 25 | if eigen_smooth: 26 | print( 27 | "Warning: HiResCAM's faithfulness guarantees do not hold if smoothing is applied") 28 | cam = get_2d_projection(elementwise_activations) 29 | else: 30 | cam = elementwise_activations.sum(axis=1) 31 | return cam 32 | -------------------------------------------------------------------------------- /pytorch_grad_cam/kpca_cam.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.base_cam import BaseCAM 2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection_kernel 3 | 4 | class KPCA_CAM(BaseCAM): 5 | def __init__(self, model, target_layers, 6 | reshape_transform=None, kernel='sigmoid', gamma=None): 7 | super(KPCA_CAM, self).__init__(model, 8 | target_layers, 9 | reshape_transform, 10 | uses_gradients=False) 11 | self.kernel=kernel 12 | self.gamma=gamma 13 | 14 | def get_cam_image(self, 15 | input_tensor, 16 | target_layer, 17 | target_category, 18 | 
activations, 19 | grads, 20 | eigen_smooth): 21 | return get_2d_projection_kernel(activations, self.kernel, self.gamma) 22 | -------------------------------------------------------------------------------- /pytorch_grad_cam/layer_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 4 | 5 | # https://ieeexplore.ieee.org/document/9462463 6 | 7 | 8 | class LayerCAM(BaseCAM): 9 | def __init__( 10 | self, 11 | model, 12 | target_layers, 13 | reshape_transform=None): 14 | super( 15 | LayerCAM, 16 | self).__init__( 17 | model, 18 | target_layers, 19 | reshape_transform) 20 | 21 | def get_cam_image(self, 22 | input_tensor, 23 | target_layer, 24 | target_category, 25 | activations, 26 | grads, 27 | eigen_smooth): 28 | spatial_weighted_activations = np.maximum(grads, 0) * activations 29 | 30 | if eigen_smooth: 31 | cam = get_2d_projection(spatial_weighted_activations) 32 | else: 33 | cam = spatial_weighted_activations.sum(axis=1) 34 | return cam 35 | -------------------------------------------------------------------------------- /pytorch_grad_cam/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/pytorch_grad_cam/metrics/__init__.py -------------------------------------------------------------------------------- /pytorch_grad_cam/metrics/cam_mult_image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import List, Callable 4 | from pytorch_grad_cam.metrics.perturbation_confidence import PerturbationConfidenceMetric 5 | 6 | 7 | def multiply_tensor_with_cam(input_tensor: torch.Tensor, 8 | cam: torch.Tensor): 9 | """ Multiply an input tensor (after normalization) 10 | with a pixel attribution map 11 | """ 12 | return input_tensor * cam 13 | 14 | 15 | class CamMultImageConfidenceChange(PerturbationConfidenceMetric): 16 | def __init__(self): 17 | super(CamMultImageConfidenceChange, 18 | self).__init__(multiply_tensor_with_cam) 19 | 20 | 21 | class DropInConfidence(CamMultImageConfidenceChange): 22 | def __init__(self): 23 | super(DropInConfidence, self).__init__() 24 | 25 | def __call__(self, *args, **kwargs): 26 | scores = super(DropInConfidence, self).__call__(*args, **kwargs) 27 | scores = -scores 28 | return np.maximum(scores, 0) 29 | 30 | 31 | class IncreaseInConfidence(CamMultImageConfidenceChange): 32 | def __init__(self): 33 | super(IncreaseInConfidence, self).__init__() 34 | 35 | def __call__(self, *args, **kwargs): 36 | scores = super(IncreaseInConfidence, self).__call__(*args, **kwargs) 37 | return np.float32(scores > 0) 38 | -------------------------------------------------------------------------------- /pytorch_grad_cam/metrics/perturbation_confidence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import List, Callable 4 | 5 | import numpy as np 6 | import cv2 7 | 8 | 9 | class PerturbationConfidenceMetric: 10 | def __init__(self, perturbation): 11 | self.perturbation = perturbation 12 | 13 | def __call__(self, input_tensor: torch.Tensor, 14 | cams: np.ndarray, 15 | targets: List[Callable], 16 | model: torch.nn.Module, 17 | return_visualization=False, 18 | return_diff=True): 19 
| 20 | if return_diff: 21 | with torch.no_grad(): 22 | outputs = model(input_tensor) 23 | scores = [target(output).cpu().numpy() 24 | for target, output in zip(targets, outputs)] 25 | scores = np.float32(scores) 26 | 27 | batch_size = input_tensor.size(0) 28 | perturbated_tensors = [] 29 | for i in range(batch_size): 30 | cam = cams[i] 31 | tensor = self.perturbation(input_tensor[i, ...].cpu(), 32 | torch.from_numpy(cam)) 33 | tensor = tensor.to(input_tensor.device) 34 | perturbated_tensors.append(tensor.unsqueeze(0)) 35 | perturbated_tensors = torch.cat(perturbated_tensors) 36 | 37 | with torch.no_grad(): 38 | outputs_after_imputation = model(perturbated_tensors) 39 | scores_after_imputation = [ 40 | target(output).cpu().numpy() for target, output in zip( 41 | targets, outputs_after_imputation)] 42 | scores_after_imputation = np.float32(scores_after_imputation) 43 | 44 | if return_diff: 45 | result = scores_after_imputation - scores 46 | else: 47 | result = scores_after_imputation 48 | 49 | if return_visualization: 50 | return result, perturbated_tensors 51 | else: 52 | return result 53 | 54 | 55 | class RemoveMostRelevantFirst: 56 | def __init__(self, percentile, imputer): 57 | self.percentile = percentile 58 | self.imputer = imputer 59 | 60 | def __call__(self, input_tensor, mask): 61 | imputer = self.imputer 62 | if self.percentile != 'auto': 63 | threshold = np.percentile(mask.cpu().numpy(), self.percentile) 64 | binary_mask = np.float32(mask < threshold) 65 | else: 66 | _, binary_mask = cv2.threshold( 67 | np.uint8(mask * 255), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) 68 | 69 | binary_mask = torch.from_numpy(binary_mask) 70 | binary_mask = binary_mask.to(mask.device) 71 | return imputer(input_tensor, binary_mask) 72 | 73 | 74 | class RemoveLeastRelevantFirst(RemoveMostRelevantFirst): 75 | def __init__(self, percentile, imputer): 76 | super(RemoveLeastRelevantFirst, self).__init__(percentile, imputer) 77 | 78 | def __call__(self, input_tensor, mask): 79 | return super(RemoveLeastRelevantFirst, self).__call__( 80 | input_tensor, 1 - mask) 81 | 82 | 83 | class AveragerAcrossThresholds: 84 | def __init__( 85 | self, 86 | imputer, 87 | percentiles=[ 88 | 10, 89 | 20, 90 | 30, 91 | 40, 92 | 50, 93 | 60, 94 | 70, 95 | 80, 96 | 90]): 97 | self.imputer = imputer 98 | self.percentiles = percentiles 99 | 100 | def __call__(self, 101 | input_tensor: torch.Tensor, 102 | cams: np.ndarray, 103 | targets: List[Callable], 104 | model: torch.nn.Module): 105 | scores = [] 106 | for percentile in self.percentiles: 107 | imputer = self.imputer(percentile) 108 | scores.append(imputer(input_tensor, cams, targets, model)) 109 | return np.mean(np.float32(scores), axis=0) 110 | -------------------------------------------------------------------------------- /pytorch_grad_cam/metrics/road.py: -------------------------------------------------------------------------------- 1 | # A Consistent and Efficient Evaluation Strategy for Attribution Methods 2 | # https://arxiv.org/abs/2202.00449 3 | # Taken from https://raw.githubusercontent.com/tleemann/road_evaluation/main/imputations.py 4 | # MIT License 5 | 6 | # Copyright (c) 2022 Tobias Leemann 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and 
to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | 27 | # Implementations of our imputation models. 28 | import torch 29 | import numpy as np 30 | from scipy.sparse import lil_matrix, csc_matrix 31 | from scipy.sparse.linalg import spsolve 32 | from typing import List, Callable 33 | from pytorch_grad_cam.metrics.perturbation_confidence import PerturbationConfidenceMetric, \ 34 | AveragerAcrossThresholds, \ 35 | RemoveMostRelevantFirst, \ 36 | RemoveLeastRelevantFirst 37 | 38 | # The weights of the surrounding pixels 39 | neighbors_weights = [((1, 1), 1 / 12), 40 | ((0, 1), 1 / 6), 41 | ((-1, 1), 1 / 12), 42 | ((1, -1), 1 / 12), 43 | ((0, -1), 1 / 6), 44 | ((-1, -1), 1 / 12), 45 | ((1, 0), 1 / 6), 46 | ((-1, 0), 1 / 6)] 47 | 48 | 49 | class NoisyLinearImputer: 50 | def __init__(self, 51 | noise: float = 0.01, 52 | weighting: List[float] = neighbors_weights): 53 | """ 54 | Noisy linear imputation. 55 | noise: magnitude of noise to add (absolute, set to 0 for no noise) 56 | weighting: Weights of the neighboring pixels in the computation. 57 | List of tuples of (offset, weight) 58 | """ 59 | self.noise = noise 60 | self.weighting = neighbors_weights 61 | 62 | @staticmethod 63 | def add_offset_to_indices(indices, offset, mask_shape): 64 | """ Add the corresponding offset to the indices. 65 | Return new indices plus a valid bit-vector. """ 66 | cord1 = indices % mask_shape[1] 67 | cord0 = indices // mask_shape[1] 68 | cord0 += offset[0] 69 | cord1 += offset[1] 70 | valid = ((cord0 < 0) | (cord1 < 0) | 71 | (cord0 >= mask_shape[0]) | 72 | (cord1 >= mask_shape[1])) 73 | return ~valid, indices + offset[0] * mask_shape[1] + offset[1] 74 | 75 | @staticmethod 76 | def setup_sparse_system(mask, img, neighbors_weights): 77 | """ Vectorized version to set up the equation system. 78 | mask: (H, W)-tensor of missing pixels. 79 | Image: (H, W, C)-tensor of all values. 80 | Return (N,N)-System matrix, (N,C)-Right hand side for each of the C channels. 
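  |         Here N is the number of masked pixels to impute; each neighbor
  |         relation contributes one off-diagonal weight to this Laplacian-like
  |         system.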
81 | """ 82 | maskflt = mask.flatten() 83 | imgflat = img.reshape((img.shape[0], -1)) 84 | # Indices that are imputed in the flattened mask: 85 | indices = np.argwhere(maskflt == 0).flatten() 86 | coords_to_vidx = np.zeros(len(maskflt), dtype=int) 87 | coords_to_vidx[indices] = np.arange(len(indices)) 88 | numEquations = len(indices) 89 | # System matrix: 90 | A = lil_matrix((numEquations, numEquations)) 91 | b = np.zeros((numEquations, img.shape[0])) 92 | # Sum of weights assigned: 93 | sum_neighbors = np.ones(numEquations) 94 | for n in neighbors_weights: 95 | offset, weight = n[0], n[1] 96 | # Take out outliers 97 | valid, new_coords = NoisyLinearImputer.add_offset_to_indices( 98 | indices, offset, mask.shape) 99 | valid_coords = new_coords[valid] 100 | valid_ids = np.argwhere(valid == 1).flatten() 101 | # Add values to the right hand-side 102 | has_values_coords = valid_coords[maskflt[valid_coords] > 0.5] 103 | has_values_ids = valid_ids[maskflt[valid_coords] > 0.5] 104 | b[has_values_ids, :] -= weight * imgflat[:, has_values_coords].T 105 | # Add weights to the system (left hand side) 106 | # Find coordinates in the system. 107 | has_no_values = valid_coords[maskflt[valid_coords] < 0.5] 108 | variable_ids = coords_to_vidx[has_no_values] 109 | has_no_values_ids = valid_ids[maskflt[valid_coords] < 0.5] 110 | A[has_no_values_ids, variable_ids] = weight 111 | # Reduce weight for invalid 112 | sum_neighbors[np.argwhere(valid == 0).flatten()] = \ 113 | sum_neighbors[np.argwhere(valid == 0).flatten()] - weight 114 | 115 | A[np.arange(numEquations), np.arange(numEquations)] = -sum_neighbors 116 | return A, b 117 | 118 | def __call__(self, img: torch.Tensor, mask: torch.Tensor): 119 | """ Our linear inputation scheme. """ 120 | """ 121 | This is the function to do the linear infilling 122 | img: original image (C,H,W)-tensor; 123 | mask: mask; (H,W)-tensor 124 | 125 | """ 126 | imgflt = img.reshape(img.shape[0], -1) 127 | maskflt = mask.reshape(-1) 128 | # Indices that need to be imputed. 129 | indices_linear = np.argwhere(maskflt == 0).flatten() 130 | # Set up sparse equation system, solve system. 131 | A, b = NoisyLinearImputer.setup_sparse_system( 132 | mask.numpy(), img.numpy(), neighbors_weights) 133 | res = torch.tensor(spsolve(csc_matrix(A), b), dtype=torch.float) 134 | 135 | # Fill the values with the solution of the system. 
136 | img_infill = imgflt.clone() 137 | img_infill[:, indices_linear] = res.t() + self.noise * \ 138 | torch.randn_like(res.t()) 139 | 140 | return img_infill.reshape_as(img) 141 | 142 | 143 | class ROADMostRelevantFirst(PerturbationConfidenceMetric): 144 | def __init__(self, percentile=80): 145 | super(ROADMostRelevantFirst, self).__init__( 146 | RemoveMostRelevantFirst(percentile, NoisyLinearImputer())) 147 | 148 | 149 | class ROADLeastRelevantFirst(PerturbationConfidenceMetric): 150 | def __init__(self, percentile=20): 151 | super(ROADLeastRelevantFirst, self).__init__( 152 | RemoveLeastRelevantFirst(percentile, NoisyLinearImputer())) 153 | 154 | 155 | class ROADMostRelevantFirstAverage(AveragerAcrossThresholds): 156 | def __init__(self, percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]): 157 | super(ROADMostRelevantFirstAverage, self).__init__( 158 | ROADMostRelevantFirst, percentiles) 159 | 160 | 161 | class ROADLeastRelevantFirstAverage(AveragerAcrossThresholds): 162 | def __init__(self, percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]): 163 | super(ROADLeastRelevantFirstAverage, self).__init__( 164 | ROADLeastRelevantFirst, percentiles) 165 | 166 | 167 | class ROADCombined: 168 | def __init__(self, percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]): 169 | self.percentiles = percentiles 170 | self.morf_averager = ROADMostRelevantFirstAverage(percentiles) 171 | self.lerf_averager = ROADLeastRelevantFirstAverage(percentiles) 172 | 173 | def __call__(self, 174 | input_tensor: torch.Tensor, 175 | cams: np.ndarray, 176 | targets: List[Callable], 177 | model: torch.nn.Module): 178 | 179 | scores_lerf = self.lerf_averager(input_tensor, cams, targets, model) 180 | scores_morf = self.morf_averager(input_tensor, cams, targets, model) 181 | return (scores_lerf - scores_morf) / 2 182 | -------------------------------------------------------------------------------- /pytorch_grad_cam/random_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | 4 | 5 | class RandomCAM(BaseCAM): 6 | def __init__(self, model, target_layers, 7 | reshape_transform=None): 8 | super( 9 | RandomCAM, 10 | self).__init__( 11 | model, 12 | target_layers, 13 | reshape_transform) 14 | 15 | def get_cam_weights(self, 16 | input_tensor, 17 | target_layer, 18 | target_category, 19 | activations, 20 | grads): 21 | return np.random.uniform(-1, 1, size=(grads.shape[0], grads.shape[1])) 22 | -------------------------------------------------------------------------------- /pytorch_grad_cam/score_cam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from pytorch_grad_cam.base_cam import BaseCAM 4 | 5 | 6 | class ScoreCAM(BaseCAM): 7 | def __init__( 8 | self, 9 | model, 10 | target_layers, 11 | reshape_transform=None): 12 | super(ScoreCAM, self).__init__(model, 13 | target_layers, 14 | reshape_transform=reshape_transform, 15 | uses_gradients=False) 16 | 17 | def get_cam_weights(self, 18 | input_tensor, 19 | target_layer, 20 | targets, 21 | activations, 22 | grads): 23 | with torch.no_grad(): 24 | upsample = torch.nn.UpsamplingBilinear2d( 25 | size=input_tensor.shape[-2:]) 26 | activation_tensor = torch.from_numpy(activations) 27 | activation_tensor = activation_tensor.to(self.device) 28 | 29 | upsampled = upsample(activation_tensor) 30 | 31 | maxs = upsampled.view(upsampled.size(0), 32 | upsampled.size(1), -1).max(dim=-1)[0] 33 | mins = 
upsampled.view(upsampled.size(0), 34 |                                   upsampled.size(1), -1).min(dim=-1)[0] 35 | 36 |             maxs, mins = maxs[:, :, None, None], mins[:, :, None, None] 37 |             upsampled = (upsampled - mins) / (maxs - mins + 1e-8) 38 | 39 |             input_tensors = input_tensor[:, None, 40 |                                          :, :] * upsampled[:, :, None, :, :] 41 | 42 |             if hasattr(self, "batch_size"): 43 |                 BATCH_SIZE = self.batch_size 44 |             else: 45 |                 BATCH_SIZE = 16 46 | 47 |             scores = [] 48 |             for target, tensor in zip(targets, input_tensors): 49 |                 for i in tqdm.tqdm(range(0, tensor.size(0), BATCH_SIZE)): 50 |                     batch = tensor[i: i + BATCH_SIZE, :] 51 |                     outputs = [target(o).cpu().item() 52 |                                for o in self.model(batch)] 53 |                     scores.extend(outputs) 54 |             scores = torch.Tensor(scores) 55 |             scores = scores.view(activations.shape[0], activations.shape[1]) 56 |             weights = torch.nn.Softmax(dim=-1)(scores).numpy() 57 |             return weights 58 | -------------------------------------------------------------------------------- /pytorch_grad_cam/shapley_cam.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional, Tuple 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget 4 | import torch 5 | import numpy as np 6 | 7 | """ 8 | Weights the activation maps using the gradient and Hessian-vector product. 9 | This method (https://arxiv.org/abs/2501.06261) reinterprets CAM methods (including GradCAM, HiResCAM and the original CAM) from a Shapley value perspective. 10 | """ 11 | class ShapleyCAM(BaseCAM): 12 |     def __init__(self, model, target_layers, 13 |                  reshape_transform=None): 14 |         super( 15 |             ShapleyCAM, 16 |             self).__init__( 17 |             model = model, 18 |             target_layers = target_layers, 19 |             reshape_transform = reshape_transform, 20 |             compute_input_gradient = True, 21 |             uses_gradients = True, 22 |             detach = False) 23 | 24 |     def get_cam_weights(self, 25 |                         input_tensor, 26 |                         target_layer, 27 |                         target_category, 28 |                         activations, 29 |                         grads): 30 | 31 |         hvp = torch.autograd.grad( 32 |             outputs=grads, 33 |             inputs=activations, 34 |             grad_outputs=activations, 35 |             retain_graph=False, 36 |             allow_unused=True 37 |         )[0] 38 |         # print(torch.max(hvp[0]).item()) # check if hvp is not all zeros 39 |         if hvp is None: 40 |             hvp = torch.tensor(0).to(self.device) 41 |         else: 42 |             if self.activations_and_grads.reshape_transform is not None: 43 |                 hvp = self.activations_and_grads.reshape_transform(hvp) 44 | 45 |         if self.activations_and_grads.reshape_transform is not None: 46 |             activations = self.activations_and_grads.reshape_transform(activations) 47 |             grads = self.activations_and_grads.reshape_transform(grads) 48 | 49 |         weight = (grads - 0.5 * hvp).detach().cpu().numpy() 50 |         # 2D image 51 |         if len(activations.shape) == 4: 52 |             weight = np.mean(weight, axis=(2, 3)) 53 |             return weight 54 |         # 3D image 55 |         elif len(activations.shape) == 5: 56 |             weight = np.mean(weight, axis=(2, 3, 4)) 57 |             return weight 58 |         else: 59 |             raise ValueError("Invalid grads shape."
60 | "Shape of grads should be 4 (2D image) or 5 (3D image).") 61 | -------------------------------------------------------------------------------- /pytorch_grad_cam/sobel_cam.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | def sobel_cam(img): 5 | gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 6 | grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) 7 | grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3) 8 | abs_grad_x = cv2.convertScaleAbs(grad_x) 9 | abs_grad_y = cv2.convertScaleAbs(grad_y) 10 | grad = cv2.addWeighted(abs_grad_x, 0.5, abs_grad_y, 0.5, 0) 11 | return grad 12 | -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.utils.image import deprocess_image 2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 3 | from pytorch_grad_cam.utils import model_targets 4 | from pytorch_grad_cam.utils import reshape_transforms 5 | -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/find_layers.py: -------------------------------------------------------------------------------- 1 | def replace_layer_recursive(model, old_layer, new_layer): 2 | for name, layer in model._modules.items(): 3 | if layer == old_layer: 4 | model._modules[name] = new_layer 5 | return True 6 | elif replace_layer_recursive(layer, old_layer, new_layer): 7 | return True 8 | return False 9 | 10 | 11 | def replace_all_layer_type_recursive(model, old_layer_type, new_layer): 12 | for name, layer in model._modules.items(): 13 | if isinstance(layer, old_layer_type): 14 | model._modules[name] = new_layer 15 | replace_all_layer_type_recursive(layer, old_layer_type, new_layer) 16 | 17 | 18 | def find_layer_types_recursive(model, layer_types): 19 | def predicate(layer): 20 | return type(layer) in layer_types 21 | return find_layer_predicate_recursive(model, predicate) 22 | 23 | 24 | def find_layer_predicate_recursive(model, predicate): 25 | result = [] 26 | for name, layer in model._modules.items(): 27 | if predicate(layer): 28 | result.append(layer) 29 | result.extend(find_layer_predicate_recursive(layer, predicate)) 30 | return result 31 | -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/image.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, List 3 | 4 | import cv2 5 | import matplotlib 6 | import numpy as np 7 | import torch 8 | from matplotlib import pyplot as plt 9 | from matplotlib.lines import Line2D 10 | from scipy.ndimage import zoom 11 | from torchvision.transforms import Compose, Normalize, ToTensor 12 | 13 | 14 | def preprocess_image( 15 | img: np.ndarray, mean=[ 16 | 0.5, 0.5, 0.5], std=[ 17 | 0.5, 0.5, 0.5]) -> torch.Tensor: 18 | preprocessing = Compose([ 19 | ToTensor(), 20 | Normalize(mean=mean, std=std) 21 | ]) 22 | return preprocessing(img.copy()).unsqueeze(0) 23 | 24 | 25 | def deprocess_image(img): 26 | """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """ 27 | img = img - np.mean(img) 28 | img = img / (np.std(img) + 1e-5) 29 | img = img * 0.1 30 | img = img + 0.5 31 | img = np.clip(img, 0, 1) 32 | return np.uint8(img * 255) 33 | 34 | 35 | def show_cam_on_image(img: np.ndarray, 36 | mask: np.ndarray, 37 | use_rgb: bool = False, 38 | colormap: 
int = cv2.COLORMAP_JET, 39 |                       image_weight: float = 0.5) -> np.ndarray: 40 |     """ This function overlays the cam mask on the image as a heatmap. 41 |     By default the heatmap is in BGR format. 42 | 43 |     :param img: The base image in RGB or BGR format. 44 |     :param mask: The cam mask. 45 |     :param use_rgb: Whether to use an RGB or BGR heatmap; this should be set to True if 'img' is in RGB format. 46 |     :param colormap: The OpenCV colormap to be used. 47 |     :param image_weight: The final result is image_weight * img + (1-image_weight) * mask. 48 |     :returns: The base image with the cam overlay. 49 |     """ 50 |     heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap) 51 |     if use_rgb: 52 |         heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) 53 |     heatmap = np.float32(heatmap) / 255 54 | 55 |     if np.max(img) > 1: 56 |         raise Exception( 57 |             "The input image should be np.float32 in the range [0, 1]") 58 | 59 |     if image_weight < 0 or image_weight > 1: 60 |         raise Exception( 61 |             f"image_weight should be in the range [0, 1].\ 62 |                 Got: {image_weight}") 63 | 64 |     cam = (1 - image_weight) * heatmap + image_weight * img 65 |     cam = cam / np.max(cam) 66 |     return np.uint8(255 * cam) 67 | 68 | 69 | def create_labels_legend(concept_scores: np.ndarray, 70 |                          labels: Dict[int, str], 71 |                          top_k=2): 72 |     concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k] 73 |     concept_labels_topk = [] 74 |     for concept_index in range(concept_categories.shape[0]): 75 |         categories = concept_categories[concept_index, :] 76 |         concept_labels = [] 77 |         for category in categories: 78 |             score = concept_scores[concept_index, category] 79 |             label = f"{','.join(labels[category].split(',')[:3])}:{score:.2f}" 80 |             concept_labels.append(label) 81 |         concept_labels_topk.append("\n".join(concept_labels)) 82 |     return concept_labels_topk 83 | 84 | 85 | def show_factorization_on_image(img: np.ndarray, 86 |                                 explanations: np.ndarray, 87 |                                 colors: List[np.ndarray] = None, 88 |                                 image_weight: float = 0.5, 89 |                                 concept_labels: List = None) -> np.ndarray: 90 |     """ Color code the different component heatmaps on top of the image. 91 |     Every component color code will be magnified according to the heatmap intensity 92 |     (by modifying the V channel in the HSV color space), 93 |     and optionally create a legend that shows the labels. 94 | 95 |     Since different factorization component heatmaps can overlap in principle, 96 |     we need a strategy to decide how to deal with the overlaps. 97 |     This keeps the component that has a higher value in its heatmap. 98 | 99 |     :param img: The base image in RGB format. 100 |     :param explanations: A tensor of shape num_components x height x width, with the component visualizations. 101 |     :param colors: List of R, G, B colors to be used for the components. 102 |                    If None, will use the gist_rainbow cmap as a default. 103 |     :param image_weight: The final result is image_weight * img + (1-image_weight) * visualization. 104 |     :param concept_labels: A list of strings for every component. If this is passed, a legend that shows 105 |                      the labels and their colors will be added to the image. 106 |     :returns: The visualized image.
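Example (a minimal sketch): assuming `rgb_img` is a float32 RGB image in
[0, 1] and `concepts` is a num_components x H x W array (e.g. produced by
deep feature factorization), a call could look like:
    visualization = show_factorization_on_image(rgb_img, concepts, image_weight=0.3)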
107 | """ 108 | n_components = explanations.shape[0] 109 | if colors is None: 110 | # taken from https://github.com/edocollins/DFF/blob/master/utils.py 111 | _cmap = plt.cm.get_cmap('gist_rainbow') 112 | colors = [ 113 | np.array( 114 | _cmap(i)) for i in np.arange( 115 | 0, 116 | 1, 117 | 1.0 / 118 | n_components)] 119 | concept_per_pixel = explanations.argmax(axis=0) 120 | masks = [] 121 | for i in range(n_components): 122 | mask = np.zeros(shape=(img.shape[0], img.shape[1], 3)) 123 | mask[:, :, :] = colors[i][:3] 124 | explanation = explanations[i] 125 | explanation[concept_per_pixel != i] = 0 126 | mask = np.uint8(mask * 255) 127 | mask = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV) 128 | mask[:, :, 2] = np.uint8(255 * explanation) 129 | mask = cv2.cvtColor(mask, cv2.COLOR_HSV2RGB) 130 | mask = np.float32(mask) / 255 131 | masks.append(mask) 132 | 133 | mask = np.sum(np.float32(masks), axis=0) 134 | result = img * image_weight + mask * (1 - image_weight) 135 | result = np.uint8(result * 255) 136 | 137 | if concept_labels is not None: 138 | px = 1 / plt.rcParams['figure.dpi'] # pixel in inches 139 | fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px)) 140 | plt.rcParams['legend.fontsize'] = int( 141 | 14 * result.shape[0] / 256 / max(1, n_components / 6)) 142 | lw = 5 * result.shape[0] / 256 143 | lines = [Line2D([0], [0], color=colors[i], lw=lw) 144 | for i in range(n_components)] 145 | plt.legend(lines, 146 | concept_labels, 147 | mode="expand", 148 | fancybox=True, 149 | shadow=True) 150 | 151 | plt.tight_layout(pad=0, w_pad=0, h_pad=0) 152 | plt.axis('off') 153 | fig.canvas.draw() 154 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 155 | plt.close(fig=fig) 156 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 157 | data = cv2.resize(data, (result.shape[1], result.shape[0])) 158 | result = np.hstack((result, data)) 159 | return result 160 | 161 | 162 | def scale_cam_image(cam, target_size=None): 163 | result = [] 164 | for img in cam: 165 | img = img - np.min(img) 166 | img = img / (1e-7 + np.max(img)) 167 | if target_size is not None: 168 | if len(img.shape) > 2: 169 | img = zoom(np.float32(img), [ 170 | (t_s / i_s) for i_s, t_s in zip(img.shape, target_size[::-1])]) 171 | else: 172 | img = cv2.resize(np.float32(img), target_size) 173 | 174 | result.append(img) 175 | result = np.float32(result) 176 | 177 | return result 178 | 179 | 180 | def scale_accross_batch_and_channels(tensor, target_size): 181 | batch_size, channel_size = tensor.shape[:2] 182 | reshaped_tensor = tensor.reshape( 183 | batch_size * channel_size, *tensor.shape[2:]) 184 | result = scale_cam_image(reshaped_tensor, target_size) 185 | result = result.reshape( 186 | batch_size, 187 | channel_size, 188 | target_size[1], 189 | target_size[0]) 190 | return result 191 | -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/model_targets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | 5 | 6 | class ClassifierOutputTarget: 7 | def __init__(self, category): 8 | self.category = category 9 | 10 | def __call__(self, model_output): 11 | if len(model_output.shape) == 1: 12 | return model_output[self.category] 13 | return model_output[:, self.category] 14 | 15 | 16 | class ClassifierOutputSoftmaxTarget: 17 | def __init__(self, category): 18 | self.category = category 19 | 20 | def __call__(self, model_output): 21 | if 
len(model_output.shape) == 1: 22 |             return torch.softmax(model_output, dim=-1)[self.category] 23 |         return torch.softmax(model_output, dim=-1)[:, self.category] 24 | 25 | 26 | class ClassifierOutputReST: 27 |     """ 28 |     Using both pre-softmax and post-softmax, proposed in https://arxiv.org/abs/2501.06261 29 |     """ 30 |     def __init__(self, category): 31 |         self.category = category 32 |     def __call__(self, model_output): 33 |         if len(model_output.shape) == 1: 34 |             target = torch.tensor([self.category], device=model_output.device) 35 |             model_output = model_output.unsqueeze(0) 36 |             return model_output[0][self.category] - torch.nn.functional.cross_entropy(model_output, target) 37 |         else: 38 |             target = torch.tensor([self.category] * model_output.shape[0], device=model_output.device) 39 |             return model_output[:,self.category] - torch.nn.functional.cross_entropy(model_output, target) 40 | 41 | 42 | class BinaryClassifierOutputTarget: 43 |     def __init__(self, category): 44 |         self.category = category 45 | 46 |     def __call__(self, model_output): 47 |         if self.category == 1: 48 |             sign = 1 49 |         else: 50 |             sign = -1 51 |         return model_output * sign 52 | 53 | 54 | class SoftmaxOutputTarget: 55 |     def __init__(self): 56 |         pass 57 | 58 |     def __call__(self, model_output): 59 |         return torch.softmax(model_output, dim=-1) 60 | 61 | 62 | class RawScoresOutputTarget: 63 |     def __init__(self): 64 |         pass 65 | 66 |     def __call__(self, model_output): 67 |         return model_output 68 | 69 | 70 | class SemanticSegmentationTarget: 71 |     """ Gets a binary spatial mask and a category, 72 |     and returns the sum of the category scores 73 |     of the pixels inside the mask. """ 74 | 75 |     def __init__(self, category, mask): 76 |         self.category = category 77 |         self.mask = torch.from_numpy(mask) 78 | 79 |     def __call__(self, model_output): 80 |         return (model_output[self.category, :, :] * self.mask.to(model_output.device)).sum() 81 | 82 | 83 | class FasterRCNNBoxScoreTarget: 84 |     """ For every originally detected bounding box specified in "bounding_boxes", 85 |         assign a score for how well the current bounding boxes match it: 86 |         1. In IoU. 87 |         2. In the classification score. 88 |         If there is no sufficiently large overlap, or the category changed, 89 |         assign a score of 0. 90 | 91 |         The total score is the sum of all the box scores. 92 |     """ 93 | 94 |     def __init__(self, labels, bounding_boxes, iou_threshold=0.5): 95 |         self.labels = labels 96 |         self.bounding_boxes = bounding_boxes 97 |         self.iou_threshold = iou_threshold 98 | 99 |     def __call__(self, model_outputs): 100 |         output = torch.Tensor([0]) 101 |         if torch.cuda.is_available(): 102 |             output = output.cuda() 103 |         elif torch.backends.mps.is_available(): 104 |             output = output.to("mps") 105 | 106 |         if len(model_outputs["boxes"]) == 0: 107 |             return output 108 | 109 |         for box, label in zip(self.bounding_boxes, self.labels): 110 |             box = torch.Tensor(box[None, :]) 111 |             if torch.cuda.is_available(): 112 |                 box = box.cuda() 113 |             elif torch.backends.mps.is_available(): 114 |                 box = box.to("mps") 115 | 116 |             ious = torchvision.ops.box_iou(box, model_outputs["boxes"]) 117 |             index = ious.argmax() 118 |             if ious[0, index] > self.iou_threshold and model_outputs["labels"][index] == label: 119 |                 score = ious[0, index] + model_outputs["scores"][index] 120 |                 output = output + score 121 |         return output 122 | 123 | class FinerWeightedTarget: 124 |     """ 125 |     Computes a weighted difference between a primary category and a set of comparison categories.
126 | 127 | This target calculates the difference between the score for the main category and each of the comparison categories. 128 | It obtains a weight for each comparison category from the softmax probabilities of the model output and computes a 129 | weighted difference scaled by a comparison strength factor alpha. 130 | """ 131 | def __init__(self, main_category, comparison_categories, alpha): 132 | self.main_category = main_category 133 | self.comparison_categories = comparison_categories 134 | self.alpha = alpha 135 | 136 | def __call__(self, model_output): 137 | select = lambda idx: model_output[idx] if model_output.ndim == 1 else model_output[..., idx] 138 | 139 | wn = select(self.main_category) 140 | 141 | prob = torch.softmax(model_output, dim=-1) 142 | 143 | weights = [prob[idx] if model_output.ndim == 1 else prob[..., idx] for idx in self.comparison_categories] 144 | numerator = sum(w * (wn - self.alpha * select(idx)) for w, idx in zip(weights, self.comparison_categories)) 145 | denominator = sum(weights) 146 | 147 | return numerator / (denominator + 1e-9) -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/reshape_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def fasterrcnn_reshape_transform(x): 5 | target_size = x['pool'].size()[-2:] 6 | activations = [] 7 | for key, value in x.items(): 8 | activations.append( 9 | torch.nn.functional.interpolate( 10 | torch.abs(value), 11 | target_size, 12 | mode='bilinear')) 13 | activations = torch.cat(activations, axis=1) 14 | return activations 15 | 16 | 17 | def swinT_reshape_transform(tensor, height=7, width=7): 18 | result = tensor.reshape(tensor.size(0), 19 | height, width, tensor.size(2)) 20 | 21 | # Bring the channels to the first dimension, 22 | # like in CNNs. 23 | result = result.transpose(2, 3).transpose(1, 2) 24 | return result 25 | 26 | 27 | def vit_reshape_transform(tensor, height=14, width=14): 28 | result = tensor[:, 1:, :].reshape(tensor.size(0), 29 | height, width, tensor.size(2)) 30 | 31 | # Bring the channels to the first dimension, 32 | # like in CNNs. 
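# (batch, height, width, channels) -> transpose(2, 3) -> (batch, height, channels, width)
# -> transpose(1, 2) -> (batch, channels, height, width)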
33 |     result = result.transpose(2, 3).transpose(1, 2) 34 |     return result 35 | -------------------------------------------------------------------------------- /pytorch_grad_cam/utils/svd_on_activations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.decomposition import KernelPCA 3 | 4 | 5 | def get_2d_projection(activation_batch): 6 |     # TBD: use pytorch batch svd implementation 7 |     activation_batch[np.isnan(activation_batch)] = 0 8 |     projections = [] 9 |     for activations in activation_batch: 10 |         reshaped_activations = (activations).reshape( 11 |             activations.shape[0], -1).transpose() 12 |         # Centering before the SVD seems to be important here; 13 |         # otherwise the returned image is negative. 14 |         reshaped_activations = reshaped_activations - \ 15 |             reshaped_activations.mean(axis=0) 16 |         U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True) 17 |         projection = reshaped_activations @ VT[0, :] 18 |         projection = projection.reshape(activations.shape[1:]) 19 |         projections.append(projection) 20 |     return np.float32(projections) 21 | 22 | 23 | 24 | def get_2d_projection_kernel(activation_batch, kernel='sigmoid', gamma=None): 25 |     activation_batch[np.isnan(activation_batch)] = 0 26 |     projections = [] 27 |     for activations in activation_batch: 28 |         reshaped_activations = activations.reshape(activations.shape[0], -1).transpose() 29 |         reshaped_activations = reshaped_activations - reshaped_activations.mean(axis=0) 30 |         # Apply Kernel PCA 31 |         kpca = KernelPCA(n_components=1, kernel=kernel, gamma=gamma) 32 |         projection = kpca.fit_transform(reshaped_activations) 33 |         projection = projection.reshape(activations.shape[1:]) 34 |         projections.append(projection) 35 |     return np.float32(projections) 36 | -------------------------------------------------------------------------------- /pytorch_grad_cam/xgrad_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch_grad_cam.base_cam import BaseCAM 3 | 4 | 5 | class XGradCAM(BaseCAM): 6 |     def __init__( 7 |             self, 8 |             model, 9 |             target_layers, 10 |             reshape_transform=None): 11 |         super( 12 |             XGradCAM, 13 |             self).__init__( 14 |             model, 15 |             target_layers, 16 |             reshape_transform) 17 | 18 |     def get_cam_weights(self, 19 |                         input_tensor, 20 |                         target_layer, 21 |                         target_category, 22 |                         activations, 23 |                         grads): 24 |         sum_activations = np.sum(activations, axis=(2, 3)) 25 |         eps = 1e-7 26 |         weights = grads * activations / \ 27 |             (sum_activations[:, :, None, None] + eps) 28 |         weights = weights.sum(axis=(2, 3)) 29 |         return weights 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | Pillow 3 | torch>=1.7.1 4 | torchvision>=0.8.2 5 | ttach 6 | tqdm 7 | opencv-python 8 | matplotlib 9 | scikit-learn -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = grad-cam 3 | version = 1.1.0 4 | author = Jacob Gildenblat 5 | author_email = jacob.gildenblat@gmail.com 6 | description = Many Class Activation Map methods implemented in Pytorch. 
Including Grad-CAM, Grad-CAM++, Score-CAM, Ablation-CAM and XGrad-CAM 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/jacobgil/pytorch-grad-cam 10 | project_urls = 11 | Bug Tracker = https://github.com/jacobgil/pytorch-grad-cam/issues 12 | classifiers = 13 | Programming Language :: Python :: 3 14 | License :: OSI Approved :: MIT License 15 | Operating System :: OS Independent 16 | 17 | [options] 18 | packages = find: 19 | python_requires = >=3.6 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open('README.md', mode='r', encoding='utf-8') as fh: 4 | long_description = fh.read() 5 | 6 | with open("requirements.txt", "r") as f: 7 | requirements = f.readlines() 8 | 9 | setuptools.setup( 10 | name='grad-cam', 11 | version='1.5.5', 12 | author='Jacob Gildenblat', 13 | author_email='jacob.gildenblat@gmail.com', 14 | description='Many Class Activation Map methods implemented in Pytorch for classification, segmentation, object detection and more', 15 | long_description=long_description, 16 | long_description_content_type='text/markdown', 17 | url='https://github.com/jacobgil/pytorch-grad-cam', 18 | project_urls={ 19 | 'Bug Tracker': 'https://github.com/jacobgil/pytorch-grad-cam/issues', 20 | }, 21 | classifiers=[ 22 | 'Programming Language :: Python :: 3', 23 | 'License :: OSI Approved :: MIT License', 24 | 'Operating System :: OS Independent', 25 | ], 26 | packages=setuptools.find_packages( 27 | exclude=["*tutorials*"]), 28 | python_requires='>=3.8', 29 | install_requires=requirements) 30 | -------------------------------------------------------------------------------- /tests/test_context_release.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torchvision 3 | import torch 4 | import cv2 5 | import psutil 6 | from pytorch_grad_cam import GradCAM, \ 7 | ScoreCAM, \ 8 | GradCAMPlusPlus, \ 9 | AblationCAM, \ 10 | XGradCAM, \ 11 | EigenCAM, \ 12 | EigenGradCAM, \ 13 | LayerCAM, \ 14 | FullGrad 15 | from pytorch_grad_cam.utils.image import show_cam_on_image, \ 16 | preprocess_image 17 | 18 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget 19 | 20 | torch.manual_seed(0) 21 | 22 | 23 | @pytest.fixture 24 | def numpy_image(): 25 | return cv2.imread("examples/both.png") 26 | 27 | 28 | @pytest.mark.parametrize("cnn_model,target_layer_names", [ 29 | (torchvision.models.resnet18, ["layer4[-1]"]) 30 | ]) 31 | @pytest.mark.parametrize("batch_size,width,height", [ 32 | (1, 224, 224) 33 | ]) 34 | @pytest.mark.parametrize("target_category", [ 35 | 100 36 | ]) 37 | @pytest.mark.parametrize("aug_smooth", [ 38 | False 39 | ]) 40 | @pytest.mark.parametrize("eigen_smooth", [ 41 | False 42 | ]) 43 | @pytest.mark.parametrize("cam_method", 44 | [GradCAM]) 45 | def test_memory_usage_in_loop(numpy_image, batch_size, width, height, 46 | cnn_model, target_layer_names, cam_method, 47 | target_category, aug_smooth, eigen_smooth): 48 | img = cv2.resize(numpy_image, (width, height)) 49 | input_tensor = preprocess_image(img) 50 | input_tensor = input_tensor.repeat(batch_size, 1, 1, 1) 51 | model = cnn_model(pretrained=True) 52 | target_layers = [] 53 | for layer in target_layer_names: 54 | target_layers.append(eval(f"model.{layer}")) 55 | targets = [ClassifierOutputTarget(target_category) 56 | for _ in 
range(batch_size)] 57 |     initial_memory = 0 58 |     for i in range(100): 59 |         with cam_method(model=model, 60 |                         target_layers=target_layers) as cam: 61 |             grayscale_cam = cam(input_tensor=input_tensor, 62 |                                 targets=targets, 63 |                                 aug_smooth=aug_smooth, 64 |                                 eigen_smooth=eigen_smooth) 65 | 66 |         if i == 0: 67 |             initial_memory = psutil.virtual_memory()[2] 68 |         assert(psutil.virtual_memory()[2] <= initial_memory * 1.5) 69 | -------------------------------------------------------------------------------- /tests/test_one_channel.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torchvision 3 | import torch 4 | import cv2 5 | import numpy as np 6 | from pytorch_grad_cam import GradCAM, \ 7 |     ScoreCAM, \ 8 |     GradCAMPlusPlus, \ 9 |     AblationCAM, \ 10 |     XGradCAM, \ 11 |     EigenCAM, \ 12 |     EigenGradCAM, \ 13 |     LayerCAM, \ 14 |     FullGrad 15 | from pytorch_grad_cam.utils.image import show_cam_on_image, \ 16 |     preprocess_image 17 | 18 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget 19 | 20 | torch.manual_seed(0) 21 | 22 | 23 | @pytest.fixture 24 | def numpy_image(): 25 |     return cv2.imread("examples/both.png") 26 | 27 | 28 | @pytest.mark.parametrize("cam_method", 29 |                          [GradCAM]) 30 | def test_one_channel_input_runs(numpy_image, cam_method): 31 |     model = torchvision.models.resnet18(pretrained=False) 32 |     model.conv1 = torch.nn.Conv2d( 33 |         1, 64, kernel_size=( 34 |             7, 7), stride=( 35 |             2, 2), padding=( 36 |             3, 3), bias=False) 37 |     target_layers = [model.layer4] 38 |     gray_img = numpy_image[:, :, 0] 39 |     input_tensor = torch.from_numpy( 40 |         np.float32(gray_img)).unsqueeze(0).unsqueeze(0) 41 |     input_tensor = input_tensor.repeat(16, 1, 1, 1) 42 |     targets = None 43 |     with cam_method(model=model, 44 |                     target_layers=target_layers) as cam: 45 |         grayscale_cam = cam(input_tensor=input_tensor, 46 |                             targets=targets) 47 |         print(grayscale_cam.shape) 48 | -------------------------------------------------------------------------------- /tests/test_run_all_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torchvision 3 | import torch 4 | import cv2 5 | from pytorch_grad_cam import GradCAM, \ 6 |     ScoreCAM, \ 7 |     GradCAMPlusPlus, \ 8 |     AblationCAM, \ 9 |     XGradCAM, \ 10 |     EigenCAM, \ 11 |     EigenGradCAM, \ 12 |     LayerCAM, \ 13 |     FullGrad, \ 14 |     KPCA_CAM 15 | from pytorch_grad_cam.utils.image import show_cam_on_image, \ 16 |     preprocess_image 17 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget 18 | 19 | 20 | @pytest.fixture 21 | def numpy_image(): 22 |     return cv2.imread("examples/both.png") 23 | 24 | 25 | @pytest.mark.parametrize("cnn_model,target_layer_names", [ 26 |     (torchvision.models.resnet18, ["layer4[-1]", "layer4[-2]"]), 27 |     (torchvision.models.vgg11, ["features[-1]"]) 28 | ]) 29 | @pytest.mark.parametrize("batch_size,width,height", [ 30 |     (2, 32, 32), 31 |     (1, 32, 40) 32 | ]) 33 | @pytest.mark.parametrize("target_category", [ 34 |     None, 35 |     100 36 | ]) 37 | @pytest.mark.parametrize("aug_smooth", [ 38 |     False 39 | ]) 40 | @pytest.mark.parametrize("eigen_smooth", [ 41 |     True, 42 |     False 43 | ]) 44 | @pytest.mark.parametrize("cam_method", 45 |                          [ScoreCAM, 46 |                           AblationCAM, 47 |                           GradCAM, 48 |                           GradCAMPlusPlus, 49 |                           XGradCAM, 50 |                           EigenCAM, 51 |                           EigenGradCAM, 52 |                           LayerCAM, 53 |                           FullGrad, 54 |                           KPCA_CAM]) 55 | 56 | 57 | def test_all_cam_models_can_run(numpy_image, batch_size, width, height, 58 |                                 cnn_model, target_layer_names, cam_method, 59 | 
target_category, aug_smooth, eigen_smooth): 60 |     img = cv2.resize(numpy_image, (width, height)) 61 |     input_tensor = preprocess_image(img) 62 |     input_tensor = input_tensor.repeat(batch_size, 1, 1, 1) 63 | 64 |     model = cnn_model(pretrained=True) 65 |     target_layers = [] 66 |     for layer in target_layer_names: 67 |         target_layers.append(eval(f"model.{layer}")) 68 | 69 |     cam = cam_method(model=model, 70 |                      target_layers=target_layers) 71 |     cam.batch_size = 4 72 |     if target_category is None: 73 |         targets = None 74 |     else: 75 |         targets = [ClassifierOutputTarget(target_category) 76 |                    for _ in range(batch_size)] 77 | 78 |     grayscale_cam = cam(input_tensor=input_tensor, 79 |                         targets=targets, 80 |                         aug_smooth=aug_smooth, 81 |                         eigen_smooth=eigen_smooth) 82 |     assert(grayscale_cam.shape[0] == input_tensor.shape[0]) 83 |     assert(grayscale_cam.shape[1:] == input_tensor.shape[2:]) 84 | -------------------------------------------------------------------------------- /tutorials/bbox.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/tutorials/bbox.png -------------------------------------------------------------------------------- /tutorials/huggingface_dff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/tutorials/huggingface_dff.png -------------------------------------------------------------------------------- /tutorials/multimage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/tutorials/multimage.png -------------------------------------------------------------------------------- /tutorials/puppies.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/781dbc0d16ffa95b6d18b96b7b829840a82d93d1/tutorials/puppies.jpg -------------------------------------------------------------------------------- /tutorials/vision_transformers.md: -------------------------------------------------------------------------------- 1 | # How does it work with Vision Transformers 2 | 3 | *See [usage_examples/vit_example.py](../usage_examples/vit_example.py)* 4 | 5 | In ViT the output of the layers is typically BATCH x 197 x 192. 6 | In the dimension with 197, the first element represents the class token, and the rest represent the 14x14 patches in the image. 7 | We can treat the last 196 elements as a 14x14 spatial image, with 192 channels. 8 | 9 | To reshape the activations and gradients to 2D spatial images, 10 | we can pass the CAM constructor a reshape_transform function. 11 | 12 | This can also be a starting point for other architectures in the future. 13 | 14 | ```python 15 | 16 | GradCAM(model=model, target_layers=target_layers, reshape_transform=reshape_transform) 17 | 18 | def reshape_transform(tensor, height=14, width=14): 19 |     result = tensor[:, 1 :  , :].reshape(tensor.size(0), 20 |         height, width, tensor.size(2)) 21 | 22 |     # Bring the channels to the first dimension, 23 |     # like in CNNs. 24 |     result = result.transpose(2, 3).transpose(1, 2) 25 |     return result 26 | ``` 27 | 
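As a rough end-to-end sketch (mirroring [usage_examples/vit_example.py](../usage_examples/vit_example.py); `input_tensor` is assumed to be a preprocessed BATCH x 3 x 224 x 224 tensor):

```python
import torch
from pytorch_grad_cam import GradCAM

def reshape_transform(tensor, height=14, width=14):
    # Drop the class token and restore the 14x14 spatial layout.
    result = tensor[:, 1:, :].reshape(tensor.size(0),
                                      height, width, tensor.size(2))
    return result.transpose(2, 3).transpose(1, 2)

model = torch.hub.load('facebookresearch/deit:main',
                       'deit_tiny_patch16_224', pretrained=True).eval()
target_layers = [model.blocks[-1].norm1]
cam = GradCAM(model=model, target_layers=target_layers,
              reshape_transform=reshape_transform)
grayscale_cam = cam(input_tensor=input_tensor, targets=None)
```

28 | ### Which target_layer should we choose for Vision Transformers?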
29 | 30 | Since the final classification is done on the class token computed in the last attention block, 31 | the output will not be affected by the 14x14 channels in the last layer. 32 | The gradient of the output with respect to them will be 0! 33 | 34 | We should choose any layer before the final attention block, for example: 35 | ```python 36 | target_layers = [model.blocks[-1].norm1] 37 | ``` 38 | 39 | ---------- 40 | 41 | # How does it work with Swin Transformers 42 | 43 | *See [usage_examples/swinT_example.py](../usage_examples/swinT_example.py)* 44 | 45 | In Swin Transformer base, the output of the layers is typically BATCH x 49 x 1024. 46 | We can treat the last 49 elements as a 7x7 spatial image, with 1024 channels. 47 | 48 | To reshape the activations and gradients to 2D spatial images, 49 | we can pass the CAM constructor a reshape_transform function. 50 | 51 | This can also be a starting point for other architectures in the future. 52 | 53 | ```python 54 | 55 | GradCAM(model=model, target_layers=target_layers, reshape_transform=reshape_transform) 56 | 57 | def reshape_transform(tensor, height=7, width=7): 58 |     result = tensor.reshape(tensor.size(0), 59 |         height, width, tensor.size(2)) 60 | 61 |     # Bring the channels to the first dimension, 62 |     # like in CNNs. 63 |     result = result.transpose(2, 3).transpose(1, 2) 64 |     return result 65 | ``` 66 | 67 | ### Which target_layer should we choose for Swin Transformers? 68 | 69 | Unlike ViT, the Swin Transformer does not contain a `cls_token`, 70 | so we use all the 7x7 spatial tokens we get from the last block of the last layer. 71 | 72 | We should choose any layer before the final attention block, for example: 73 | ```python 74 | target_layers = [model.layers[-1].blocks[-1].norm1] 75 | ``` 76 | -------------------------------------------------------------------------------- /usage_examples/clip_example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget 8 | try: 9 |     from transformers import CLIPProcessor, CLIPModel 10 | except ImportError: 11 |     print("The transformers package is not installed. 
Please install it to use CLIP.") 12 |     exit(1) 13 | 14 | 15 | from pytorch_grad_cam import GradCAM, \ 16 |     ScoreCAM, \ 17 |     GradCAMPlusPlus, \ 18 |     AblationCAM, \ 19 |     XGradCAM, \ 20 |     EigenCAM, \ 21 |     EigenGradCAM, \ 22 |     LayerCAM, \ 23 |     FullGrad 24 | 25 | from pytorch_grad_cam.utils.image import show_cam_on_image, \ 26 |     preprocess_image 27 | from pytorch_grad_cam.ablation_layer import AblationLayerVit 28 | 29 | 30 | def get_args(): 31 |     parser = argparse.ArgumentParser() 32 |     parser.add_argument('--device', type=str, default='cpu', 33 |                         help='Torch device to use') 34 |     parser.add_argument( 35 |         '--image-path', 36 |         type=str, 37 |         default='./examples/both.png', 38 |         help='Input image path') 39 |     parser.add_argument( 40 |         '--labels', 41 |         type=str, 42 |         nargs='+', 43 |         default=["a cat", "a dog", "a car", "a person", "a shoe"], 44 |         help='Recognition labels' 45 |     ) 46 | 47 |     parser.add_argument('--aug_smooth', action='store_true', 48 |                         help='Apply test time augmentation to smooth the CAM') 49 |     parser.add_argument( 50 |         '--eigen_smooth', 51 |         action='store_true', 52 |         help='Reduce noise by taking the first principal component ' 53 |         'of cam_weights*activations') 54 | 55 |     parser.add_argument( 56 |         '--method', 57 |         type=str, 58 |         default='gradcam', 59 |         help='Can be gradcam/gradcam++/scorecam/xgradcam/ablationcam') 60 | 61 |     args = parser.parse_args() 62 |     if args.device: 63 |         print(f'Using device "{args.device}" for acceleration') 64 |     else: 65 |         print('Using CPU for computation') 66 | 67 |     return args 68 | 69 | 70 | def reshape_transform(tensor, height=16, width=16): 71 |     result = tensor[:, 1:, :].reshape(tensor.size(0), 72 |                                       height, width, tensor.size(2)) 73 | 74 |     # Bring the channels to the first dimension, 75 |     # like in CNNs. 76 |     result = result.transpose(2, 3).transpose(1, 2) 77 |     return result 78 | 79 | 80 | class ImageClassifier(nn.Module): 81 |     def __init__(self, labels): 82 |         super(ImageClassifier, self).__init__() 83 |         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") 84 |         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") 85 |         self.labels = labels 86 | 87 |     def forward(self, x): 88 |         text_inputs = self.processor(text=self.labels, return_tensors="pt", padding=True) 89 | 90 |         outputs = self.clip(pixel_values=x, input_ids=text_inputs['input_ids'].to(self.clip.device), 91 |                             attention_mask=text_inputs['attention_mask'].to(self.clip.device)) 92 | 93 |         logits_per_image = outputs.logits_per_image 94 |         probs = logits_per_image.softmax(dim=1) 95 | 96 |         for label, prob in zip(self.labels, probs[0]): 97 |             print(f"{label}: {prob:.4f}") 98 |         return probs 99 | 100 | 101 | if __name__ == '__main__': 102 |     """ python clip_example.py --image-path <path> 103 |     Example usage of CAM methods on a CLIP model. 
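Example:
    python usage_examples/clip_example.py --image-path ./examples/both.png --method gradcam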
104 | 105 | """ 106 | 107 | args = get_args() 108 | methods = \ 109 | {"gradcam": GradCAM, 110 | "scorecam": ScoreCAM, 111 | "gradcam++": GradCAMPlusPlus, 112 | "ablationcam": AblationCAM, 113 | "xgradcam": XGradCAM, 114 | "eigencam": EigenCAM, 115 | "eigengradcam": EigenGradCAM, 116 | "layercam": LayerCAM, 117 | "fullgrad": FullGrad} 118 | 119 | if args.method not in list(methods.keys()): 120 | raise Exception(f"method should be one of {list(methods.keys())}") 121 | 122 | labels = args.labels 123 | model = ImageClassifier(labels).to(torch.device(args.device)).eval() 124 | print(model) 125 | 126 | target_layers = [model.clip.vision_model.encoder.layers[-1].layer_norm1] 127 | 128 | if args.method not in methods: 129 | raise Exception(f"Method {args.method} not implemented") 130 | 131 | rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1] 132 | rgb_img = cv2.resize(rgb_img, (224, 224)) 133 | rgb_img = np.float32(rgb_img) / 255 134 | input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5], 135 | std=[0.5, 0.5, 0.5]).to(args.device) 136 | 137 | if args.method == "ablationcam": 138 | cam = methods[args.method](model=model, 139 | target_layers=target_layers, 140 | reshape_transform=reshape_transform, 141 | ablation_layer=AblationLayerVit()) 142 | else: 143 | cam = methods[args.method](model=model, 144 | target_layers=target_layers, 145 | reshape_transform=reshape_transform) 146 | 147 | 148 | 149 | # If None, returns the map for the highest scoring category. 150 | # Otherwise, targets the requested category. 151 | #targets = [ClassifierOutputTarget(1)] 152 | targets = None 153 | 154 | # AblationCAM and ScoreCAM have batched implementations. 155 | # You can override the internal batch size for faster computation. 156 | cam.batch_size = 32 157 | 158 | grayscale_cam = cam(input_tensor=input_tensor, 159 | targets=targets, 160 | eigen_smooth=args.eigen_smooth, 161 | aug_smooth=args.aug_smooth) 162 | 163 | # Here grayscale_cam has only one image in the batch 164 | grayscale_cam = grayscale_cam[0, :] 165 | 166 | cam_image = show_cam_on_image(rgb_img, grayscale_cam) 167 | cv2.imwrite(f'{args.method}_cam.jpg', cam_image) 168 | -------------------------------------------------------------------------------- /usage_examples/swinT_example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import torch 5 | import timm 6 | 7 | from pytorch_grad_cam import GradCAM, \ 8 | ScoreCAM, \ 9 | GradCAMPlusPlus, \ 10 | AblationCAM, \ 11 | XGradCAM, \ 12 | EigenCAM, \ 13 | EigenGradCAM, \ 14 | LayerCAM, \ 15 | FullGrad 16 | 17 | from pytorch_grad_cam.utils.image import show_cam_on_image, \ 18 | preprocess_image 19 | from pytorch_grad_cam.ablation_layer import AblationLayerVit 20 | 21 | 22 | def get_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--use-cuda', action='store_true', default=False, 25 | help='Use NVIDIA GPU acceleration') 26 | parser.add_argument( 27 | '--image-path', 28 | type=str, 29 | default='./examples/both.png', 30 | help='Input image path') 31 | parser.add_argument('--aug_smooth', action='store_true', 32 | help='Apply test time augmentation to smooth the CAM') 33 | parser.add_argument( 34 | '--eigen_smooth', 35 | action='store_true', 36 | help='Reduce noise by taking the first principle componenet' 37 | 'of cam_weights*activations') 38 | 39 | parser.add_argument( 40 | '--method', 41 | type=str, 42 | default='scorecam', 43 | help='Can be 
gradcam/gradcam++/scorecam/xgradcam/ablationcam') 44 | 45 |     args = parser.parse_args() 46 |     args.use_cuda = args.use_cuda and torch.cuda.is_available() 47 |     if args.use_cuda: 48 |         print('Using GPU for acceleration') 49 |     else: 50 |         print('Using CPU for computation') 51 | 52 |     return args 53 | 54 | 55 | def reshape_transform(tensor, height=7, width=7): 56 |     result = tensor.reshape(tensor.size(0), 57 |                             height, width, tensor.size(2)) 58 | 59 |     # Bring the channels to the first dimension, 60 |     # like in CNNs. 61 |     result = result.transpose(2, 3).transpose(1, 2) 62 |     return result 63 | 64 | 65 | if __name__ == '__main__': 66 |     """ python swinT_example.py --image-path 67 |     Example usage of CAM methods on a Swin Transformer network. 68 | 69 |     """ 70 | 71 |     args = get_args() 72 |     methods = \ 73 |         {"gradcam": GradCAM, 74 |          "scorecam": ScoreCAM, 75 |          "gradcam++": GradCAMPlusPlus, 76 |          "ablationcam": AblationCAM, 77 |          "xgradcam": XGradCAM, 78 |          "eigencam": EigenCAM, 79 |          "eigengradcam": EigenGradCAM, 80 |          "layercam": LayerCAM, 81 |          "fullgrad": FullGrad} 82 | 83 |     if args.method not in list(methods.keys()): 84 |         raise Exception(f"method should be one of {list(methods.keys())}") 85 | 86 |     model = timm.create_model('swin_base_patch4_window7_224', pretrained=True) 87 |     model.eval() 88 | 89 |     if args.use_cuda: 90 |         model = model.cuda() 91 | 92 |     target_layers = [model.layers[-1].blocks[-1].norm2] 93 | 94 |     if args.method not in methods: 95 |         raise Exception(f"Method {args.method} not implemented") 96 | 97 |     if args.method == "ablationcam": 98 |         cam = methods[args.method](model=model, 99 |                                    target_layers=target_layers, 100 |                                    reshape_transform=reshape_transform, 101 |                                    ablation_layer=AblationLayerVit()) 102 |     else: 103 |         cam = methods[args.method](model=model, 104 |                                    target_layers=target_layers, 105 |                                    reshape_transform=reshape_transform) 106 | 107 |     rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1] 108 |     rgb_img = cv2.resize(rgb_img, (224, 224)) 109 |     rgb_img = np.float32(rgb_img) / 255 110 |     input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5], 111 |                                     std=[0.5, 0.5, 0.5]) 112 | 113 |     # AblationCAM and ScoreCAM have batched implementations. 114 |     # You can override the internal batch size for faster computation. 
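# Methods that do not batch internally (e.g. GradCAM) simply ignore this
# attribute; for ScoreCAM/AblationCAM it trades memory for speed.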
115 |     cam.batch_size = 32 116 | 117 |     grayscale_cam = cam(input_tensor=input_tensor, 118 |                         targets=None, 119 |                         eigen_smooth=args.eigen_smooth, 120 |                         aug_smooth=args.aug_smooth) 121 | 122 |     # Here grayscale_cam has only one image in the batch 123 |     grayscale_cam = grayscale_cam[0, :] 124 | 125 |     cam_image = show_cam_on_image(rgb_img, grayscale_cam) 126 |     cv2.imwrite(f'{args.method}_cam.jpg', cam_image) 127 | -------------------------------------------------------------------------------- /usage_examples/vit_example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import torch 5 | 6 | from pytorch_grad_cam import GradCAM, \ 7 |     ScoreCAM, \ 8 |     GradCAMPlusPlus, \ 9 |     AblationCAM, \ 10 |     XGradCAM, \ 11 |     EigenCAM, \ 12 |     EigenGradCAM, \ 13 |     LayerCAM, \ 14 |     FullGrad 15 | 16 | from pytorch_grad_cam import GuidedBackpropReLUModel 17 | from pytorch_grad_cam.utils.image import show_cam_on_image, \ 18 |     preprocess_image 19 | from pytorch_grad_cam.ablation_layer import AblationLayerVit 20 | 21 | 22 | def get_args(): 23 |     parser = argparse.ArgumentParser() 24 |     parser.add_argument('--device', type=str, default='cpu', 25 |                         help='Torch device to use') 26 | 27 |     parser.add_argument( 28 |         '--image-path', 29 |         type=str, 30 |         default='./examples/both.png', 31 |         help='Input image path') 32 |     parser.add_argument('--aug_smooth', action='store_true', 33 |                         help='Apply test time augmentation to smooth the CAM') 34 |     parser.add_argument( 35 |         '--eigen_smooth', 36 |         action='store_true', 37 |         help='Reduce noise by taking the first principal component ' 38 |         'of cam_weights*activations') 39 | 40 |     parser.add_argument( 41 |         '--method', 42 |         type=str, 43 |         default='gradcam', 44 |         help='Can be gradcam/gradcam++/scorecam/xgradcam/ablationcam') 45 | 46 |     args = parser.parse_args() 47 |     if args.device: 48 |         print(f'Using device "{args.device}" for acceleration') 49 |     else: 50 |         print('Using CPU for computation') 51 | 52 |     return args 53 | 54 | 55 | def reshape_transform(tensor, height=14, width=14): 56 |     result = tensor[:, 1:, :].reshape(tensor.size(0), 57 |                                       height, width, tensor.size(2)) 58 | 59 |     # Bring the channels to the first dimension, 60 |     # like in CNNs. 61 |     result = result.transpose(2, 3).transpose(1, 2) 62 |     return result 63 | 64 | 65 | if __name__ == '__main__': 66 |     """ python vit_example.py --image-path 67 |     Example usage of CAM methods on a ViT network. 
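Example:
    python usage_examples/vit_example.py --image-path ./examples/both.png --method gradcam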
68 | 69 | """ 70 | 71 | args = get_args() 72 | methods = \ 73 | {"gradcam": GradCAM, 74 | "scorecam": ScoreCAM, 75 | "gradcam++": GradCAMPlusPlus, 76 | "ablationcam": AblationCAM, 77 | "xgradcam": XGradCAM, 78 | "eigencam": EigenCAM, 79 | "eigengradcam": EigenGradCAM, 80 | "layercam": LayerCAM, 81 | "fullgrad": FullGrad} 82 | 83 | if args.method not in list(methods.keys()): 84 | raise Exception(f"method should be one of {list(methods.keys())}") 85 | 86 | model = torch.hub.load('facebookresearch/deit:main', 87 | 'deit_tiny_patch16_224', pretrained=True).to(torch.device(args.device)).eval() 88 | 89 | 90 | target_layers = [model.blocks[-1].norm1] 91 | 92 | if args.method not in methods: 93 | raise Exception(f"Method {args.method} not implemented") 94 | 95 | if args.method == "ablationcam": 96 | cam = methods[args.method](model=model, 97 | target_layers=target_layers, 98 | reshape_transform=reshape_transform, 99 | ablation_layer=AblationLayerVit()) 100 | else: 101 | cam = methods[args.method](model=model, 102 | target_layers=target_layers, 103 | reshape_transform=reshape_transform) 104 | 105 | rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1] 106 | rgb_img = cv2.resize(rgb_img, (224, 224)) 107 | rgb_img = np.float32(rgb_img) / 255 108 | input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5], 109 | std=[0.5, 0.5, 0.5]).to(args.device) 110 | 111 | # If None, returns the map for the highest scoring category. 112 | # Otherwise, targets the requested category. 113 | targets = None 114 | 115 | # AblationCAM and ScoreCAM have batched implementations. 116 | # You can override the internal batch size for faster computation. 117 | cam.batch_size = 32 118 | 119 | grayscale_cam = cam(input_tensor=input_tensor, 120 | targets=targets, 121 | eigen_smooth=args.eigen_smooth, 122 | aug_smooth=args.aug_smooth) 123 | 124 | # Here grayscale_cam has only one image in the batch 125 | grayscale_cam = grayscale_cam[0, :] 126 | 127 | cam_image = show_cam_on_image(rgb_img, grayscale_cam) 128 | cv2.imwrite(f'{args.method}_cam.jpg', cam_image) 129 | --------------------------------------------------------------------------------