├── .gitignore
├── README.md
├── __init__.py
├── attribution
│   ├── __init__.py
│   ├── ablation_cam.py
│   ├── bag_cam.py
│   ├── base.py
│   ├── blur_ig.py
│   ├── eigen_cam.py
│   ├── eigen_gradcam.py
│   ├── fullgrad_cam.py
│   ├── grad_cam.py
│   ├── grad_cam_plusplus.py
│   ├── guided_backprop.py
│   ├── guided_gradcam.py
│   ├── guided_ig.py
│   ├── hires_cam.py
│   ├── ig.py
│   ├── layer_cam.py
│   ├── occlusion.py
│   ├── score_cam.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── access_layers.py
│   │   ├── reshape_for_transformer.py
│   │   ├── svd_on_feature_maps.py
│   │   └── visualization_maps.py
│   └── xgrad_cam.py
├── cam_visualization_examples.py
├── cam_visualization_for_transformers_examples.py
├── combine_cam_and_gradients_visualization_examples.py
├── examples
│   ├── attribution_methods.png
│   ├── cam_visualization.png
│   ├── cam_visualization_for_transformers.png
│   ├── cat.png
│   ├── combine_cam_and_gradients_visualization.png
│   ├── dog.png
│   ├── dog_and_cat.png
│   ├── gradients_visualization.png
│   ├── gradients_visualization_for_transformers.png
│   ├── perturbation_based_visualization.png
│   └── quick_start.png
├── gradients_visualization_examples.py
├── gradients_visualization_for_transformers_examples.py
├── metrics
│   ├── __init__.py
│   ├── correlation_value.py
│   ├── insert_and_delete.py
│   └── keep_remove_mask.py
├── perturbation_based_attribution_visualization_examples.py
└── quick_start.py

/.gitignore:
--------------------------------------------------------------------------------
1 | gradients
2 | saliency
3 | pytorch-grad-cam
4 | test_code.ipynb
5 | test_code.py
6 | .vscode
7 | *__pycache__*
8 | temp
9 | dix
10 | six

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attribution Methods for Image Classification Models
2 | 
3 | 
4 | We only consider plug-and-play methods that **do not have special requirements on the model architecture and do not need to add modules with learnable parameters or additional training**. At the same time, we want to facilitate weakly-supervised localization and segmentation with attribution results, as well as to incorporate them into model training (e.g., using attribution results as additional supervision). Therefore, all methods use PyTorch tensors for computation as much as possible, support batched inputs, and run on the GPU.
5 | 
6 | Since we mainly aim to use attribution results to assist weakly-supervised training, localization, segmentation, model distillation, etc., we do not include explainability methods for black-box models such as RISE and HISC here.
7 | 
8 | ## Gradients Visualization
9 | **CNN models**: some results of resnet50 from timm; example code at [./gradients_visualization_examples.py](./gradients_visualization_examples.py).
10 | 
11 | 
12 | 
13 | 
14 | **Vision/Swin Transformers**: gradient visualization methods can be used directly on transformers. The model used here is `vit_tiny_patch16_224.augreg_in21k_ft_in1k` from timm; example code at [./gradients_visualization_for_transformers_examples.py](./gradients_visualization_for_transformers_examples.py).
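A minimal sketch of the shared gradient-visualization API (assuming `model`, `img`, and `target_index` are prepared as in the Quick Start section below; the gradient methods expose `get_mask` and a SmoothGrad-style `get_smoothed_mask`):

```python
from attribution import VanillaGradient, IntegratedGradients
from attribution.utils import normalize_saliency, visualize_single_saliency

# Plain input gradients for the predicted classes of a batch
vanilla = VanillaGradient(model)
grads = normalize_saliency(vanilla.get_mask(img, target_index))

# SmoothGrad-style averaging over noisy copies of the input
smoothed = normalize_saliency(
    vanilla.get_smoothed_mask(img, target_index, samples=25, std=0.15))

# Integrated Gradients from a black baseline
ig = IntegratedGradients(model)
ig_grads = normalize_saliency(
    ig.get_mask(img, target_index, baseline='black', steps=128))

visualize_single_saliency(ig_grads[0].unsqueeze(0))
```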
15 | 
16 | 
17 | 
18 | ## Class Activation Map (CAM) Visualization
19 | resnet50, target layer `layer3`; example code at [./cam_visualization_examples.py](./cam_visualization_examples.py).
20 | 
21 | 
22 | 
23 | ## CAM Visualization for ViT and Swin Transformer
24 | Use `attribution.utils.get_reshape_transform` when creating the attribution model; example code at [./cam_visualization_for_transformers_examples.py](./cam_visualization_for_transformers_examples.py). The target layer used for ViT here is `blocks.11.norm1`, and for Swin Transformer it is `norm`.
25 | 
26 | 
27 | 
28 | Currently, some methods (e.g., Ablation-CAM) are not supported for transformers, and the visualization quality is not as good as for CNN models, since many methods are designed around the concept of spatial feature maps. We will try to add visualization methods designed specifically for transformers in the future.
29 | 
30 | ## Combine Gradients and CAM Visualization
31 | Similar to Guided Grad-CAM, any gradient visualization method can be combined with CAM visualization; example code at [./combine_cam_and_gradients_visualization_examples.py](./combine_cam_and_gradients_visualization_examples.py).
32 | 
33 | 
34 | 
35 | ## Black-Box Perturbation-based Attribution Visualization
36 | Example code at [./perturbation_based_attribution_visualization_examples.py](./perturbation_based_attribution_visualization_examples.py).
37 | 
38 | 
39 | 
40 | ## Quick Start
41 | ```python
42 | from matplotlib import pyplot as plt
43 | from PIL import Image
44 | import requests
45 | import timm
46 | from timm.data import resolve_model_data_config
47 | from timm.data.transforms_factory import create_transform
48 | import torch
49 | 
50 | from attribution import BlurIG, GradCAM, CombinedWrapper
51 | from attribution.utils import normalize_saliency, visualize_single_saliency
52 | 
53 | # Load imagenet labels
54 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
55 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
56 | 
57 | # Load model
58 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59 | model = timm.create_model('resnet50', pretrained=True)
60 | model = model.to(device)
61 | model.eval()
62 | config = resolve_model_data_config(model, None)
63 | transform = create_transform(**config)
64 | 
65 | # Load image
66 | dog = Image.open('examples/dog.png').convert('RGB')
67 | dog_tensor = transform(dog).unsqueeze(0)
68 | H, W = dog_tensor.shape[-2:]
69 | img = transform(dog).unsqueeze(0)
70 | 
71 | # We support batch input
72 | img = torch.cat([img, img])
73 | img = img.to(device)
74 | output = model(img)
75 | target_index = torch.argmax(output, dim=1).cpu()
76 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()])
77 | 
78 | # Gradients visualization
79 | blur_ig_kwargs = {'steps': 100,
80 |                   'batch_size': 4,
81 |                   'max_sigma': 50,
82 |                   'grad_step': 0.01,
83 |                   'sqrt': False}
84 | blur_ig_net = BlurIG(model)
85 | blur_ig = normalize_saliency(blur_ig_net.get_mask(img, target_index, **blur_ig_kwargs))
86 | 
87 | # CAM visualization
88 | gradcam_net = GradCAM(model)
89 | gradcam = normalize_saliency(gradcam_net.get_mask(img, target_index, target_layer='layer3'))
90 | 
91 | # Combine Gradients and CAM visualization
92 | combined = CombinedWrapper(model, BlurIG, GradCAM)
93 | combined_saliency = normalize_saliency(
94 |     combined.get_mask(img, target_index, target_layer='layer3', **blur_ig_kwargs))
95 | 
96 | # Visualize
97 | plt.figure(figsize=(16, 5))
98 | plt.subplot(1, 4, 1)
99 | plt.imshow(dog)
100 | plt.title('Input Image')
101 | plt.axis('off')
102 | plt.subplot(1, 4, 2)
103 | visualize_single_saliency(blur_ig[0].unsqueeze(0))
104 | plt.title('Blur IG')
105 | plt.subplot(1, 4, 3)
106 | visualize_single_saliency(gradcam[0].unsqueeze(0))
107 | plt.title('GradCAM')
108 | plt.subplot(1, 4, 4)
109 | visualize_single_saliency(combined_saliency[0].unsqueeze(0))
110 | plt.title('Combined')
111 | plt.tight_layout()
112 | plt.savefig('examples/quick_start.png', bbox_inches='tight', pad_inches=0.5)
113 | ```
114 | 
115 | 
116 | 
117 | 
118 | ## TODO:
119 | This is ongoing work to implement various attribution methods for image classification models in PyTorch under a unified framework.
120 | - [x] Unify gradient visualization API.
121 | - [x] Implement CAM visualization for CNN models based on known target_layer names.
122 | - [x] Implement CAM for ViT, Swin Transformer, etc.
123 | - [x] Implement keep positive/negative mask and keep/remove absolute mask metrics. For details, please refer to [Fast Axiomatic Attribution for Neural Networks](https://proceedings.neurips.cc/paper/2021/hash/a284df1155ec3e67286080500df36a9a-Abstract.html).
124 | - [ ] Add LIFT-CAM (ICCV2021), IIA (ICCV2023), Dix (CIKM) and Six (ICDM).
125 | - [ ] Unify all APIs.
126 | - [ ] Documentation.
127 | 
128 | 
129 | ## Acknowledgements
130 | This project is inspired by [jacobgil/pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam), [PAIR-code/saliency](https://github.com/PAIR-code/saliency) and [hummat/saliency](https://github.com/hummat/saliency). Thanks for their wonderful work.
131 | 
132 | 

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .attribution import *
2 | from .attribution.utils import *
3 | 
4 | from .metrics import *

--------------------------------------------------------------------------------
/attribution/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import VanillaGradient
2 | from .ig import IntegratedGradients
3 | from .blur_ig import BlurIG
4 | from .guided_ig import GuidedIG
5 | from .guided_backprop import GuidedBackProp
6 | 
7 | from .occlusion import Occlusion
8 | 
9 | from .grad_cam import GradCAM
10 | from .guided_gradcam import GuidedGradCAM
11 | from .hires_cam import HiResCAM
12 | from .grad_cam_plusplus import GradCAMPlusPlus
13 | from .xgrad_cam import XGradCAM
14 | from .bag_cam import BagCAM
15 | from .score_cam import ScoreCAM
16 | from .layer_cam import LayerCAM
17 | from .ablation_cam import AblationCAM
18 | from .fullgrad_cam import FullGrad
19 | from .eigen_cam import EigenCAM
20 | from .eigen_gradcam import EigenGradCAM
21 | 
22 | from .base import CombinedWrapper

--------------------------------------------------------------------------------
/attribution/ablation_cam.py:
--------------------------------------------------------------------------------
1 | from .base import CAMWrapper
2 | from .utils import replace_layer_recursive, get_2d_projection
3 | 
4 | import numpy as np
5 | import torch
6 | from torch.nn import functional as F
7 | 
8 | class AblationLayer(torch.nn.Module):
9 |     def __init__(self):
10 |         super(AblationLayer, self).__init__()
11 | 
12 |     def objectiveness_mask_from_svd(self, feature_maps, threshold=0.01):
13 |         projection = get_2d_projection(feature_maps[None, :])[0, :]
14 |         projection = np.abs(projection)
15 |         projection = projection - 
projection.min() 16 | projection = projection / projection.max() 17 | projection = projection > threshold 18 | return projection 19 | 20 | def feature_maps_to_be_ablated( 21 | self, 22 | feature_maps, 23 | ratio_channels_to_ablate=1.0): 24 | if ratio_channels_to_ablate == 1.0: 25 | self.indices = np.int32(range(feature_maps.shape[0])) 26 | return self.indices 27 | 28 | projection = self.objectiveness_mask_from_svd(feature_maps) 29 | 30 | scores = [] 31 | for channel in feature_maps: 32 | normalized = np.abs(channel) 33 | normalized = normalized - normalized.min() 34 | normalized = normalized / np.max(normalized) 35 | score = (projection * normalized).sum() / normalized.sum() 36 | scores.append(score) 37 | scores = np.float32(scores) 38 | 39 | indices = list(np.argsort(scores)) 40 | high_score_indices = indices[::- 41 | 1][: int(len(indices) * 42 | ratio_channels_to_ablate)] 43 | low_score_indices = indices[: int( 44 | len(indices) * ratio_channels_to_ablate)] 45 | self.indices = np.int32(high_score_indices + low_score_indices) 46 | return self.indices 47 | 48 | def set_next_batch( 49 | self, 50 | input_batch_index, 51 | feature_maps: torch.Tensor, 52 | num_channels_to_ablate): 53 | self.feature_maps_ablation_layer = feature_maps[input_batch_index, :, :, :].clone( 54 | ).unsqueeze(0).repeat(num_channels_to_ablate, 1, 1, 1) 55 | 56 | def __call__(self, x): 57 | output = self.feature_maps_ablation_layer 58 | for i in range(output.size(0)): 59 | # Commonly the minimum activation will be 0, 60 | # And then it makes sense to zero it out. 61 | # However depending on the architecture, 62 | # If the values can be negative, we use very negative values 63 | # to perform the ablation, deviating from the paper. 64 | if torch.min(output) == 0: 65 | output[i, self.indices[i], :] = 0 66 | else: 67 | ABLATION_VALUE = 1e7 68 | output[i, self.indices[i], :] = torch.min( 69 | output) - ABLATION_VALUE 70 | 71 | return output 72 | 73 | class AblationCAM(CAMWrapper): 74 | def get_mask(self, img: torch.Tensor, 75 | target_class: torch.Tensor, 76 | target_layer: str, 77 | batch_size: int = 32, 78 | ratio_channels_to_ablate: float = 1.0): 79 | 80 | B, C, H, W = img.size() 81 | self.model.eval() 82 | self.model.zero_grad() 83 | 84 | # class-specific backpropagation 85 | logits = self.model(img) 86 | target = self._encode_one_hot(target_class, logits) 87 | self.model.zero_grad() 88 | logits.backward(gradient=target, retain_graph=True) 89 | 90 | get_targets = lambda o, target: o[target] 91 | if isinstance(target_class, torch.Tensor): 92 | target_class = target_class.cpu().tolist() 93 | if not isinstance(target_class, list): 94 | target_class = [target_class] 95 | 96 | original_scores = [get_targets(o, t).cpu().detach() for o, t in zip(logits, target_class)] 97 | feature_maps = self._find(self.feature_maps, target_layer) 98 | 99 | # save original layer and replace the model back to the original state later 100 | original_target_layer = self.get_target_module(target_layer) 101 | ablation_layer = AblationLayer() 102 | replace_layer_recursive(self.model, original_target_layer, ablation_layer) 103 | 104 | # get weights 105 | number_of_channels = feature_maps.size(1) 106 | weights = [] 107 | with torch.no_grad(): 108 | for batch_idx, (target, I) in enumerate(zip(target_class, img)): 109 | new_scores = [] 110 | batch_tensor = I.repeat(batch_size, 1, 1, 1) 111 | 112 | channels_to_ablate = ablation_layer.feature_maps_to_be_ablated( 113 | feature_maps[batch_idx, :], ratio_channels_to_ablate 114 | ) 115 | 
number_of_channels_to_ablate = len(channels_to_ablate) 116 | 117 | for i in range(0, number_of_channels_to_ablate, batch_size): 118 | if i + batch_size > number_of_channels_to_ablate: 119 | batch_tensor = batch_tensor[:(number_of_channels_to_ablate - i)] 120 | 121 | ablation_layer.set_next_batch( 122 | batch_idx, feature_maps, batch_tensor.size(0) 123 | ) 124 | 125 | new_scores.extend([get_targets(o, target).cpu().detach() for o in self.model(batch_tensor)]) 126 | ablation_layer.indices = ablation_layer.indices[batch_size:] 127 | 128 | new_scores = self.assemble_ablation_scores( 129 | new_scores, original_scores[batch_idx], channels_to_ablate, number_of_channels 130 | ) 131 | weights.extend(new_scores) 132 | 133 | weights = np.float32(weights) 134 | weights = weights.reshape(feature_maps.shape[:2]) 135 | original_scores = np.array(original_scores)[:, None] 136 | weights = (original_scores - weights) / original_scores 137 | weights = torch.from_numpy(weights).to(self.device)[:, :, None, None] 138 | 139 | cam = torch.mul(feature_maps, weights).sum(dim=1, keepdim=True) 140 | cam = F.relu(cam) 141 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False) 142 | cam = self.normalize_cam(cam) 143 | 144 | replace_layer_recursive(self.model, ablation_layer, original_target_layer) 145 | 146 | return cam 147 | 148 | def get_target_module(self, target_layer: str): 149 | for name, module in self.model.named_modules(): 150 | if name == target_layer: 151 | return module 152 | raise ValueError(f"Layer {target_layer} not found in model") 153 | 154 | def assemble_ablation_scores(self, 155 | new_scores: list, 156 | original_score: float, 157 | ablated_channels: np.ndarray, 158 | number_of_channels: int) -> np.ndarray: 159 | """ Take the value from the channels that were ablated, 160 | and just set the original score for the channels that were skipped """ 161 | 162 | index = 0 163 | result = [] 164 | sorted_indices = np.argsort(ablated_channels) 165 | ablated_channels = ablated_channels[sorted_indices] 166 | new_scores = np.float32(new_scores)[sorted_indices] 167 | 168 | for i in range(number_of_channels): 169 | if index < len(ablated_channels) and ablated_channels[index] == i: 170 | weight = new_scores[index] 171 | index = index + 1 172 | else: 173 | weight = original_score 174 | result.append(weight) 175 | 176 | return result -------------------------------------------------------------------------------- /attribution/bag_cam.py: -------------------------------------------------------------------------------- 1 | from .base import CAMWrapper 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | class BagCAM(CAMWrapper): 8 | def get_mask(self, img: torch.Tensor, 9 | target_class: torch.Tensor, 10 | target_layer: str): 11 | B, C, H, W = img.size() 12 | self.model.eval() 13 | self.model.zero_grad() 14 | 15 | # class-specific backpropagation 16 | logits = self.model(img) 17 | target = self._encode_one_hot(target_class, logits) 18 | self.model.zero_grad() 19 | logits.backward(gradient=target, retain_graph=True) 20 | 21 | # get feature maps and gradients 22 | feature_maps = self._find(self.feature_maps, target_layer) 23 | gradients = self._find(self.gradients, target_layer) 24 | 25 | # generate CAM 26 | with torch.no_grad(): 27 | term_2 = gradients * feature_maps 28 | term_1 = term_2 + 1 29 | term_1 = F.adaptive_avg_pool2d(term_1, 1) 30 | cam = F.relu(torch.mul(term_1, term_2)).sum(dim=1, keepdim=True) 31 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False) 
32 |             cam = self.normalize_cam(cam)
33 | 
34 |         return cam

--------------------------------------------------------------------------------
/attribution/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional, List, Union
3 | from .utils import normalize_saliency
4 | 
5 | class Core(torch.nn.Module):
6 |     def __init__(self, model: torch.nn.Module):
7 |         super(Core, self).__init__()
8 |         self.model = model
9 |         self.hooks = list()  # list of hooks
10 |         self.device = next(model.parameters()).device
11 | 
12 |     def _encode_one_hot(self, targets: torch.Tensor, logits: torch.Tensor):
13 |         if not isinstance(targets, torch.Tensor):
14 |             targets = torch.tensor([targets], device=self.device)
15 |         targets = targets.view(logits.size(0))
16 |         # cast to float so the one-hot tensor can be passed as the `gradient` argument of backward()
17 |         one_hot = torch.nn.functional.one_hot(targets, num_classes=logits.size(-1)).float().to(self.device)
18 |         return one_hot
19 | 
20 |     def get_mask(self, img, target_class=None):
21 |         raise NotImplementedError(
22 |             'A derived class should implement this method')
23 | 
24 |     def remove_hooks(self):
25 |         for hook in self.hooks:
26 |             hook.remove()
27 | 
28 |     def forward(self, img):
29 |         logits = self.model(img)
30 |         return logits
31 | 
32 | class VanillaGradient(Core):
33 |     def __init__(self, model):
34 |         super(VanillaGradient, self).__init__(model)
35 | 
36 |     # return gradients
37 |     def get_mask(self, img: torch.Tensor, target_class: torch.Tensor) -> torch.Tensor:
38 |         self.model.eval()
39 |         self.model.zero_grad()
40 | 
41 |         img = img.clone()
42 |         img.requires_grad = True
43 |         img.retain_grad()
44 | 
45 |         logits = self.model(img)
46 | 
47 |         target = self._encode_one_hot(target_class, logits)
48 |         self.model.zero_grad()
49 |         logits.backward(gradient=target, retain_graph=True)
50 |         return img.grad.detach()
51 | 
52 |     def get_smoothed_mask(self, img, target_class, samples=25, std=0.15, process=lambda x: x**2):
53 |         std = std * (torch.max(img) - torch.min(img)).detach().cpu().numpy()
54 | 
55 |         B, C, H, W = img.size()
56 |         grad_sum = torch.zeros((B, C, H, W), device=self.device)
57 |         for sample in range(samples):
58 |             noise = torch.empty(img.size()).normal_(0, std).to(self.device)
59 |             noise_image = img + noise
60 |             grad_sum += process(self.get_mask(noise_image, target_class))
61 |         return grad_sum / samples
62 | 
63 | 
64 | class CAMWrapper(Core):
65 |     def __init__(self, model: torch.nn.Module, reshape_transform=None):
66 |         super(CAMWrapper, self).__init__(model)
67 | 
68 |         self.feature_maps = dict()
69 |         self.gradients = dict()
70 |         self.reshape_transform = reshape_transform
71 | 
72 |         def save_feature_maps(name):
73 |             def forward_hook(module, input, output):
74 |                 self.feature_maps[name] = output.detach()
75 | 
76 |             return forward_hook
77 | 
78 |         def save_gradients(name):
79 |             def _store_grad(grad):
80 |                 self.gradients[name] = grad.detach()
81 |             def forward_hook(module, input, output):
82 |                 if not hasattr(output, 'requires_grad') or not output.requires_grad:
83 |                     return
84 |                 output.register_hook(_store_grad)
85 | 
86 |             return forward_hook
87 | 
88 |         for name, module in self.model.named_modules():
89 |             self.hooks.append(module.register_forward_hook(save_feature_maps(name)))
90 |             self.hooks.append(module.register_forward_hook(save_gradients(name)))
91 | 
92 |     def _find(self, saved_dict, name: str):
93 |         if name in saved_dict.keys():
94 |             if self.reshape_transform is not None:
95 |                 return self.reshape_transform(saved_dict[name])
96 |             return saved_dict[name]
97 | 
98 |         raise ValueError('Invalid layer name')
99 | 
100 |     def get_mask(self, img: torch.Tensor,
101 |                  target_class: torch.Tensor,
102 |                  target_layer: Union[str, List[str]]):
103 |         raise NotImplementedError('A derived class should implement this method')
104 | 
105 |     @torch.no_grad()
106 |     def normalize_cam(self, cam: torch.Tensor):
107 |         B, C, H, W = cam.size()
108 |         cam = cam.view(cam.size(0), -1)
109 |         cam -= cam.min(dim=1, keepdim=True)[0]
110 |         cam = cam / (cam.max(dim=1, keepdim=True)[0]+1e-5)
111 |         return cam.view(B, C, H, W)
112 | 
113 | 
114 | class CombinedWrapper(Core):
115 |     def __init__(self, model: torch.nn.Module, gradient_net: Core, cam_net: CAMWrapper, reshape_transform=None):
116 |         super(CombinedWrapper, self).__init__(model)
117 |         self.gradient_net = gradient_net(model)
118 |         self.cam_net = cam_net(model, reshape_transform)
119 | 
120 |     def get_mask(self, img: torch.Tensor, target_class: torch.Tensor, target_layer: str, **kwargs_for_gradient_net):
121 |         B, C, H, W = img.size()
122 |         self.model.eval()
123 |         self.model.zero_grad()
124 |         cam = self.cam_net.get_mask(img, target_class, target_layer)
125 |         gradients = self.gradient_net.get_mask(img, target_class, **kwargs_for_gradient_net)
126 |         gradients = normalize_saliency(gradients, return_device=self.device)
127 | 
128 |         with torch.no_grad():
129 |             attribution = gradients * cam
130 |             attribution = torch.nn.functional.relu(attribution)
131 |             attribution = self.cam_net.normalize_cam(attribution)
132 | 
133 |         return attribution

--------------------------------------------------------------------------------
/attribution/blur_ig.py:
--------------------------------------------------------------------------------
1 | from .base import VanillaGradient
2 | 
3 | import math
4 | import torch
5 | from torchvision.transforms import GaussianBlur
6 | 
7 | def gaussian_blur(img: torch.Tensor, sigma: int):
8 |     if sigma == 0:
9 |         return img
10 |     kernel_size = int(4 * sigma + 0.5) + 1
11 |     return GaussianBlur(kernel_size=kernel_size, sigma=sigma)(img)
12 | 
13 | 
14 | class BlurIG(VanillaGradient):
15 |     def get_mask(self, img: torch.Tensor,
16 |                  target_class: torch.Tensor,
17 |                  max_sigma: int = 50,
18 |                  steps: int = 100,
19 |                  grad_step: float = 0.01,
20 |                  sqrt: bool = False,
21 |                  batch_size: int = 4):
22 |         self.model.eval()
23 |         self.model.zero_grad()
24 | 
25 |         if sqrt:
26 |             sigmas = [math.sqrt(float(i) * max_sigma / float(steps)) for i in range(0, steps+1)]
27 |         else:
28 |             sigmas = [float(i) * max_sigma / float(steps) for i in range(0, steps+1)]
29 | 
30 |         step_vector_diff = [sigmas[i+1] - sigmas[i] for i in range(0, steps)]
31 |         total_gradients = torch.zeros_like(img)
32 |         x_step_batched = []
33 |         gaussian_gradient_batched = []
34 | 
35 |         for i in range(steps):
36 |             with torch.no_grad():
37 |                 x_step = gaussian_blur(img, sigmas[i])
38 |                 gaussian_gradients = (gaussian_blur(img, sigmas[i] + grad_step) - x_step) / grad_step
39 |                 x_step_batched.append(x_step)
40 |                 gaussian_gradient_batched.append(gaussian_gradients)
41 |             if len(x_step_batched) == batch_size or i == steps - 1:
42 |                 x_step_batched = torch.cat(x_step_batched, dim=0)
43 |                 x_step_batched.requires_grad = True
44 |                 outputs = torch.softmax(self.model(x_step_batched), dim=1)[:, target_class]
45 |                 gradients = torch.autograd.grad(outputs, x_step_batched, torch.ones_like(outputs), create_graph=True)[0]
46 |                 gradients = gradients.detach()
47 |                 # gradients = super(BlurIG, self).get_mask(x_step_batched, torch.stack([target_class] * x_step_batched.size(0), dim=0))
48 | 
49 |                 with torch.no_grad():
50 |                     total_gradients += (step_vector_diff[i] *
51 |                                         torch.mul(torch.cat(gaussian_gradient_batched, dim=0), gradients.clone())).sum(dim=0)
52 |                     x_step_batched = []
53 |                     gaussian_gradient_batched = []
54 | 
55 |         with torch.no_grad():
56 |             blur_ig = total_gradients * -1.0
57 | 
58 |         return blur_ig
59 | 
60 |     def get_smoothed_mask(self, img: torch.Tensor,
61 |                           target_class: torch.Tensor,
62 |                           max_sigma: int = 50,
63 |                           steps: int = 100,
64 |                           grad_step: float = 0.01,
65 |                           sqrt: bool = False,
66 |                           batch_size: int = 4,
67 |                           samples: int = 25,
68 |                           std: float = 0.15,
69 |                           process=lambda x: x**2):
70 |         std = std * (torch.max(img) - torch.min(img)).detach().cpu().numpy()
71 | 
72 |         B, C, H, W = img.size()
73 |         grad_sum = torch.zeros((B, C, H, W), device=self.device)
74 |         for sample in range(samples):
75 |             noise = torch.empty(img.size()).normal_(0, std).to(self.device)
76 |             noise_image = img + noise
77 |             grad_sum += process(self.get_mask(noise_image, target_class, max_sigma, steps, grad_step, sqrt, batch_size))
78 |         return grad_sum / samples

--------------------------------------------------------------------------------
/attribution/eigen_cam.py:
--------------------------------------------------------------------------------
1 | from .utils import get_2d_projection
2 | from .base import CAMWrapper
3 | 
4 | import torch
5 | from torch.nn import functional as F
6 | from typing import List
7 | 
8 | class EigenCAM(CAMWrapper):
9 |     def get_mask(self, img: torch.Tensor,
10 |                  target_class: torch.Tensor,
11 |                  target_layer: str):
12 |         B, C, H, W = img.size()
13 |         self.model.eval()
14 |         self.model.zero_grad()
15 | 
16 |         # class-specific backpropagation
17 |         logits = self.model(img)
18 |         target = self._encode_one_hot(target_class, logits)
19 |         self.model.zero_grad()
20 |         logits.backward(gradient=target, retain_graph=True)
21 | 
22 |         # get feature maps
23 |         feature_maps = self._find(self.feature_maps, target_layer)
24 | 
25 |         # generate CAM
26 |         with torch.no_grad():
27 |             cam = get_2d_projection(feature_maps.cpu().numpy())
28 |             cam = torch.from_numpy(cam).to(self.device)[:, None, :, :]
29 |             cam = F.relu(cam)
30 |             cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
31 |             cam = self.normalize_cam(cam)
32 | 
33 |         return cam

--------------------------------------------------------------------------------
/attribution/eigen_gradcam.py:
--------------------------------------------------------------------------------
1 | from .utils import get_2d_projection
2 | from .base import CAMWrapper
3 | 
4 | import torch
5 | from torch.nn import functional as F
6 | from typing import List
7 | 
8 | class EigenGradCAM(CAMWrapper):
9 |     def get_mask(self, img: torch.Tensor,
10 |                  target_class: torch.Tensor,
11 |                  target_layer: str):
12 |         B, C, H, W = img.size()
13 |         self.model.eval()
14 |         self.model.zero_grad()
15 | 
16 |         # class-specific backpropagation
17 |         logits = self.model(img)
18 |         target = self._encode_one_hot(target_class, logits)
19 |         self.model.zero_grad()
20 |         logits.backward(gradient=target, retain_graph=True)
21 | 
22 |         # get feature maps and gradients
23 |         feature_maps = self._find(self.feature_maps, target_layer)
24 |         gradients = self._find(self.gradients, target_layer)
25 | 
26 |         # generate CAM
27 |         with torch.no_grad():
28 |             cam = get_2d_projection((gradients * feature_maps).cpu().numpy())
29 |             cam = torch.from_numpy(cam).to(self.device)[:, None, :, :]
30 |             cam = F.relu(cam)
31 |             cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
32 |             cam = self.normalize_cam(cam)
33 | 
34 |         return cam
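# Note: EigenCAM projects the raw feature maps onto their first principal
# component (via get_2d_projection), so its map is class-agnostic and the
# backward pass above is not actually used; EigenGradCAM applies the same
# projection to the gradient-weighted feature maps, which makes the map
# class-specific.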
-------------------------------------------------------------------------------- /attribution/fullgrad_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import functional as F 4 | from typing import Optional 5 | 6 | from .base import CAMWrapper 7 | from .utils import find_layer_predicate_recursive 8 | 9 | class FullGrad(CAMWrapper): 10 | def __init__(self, model: torch.nn.Module, reshape_transform=None): 11 | if reshape_transform is not None: 12 | print('Warning: FullGrad may not work properly with ViT and Swin Transformer models.') 13 | super().__init__(model, reshape_transform) 14 | 15 | self.target_layers = find_layer_predicate_recursive(self.model, self.layer_with_2D_bias) 16 | self.bias_data = [self.get_bias_data(layer) for layer in self.target_layers] 17 | 18 | def layer_with_2D_bias(self, layer): 19 | bias_target_layers = [torch.nn.Conv2d, torch.nn.BatchNorm2d] 20 | if type(layer) in bias_target_layers and layer.bias is not None: 21 | return True 22 | return False 23 | 24 | def get_bias_data(self, layer): 25 | if isinstance(layer, torch.nn.BatchNorm2d): 26 | bias = - (layer.running_mean * layer.weight 27 | / torch.sqrt(layer.running_var + layer.eps)) + layer.bias 28 | return bias.data 29 | else: 30 | return layer.bias.data 31 | 32 | def get_mask(self, img: torch.Tensor, 33 | target_class: torch.Tensor, 34 | target_layer: Optional[str] = None): 35 | if target_layer is not None: 36 | print('Warning: target_layer is ignored in FullGrad. All bias layers will be used instead') 37 | 38 | B, C, H, W = img.size() 39 | img = torch.autograd.Variable(img, requires_grad=True) 40 | self.model.eval() 41 | self.model.zero_grad() 42 | 43 | # class-specific backpropagation 44 | logits = self.model(img) 45 | target = self._encode_one_hot(target_class, logits) 46 | self.model.zero_grad() 47 | logits.backward(gradient=target, retain_graph=True) 48 | 49 | target_layer_names = [] 50 | for name, layer in self.model.named_modules(): 51 | if layer in self.target_layers: 52 | target_layer_names.append(name) 53 | gradients_list = [self._find(self.gradients, layer_name) for layer_name in target_layer_names] 54 | 55 | input_gradients = img.grad.detach() 56 | cam_per_target_layer = [] 57 | 58 | with torch.no_grad(): 59 | gradient_multiplied_input = input_gradients * img 60 | gradient_multiplied_input = torch.abs(gradient_multiplied_input) 61 | gradient_multiplied_input = self.scale_input_across_batch_and_channels(gradient_multiplied_input) 62 | 63 | cam_per_target_layer.append(gradient_multiplied_input) 64 | assert len(gradients_list) == len(self.bias_data) 65 | for bias, gradients in zip(self.bias_data, gradients_list): 66 | bias = bias[None, :, None, None] 67 | bias_grad = torch.abs(bias * gradients) 68 | bias_grad = self.scale_input_across_batch_and_channels(bias_grad, (H, W)) 69 | bias_grad = bias_grad.sum(dim=1, keepdim=True) 70 | cam_per_target_layer.append(bias_grad) 71 | cam_per_target_layer = torch.cat(cam_per_target_layer, dim=1) 72 | cam = cam_per_target_layer.sum(dim=1, keepdim=True) 73 | cam = F.relu(cam) 74 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False) 75 | cam = self.normalize_cam(cam) 76 | 77 | return cam 78 | 79 | @torch.no_grad() 80 | def scale_input_across_batch_and_channels(self, input_tensor: torch.Tensor, target_size=None): 81 | # target_size should be like (H, W) 82 | B, C, H, W = input_tensor.size() 83 | input_tensor = input_tensor.view(B, C, -1) 84 | input_tensor -= 
input_tensor.min(dim=2, keepdim=True)[0] 85 | input_tensor /= (input_tensor.max(dim=2, keepdim=True)[0] + 1e-7) 86 | input_tensor = input_tensor.view(B, C, H, W) 87 | if target_size is not None: 88 | input_tensor = F.interpolate(input_tensor, target_size, mode='bilinear', align_corners=False) 89 | return input_tensor 90 | -------------------------------------------------------------------------------- /attribution/grad_cam.py: -------------------------------------------------------------------------------- 1 | from .base import CAMWrapper 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | class GradCAM(CAMWrapper): 7 | def get_mask(self, img: torch.Tensor, 8 | target_class: torch.Tensor, 9 | target_layer: str): 10 | 11 | B, C, H, W = img.size() 12 | self.model.eval() 13 | self.model.zero_grad() 14 | 15 | # class-specific backpropagation 16 | logits = self.model(img) 17 | target = self._encode_one_hot(target_class, logits) 18 | self.model.zero_grad() 19 | logits.backward(gradient=target, retain_graph=True) 20 | 21 | # get feature maps and gradients 22 | feature_maps = self._find(self.feature_maps, target_layer) 23 | gradients = self._find(self.gradients, target_layer) 24 | 25 | # generate CAM 26 | with torch.no_grad(): 27 | weights = F.adaptive_avg_pool2d(gradients, 1) 28 | cam = torch.mul(feature_maps, weights).sum(dim=1, keepdim=True) 29 | cam = F.relu(cam) 30 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False) 31 | cam = self.normalize_cam(cam) 32 | 33 | return cam -------------------------------------------------------------------------------- /attribution/grad_cam_plusplus.py: -------------------------------------------------------------------------------- 1 | from .grad_cam import GradCAM 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | class GradCAMPlusPlus(GradCAM): 8 | def get_mask(self, img: torch.Tensor, 9 | target_class: torch.Tensor, 10 | target_layer: str): 11 | B, C, H, W = img.size() 12 | self.model.eval() 13 | self.model.zero_grad() 14 | 15 | # class-specific backpropagation 16 | logits = self.model(img) 17 | target = self._encode_one_hot(target_class, logits) 18 | self.model.zero_grad() 19 | logits.backward(gradient=target, retain_graph=True) 20 | 21 | # get feature maps and gradients 22 | feature_maps = self._find(self.feature_maps, target_layer) 23 | gradients = self._find(self.gradients, target_layer) 24 | 25 | # generate CAM 26 | with torch.no_grad(): 27 | gradients_power_2 = gradients**2 28 | gradients_power_3 = gradients_power_2 * gradients 29 | sum_feature_maps = torch.sum(feature_maps, dim=(2, 3)) 30 | sum_feature_maps = sum_feature_maps[:, :, None, None] 31 | eps = 1e-6 32 | aij = gradients_power_2 / (2 * gradients_power_2 + sum_feature_maps * gradients_power_3 + eps) 33 | aij = torch.where(gradients != 0, aij, 0) 34 | weights = torch.maximum(gradients, torch.tensor(0, device=self.device)) * aij 35 | weights = torch.sum(weights, dim=(2, 3)) 36 | weights = weights[:, :, None, None] 37 | 38 | cam = torch.mul(feature_maps, weights).sum(dim=1, keepdim=True) 39 | cam = F.relu(cam) 40 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False) 41 | cam = self.normalize_cam(cam) 42 | 43 | return cam -------------------------------------------------------------------------------- /attribution/guided_backprop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Function 4 | 5 | from .base import 
VanillaGradient 6 | from .utils import replace_all_layer_type_recursive 7 | 8 | 9 | class GuidedBackpropReLU(Function): 10 | @staticmethod 11 | def forward(self, input_img): 12 | positive_mask = (input_img > 0).type_as(input_img) 13 | output = torch.addcmul( 14 | torch.zeros( 15 | input_img.size()).type_as(input_img), 16 | input_img, 17 | positive_mask) 18 | self.save_for_backward(input_img, output) 19 | return output 20 | 21 | @staticmethod 22 | def backward(self, grad_output): 23 | input_img, output = self.saved_tensors 24 | grad_input = None 25 | 26 | positive_mask_1 = (input_img > 0).type_as(grad_output) 27 | positive_mask_2 = (grad_output > 0).type_as(grad_output) 28 | grad_input = torch.addcmul( 29 | torch.zeros( 30 | input_img.size()).type_as(input_img), 31 | torch.addcmul( 32 | torch.zeros( 33 | input_img.size()).type_as(input_img), 34 | grad_output, 35 | positive_mask_1), 36 | positive_mask_2) 37 | return grad_input 38 | 39 | 40 | class GuidedBackpropReLUasModule(torch.nn.Module): 41 | def __init__(self): 42 | super(GuidedBackpropReLUasModule, self).__init__() 43 | 44 | def forward(self, input_img): 45 | return GuidedBackpropReLU.apply(input_img) 46 | 47 | 48 | class GuidedBackProp(VanillaGradient): 49 | def get_mask(self, img: torch.Tensor, target_class: torch.Tensor) -> torch.Tensor: 50 | 51 | replace_all_layer_type_recursive(self.model, torch.nn.ReLU, GuidedBackpropReLUasModule()) 52 | 53 | grads = super().get_mask(img, target_class) 54 | 55 | replace_all_layer_type_recursive(self.model, GuidedBackpropReLUasModule, torch.nn.ReLU()) 56 | 57 | return grads -------------------------------------------------------------------------------- /attribution/guided_gradcam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from .base import CombinedWrapper 5 | from .grad_cam import GradCAM 6 | from .guided_backprop import GuidedBackProp 7 | 8 | class GuidedGradCAM(CombinedWrapper): 9 | def __init__(self, model: torch.nn.Module, reshape_transform=None): 10 | gradient_net = GuidedBackProp 11 | cam_net = GradCAM 12 | super(GuidedGradCAM, self).__init__(model, gradient_net, cam_net, reshape_transform) -------------------------------------------------------------------------------- /attribution/guided_ig.py: -------------------------------------------------------------------------------- 1 | from .base import VanillaGradient 2 | 3 | import torch 4 | import numpy as np 5 | import math 6 | 7 | eps = 1e-9 8 | 9 | @torch.no_grad() 10 | def l1_distance(x, y): 11 | return torch.abs(x - y).sum() 12 | 13 | @torch.no_grad() 14 | def translate_x_to_alpha(x: torch.Tensor, x_input: torch.Tensor, x_baseline: torch.Tensor): 15 | '''Alpha shows the relative position of x in the straight line between x_baseline and x_input''' 16 | alpha = torch.where(x_input-x_baseline != 0, 17 | (x - x_baseline) / (x_input - x_baseline + eps), torch.tensor(np.nan).to(x.device)) 18 | return alpha 19 | 20 | @torch.no_grad() 21 | def translate_alpha_to_x(alpha: torch.Tensor, x_input: torch.Tensor, x_baseline: torch.Tensor): 22 | '''Translate alpha to x''' 23 | assert 0.0 <= alpha <= 1.0, 'Alpha should be in the range [0, 1]' 24 | x = x_baseline + alpha * (x_input - x_baseline) 25 | return x 26 | 27 | class GuidedIG(VanillaGradient): 28 | def get_mask(self, img: torch.Tensor, 29 | target_class: torch.Tensor, 30 | baseline='None', 31 | steps=128, 32 | fraction=0.25, 33 | max_dist=0.02): 34 | 35 | self.model.eval() 36 | 
self.model.zero_grad() 37 | 38 | if baseline is None or baseline == 'None': 39 | baseline = torch.zeros_like(img, device=self.device) 40 | elif baseline == 'black': 41 | baseline = torch.ones_like(img, device=self.device) * torch.min(img).detach() 42 | elif baseline == 'white': 43 | baseline = torch.ones_like(img, device=self.device) * torch.max(img).detach() 44 | else: 45 | raise ValueError(f'Baseline {baseline} is not supported, use "black", "white" or None') 46 | 47 | return self.guided_ig_impl(img, target_class, baseline, steps, fraction, max_dist) 48 | 49 | def get_smoothed_mask(self, img: torch.Tensor, 50 | target_class: torch.Tensor, 51 | baseline = 'None', 52 | steps: int = 128, 53 | fraction: float = 0.25, 54 | max_dist: float = 0.02, 55 | samples: int = 25, 56 | std: float = 0.15, 57 | process=lambda x: x**2): 58 | std = std * (torch.max(img) - torch.min(img)).detach().cpu().numpy() 59 | 60 | B, C, H, W = img.size() 61 | grad_sum = torch.zeros((B, C, H, W), device=self.device) 62 | for sample in range(samples): 63 | noise = torch.empty(img.size()).normal_(0, std).to(self.device) 64 | noise_image = img + noise 65 | grad_sum += process(self.get_mask(noise_image, target_class, baseline, steps, fraction, max_dist)) 66 | return grad_sum / samples 67 | 68 | def guided_ig_impl(self, img: torch.Tensor, 69 | target_class: torch.Tensor, 70 | baseline: torch.Tensor, 71 | steps: int, 72 | fraction: float, 73 | max_dist: float): 74 | guided_ig = torch.zeros_like(img, dtype=torch.float64, device=self.device) 75 | 76 | x = baseline.clone() 77 | 78 | l1_total = l1_distance(baseline, img) 79 | if l1_total.sum() < eps: 80 | return guided_ig 81 | 82 | for step in range(steps): 83 | gradients_actual = super().get_mask(x, target_class) 84 | gradients = gradients_actual.clone().detach() 85 | alpha = (step + 1.0) / steps 86 | alpha_min = max(alpha - max_dist, 0.0) 87 | alpha_max = min(alpha + max_dist, 1.0) 88 | x_min = translate_alpha_to_x(alpha_min, img, baseline).detach() 89 | x_max = translate_alpha_to_x(alpha_max, img, baseline).detach() 90 | 91 | with torch.no_grad(): 92 | l1_target = l1_total * (1 - (step + 1) / steps) 93 | 94 | # Iterate until the desired L1 distance has been reached. 
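            # Each pass selects the coordinates whose |gradient| lies below the
            # `fraction` quantile and moves them toward the input until the remaining
            # L1 distance matches l1_target; gamma > 1 means the selected subset cannot
            # close the gap on its own, so it is saturated at x_max and a larger set of
            # coordinates is considered in the next pass.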
95 | gamma = torch.tensor(np.inf) 96 | while gamma > 1.0: 97 | x_old = x.clone().detach() 98 | x_alpha = translate_x_to_alpha(x, img, baseline).detach() 99 | x_alpha[torch.isnan(x_alpha)] = alpha_max 100 | x.requires_grad = False 101 | x[x_alpha < alpha_min] = x_min[x_alpha < alpha_min] 102 | 103 | l1_current = l1_distance(x, img) 104 | 105 | if math.isclose(l1_target, l1_current, rel_tol=eps, abs_tol=eps): 106 | with torch.no_grad(): 107 | guided_ig += gradients_actual * (x - x_old) 108 | break 109 | 110 | gradients[x == x_max] = torch.tensor(np.inf) 111 | 112 | threshold = torch.quantile(torch.abs(gradients), fraction, dim=None, keepdim=False, interpolation='lower') 113 | s = torch.logical_and(torch.abs(gradients) <= threshold, gradients != torch.tensor(np.inf)) 114 | 115 | with torch.no_grad(): 116 | l1_s = (torch.abs(x - x_max) * s).sum() 117 | 118 | if l1_s > 0: 119 | gamma = (l1_current - l1_target) / l1_s 120 | else: 121 | gamma = torch.tensor(np.inf) 122 | 123 | if gamma > 1.0: 124 | x[s] = x_max[s] 125 | else: 126 | assert gamma >= 0.0, f'Gamma should be non-negative, but got {gamma.min()}' 127 | x[s] = translate_alpha_to_x(gamma, x_max, x)[s] 128 | 129 | with torch.no_grad(): 130 | guided_ig += gradients_actual * (x - x_old) 131 | 132 | return guided_ig 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /attribution/hires_cam.py: -------------------------------------------------------------------------------- 1 | from .base import CAMWrapper 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from typing import List 6 | 7 | class HiResCAM(CAMWrapper): 8 | def get_mask(self, img: torch.Tensor, 9 | target_class: torch.Tensor, 10 | target_layer: str): 11 | B, C, H, W = img.size() 12 | self.model.eval() 13 | self.model.zero_grad() 14 | 15 | # class-specific backpropagation 16 | logits = self.model(img) 17 | target = self._encode_one_hot(target_class, logits) 18 | self.model.zero_grad() 19 | logits.backward(gradient=target, retain_graph=True) 20 | 21 | # get feature maps and gradients 22 | feature_maps = self._find(self.feature_maps, target_layer) 23 | gradients = self._find(self.gradients, target_layer) 24 | 25 | # generate CAM 26 | with torch.no_grad(): 27 | cam = gradients * feature_maps 28 | cam = cam.sum(dim=1, keepdim=True) 29 | cam = F.relu(cam) 30 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False) 31 | cam = self.normalize_cam(cam) 32 | 33 | return cam -------------------------------------------------------------------------------- /attribution/ig.py: -------------------------------------------------------------------------------- 1 | from .base import VanillaGradient 2 | 3 | import torch 4 | import numpy as np 5 | 6 | class IntegratedGradients(VanillaGradient): 7 | def get_mask(self, img: torch.Tensor, 8 | target_class: torch.Tensor, 9 | baseline='black', 10 | steps=128, 11 | process=lambda x: x): 12 | if baseline == 'black': 13 | baseline = torch.ones_like(img, device=self.device) * torch.min(img).detach() 14 | elif baseline == 'white': 15 | baseline = torch.ones_like(img, device=self.device) * torch.max(img).detach() 16 | else: 17 | baseline = torch.zeros_like(img, device=self.device) 18 | 19 | B, C, H, W = img.size() 20 | grad_sum = torch.zeros((B, C, H, W), device=self.device) 21 | image_diff = img - baseline 22 | 23 | for step, alpha in enumerate(np.linspace(0, 1, steps)): 24 | image_step = baseline + alpha * image_diff 25 | grad_sum += process(super(IntegratedGradients, 26 | 
self).get_mask(image_step, target_class)) 27 | return grad_sum * image_diff.detach() / steps 28 | 29 | def get_smoothed_mask(self, img: torch.Tensor, 30 | target_class: torch.Tensor, 31 | baseline='black', 32 | steps=128, 33 | process_ig=lambda x: x, # used in self.get_mask 34 | samples=25, 35 | std=0.15, 36 | process=lambda x: x**2): 37 | std = std * (torch.max(img) - torch.min(img)).detach().cpu().numpy() 38 | 39 | B, C, H, W = img.size() 40 | grad_sum = torch.zeros((B, C, H, W), device=self.device) 41 | for sample in range(samples): 42 | noise = torch.empty(img.size()).normal_(0, std).to(self.device) 43 | noise_image = img + noise 44 | grad_sum += process(self.get_mask(noise_image, target_class, baseline, steps, process_ig)) 45 | return grad_sum / samples -------------------------------------------------------------------------------- /attribution/layer_cam.py: -------------------------------------------------------------------------------- 1 | from .base import CAMWrapper 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | class LayerCAM(CAMWrapper): 8 | def get_mask(self, img: torch.Tensor, 9 | target_class: torch.Tensor, 10 | target_layer: str): 11 | B, C, H, W = img.size() 12 | self.model.eval() 13 | self.model.zero_grad() 14 | 15 | # class-specific backpropagation 16 | logits = self.model(img) 17 | target = self._encode_one_hot(target_class, logits) 18 | self.model.zero_grad() 19 | logits.backward(gradient=target, retain_graph=True) 20 | 21 | # get feature maps and gradients 22 | feature_maps = self._find(self.feature_maps, target_layer) 23 | gradients = self._find(self.gradients, target_layer) 24 | 25 | # generate CAM 26 | with torch.no_grad(): 27 | spatial_weighted_feature_maps = torch.maximum(gradients, torch.tensor(0, device=self.device)) * feature_maps 28 | cam = spatial_weighted_feature_maps.sum(dim=1, keepdim=True) 29 | cam = F.relu(cam) 30 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False) 31 | cam = self.normalize_cam(cam) 32 | 33 | return cam -------------------------------------------------------------------------------- /attribution/occlusion.py: -------------------------------------------------------------------------------- 1 | from .base import Core 2 | 3 | import torch 4 | 5 | 6 | class Occlusion(Core): 7 | 8 | @torch.no_grad() 9 | def get_mask(self, img, target_class, size=15, value=0.0): 10 | ''' 11 | size: height and width of the occlusion window. 12 | value: value to replace values inside the occlusion window with. 
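        Note: this runs one forward pass of the whole batch per window position,
        i.e. (1 + H - size) * (1 + W - size) passes per call, so it can be slow
        for large inputs.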
13 | ''' 14 | B, C, H, W = img.size() 15 | occlusion_scores = torch.zeros_like(img, device=self.device) 16 | occlusion_window = torch.fill_(torch.zeros((B, C, size, size), device=self.device), value) 17 | 18 | original_output = self.model(img) 19 | for row in range(1 + H - size): 20 | for col in range(1 + W - size): 21 | img_occluded = img.clone() 22 | img_occluded[:, :, row:row+size, col:col+size] = occlusion_window 23 | output = self.model(img_occluded) 24 | score_diff = original_output - output 25 | # the score_diff for the target class 26 | score_diff = score_diff[torch.arange(B), target_class] 27 | occlusion_scores[:, :, row:row+size, col:col+size] += score_diff[:, None, None, None] 28 | 29 | return occlusion_scores 30 | -------------------------------------------------------------------------------- /attribution/score_cam.py: -------------------------------------------------------------------------------- 1 | from .base import CAMWrapper 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | class ScoreCAM(CAMWrapper): 8 | def get_mask(self, img: torch.Tensor, 9 | target_class: torch.Tensor, 10 | target_layer: str, 11 | batch_size: int = 16,): 12 | B, C, H, W = img.size() 13 | self.model.eval() 14 | self.model.zero_grad() 15 | 16 | # class-specific backpropagation 17 | logits = self.model(img) 18 | target = self._encode_one_hot(target_class, logits) 19 | self.model.zero_grad() 20 | logits.backward(gradient=target, retain_graph=True) 21 | 22 | # get feature maps and gradients 23 | feature_maps = self._find(self.feature_maps, target_layer) 24 | 25 | with torch.no_grad(): 26 | upsample = torch.nn.UpsamplingBilinear2d(size=(H, W)) 27 | upsampled = upsample(feature_maps) 28 | maxs = upsampled.view(upsampled.size(0), upsampled.size(1), -1).max(dim=-1)[0] 29 | mins = upsampled.view(upsampled.size(0), upsampled.size(1), -1).min(dim=-1)[0] 30 | 31 | maxs, mins = maxs[:, :, None, None], mins[:, :, None, None] 32 | upsampled = (upsampled - mins) / (maxs - mins + 1e-8) 33 | 34 | imgs = img[:, None, :, :] * upsampled[:, :, None, :, :] 35 | 36 | get_targets = lambda o, target: o[target] 37 | if isinstance(target_class, torch.Tensor): 38 | target_class = target_class.cpu().tolist() 39 | if not isinstance(target_class, list): 40 | target_class = [target_class] 41 | 42 | scores = [] 43 | with torch.no_grad(): 44 | for i in range(imgs.size(0)): 45 | input_img = imgs[i] 46 | for batch_i in range(0, input_img.size(0), batch_size): 47 | batch = input_img[batch_i: batch_i + batch_size, :] 48 | outputs = [get_targets(o, target_class[i]).detach() for o in self.model(batch)] 49 | scores.extend(outputs) 50 | scores = torch.tensor(scores) 51 | scores = scores.view(feature_maps.shape[0], feature_maps.shape[1]) 52 | weights = F.softmax(scores, dim=-1) 53 | weights = weights.to(self.device) 54 | cam = torch.mul(feature_maps, weights[:, :, None, None]).sum(dim=1, keepdim=True) 55 | cam = F.relu(cam) 56 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False) 57 | cam = self.normalize_cam(cam) 58 | 59 | return cam 60 | -------------------------------------------------------------------------------- /attribution/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization_maps import normalize_saliency, visualize_single_saliency 2 | from .access_layers import find_layer_predicate_recursive, find_layer_types_recursive, replace_all_layer_type_recursive, replace_layer_recursive 3 | from .svd_on_feature_maps import 
get_2d_projection
4 | from .reshape_for_transformer import get_reshape_transform

--------------------------------------------------------------------------------
/attribution/utils/access_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def replace_layer_recursive(model: torch.nn.Module, old_layer: torch.nn.Module, new_layer: torch.nn.Module):
4 |     for name, layer in model._modules.items():
5 |         if layer == old_layer:
6 |             model._modules[name] = new_layer
7 |             return True
8 |         elif replace_layer_recursive(layer, old_layer, new_layer):
9 |             return True
10 |     return False
11 | 
12 | 
13 | def replace_all_layer_type_recursive(model: torch.nn.Module, old_layer_type: torch.nn.Module, new_layer: torch.nn.Module):
14 |     '''new_layer is an instance of the new layer type, not the type itself.'''
15 |     for name, layer in model._modules.items():
16 |         if isinstance(layer, old_layer_type):
17 |             model._modules[name] = new_layer
18 |         replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
19 | 
20 | 
21 | def find_layer_types_recursive(model: torch.nn.Module, layer_types):
22 |     def predicate(layer):
23 |         return type(layer) in layer_types
24 |     return find_layer_predicate_recursive(model, predicate)
25 | 
26 | 
27 | def find_layer_predicate_recursive(model: torch.nn.Module, predicate):
28 |     result = []
29 |     for name, layer in model._modules.items():
30 |         if predicate(layer):
31 |             result.append(layer)
32 |         result.extend(find_layer_predicate_recursive(layer, predicate))
33 |     return result

--------------------------------------------------------------------------------
/attribution/utils/reshape_for_transformer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | 
4 | def get_reshape_transform(has_cls_token=False):
5 |     def reshape_transform(x: torch.Tensor):
6 |         if not has_cls_token:
7 |             # typically the output of a Swin Transformer
8 |             x = x.permute(0, 3, 1, 2)
9 |             return x
10 |         # typically the output of a ViT
11 |         B, token_numbers, patch_size = x.size()
12 |         if has_cls_token:
13 |             token_numbers -= 1
14 |             x = x[:, 1:, :]
15 | 
16 |         img_width = int(math.sqrt(token_numbers))
17 |         assert img_width * img_width == token_numbers
18 |         x = x.view(B, img_width, img_width, patch_size)
19 |         x = x.permute(0, 3, 1, 2)
20 |         return x
21 | 
22 |     return reshape_transform

--------------------------------------------------------------------------------
/attribution/utils/svd_on_feature_maps.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | def get_2d_projection(activation_batch):
5 |     # TODO: use pytorch batch svd implementation
6 |     activation_batch[np.isnan(activation_batch)] = 0
7 |     projections = []
8 |     for activations in activation_batch:
9 |         reshaped_activations = (activations).reshape(
10 |             activations.shape[0], -1).transpose()
11 |         # Centering before the SVD seems to be important here,
12 |         # Otherwise the image returned is negative
13 |         reshaped_activations = reshaped_activations - \
14 |             reshaped_activations.mean(axis=0)
15 |         U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True)
16 |         projection = reshaped_activations @ VT[0, :]
17 |         projection = projection.reshape(activations.shape[1:])
18 |         projections.append(projection)
19 |     return np.float32(projections)
20 | 
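As the TODO comment above notes, this could use PyTorch's batched SVD instead of the per-sample NumPy loop. A rough torch-only equivalent (a sketch, not part of the repository) might look like:

```python
import torch

def get_2d_projection_torch(activations: torch.Tensor) -> torch.Tensor:
    """Sketch: batched first principal component of (B, C, H, W) feature maps."""
    B, C, H, W = activations.shape
    flat = activations.reshape(B, C, -1).transpose(1, 2)    # (B, H*W, C)
    flat = flat - flat.mean(dim=1, keepdim=True)            # center, as in the NumPy version
    _, _, Vh = torch.linalg.svd(flat, full_matrices=False)  # Vh: (B, C, C)
    projection = flat @ Vh[:, 0, :].unsqueeze(-1)           # project onto first right singular vector
    return projection.squeeze(-1).reshape(B, H, W)
```

--------------------------------------------------------------------------------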
/attribution/utils/visualization_maps.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | 
7 | @torch.no_grad()
8 | def normalize_saliency(saliency_map, return_device=torch.device('cpu')):
9 |     B, C, H, W = saliency_map.size()
10 |     if C > 1:  # the input image is multi-channel
11 |         saliency_map = saliency_map.max(dim=1, keepdim=True)[0]
12 |     saliency_map = F.relu(saliency_map)
13 |     # the shape is B x 1 x H x W, normalize the saliency map along the channel dimension
14 |     saliency_map = saliency_map.view(saliency_map.size(0), -1)
15 |     saliency_map -= saliency_map.min(dim=1, keepdim=True)[0]
16 |     saliency_map /= (saliency_map.max(dim=1, keepdim=True)[0] + 1e-8)  # guard against all-zero maps
17 |     saliency_map = saliency_map.view(B, 1, H, W)
18 |     return saliency_map.to(return_device)
19 | 
20 | def visualize_single_saliency(saliency_map, img_size=None):
21 |     B, C, H, W = saliency_map.size()
22 |     assert B == 1, 'The input saliency map should be a single map, not a batch'
23 |     if saliency_map.max() > 1 or C > 1:
24 |         saliency_map = normalize_saliency(saliency_map)
25 |     saliency_map = saliency_map.view(H, W, 1)
26 |     saliency_map = saliency_map.cpu().numpy()
27 |     if img_size is not None:
28 |         saliency_map = cv2.resize(saliency_map, (img_size[1], img_size[0]), interpolation=cv2.INTER_LINEAR)
29 |     else:
30 |         saliency_map = cv2.resize(saliency_map, (W, H), interpolation=cv2.INTER_LINEAR)
31 |     saliency_map = cv2.applyColorMap(np.uint8(saliency_map * 255.0), cv2.COLORMAP_JET)
32 |     saliency_map = cv2.cvtColor(saliency_map, cv2.COLOR_BGR2RGB)
33 |     plt.axis('off')
34 |     plt.imshow(saliency_map)
35 |     return saliency_map

--------------------------------------------------------------------------------
/attribution/xgrad_cam.py:
--------------------------------------------------------------------------------
1 | from .grad_cam import GradCAM
2 | 
3 | import torch
4 | from torch.nn import functional as F
5 | 
6 | 
7 | class XGradCAM(GradCAM):
8 |     def get_mask(self, img: torch.Tensor,
9 |                  target_class: torch.Tensor,
10 |                  target_layer: str):
11 | 
12 |         B, C, H, W = img.size()
13 |         self.model.eval()
14 |         self.model.zero_grad()
15 | 
16 |         # class-specific backpropagation
17 |         logits = self.model(img)
18 |         target = self._encode_one_hot(target_class, logits)
19 |         self.model.zero_grad()
20 |         logits.backward(gradient=target, retain_graph=True)
21 | 
22 |         # get feature maps and gradients
23 |         feature_maps = self._find(self.feature_maps, target_layer)
24 |         gradients = self._find(self.gradients, target_layer)
25 | 
26 |         # generate CAM
27 |         with torch.no_grad():
28 |             sum_feature_maps = torch.sum(feature_maps, dim=(2, 3))
29 |             sum_feature_maps = sum_feature_maps[:, :, None, None]
30 |             eps = 1e-7
31 |             weights = gradients * feature_maps / (sum_feature_maps + eps)
32 |             weights = torch.sum(weights, dim=(2, 3))
33 |             weights = weights[:, :, None, None]
34 |             cam = torch.mul(feature_maps, weights).sum(dim=1, keepdim=True)
35 |             cam = F.relu(cam)
36 |             cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
37 |             cam = self.normalize_cam(cam)
38 | 
39 |         return cam

--------------------------------------------------------------------------------
/cam_visualization_examples.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import timm
5 | from timm.data import resolve_model_data_config
6 | from 
timm.data.transforms_factory import create_transform 7 | import requests 8 | 9 | from attribution import GradCAM, GradCAMPlusPlus, XGradCAM, BagCAM, ScoreCAM, LayerCAM, AblationCAM, FullGrad, EigenCAM, EigenGradCAM, HiResCAM 10 | from attribution.utils import normalize_saliency, visualize_single_saliency 11 | 12 | 13 | if __name__ == '__main__': 14 | 15 | # Load imagenet labels 16 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt' 17 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n') 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | # Load a pretrained model 21 | model = timm.create_model('resnet50', pretrained=True) 22 | model = model.to(device) 23 | model.eval() 24 | config = resolve_model_data_config(model, None) 25 | transform = create_transform(**config) 26 | 27 | # Load an image 28 | dog = Image.open('examples/dog.png').convert('RGB') 29 | dog_tensor = transform(dog).unsqueeze(0) 30 | H, W = dog_tensor.shape[-2:] 31 | 32 | # Predict the image 33 | img = transform(dog).unsqueeze(0) 34 | img = torch.cat([img, img]) 35 | img = img.to(device) 36 | output = model(img) 37 | target_index = torch.argmax(output, dim=1).cpu() 38 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()]) 39 | 40 | # Get the target layer 41 | target_layer_candidates = list() 42 | for name, module in model.named_modules(): 43 | print(name) 44 | target_layer_candidates.append(name) 45 | '''e.g., 46 | --ResNet: 47 | conv1 48 | bn1 49 | act1 50 | maxpool 51 | layer1 52 | layer2 53 | layer3 54 | layer4 55 | global_pool 56 | fc 57 | 58 | --Xception41: 59 | stem 60 | blocks 61 | act 62 | head 63 | ''' 64 | target_layer = input('Enter the target layer: ') 65 | while target_layer not in target_layer_candidates: 66 | print('Invalid layer name') 67 | target_layer = input('Enter the target layer: ') 68 | 69 | # GradCAM 70 | gradcam_net = GradCAM(model) 71 | grad_cam = normalize_saliency(gradcam_net.get_mask(img, target_index, target_layer)) 72 | print('GradCAM', grad_cam.shape) 73 | 74 | # GradCAM++ 75 | gradcam_plus_plus_net = GradCAMPlusPlus(model) 76 | grad_cam_plus_plus = normalize_saliency(gradcam_plus_plus_net.get_mask(img, target_index, target_layer)) 77 | print('GradCAM++:', grad_cam_plus_plus.shape) 78 | 79 | # HiResCAM 80 | hirescam_net = HiResCAM(model) 81 | hires_cam = normalize_saliency(hirescam_net.get_mask(img, target_index, target_layer)) 82 | print('HiResCAM:', hires_cam.shape) 83 | 84 | # XGradCAM 85 | xgradcam_net = XGradCAM(model) 86 | xgrad_cam = normalize_saliency(xgradcam_net.get_mask(img, target_index, target_layer)) 87 | print('XGradCAM:', xgrad_cam.shape) 88 | 89 | # BagCAM 90 | bagcam_net = BagCAM(model) 91 | bag_cam = normalize_saliency(bagcam_net.get_mask(img, target_index, target_layer)) 92 | print('BagCAM:', bag_cam.shape) 93 | 94 | # ScoreCAM 95 | scorecam_net = ScoreCAM(model) 96 | score_cam = normalize_saliency(scorecam_net.get_mask(img, target_index, target_layer)) 97 | print('ScoreCAM', score_cam.shape) 98 | 99 | # EigenCAM 100 | eigencam_net = EigenCAM(model) 101 | eigen_cam = normalize_saliency(eigencam_net.get_mask(img, target_index, target_layer)) 102 | print('EigenCAM', eigen_cam.shape) 103 | 104 | # EigenGradCAM 105 | eigengradcam_net = EigenGradCAM(model) 106 | eigen_grad_cam = normalize_saliency(eigengradcam_net.get_mask(img, target_index, target_layer)) 107 | print('EigenGradCAM', eigen_grad_cam.shape) 108 | 109 | # LayerCAM 110 | layercam_net = LayerCAM(model) 111 | 
layer_cam = normalize_saliency(layercam_net.get_mask(img, target_index, target_layer)) 112 | print('LayerCAM', layer_cam.shape) 113 | 114 | # AblationCAM 115 | ablationcam_net = AblationCAM(model) 116 | ablation_cam = normalize_saliency(ablationcam_net.get_mask(img, target_index, target_layer)) 117 | print('AblationCAM', ablation_cam.shape) 118 | 119 | # FullGrad 120 | fullgrad_net = FullGrad(model) 121 | full_grad = normalize_saliency(fullgrad_net.get_mask(img, target_index, target_layer=None)) 122 | print('FullGrad', full_grad.shape) 123 | 124 | # Visualize the saliency maps 125 | plt.figure(figsize=(16, 15)) 126 | plt.subplot(3,5,1) 127 | plt.title('Input') 128 | plt.axis('off') 129 | plt.imshow(dog) 130 | plt.subplot(3,5,2) 131 | plt.title('GradCAM') 132 | visualize_single_saliency(grad_cam[0].unsqueeze(0)) 133 | plt.subplot(3,5,3) 134 | plt.title('GradCAM++') 135 | visualize_single_saliency(grad_cam_plus_plus[0].unsqueeze(0)) 136 | plt.subplot(3,5,4) 137 | plt.title('HiResCAM') 138 | visualize_single_saliency(hires_cam[0].unsqueeze(0)) 139 | plt.subplot(3,5,5) 140 | plt.title('FullGrad') 141 | visualize_single_saliency(full_grad[0].unsqueeze(0)) 142 | plt.subplot(3,5,6) 143 | plt.title('AblationCAM') 144 | visualize_single_saliency(ablation_cam[0].unsqueeze(0)) 145 | plt.subplot(3,5,7) 146 | plt.title('ScoreCAM') 147 | visualize_single_saliency(score_cam[0].unsqueeze(0)) 148 | plt.subplot(3,5,8) 149 | plt.title('EigenCAM') 150 | visualize_single_saliency(eigen_cam[0].unsqueeze(0)) 151 | plt.subplot(3,5,9) 152 | plt.title('EigenGradCAM') 153 | visualize_single_saliency(eigen_grad_cam[0].unsqueeze(0)) 154 | plt.subplot(3,5,10) 155 | plt.title('XGradCAM') 156 | visualize_single_saliency(xgrad_cam[0].unsqueeze(0)) 157 | plt.subplot(3,5,11) 158 | plt.title('LayerCAM') 159 | visualize_single_saliency(layer_cam[0].unsqueeze(0)) 160 | plt.subplot(3,5,12) 161 | plt.title('BagCAM') 162 | visualize_single_saliency(bag_cam[0].unsqueeze(0)) 163 | 164 | 165 | plt.tight_layout() 166 | plt.savefig('examples/cam_visualization.png', bbox_inches='tight', pad_inches=0.5) 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /cam_visualization_for_transformers_examples.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import timm 5 | from timm.data import resolve_model_data_config 6 | from timm.data.transforms_factory import create_transform 7 | import requests 8 | 9 | from attribution import GradCAM, GradCAMPlusPlus, XGradCAM, BagCAM, ScoreCAM, LayerCAM, AblationCAM, FullGrad, EigenCAM, EigenGradCAM, HiResCAM 10 | from attribution.utils import normalize_saliency, visualize_single_saliency, get_reshape_transform 11 | 12 | 13 | if __name__ == '__main__': 14 | 15 | # Load imagenet labels 16 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt' 17 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n') 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | # Load a pretrained model 21 | # model = timm.create_model('vit_tiny_patch16_224.augreg_in21k_ft_in1k', pretrained=True) 22 | model = timm.create_model('swin_base_patch4_window7_224.ms_in22k_ft_in1k', pretrained=True) 23 | model = model.to(device) 24 | model.eval() 25 | config = resolve_model_data_config(model, None) 26 | transform = create_transform(**config) 27 | 28 | # Load an image 29 | dog = 
Image.open('examples/dog.png').convert('RGB') 30 | dog_tensor = transform(dog).unsqueeze(0) 31 | H, W = dog_tensor.shape[-2:] 32 | 33 | # Predict the image 34 | img = transform(dog).unsqueeze(0) 35 | img = torch.cat([img, img]) 36 | img = img.to(device) 37 | output = model(img) 38 | target_index = torch.argmax(output, dim=1).cpu() 39 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()]) 40 | 41 | # Get the target layer 42 | target_layer_candidates = list() 43 | for name, module in model.named_modules(): 44 | print(name) 45 | target_layer_candidates.append(name) 46 | '''e.g., 47 | --ResNet: 48 | conv1 49 | bn1 50 | act1 51 | maxpool 52 | layer1 53 | layer2 54 | layer3 55 | layer4 56 | global_pool 57 | fc 58 | 59 | --Xception41: 60 | stem 61 | blocks 62 | act 63 | head 64 | ''' 65 | target_layer = input('Enter the target layer for swin_base_patch4_window7_224: ') 66 | while target_layer not in target_layer_candidates: 67 | print('Invalid layer name') 68 | target_layer = input('Enter the target layer: ') 69 | 70 | # GradCAM 71 | gradcam_net = GradCAM(model, get_reshape_transform(has_cls_token=False)) 72 | grad_cam = normalize_saliency(gradcam_net.get_mask(img, target_index, target_layer)) 73 | print('GradCAM', grad_cam.shape) 74 | 75 | # GradCAM++ 76 | gradcam_plus_plus_net = GradCAMPlusPlus(model, get_reshape_transform(has_cls_token=False)) 77 | grad_cam_plus_plus = normalize_saliency(gradcam_plus_plus_net.get_mask(img, target_index, target_layer)) 78 | print('GradCAM++:', grad_cam_plus_plus.shape) 79 | 80 | # HiResCAM 81 | hirescam_net = HiResCAM(model, get_reshape_transform(has_cls_token=False)) 82 | hires_cam = normalize_saliency(hirescam_net.get_mask(img, target_index, target_layer)) 83 | print('HiResCAM:', hires_cam.shape) 84 | 85 | # XGradCAM 86 | xgradcam_net = XGradCAM(model, get_reshape_transform(has_cls_token=False)) 87 | xgrad_cam = normalize_saliency(xgradcam_net.get_mask(img, target_index, target_layer)) 88 | print('XGradCAM:', xgrad_cam.shape) 89 | 90 | # LayerCAM 91 | layercam_net = LayerCAM(model, get_reshape_transform(has_cls_token=False)) 92 | layer_cam = normalize_saliency(layercam_net.get_mask(img, target_index, target_layer)) 93 | print('LayerCAM', layer_cam.shape) 94 | 95 | # Visualize the saliency maps 96 | plt.figure(figsize=(18, 10)) 97 | plt.subplot(2,6,1) 98 | plt.title('Input') 99 | plt.axis('off') 100 | plt.imshow(dog) 101 | plt.subplot(2,6,2) 102 | plt.title('Swin GradCAM') 103 | visualize_single_saliency(grad_cam[0].unsqueeze(0)) 104 | plt.subplot(2,6,3) 105 | plt.title('Swin GradCAM++') 106 | visualize_single_saliency(grad_cam_plus_plus[0].unsqueeze(0)) 107 | plt.subplot(2,6,4) 108 | plt.title('Swin HiResCAM') 109 | visualize_single_saliency(hires_cam[0].unsqueeze(0)) 110 | plt.subplot(2,6,5) 111 | plt.title('Swin XGradCAM') 112 | visualize_single_saliency(xgrad_cam[0].unsqueeze(0)) 113 | plt.subplot(2,6,6) 114 | plt.title('Swin LayerCAM') 115 | visualize_single_saliency(layer_cam[0].unsqueeze(0)) 116 | 117 | model = timm.create_model('vit_tiny_patch16_224.augreg_in21k_ft_in1k', pretrained=True) 118 | model = model.to(device) 119 | model.eval() 120 | config = resolve_model_data_config(model, None) 121 | transform = create_transform(**config) 122 | 123 | # Load an image 124 | dog = Image.open('examples/dog.png').convert('RGB') 125 | dog_tensor = transform(dog).unsqueeze(0) 126 | H, W = dog_tensor.shape[-2:] 127 | 128 | # Predict the image 129 | img = transform(dog).unsqueeze(0) 130 | img = torch.cat([img, img]) 131 | img = img.to(device) 132 | 
output = model(img) 133 | target_index = torch.argmax(output, dim=1).cpu() 134 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()]) 135 | 136 | # Get the target layer 137 | target_layer_candidates = list() 138 | for name, module in model.named_modules(): 139 | print(name) 140 | target_layer_candidates.append(name) 141 | '''e.g., 142 | --ResNet: 143 | conv1 144 | bn1 145 | act1 146 | maxpool 147 | layer1 148 | layer2 149 | layer3 150 | layer4 151 | global_pool 152 | fc 153 | 154 | --Xception41: 155 | stem 156 | blocks 157 | act 158 | head 159 | ''' 160 | target_layer = input('Enter the target layer for vit_tiny_patch16_224: ') 161 | while target_layer not in target_layer_candidates: 162 | print('Invalid layer name') 163 | target_layer = input('Enter the target layer: ') 164 | 165 | # GradCAM 166 | gradcam_net = GradCAM(model, get_reshape_transform(has_cls_token=True)) 167 | grad_cam = normalize_saliency(gradcam_net.get_mask(img, target_index, target_layer)) 168 | print('GradCAM', grad_cam.shape) 169 | 170 | # GradCAM++ 171 | gradcam_plus_plus_net = GradCAMPlusPlus(model, get_reshape_transform(has_cls_token=True)) 172 | grad_cam_plus_plus = normalize_saliency(gradcam_plus_plus_net.get_mask(img, target_index, target_layer)) 173 | print('GradCAM++:', grad_cam_plus_plus.shape) 174 | 175 | # HiResCAM 176 | hirescam_net = HiResCAM(model, get_reshape_transform(has_cls_token=True)) 177 | hires_cam = normalize_saliency(hirescam_net.get_mask(img, target_index, target_layer)) 178 | print('HiResCAM:', hires_cam.shape) 179 | 180 | # XGradCAM 181 | xgradcam_net = XGradCAM(model, get_reshape_transform(has_cls_token=True)) 182 | xgrad_cam = normalize_saliency(xgradcam_net.get_mask(img, target_index, target_layer)) 183 | print('XGradCAM:', xgrad_cam.shape) 184 | 185 | # LayerCAM 186 | layercam_net = LayerCAM(model, get_reshape_transform(has_cls_token=True)) 187 | layer_cam = normalize_saliency(layercam_net.get_mask(img, target_index, target_layer)) 188 | print('LayerCAM', layer_cam.shape) 189 | 190 | plt.subplot(2,6,8) 191 | plt.title('ViT GradCAM') 192 | visualize_single_saliency(grad_cam[0].unsqueeze(0)) 193 | plt.subplot(2,6,9) 194 | plt.title('ViT GradCAM++') 195 | visualize_single_saliency(grad_cam_plus_plus[0].unsqueeze(0)) 196 | plt.subplot(2,6,10) 197 | plt.title('ViT HiResCAM') 198 | visualize_single_saliency(hires_cam[0].unsqueeze(0)) 199 | plt.subplot(2,6,11) 200 | plt.title('ViT XGradCAM') 201 | visualize_single_saliency(xgrad_cam[0].unsqueeze(0)) 202 | plt.subplot(2,6,12) 203 | plt.title('ViT LayerCAM') 204 | visualize_single_saliency(layer_cam[0].unsqueeze(0)) 205 | 206 | plt.tight_layout() 207 | plt.savefig('examples/cam_visualization_for_transformers.png', bbox_inches='tight', pad_inches=0.5) 208 | 209 | 210 | 211 | -------------------------------------------------------------------------------- /combine_cam_and_gradients_visualization_examples.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import timm 5 | from timm.data import resolve_model_data_config 6 | from timm.data.transforms_factory import create_transform 7 | import requests 8 | 9 | from attribution import GradCAM, GradCAMPlusPlus, XGradCAM, BagCAM, ScoreCAM, LayerCAM, AblationCAM, FullGrad, EigenCAM, EigenGradCAM, HiResCAM 10 | from attribution import VanillaGradient, GuidedBackProp, IntegratedGradients, BlurIG, GuidedIG 11 | from attribution import CombinedWrapper, GuidedGradCAM 12 | from
attribution.utils import normalize_saliency, visualize_single_saliency 13 | 14 | 15 | if __name__ == '__main__': 16 | 17 | # Load imagenet labels 18 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt' 19 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n') 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | # Load a pretrained model 23 | model = timm.create_model('resnet50', pretrained=True) 24 | model = model.to(device) 25 | model.eval() 26 | config = resolve_model_data_config(model, None) 27 | transform = create_transform(**config) 28 | 29 | # Load an image 30 | dog = Image.open('examples/dog.png').convert('RGB') 31 | dog_tensor = transform(dog).unsqueeze(0) 32 | H, W = dog_tensor.shape[-2:] 33 | 34 | # Predict the image 35 | img = transform(dog).unsqueeze(0) 36 | img = torch.cat([img, img]) 37 | img = img.to(device) 38 | output = model(img) 39 | target_index = torch.argmax(output, dim=1).cpu() 40 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()]) 41 | 42 | # Get the target layer 43 | target_layer_candidates = list() 44 | for name, module in model.named_modules(): 45 | print(name) 46 | target_layer_candidates.append(name) 47 | '''e.g., 48 | --ResNet: 49 | conv1 50 | bn1 51 | act1 52 | maxpool 53 | layer1 54 | layer2 55 | layer3 56 | layer4 57 | global_pool 58 | fc 59 | 60 | --Xception41: 61 | stem 62 | blocks 63 | act 64 | head 65 | ''' 66 | target_layer = input('Enter the target layer: ') 67 | while target_layer not in target_layer_candidates: 68 | print('Invalid layer name') 69 | target_layer = input('Enter the target layer: ') 70 | 71 | # Guided Grad-CAM 72 | net = GuidedGradCAM(model) 73 | guided_gradcam = normalize_saliency(net.get_mask(img, target_index, target_layer)) 74 | print('Guided Grad-CAM', guided_gradcam.shape) 75 | 76 | # GuidedBackprop + FullGrad 77 | net = CombinedWrapper(model, GuidedBackProp, FullGrad) 78 | guided_fullgrad = normalize_saliency(net.get_mask(img, target_index, target_layer)) 79 | print('GuidedBackProp + FullGrad', guided_fullgrad.shape) 80 | 81 | # BlurIG + EigenCAM 82 | net = CombinedWrapper(model, BlurIG, EigenCAM) 83 | kwargs = {'steps': 20} 84 | blurig_eigencam = normalize_saliency(net.get_mask(img, target_index, target_layer, **kwargs)) 85 | print('BlurIG + EigenCAM', blurig_eigencam.shape) 86 | 87 | # Visualize the results 88 | plt.figure(figsize=(16, 5)) 89 | plt.subplot(1, 4, 1) 90 | plt.title('Input') 91 | plt.axis('off') 92 | plt.imshow(dog) 93 | plt.subplot(1, 4, 2) 94 | plt.title('Guided Grad-CAM') 95 | visualize_single_saliency(guided_gradcam[0].unsqueeze(0)) 96 | plt.subplot(1, 4, 3) 97 | plt.title('GuidedBackProp + FullGrad') 98 | visualize_single_saliency(guided_fullgrad[0].unsqueeze(0)) 99 | plt.subplot(1, 4, 4) 100 | plt.title('BlurIG + EigenCAM') 101 | visualize_single_saliency(blurig_eigencam[0].unsqueeze(0)) 102 | 103 | plt.tight_layout() 104 | plt.savefig('examples/combine_cam_and_gradients_visualization.png', bbox_inches='tight', pad_inches=0.5) -------------------------------------------------------------------------------- /examples/attribution_methods.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/attribution_methods.png -------------------------------------------------------------------------------- /examples/cam_visualization.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/cam_visualization.png -------------------------------------------------------------------------------- /examples/cam_visualization_for_transformers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/cam_visualization_for_transformers.png -------------------------------------------------------------------------------- /examples/cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/cat.png -------------------------------------------------------------------------------- /examples/combine_cam_and_gradients_visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/combine_cam_and_gradients_visualization.png -------------------------------------------------------------------------------- /examples/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/dog.png -------------------------------------------------------------------------------- /examples/dog_and_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/dog_and_cat.png -------------------------------------------------------------------------------- /examples/gradients_visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/gradients_visualization.png -------------------------------------------------------------------------------- /examples/gradients_visualization_for_transformers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/gradients_visualization_for_transformers.png -------------------------------------------------------------------------------- /examples/perturbation_based_visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/perturbation_based_visualization.png -------------------------------------------------------------------------------- /examples/quick_start.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/quick_start.png -------------------------------------------------------------------------------- /gradients_visualization_examples.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import matplotlib.pyplot as plt 3 | import 
torch 4 | import timm 5 | from timm.data import resolve_model_data_config 6 | from timm.data.transforms_factory import create_transform 7 | import requests 8 | 9 | 10 | from attribution import VanillaGradient, IntegratedGradients, BlurIG, GuidedIG, GuidedBackProp 11 | from attribution.utils import normalize_saliency, visualize_single_saliency 12 | 13 | 14 | if __name__ == '__main__': 15 | 16 | # Load imagenet labels 17 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt' 18 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n') 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | # Load a pretrained model 22 | model = timm.create_model('resnet50', pretrained=True) 23 | model = model.to(device) 24 | model.eval() 25 | config = resolve_model_data_config(model, None) 26 | transform = create_transform(**config) 27 | 28 | # Load an image 29 | dog = Image.open('examples/dog.png').convert('RGB') 30 | dog_tensor = transform(dog).unsqueeze(0) 31 | H, W = dog_tensor.shape[-2:] 32 | 33 | # Predict the image 34 | img = transform(dog).unsqueeze(0) 35 | img = torch.cat([img, img]) 36 | img = img.to(device) 37 | output = model(img) 38 | target_index = torch.argmax(output, dim=1).cpu() 39 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()]) 40 | 41 | # Vanilla Gradient 42 | gradient_net = VanillaGradient(model) 43 | attribution_gradients = normalize_saliency(gradient_net.get_mask(img, target_index)) 44 | attribution_smooth_gradients = normalize_saliency(gradient_net.get_smoothed_mask(img, target_index, samples=10, std=0.1)) 45 | 46 | # Guided Backpropagation 47 | guided_bp_net = GuidedBackProp(model) 48 | attribution_guided_bp = normalize_saliency(guided_bp_net.get_mask(img, target_index)) 49 | 50 | # Integrated Gradients 51 | ig_net = IntegratedGradients(model) 52 | attribution_ig = normalize_saliency(ig_net.get_mask(img, target_index, steps=100)) 53 | attribution_smooth_ig = normalize_saliency(ig_net.get_smoothed_mask(img, target_index, steps=100, std=0.15, samples=10)) 54 | # Blur Integrated Gradients 55 | blur_ig_net = BlurIG(model) 56 | attribution_blur_ig = normalize_saliency(blur_ig_net.get_mask(img, target_index, steps=100)) 57 | attribution_smooth_blur_ig = normalize_saliency(blur_ig_net.get_smoothed_mask(img, target_index, steps=100, std=0.15, samples=10)) 58 | # Guided Integrated Gradients 59 | guided_ig_net = GuidedIG(model) 60 | attribution_guided_ig = normalize_saliency(guided_ig_net.get_mask(img, target_index, steps=100)) 61 | attribution_smooth_guided_ig = normalize_saliency(guided_ig_net.get_smoothed_mask(img, target_index, steps=100, std=0.15, samples=10)) 62 | 63 | # Visualize the results 64 | plt.figure(figsize=(16, 10)) 65 | plt.subplot(2, 5, 1) 66 | plt.title('Input') 67 | plt.axis('off') 68 | plt.imshow(dog) 69 | plt.subplot(2, 5, 6) 70 | plt.title('Guided Backprop') 71 | visualize_single_saliency(attribution_guided_bp[0].unsqueeze(0)) 72 | plt.subplot(2, 5, 2) 73 | plt.title('Vanilla Gradient') 74 | visualize_single_saliency(attribution_gradients[0].unsqueeze(0)) 75 | plt.subplot(2, 5, 7) 76 | plt.title('Smoothed Vanilla Gradient') 77 | visualize_single_saliency(attribution_smooth_gradients[0].unsqueeze(0)) 78 | plt.subplot(2, 5, 3) 79 | plt.title('Integrated Gradients') 80 | visualize_single_saliency(attribution_ig[0].unsqueeze(0)) 81 | plt.subplot(2, 5, 8) 82 | plt.title('Smoothed Integrated Gradients') 83 | visualize_single_saliency(attribution_smooth_ig[0].unsqueeze(0)) 
84 | plt.subplot(2, 5, 4) 85 | plt.title('Blur IG') 86 | visualize_single_saliency(attribution_blur_ig[0].unsqueeze(0)) 87 | plt.subplot(2, 5, 9) 88 | plt.title('Smoothed Blur IG') 89 | visualize_single_saliency(attribution_smooth_blur_ig[0].unsqueeze(0)) 90 | plt.subplot(2, 5, 5) 91 | plt.title('Guided IG') 92 | visualize_single_saliency(attribution_guided_ig[0].unsqueeze(0)) 93 | plt.subplot(2, 5, 10) 94 | plt.title('Smoothed Guided IG') 95 | visualize_single_saliency(attribution_smooth_guided_ig[0].unsqueeze(0)) 96 | plt.tight_layout() 97 | plt.savefig('examples/gradients_visualization.png', bbox_inches='tight', pad_inches=0.5) -------------------------------------------------------------------------------- /gradients_visualization_for_transformers_examples.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import timm 5 | from timm.data import resolve_model_data_config 6 | from timm.data.transforms_factory import create_transform 7 | import requests 8 | 9 | 10 | from attribution import VanillaGradient, IntegratedGradients, BlurIG, GuidedIG, GuidedBackProp 11 | from attribution.utils import normalize_saliency, visualize_single_saliency 12 | 13 | 14 | if __name__ == '__main__': 15 | 16 | # Load imagenet labels 17 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt' 18 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n') 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | # Load a pretrained model 22 | model = timm.create_model('vit_tiny_patch16_224.augreg_in21k_ft_in1k', pretrained=True) 23 | model = model.to(device) 24 | model.eval() 25 | config = resolve_model_data_config(model, None) 26 | transform = create_transform(**config) 27 | 28 | # Load an image 29 | dog = Image.open('examples/dog.png').convert('RGB') 30 | dog_tensor = transform(dog).unsqueeze(0) 31 | H, W = dog_tensor.shape[-2:] 32 | 33 | # Predict the image 34 | img = transform(dog).unsqueeze(0) 35 | img = torch.cat([img, img]) 36 | img = img.to(device) 37 | output = model(img) 38 | target_index = torch.argmax(output, dim=1).cpu() 39 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()]) 40 | 41 | # Vanilla Gradient 42 | gradient_net = VanillaGradient(model) 43 | attribution_gradients = normalize_saliency(gradient_net.get_mask(img, target_index)) 44 | attribution_smooth_gradients = normalize_saliency(gradient_net.get_smoothed_mask(img, target_index, samples=10, std=0.1)) 45 | 46 | # Guided Backpropagation 47 | guided_bp_net = GuidedBackProp(model) 48 | attribution_guided_bp = normalize_saliency(guided_bp_net.get_mask(img, target_index)) 49 | 50 | # Integrated Gradients 51 | ig_net = IntegratedGradients(model) 52 | attribution_ig = normalize_saliency(ig_net.get_mask(img, target_index, steps=100)) 53 | attribution_smooth_ig = normalize_saliency(ig_net.get_smoothed_mask(img, target_index, steps=100, std=0.15, samples=10)) 54 | # Blur Integrated Gradients 55 | blur_ig_net = BlurIG(model) 56 | attribution_blur_ig = normalize_saliency(blur_ig_net.get_mask(img, target_index, steps=100)) 57 | attribution_smooth_blur_ig = normalize_saliency(blur_ig_net.get_smoothed_mask(img, target_index, steps=100, std=0.15, samples=10)) 58 | # Guided Integrated Gradients 59 | guided_ig_net = GuidedIG(model) 60 | attribution_guided_ig = normalize_saliency(guided_ig_net.get_mask(img, target_index, steps=100)) 61 | 
attribution_smooth_guided_ig = normalize_saliency(guided_ig_net.get_smoothed_mask(img, target_index, steps=100, std=0.15, samples=10)) 62 | 63 | # Visualize the results 64 | plt.figure(figsize=(16, 10)) 65 | plt.subplot(2, 5, 1) 66 | plt.title('Input') 67 | plt.axis('off') 68 | plt.imshow(dog) 69 | plt.subplot(2, 5, 6) 70 | plt.title('Guided Backprop') 71 | visualize_single_saliency(attribution_guided_bp[0].unsqueeze(0)) 72 | plt.subplot(2, 5, 2) 73 | plt.title('Vanilla Gradient') 74 | visualize_single_saliency(attribution_gradients[0].unsqueeze(0)) 75 | plt.subplot(2, 5, 7) 76 | plt.title('Smoothed Vanilla Gradient') 77 | visualize_single_saliency(attribution_smooth_gradients[0].unsqueeze(0)) 78 | plt.subplot(2, 5, 3) 79 | plt.title('Integrated Gradients') 80 | visualize_single_saliency(attribution_ig[0].unsqueeze(0)) 81 | plt.subplot(2, 5, 8) 82 | plt.title('Smoothed Integrated Gradients') 83 | visualize_single_saliency(attribution_smooth_ig[0].unsqueeze(0)) 84 | plt.subplot(2, 5, 4) 85 | plt.title('Blur IG') 86 | visualize_single_saliency(attribution_blur_ig[0].unsqueeze(0)) 87 | plt.subplot(2, 5, 9) 88 | plt.title('Smoothed Blur IG') 89 | visualize_single_saliency(attribution_smooth_blur_ig[0].unsqueeze(0)) 90 | plt.subplot(2, 5, 5) 91 | plt.title('Guided IG') 92 | visualize_single_saliency(attribution_guided_ig[0].unsqueeze(0)) 93 | plt.subplot(2, 5, 10) 94 | plt.title('Smoothed Guided IG') 95 | visualize_single_saliency(attribution_smooth_guided_ig[0].unsqueeze(0)) 96 | plt.tight_layout() 97 | plt.savefig('examples/gradients_visualization_for_transformers.png', bbox_inches='tight', pad_inches=0.5) -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .keep_remove_mask import kpm, kam, ram, knm, keep_remove_metrics 2 | from .correlation_value import class_specific_spearman_rank_corr 3 | from .insert_and_delete import Insert_Delete_Metric, insert_score, delete_score -------------------------------------------------------------------------------- /metrics/correlation_value.py: -------------------------------------------------------------------------------- 1 | from torchmetrics.functional.regression import spearman_corrcoef 2 | import torch 3 | 4 | @torch.no_grad() 5 | def spearman_rank_correlation_value(att1:torch.Tensor, att2: torch.Tensor): 6 | assert att1.size() == att2.size(), 'Attribution maps must have the same size' 7 | att1 = att1.view(att1.size(0), -1) 8 | att2 = att2.view(att2.size(0), -1) 9 | corr = spearman_corrcoef(att1.T, att2.T) 10 | return corr.cpu() 11 | 12 | def class_specific_spearman_rank_corr(att_target, att_other_class_list): 13 | ''' 14 | Return: 15 | corr: torch.Tensor, shape (Batch_size,) 16 | ''' 17 | corr = torch.zeros(att_target.size(0)) 18 | for att_other_class in att_other_class_list: 19 | class_corr = spearman_rank_correlation_value(att_target, att_other_class) 20 | corr += class_corr 21 | return corr / len(att_other_class_list) 22 | 23 | if __name__ == '__main__': 24 | device = torch.device('cuda') 25 | att1 = torch.rand(32, 1, 224, 224).to(device) 26 | att2 = -att1.clone().to(device) 27 | att3 = att1.clone().to(device) 28 | att4 = torch.rand(32, 1, 224, 224).to(device) 29 | att4[1] = -att1[1] 30 | 31 | att_other_class_list = [att2, att3, att4] 32 | corr = class_specific_spearman_rank_corr(att1, att_other_class_list) 33 | print('\n', corr) 
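A minimal sketch (not part of the repo) of feeding real attribution maps into `class_specific_spearman_rank_corr`; the class indices are arbitrary stand-ins, and the random input stands in for a properly transformed batch. A lower correlation between the target-class map and other-class maps suggests the attribution maps are more class-specific.

```python
import timm
import torch

from attribution import GradCAM
from attribution.utils import normalize_saliency
from metrics import class_specific_spearman_rank_corr

model = timm.create_model('resnet50', pretrained=True).eval()
img = torch.randn(2, 3, 224, 224)  # stand-in for a transformed batch
cam = GradCAM(model)

# attribution for the predicted class vs. two arbitrary other classes
target = model(img).argmax(dim=1)
att_target = normalize_saliency(cam.get_mask(img, target, target_layer='layer3'))
att_others = [normalize_saliency(cam.get_mask(img, torch.full_like(target, c), target_layer='layer3'))
              for c in (281, 285)]
corr = class_specific_spearman_rank_corr(att_target, att_others)  # shape (B,)
print(corr)
```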
-------------------------------------------------------------------------------- /metrics/insert_and_delete.py: -------------------------------------------------------------------------------- 1 | ''' 2 | adapted from https://github.com/LMBTough/MFABA 3 | @article{zhu2023mfaba, 4 | title={MFABA: A More Faithful and Accelerated Boundary-based Attribution Method for Deep Neural Networks}, 5 | author={Zhu, Zhiyu and Chen, Huaming and Zhang, Jiayu and Wang, Xinyi and Jin, Zhibo and Xue, Minhui and Zhu, Dongxiao and Choo, Kim-Kwang Raymond}, 6 | journal={arXiv preprint arXiv:2312.13630}, 7 | year={2023} 8 | } 9 | ''' 10 | 11 | import torch 12 | import numpy as np 13 | 14 | @torch.no_grad() 15 | def insert_score(model, step, images, explanations): 16 | model.eval() 17 | B, C, H, W = images.size() 18 | if explanations.size(1) == 1: 19 | explanations = explanations.expand(-1, C, -1, -1) 20 | predictions = model(images) 21 | top, c = torch.max(predictions, -1) 22 | n_steps = (H * W + step - 1) // step 23 | scores = np.empty((B, n_steps + 1)) 24 | salient_order = explanations.view(B, C, H * W).argsort(descending=True) 25 | 26 | start = torch.zeros_like(images, device=images.device) 27 | finish = images.clone() 28 | finish = finish.view(B, C, H * W) 29 | 30 | for i in range(n_steps + 1): 31 | pred = model(start) 32 | pred = torch.nn.functional.softmax(pred, dim=-1) 33 | scores[:, i] = pred[torch.arange(B), c].cpu().numpy() 34 | if i < n_steps: 35 | coords = salient_order[:, :, i * step:(i + 1) * step] 36 | # change the value of the pixels according to the coords 37 | start = start.view(B, C, H * W) 38 | start.scatter_(dim=2, index=coords, src=torch.gather(finish, dim=2, index=coords)) 39 | start = start.view(B, C, H, W) 40 | 41 | scores = np.sum(scores, axis=0) 42 | xs = np.linspace(0, 1, scores.shape[0]) 43 | auc = np.trapz(scores, dx=xs[1] - xs[0]) / B 44 | 45 | return auc, scores 46 | 47 | @torch.no_grad() 48 | def delete_score(model, step, images, explanations): 49 | model.eval() 50 | B, C, H, W = images.size() 51 | if explanations.size(1) == 1: 52 | explanations = explanations.expand(-1, C, -1, -1) 53 | predictions = model(images) 54 | top, c = torch.max(predictions, -1) 55 | n_steps = (H * W + step - 1) // step 56 | scores = np.empty((B, n_steps + 1)) 57 | salient_order = explanations.view(B, C, H * W).argsort(descending=True) 58 | 59 | start = images.clone() 60 | finish = torch.zeros_like(images, device=images.device) 61 | finish = finish.view(B, C, H * W) 62 | 63 | for i in range(n_steps + 1): 64 | pred = model(start) 65 | pred = torch.nn.functional.softmax(pred, dim=-1) 66 | scores[:, i] = pred[torch.arange(B), c].cpu().numpy() 67 | if i < n_steps: 68 | coords = salient_order[:, :, i * step:(i + 1) * step] 69 | # change the value of the pixels according to the coords 70 | start = start.view(B, C, H * W) 71 | start.scatter_(dim=2, index=coords, src=torch.gather(finish, dim=2, index=coords)) 72 | start = start.view(B, C, H, W) 73 | 74 | scores = np.sum(scores, axis=0) 75 | xs = np.linspace(0, 1, scores.shape[0]) 76 | auc = np.trapz(scores, dx=xs[1] - xs[0]) / B 77 | return auc, scores 78 | 79 | class Insert_Delete_Metric(): 80 | 81 | def __init__(self, model, mode, step): 82 | r"""Create deletion/insertion metric instance. 83 | Args: 84 | model (nn.Module): Black-box model being explained. 85 | mode (str): 'del' or 'ins'. 86 | step (int): number of pixels modified per one iteration. 87 | the baseline (substrate) image is fixed to all zeros (black).
88 | """ 89 | assert mode in ['del', 'ins'] 90 | self.model = model 91 | self.mode = mode 92 | self.step = step 93 | self.substrate_fn = lambda x: torch.zeros_like(x, device=x.device) 94 | 95 | def single_run(self, image, explanation: torch.Tensor): 96 | ''' only one image 97 | image: torch.Tensor, 1, C, H, W 98 | explanation: torch.Tensor, 1, 1, H, W or 1, C, H, W 99 | ''' 100 | _, C, H, W = image.size() 101 | if explanation.size(1) == 1: 102 | explanation = explanation.expand(-1, C, -1, -1) 103 | pred = self.model(image) 104 | top, c = torch.max(pred, 1) 105 | n_steps = (H * W + self.step - 1) // self.step 106 | 107 | if self.mode == 'del': 108 | start = image.clone() 109 | finish = self.substrate_fn(image) 110 | else: 111 | start = self.substrate_fn(image) 112 | finish = image.clone() 113 | finish = finish.view(-1, C, H * W) 114 | 115 | scores = np.empty(n_steps + 1) 116 | salient_order = explanation.view(-1, H * W).argsort(descending=True) 117 | for i in range(n_steps+1): 118 | pred = self.model(start) 119 | pred = torch.nn.functional.softmax(pred, dim=-1) 120 | scores[i] = pred[0, c] 121 | if i < n_steps: 122 | coords = salient_order[:, i * self.step:(i + 1) * self.step] 123 | start = start.view(-1, C, H * W) 124 | start[:, :, coords] = finish[:, :, coords] 125 | start = start.view(-1, C, H, W) 126 | auc = np.trapz(scores, dx=self.step) 127 | return auc, scores 128 | 129 | def batch_run(self, images, explanations): 130 | B, C, H, W = images.size() 131 | if explanations.size(1) == 1: 132 | explanations = explanations.expand(-1, C, -1, -1) 133 | predictions = self.model(images) 134 | top, c = torch.max(predictions, -1) 135 | n_steps = (H * W + self.step - 1) // self.step 136 | scores = np.empty((B, n_steps + 1)) 137 | salient_order = explanations.view(B, C, H * W).argsort(descending=True) 138 | if self.mode == 'del': 139 | start = images.clone() 140 | finish = self.substrate_fn(images) 141 | else: 142 | start = self.substrate_fn(images) 143 | finish = images.clone() 144 | 145 | for i in range(n_steps + 1): 146 | pred = self.model(start) 147 | pred = torch.nn.functional.softmax(pred, dim=-1) 148 | scores[:, i] = pred[torch.arange(B), c] 149 | if i < n_steps: 150 | coords = salient_order[:, :, i * self.step:(i + 1) * self.step] 151 | start = start.view(B, C, H * W) 152 | start[torch.arange(B).unsqueeze(1), :, coords] = finish[torch.arange(B).unsqueeze(1), :, coords] 153 | start = start.view(B, C, H, W) 154 | 155 | auc = np.trapz(scores, dx=self.step, axis=1) 156 | 157 | return auc, scores -------------------------------------------------------------------------------- /metrics/keep_remove_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import GaussianBlur 3 | import numpy as np 4 | 5 | import json 6 | import numpy as np 7 | #import shap 8 | #import shap.benchmark as benchmark 9 | import scipy as sp 10 | import math 11 | from collections import OrderedDict 12 | import numbers 13 | import torch 14 | from torch import nn 15 | from torch.nn import functional as F 16 | import torchvision.transforms as transforms 17 | import torchvision.datasets as datasets 18 | import matplotlib.pyplot as plt 19 | import matplotlib.image as mpimg 20 | import itertools 21 | from tqdm import tqdm 22 | 23 | 24 | 25 | class GaussianSmoothing(nn.Module): 26 | """ 27 | Apply gaussian smoothing on a 28 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 29 | in the input using a depthwise convolution. 
30 | Arguments: 31 | channels (int, sequence): Number of channels of the input tensors. Output will 32 | have this number of channels as well. 33 | kernel_size (int, sequence): Size of the gaussian kernel. 34 | sigma (float, sequence): Standard deviation of the gaussian kernel. 35 | dim (int, optional): The number of dimensions of the data. 36 | Default value is 2 (spatial). 37 | """ 38 | def __init__(self, device, channels, kernel_size, sigma, dim=2): 39 | super(GaussianSmoothing, self).__init__() 40 | if isinstance(kernel_size, numbers.Number): 41 | kernel_size = [kernel_size] * dim 42 | if isinstance(sigma, numbers.Number): 43 | sigma = [sigma] * dim 44 | 45 | # The gaussian kernel is the product of the 46 | # gaussian function of each dimension. 47 | kernel = 1 48 | meshgrids = torch.meshgrid( 49 | [ 50 | torch.arange(size, dtype=torch.float32) 51 | for size in kernel_size 52 | ] 53 | ) 54 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 55 | mean = (size - 1) / 2 56 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 57 | torch.exp(-((mgrid - mean) ** 2) / (2 * std ** 2)) 58 | 59 | # Make sure sum of values in gaussian kernel equals 1. 60 | kernel = kernel / torch.sum(kernel) 61 | 62 | # Reshape to depthwise convolutional weight 63 | kernel = kernel.view(1, 1, *kernel.size()) 64 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 65 | 66 | self.register_buffer('weight', kernel) 67 | self.weight = self.weight.to(device) 68 | self.groups = channels 69 | 70 | if dim == 1: 71 | self.conv = F.conv1d 72 | elif dim == 2: 73 | self.conv = F.conv2d 74 | elif dim == 3: 75 | self.conv = F.conv3d 76 | else: 77 | raise RuntimeError( 78 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 79 | ) 80 | 81 | def forward(self, input): 82 | """ 83 | Apply gaussian filter to input. 84 | Arguments: 85 | input (torch.Tensor): Input to apply gaussian filter on. 86 | Returns: 87 | filtered (torch.Tensor): Filtered output. 88 | """ 89 | return self.conv(input, weight=self.weight, groups=self.groups) 90 | 91 | @torch.no_grad() 92 | def attribution_to_mask(attribution, percent_unmasked, sort_order, perturbation, device): 93 | 94 | attribution = attribution.clone().detach() 95 | attribution = attribution.mean(dim=1) 96 | attribution = attribution.unsqueeze(1) 97 | 98 | zeros = torch.zeros(attribution.shape).to(device) 99 | 100 | # invert the attribution for the negative sort orders 101 | if sort_order == 'negative' or sort_order == 'negative_target' or sort_order == 'negative_topk': 102 | attribution = -attribution 103 | 104 | if sort_order == 'absolute': 105 | attribution = torch.abs(attribution) 106 | 107 | # for 'positive' and 'negative', values of the opposite sign are always masked and not considered for the top-k 108 | positives = torch.maximum(attribution, zeros) 109 | nb_positives = torch.count_nonzero(positives) 110 | 111 | orig_shape = positives.size() 112 | positives = positives.view(positives.size(0), 1, -1) 113 | nb_pixels = positives.size(2) 114 | 115 | if perturbation == 'keep': 116 | # find features to keep 117 | ret = torch.topk(positives, k=int(torch.minimum(torch.tensor(percent_unmasked*nb_pixels).to(device), nb_positives)), dim=2) 118 | 119 | if perturbation == 'remove': 120 | # set zeros to a large value so they are never selected 121 | positives_wo_zero = positives.clone() 122 | positives_wo_zero[positives_wo_zero == 0.]
= float("Inf") 123 | # find features to keep 124 | ret = torch.topk(positives_wo_zero, k=int(torch.minimum(torch.tensor(percent_unmasked*nb_pixels).to(device), nb_positives)), dim=2, largest=False) 125 | ret.indices.shape 126 | # Scatter to zero'd tensor 127 | res = torch.zeros_like(positives) 128 | res.scatter_(2, ret.indices, ret.values) 129 | res = res.view(*orig_shape) 130 | 131 | res = (res == 0).float() # set topk values to zero and all zeros to one 132 | res = res.repeat(1,3,1,1) 133 | return res 134 | 135 | @torch.no_grad() 136 | def keep_remove_metrics(sort_order, perturbation, model, images, labels, attribution, device, step_num=16): 137 | options = [(sort_order, perturbation)] 138 | vals = {} 139 | 140 | smoothing = GaussianSmoothing(device, 3, 51, 41) 141 | percent_unmasked_range = np.geomspace(0.01, 1.0, num=step_num) 142 | 143 | for percent_unmasked in percent_unmasked_range: 144 | vals[perturbation + "_" + sort_order + "_" + str(percent_unmasked)] = 0.0 145 | 146 | masks = [] 147 | for sort_order, perturbation in options: 148 | for percent_unmasked in percent_unmasked_range: 149 | #create masked images 150 | for sample in range(attribution.shape[0]): 151 | mask = attribution_to_mask(attribution[sample].unsqueeze(0), percent_unmasked, sort_order, perturbation, device) 152 | masks.append(mask) 153 | 154 | mask = torch.cat(masks, dim=0) 155 | 156 | images_masked_pt = images.clone().repeat(int(mask.shape[0]/images.size(0)), 1, 1, 1) 157 | images_smoothed_pt = images.clone().repeat(int(mask.shape[0]/images.size(0)), 1, 1, 1) 158 | images_smoothed_pt = F.pad(images_smoothed_pt, (25,25,25,25), mode='reflect') 159 | images_smoothed_pt = smoothing(images_smoothed_pt) 160 | images_masked_pt[mask.bool()] = images_smoothed_pt[mask.bool()] 161 | 162 | #images_masked = normalize(torch.tensor(images_masked_np / 255.).unsqueeze(0).permute(0,3,1,2)) 163 | images_masked = images_masked_pt 164 | images_masked = images_masked.to(device) 165 | out_masked = model(images_masked) 166 | out_masked = out_masked.softmax(dim=-1) 167 | 168 | #split out_masked in the chunks that correspond to the individual run 169 | option_runs = torch.split(out_masked, int(out_masked.shape[0]/len(options))) 170 | for o, (sort_order, perturbation) in enumerate(options): 171 | option_run = option_runs[o] # N, 1000 172 | percent_unmasked_runs = torch.split(option_run, int(option_run.shape[0]/len(percent_unmasked_range))) # N, 1000 173 | for p, percent_unmasked in enumerate(percent_unmasked_range): 174 | percent_unmasked_run = percent_unmasked_runs[p] # N, 1000 175 | #if len(percent_unmasked_run.shape) == 1: 176 | # percent_unmasked_run = percent_unmasked_run.unsqueeze(0) 177 | 178 | if sort_order == 'positive': 179 | vals[perturbation + "_" + sort_order + "_" + str(percent_unmasked)] += torch.gather(percent_unmasked_run, 1, labels.unsqueeze(-1)).sum().cpu().item() 180 | if sort_order == 'negative': 181 | vals[perturbation + "_" + sort_order + "_" + str(percent_unmasked)] += torch.gather(percent_unmasked_run, 1, labels.unsqueeze(-1)).sum().cpu().item() 182 | if sort_order == 'absolute': 183 | correct = (torch.max(percent_unmasked_run, 1)[1] == labels).float().sum().cpu().item() 184 | vals[perturbation + "_" + sort_order + "_" + str(percent_unmasked)] += correct 185 | 186 | for sort_order, perturbation in options: 187 | for percent_unmasked in percent_unmasked_range: 188 | vals[perturbation + "_" + sort_order + "_" + str(percent_unmasked)] /= images.shape[0] 189 | 190 | for sort_order, perturbation in options: 191 | xs = [] 
192 | ys = [] 193 | for percent_unmasked in percent_unmasked_range: 194 | xs.append(percent_unmasked) 195 | ys.append(vals[perturbation + "_" + sort_order + "_" + str(percent_unmasked)]) 196 | auc = np.trapz(ys, xs) 197 | xs = np.array(xs) 198 | ys = np.array(ys) 199 | return auc, vals, xs, ys 200 | 201 | 202 | def kpm(original_input, attribution, target_label, forward, step=0.05): 203 | """ kpm: keep positive mask, 204 | progressively unmask input pixels with positive attribution, from high attribution to low; the remaining pixels are masked with a zero baseline (a Gaussian-blurred input is an alternative, left commented out below) 205 | record the target-label score at each step, then calculate the area under the curve 206 | the score is expected to increase from near 0 to the score of the original input 207 | original_input: batch original input, B*C*H*W 208 | attribution: batch attribution, B*1*H*W 209 | target_label: batch target label, B 210 | forward: forward function, forward(masked_input, target_label), returns the mean target score and the accuracy of the batch 211 | step: step size for revealing pixels, as a fraction of the attribution value range 212 | returns: (auc, x_axis, y_axis) 213 | """ 214 | B, C, H, W = original_input.size() 215 | # get sampling interval 216 | max_attribution = attribution.max() 217 | min_attribution = 0. # attribution.min() 218 | interval = (max_attribution - min_attribution) * step 219 | N = int(1 / step) + 1 220 | 221 | y_axis = [] 222 | 223 | # get blurred input 224 | # blurred_original_input = GaussianBlur(51, 41)(original_input) 225 | blurred_original_input = torch.zeros_like(original_input, device=original_input.device) 226 | 227 | for i in range(N): 228 | # get mask regions 229 | mask_region = attribution < max_attribution - i * interval 230 | masked_input = original_input.clone() 231 | # use blurred_original_input to mask the input 232 | masked_input = torch.where(~mask_region, masked_input, blurred_original_input) 233 | 234 | # mean target score of the batch 235 | score, acc = forward(masked_input, target_label) 236 | y_axis.append(score) 237 | 238 | # plot and save the curve 239 | x_axis = np.arange(0, 1+step, step) 240 | y_axis = np.array(y_axis) 241 | auc = np.trapz(y_axis, x_axis) 242 | 243 | # plot 244 | # import matplotlib.pyplot as plt 245 | # plt.plot(x_axis, y_axis) 246 | 247 | return auc, x_axis, y_axis 248 | 249 | 250 | def knm(original_input, attribution, target_label, forward, step=0.05): 251 | """ knm: keep negative mask 252 | works like kpm, but starts from the lowest-attribution pixels and masks the rest, 253 | progressively unmasking pixels with negative attribution from low attribution to high; the remaining pixels are masked with a Gaussian-blurred version of the input 254 | """ 255 | 256 | B, C, H, W = original_input.size() 257 | # get sampling interval 258 | max_attribution = 0.
# attribution.max() 259 | min_attribution = attribution.min() 260 | interval = (max_attribution - min_attribution) * step 261 | N = int(1 / step) + 1 262 | 263 | y_axis = [] 264 | # get blurred input 265 | blurred_original_input = GaussianBlur(51, 41)(original_input) 266 | for i in range(N): 267 | # get mask regions 268 | mask_region = attribution > min_attribution + i * interval 269 | masked_input = original_input.clone() 270 | # use blurred_original_input to mask the input 271 | masked_input = torch.where(~mask_region, masked_input, blurred_original_input) 272 | 273 | # mean target score of the batch 274 | score, acc = forward(masked_input, target_label) 275 | y_axis.append(score) 276 | 277 | x_axis = np.arange(0, 1+step, step) 278 | y_axis = np.array(y_axis) 279 | auc = np.trapz(y_axis, x_axis) 280 | 281 | return auc, x_axis, y_axis 282 | 283 | 284 | def kam(original_input, attribution, target_label, forward, step=0.05): 285 | """ kam: keep absolute mask 286 | works like kpm, but we now observe the accuracy of the predictions instead of the prediction score of the target label 287 | also uses the absolute attribution map instead of the positive attribution map 288 | """ 289 | 290 | B, C, H, W = original_input.size() 291 | attribution = attribution.abs() 292 | # get sampling interval 293 | max_attribution = attribution.max() 294 | min_attribution = attribution.min() 295 | interval = (max_attribution - min_attribution) * step 296 | N = int(1 / step) + 1 297 | 298 | y_axis = [] 299 | blurred_original_input = GaussianBlur(51, 41)(original_input) 300 | for i in range(N): 301 | # get mask regions 302 | mask_region = attribution < max_attribution - i * interval 303 | masked_input = original_input.clone() 304 | # use blurred_original_input to mask the input 305 | masked_input = torch.where(~mask_region, masked_input, blurred_original_input) 306 | 307 | # mean acc of the batch 308 | score, acc = forward(masked_input, target_label) 309 | y_axis.append(acc) 310 | 311 | x_axis = np.arange(0, 1+step, step) 312 | y_axis = np.array(y_axis) 313 | auc = np.trapz(y_axis, x_axis) 314 | 315 | return auc, x_axis, y_axis 316 | 317 | 318 | def ram(original_input, attribution, target_label, forward, step=0.05): 319 | """ ram: remove absolute mask 320 | works like knm, but we now observe the accuracy of the predictions instead of the prediction score of the target label 321 | also uses the absolute attribution map instead of the positive attribution map 322 | """ 323 | 324 | B, C, H, W = original_input.size() 325 | attribution = attribution.abs() 326 | # get sampling interval 327 | max_attribution = attribution.max() 328 | min_attribution = attribution.min() 329 | interval = (max_attribution - min_attribution) * step 330 | N = int(1 / step) + 1 331 | 332 | y_axis = [] 333 | blurred_original_input = GaussianBlur(51, 41)(original_input) 334 | for i in range(N): 335 | # get mask regions 336 | mask_region = attribution > max_attribution - i * interval 337 | masked_input = original_input.clone() 338 | # use blurred_original_input to mask the input 339 | masked_input = torch.where(~mask_region, masked_input, blurred_original_input) 340 | 341 | # mean acc of the batch 342 | score, acc = forward(masked_input, target_label) 343 | y_axis.append(acc) 344 | 345 | x_axis = np.arange(0, 1+step, step) 346 | y_axis = np.array(y_axis) 347 | auc = np.trapz(y_axis, x_axis) 348 | 349 | return auc, x_axis, y_axis 350 | -------------------------------------------------------------------------------- /perturbation_based_attribution_visualization_examples.py:
-------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import timm 5 | from timm.data import resolve_model_data_config 6 | from timm.data.transforms_factory import create_transform 7 | import requests 8 | 9 | 10 | from attribution import Occlusion 11 | from attribution.utils import normalize_saliency, visualize_single_saliency 12 | 13 | 14 | if __name__ == '__main__': 15 | 16 | # Load imagenet labels 17 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt' 18 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n') 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | # Load a pretrained model 22 | model = timm.create_model('resnet50', pretrained=True) 23 | model = model.to(device) 24 | model.eval() 25 | config = resolve_model_data_config(model, None) 26 | transform = create_transform(**config) 27 | 28 | # Load an image 29 | dog = Image.open('examples/dog.png').convert('RGB') 30 | dog_tensor = transform(dog).unsqueeze(0) 31 | H, W = dog_tensor.shape[-2:] 32 | 33 | # Predict the image 34 | img = transform(dog).unsqueeze(0) 35 | img = torch.cat([img, img]) 36 | img = img.to(device) 37 | output = model(img) 38 | target_index = torch.argmax(output, dim=1).cpu() 39 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()]) 40 | 41 | # Occlusion 42 | occlusion_net = Occlusion(model) 43 | occlusion = normalize_saliency(occlusion_net.get_mask(img, target_index)) 44 | occlusion_2 = normalize_saliency(occlusion_net.get_mask(img, target_index, size=30)) 45 | 46 | # Visualize the results 47 | plt.figure(figsize=(16, 5)) 48 | plt.subplot(1, 3, 1) 49 | plt.title('Input') 50 | plt.axis('off') 51 | plt.imshow(dog) 52 | plt.subplot(1, 3, 2) 53 | plt.title('Occlusion window-15') 54 | visualize_single_saliency(occlusion[0].unsqueeze(0)) 55 | plt.subplot(1, 3, 3) 56 | plt.title('Occlusion window-30') 57 | visualize_single_saliency(occlusion_2[0].unsqueeze(0)) 58 | plt.tight_layout() 59 | plt.savefig('examples/perturbation_based_visualization.png', bbox_inches='tight', pad_inches=0.5) -------------------------------------------------------------------------------- /quick_start.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | from PIL import Image 3 | import requests 4 | import timm 5 | from timm.data import resolve_model_data_config 6 | from timm.data.transforms_factory import create_transform 7 | import torch 8 | 9 | from attribution import BlurIG, GradCAM, CombinedWrapper 10 | from attribution.utils import normalize_saliency, visualize_single_saliency 11 | 12 | # Load imagenet labels 13 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt' 14 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n') 15 | 16 | # Load model 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | model = timm.create_model('resnet50', pretrained=True) 19 | model = model.to(device) 20 | model.eval() 21 | config = resolve_model_data_config(model, None) 22 | transform = create_transform(**config) 23 | 24 | # Load image 25 | dog = Image.open('examples/dog.png').convert('RGB') 26 | dog_tensor = transform(dog).unsqueeze(0) 27 | H, W = dog_tensor.shape[-2:] 28 | img = transform(dog).unsqueeze(0) 29 | 30 | # We support batch input 31 | img = torch.cat([img, img]) 32 | 
img = img.to(device) 33 | output = model(img) 34 | target_index = torch.argmax(output, dim=1).cpu() 35 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()]) 36 | 37 | # Gradients visualization 38 | blur_ig_kwargs = {'steps': 100, 39 | 'batch_size': 4, 40 | 'max_sigma': 50, 41 | 'grad_step': 0.01, 42 | 'sqrt': False} 43 | blur_ig_net = BlurIG(model) 44 | blur_ig = normalize_saliency(blur_ig_net.get_mask(img, target_index, **blur_ig_kwargs)) 45 | 46 | # CAM visualization 47 | gradcam_net = GradCAM(model) 48 | gradcam = normalize_saliency( 49 | gradcam_net.get_mask(img, target_index, target_layer='layer3')) 50 | 51 | # Combine Gradients and CAM visualization 52 | combined = CombinedWrapper(model, BlurIG, GradCAM) 53 | combined_saliency = normalize_saliency( 54 | combined.get_mask(img, target_index, target_layer='layer3', **blur_ig_kwargs)) 55 | 56 | # Visualize 57 | plt.figure(figsize=(16, 5)) 58 | plt.subplot(1, 4, 1) 59 | plt.imshow(dog) 60 | plt.title('Input Image') 61 | plt.axis('off') 62 | plt.subplot(1, 4, 2) 63 | visualize_single_saliency(blur_ig[0].unsqueeze(0)) 64 | plt.title('Blur IG') 65 | plt.subplot(1, 4, 3) 66 | visualize_single_saliency(gradcam[0].unsqueeze(0)) 67 | plt.title('GradCAM') 68 | plt.subplot(1, 4, 4) 69 | visualize_single_saliency(combined_saliency[0].unsqueeze(0)) 70 | plt.title('Combined') 71 | plt.tight_layout() 72 | plt.savefig('examples/quick_start.png', bbox_inches='tight', pad_inches=0.5) --------------------------------------------------------------------------------
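For completeness, a minimal sketch (not part of the repo) of running the insertion/deletion metrics from metrics/insert_and_delete.py on a batch of saliency maps; the step size and the random input are arbitrary stand-ins:

```python
import timm
import torch

from attribution import GradCAM
from attribution.utils import normalize_saliency
from metrics import insert_score, delete_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model('resnet50', pretrained=True).to(device).eval()

images = torch.randn(2, 3, 224, 224, device=device)  # stand-in for a transformed batch
targets = model(images).argmax(dim=1).cpu()
# keep the saliency maps on the same device as the images for the gather/scatter steps
saliency = normalize_saliency(GradCAM(model).get_mask(images, targets, target_layer='layer3'),
                              return_device=device)

# reveal (insert) or erase (delete) 224 pixels per step; a higher insertion AUC
# and a lower deletion AUC suggest a more faithful attribution map
ins_auc, ins_curve = insert_score(model, 224, images, saliency)
del_auc, del_curve = delete_score(model, 224, images, saliency)
print(ins_auc, del_auc)
```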