├── .gitignore
├── README.md
├── __init__.py
├── attribution
│   ├── __init__.py
│   ├── ablation_cam.py
│   ├── bag_cam.py
│   ├── base.py
│   ├── blur_ig.py
│   ├── eigen_cam.py
│   ├── eigen_gradcam.py
│   ├── fullgrad_cam.py
│   ├── grad_cam.py
│   ├── grad_cam_plusplus.py
│   ├── guided_backprop.py
│   ├── guided_gradcam.py
│   ├── guided_ig.py
│   ├── hires_cam.py
│   ├── ig.py
│   ├── layer_cam.py
│   ├── occlusion.py
│   ├── score_cam.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── access_layers.py
│   │   ├── reshape_for_transformer.py
│   │   ├── svd_on_feature_maps.py
│   │   └── visualization_maps.py
│   └── xgrad_cam.py
├── cam_visualization_examples.py
├── cam_visualization_for_transformers_examples.py
├── combine_cam_and_gradients_visualization_examples.py
├── examples
│   ├── attribution_methods.png
│   ├── cam_visualization.png
│   ├── cam_visualization_for_transformers.png
│   ├── cat.png
│   ├── combine_cam_and_gradients_visualization.png
│   ├── dog.png
│   ├── dog_and_cat.png
│   ├── gradients_visualization.png
│   ├── gradients_visualization_for_transformers.png
│   ├── perturbation_based_visualization.png
│   └── quick_start.png
├── gradients_visualization_examples.py
├── gradients_visualization_for_transformers_examples.py
├── metrics
│   ├── __init__.py
│   ├── correlation_value.py
│   ├── insert_and_delete.py
│   └── keep_remove_mask.py
├── perturbation_based_attribution_visualization_examples.py
└── quick_start.py
/.gitignore:
--------------------------------------------------------------------------------
1 | gradients
2 | saliency
3 | pytorch-grad-cam
4 | test_code.ipynb
5 | test_code.py
6 | .vscode
7 | *__pycache__*
8 | temp
9 | dix
10 | six
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attribution Methods for Image Classification Models
2 |
3 |
4 | We only consider plug-and-play methods that **do not have special requirements on the model architecture and do not need additional training or extra modules with learnable parameters**. At the same time, we hope to facilitate weakly-supervised localization and segmentation with attribution results, as well as to incorporate them into model training (e.g., as additional supervision). Therefore, all methods use PyTorch tensors for computation as much as possible and support batch input and GPU execution.
5 |
6 | Since we mainly aim to use attribution results to assist weakly-supervised training, localization, segmentation, model distillation, and so on, we do not include explainability methods for black-box models such as RISE and HISC here.
7 |
8 | ## Gradients Visualization
9 | **CNN models**: some results of resnet50 from timm, example code at [./gradients_visualization_examples.py](./gradients_visualization_examples.py).
10 |
11 |
12 |
13 |
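For instance, a minimal sketch of the gradient API (assuming `model`, `img`, and `target_index` are prepared as in the Quick Start below):

```python
from attribution import VanillaGradient
from attribution.utils import normalize_saliency, visualize_single_saliency

# plain input gradients from a single backward pass
grad_net = VanillaGradient(model)
saliency = normalize_saliency(grad_net.get_mask(img, target_index))

# SmoothGrad-style variant: average squared gradients over noisy copies of the input
smoothed = normalize_saliency(grad_net.get_smoothed_mask(img, target_index, samples=25, std=0.15))

visualize_single_saliency(saliency[0].unsqueeze(0))
```

The other gradient methods (`IntegratedGradients`, `BlurIG`, `GuidedIG`, `GuidedBackProp`) expose the same `get_mask`/`get_smoothed_mask` interface.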
14 | **Vision/Swin Transformers**: gradient visualization methods can be applied to transformers directly. The model used here is `vit_tiny_patch16_224.augreg_in21k_ft_in1k` from timm, example code at [./gradients_visualization_for_transformers_examples.py](./gradients_visualization_for_transformers_examples.py).
15 |
16 |
17 |
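As a minimal sketch (assuming `img`, `target_index`, and `device` are set up as in the Quick Start below, with `img` already sized for the ViT's 224x224 input):

```python
import timm
from attribution import IntegratedGradients
from attribution.utils import normalize_saliency

# gradient-based methods need no architectural changes for transformers
vit = timm.create_model('vit_tiny_patch16_224.augreg_in21k_ft_in1k', pretrained=True).to(device).eval()
ig_net = IntegratedGradients(vit)
ig = normalize_saliency(ig_net.get_mask(img, target_index, baseline='black', steps=128))
```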
18 | ## Class Activation Map (CAM) Visualization
19 | resnet50 with `layer3` as the target layer, example code at [./cam_visualization_examples.py](./cam_visualization_examples.py).
20 |
21 |
22 |
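To choose a target layer, it can help to print the module names first, as in the example script:

```python
from attribution import GradCAM

# list candidate layer names; for resnet50 these include layer1 ... layer4
for name, _ in model.named_modules():
    print(name)

gradcam_net = GradCAM(model)
grad_cam = gradcam_net.get_mask(img, target_index, target_layer='layer3')
```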
23 | ## CAM Visualization for ViT and Swin Transformer
24 | Use `attribution.utils.get_reshape_transform` when creating the attribution model, example code at [./cam_visualization_for_transformers_examples.py](./cam_visualization_for_transformers_examples.py). The target layer used for ViT here is `blocks.11.norm1`, and the one for Swin Transformer is `norm`.
25 |
26 |
27 |
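A minimal sketch (the Swin checkpoint name is only illustrative; `img`, `target_index`, and `device` as in the Quick Start below):

```python
import timm
from attribution import GradCAM
from attribution.utils import get_reshape_transform

# ViT blocks emit (B, tokens, dim) with a leading CLS token
vit = timm.create_model('vit_tiny_patch16_224.augreg_in21k_ft_in1k', pretrained=True).to(device).eval()
vit_cam = GradCAM(vit, reshape_transform=get_reshape_transform(has_cls_token=True))
vit_saliency = vit_cam.get_mask(img, target_index, target_layer='blocks.11.norm1')

# Swin layers emit (B, H, W, C) feature maps without a CLS token
swin = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True).to(device).eval()
swin_cam = GradCAM(swin, reshape_transform=get_reshape_transform(has_cls_token=False))
swin_saliency = swin_cam.get_mask(img, target_index, target_layer='norm')
```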
28 | Currently, some methods, such as Ablation-CAM, are not supported for transformers, and the visualization quality is not as good as for CNN models, since many methods are designed around the notion of spatial feature maps. We will try to add visualization methods designed specifically for transformers in the future.
29 |
30 | ## Combine Gradients and CAM Visualization
31 | Similar to Guided Grad-CAM, any gradient visualization method can be combined with CAM visualization, example code at [./combine_cam_and_gradients_visualization_examples.py](./combine_cam_and_gradients_visualization_examples.py)
32 |
33 |
34 |
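For example (a sketch; `GuidedGradCAM` is simply `CombinedWrapper(model, GuidedBackProp, GradCAM)` under the hood):

```python
from attribution import CombinedWrapper, GuidedGradCAM, IntegratedGradients, GradCAM

# the classic combination
guided_gradcam = GuidedGradCAM(model).get_mask(img, target_index, target_layer='layer3')

# or mix and match: extra kwargs are forwarded to the gradient method
combined = CombinedWrapper(model, IntegratedGradients, GradCAM)
saliency = combined.get_mask(img, target_index, target_layer='layer3', steps=128)
```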
35 | ## Black-Box Perturbation-based Attribution Visualization
36 | example code at [./perturbation_based_attribution_visualization_examples.py](./perturbation_based_attribution_visualization_examples.py)
37 |
38 |
39 |
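A minimal sketch with the sliding-window `Occlusion` method (note that the stride-1 window means one forward pass per position, so it is slow for large inputs):

```python
from attribution import Occlusion
from attribution.utils import normalize_saliency

# slide a size x size window over the image and accumulate, per location,
# how much occluding that region lowers the target-class score
occlusion_net = Occlusion(model)
occlusion_scores = occlusion_net.get_mask(img, target_index, size=15, value=0.0)
saliency = normalize_saliency(occlusion_scores)
```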
40 | ## Quick Start
41 | ```python
42 | from matplotlib import pyplot as plt
43 | from PIL import Image
44 | import requests
45 | import timm
46 | from timm.data import resolve_model_data_config
47 | from timm.data.transforms_factory import create_transform
48 | import torch
49 |
50 | from attribution import BlurIG, GradCAM, CombinedWrapper
51 | from attribution.utils import normalize_saliency, visualize_single_saliency
52 |
53 | # Load imagenet labels
54 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
55 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
56 |
57 | # Load model
58 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59 | model = timm.create_model('resnet50', pretrained=True)
60 | model = model.to(device)
61 | model.eval()
62 | config = resolve_model_data_config(model, None)
63 | transform = create_transform(**config)
64 |
65 | # Load image
66 | dog = Image.open('examples/dog.png').convert('RGB')
67 | dog_tensor = transform(dog).unsqueeze(0)
68 | H, W = dog_tensor.shape[-2:]
69 | img = dog_tensor.clone()
70 |
71 | # We support batch input
72 | img = torch.cat([img, img])
73 | img = img.to(device)
74 | output = model(img)
75 | target_index = torch.argmax(output, dim=1).cpu()
76 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()])
77 |
78 | # Gradients visualization
79 | blur_ig_kwargs = {'steps': 100,
80 | 'batch_size': 4,
81 | 'max_sigma': 50,
82 | 'grad_step': 0.01,
83 | 'sqrt': False}
84 | blur_ig_net = BlurIG(model)
85 | blur_ig = normalize_saliency(blur_ig_net.get_mask(img, target_index, **blur_ig_kwargs))
86 |
87 | # CAM visualization
88 | gradcam_net = GradCAM(model)
89 | gradcam = normalize_saliency(gradcam_net.get_mask(img, target_index, target_layer='layer3'))
90 |
91 | # Combine Gradients and CAM visualization
92 | combined = CombinedWrapper(model, BlurIG, GradCAM)
93 | combined_saliency = normalize_saliency(
94 | combined.get_mask(img, target_index, target_layer='layer3', **blur_ig_kwargs))
95 |
96 | # Visualize
97 | plt.figure(figsize=(16, 5))
98 | plt.subplot(1, 4, 1)
99 | plt.imshow(dog)
100 | plt.title('Input Image')
101 | plt.axis('off')
102 | plt.subplot(1, 4, 2)
103 | visualize_single_saliency(blur_ig[0].unsqueeze(0))
104 | plt.title('Blur IG')
105 | plt.subplot(1, 4, 3)
106 | visualize_single_saliency(gradcam[0].unsqueeze(0))
107 | plt.title('GradCAM')
108 | plt.subplot(1, 4, 4)
109 | visualize_single_saliency(combined_saliency[0].unsqueeze(0))
110 | plt.title('Combined')
111 | plt.tight_layout()
112 | plt.savefig('examples/quick_start.png', bbox_inches='tight', pad_inches=0.5)
113 | ```
114 |
115 |
116 |
117 |
118 | ## TODO:
119 | This is an ongoing effort to implement various attribution methods for image classification models in PyTorch under a unified framework.
120 | - [x] Unify gradient visualization API.
121 | - [x] Implement CAM visualization for CNN models based on known target_layer names.
122 | - [x] Implement CAM for ViT, Swin Transformer, etc.
123 | - [x] Implement keep positive/negative mask, keep/remove absolute mask metrics. For details, please refer to [Fast Axiomatic Attribution for Neural Networks](https://proceedings.neurips.cc/paper/2021/hash/a284df1155ec3e67286080500df36a9a-Abstract.html).
124 | - [ ] Add LIFT-CAM (ICCV2021), IIA (ICCV2023), Dix (CIKM) and Six (ICDM).
125 | - [ ] Unify all APIs.
126 | - [ ] Documentation.
127 |
128 |
129 | ## Acknowledgements
130 | This project is inspired by [jacobgil/pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam), [PAIR-code/saliency](https://github.com/PAIR-code/saliency) and [hummat/saliency](https://github.com/hummat/saliency). Thanks for their wonderful work.
131 |
132 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .attribution import *
2 | from .attribution.utils import *
3 |
4 | from .metrics import *
--------------------------------------------------------------------------------
/attribution/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import VanillaGradient
2 | from .ig import IntegratedGradients
3 | from .blur_ig import BlurIG
4 | from .guided_ig import GuidedIG
5 | from .guided_backprop import GuidedBackProp
6 |
7 | from .occlusion import Occlusion
8 |
9 | from .grad_cam import GradCAM
10 | from .guided_gradcam import GuidedGradCAM
11 | from .hires_cam import HiResCAM
12 | from .grad_cam_plusplus import GradCAMPlusPlus
13 | from .xgrad_cam import XGradCAM
14 | from .bag_cam import BagCAM
15 | from .score_cam import ScoreCAM
16 | from .layer_cam import LayerCAM
17 | from .ablation_cam import AblationCAM
18 | from .fullgrad_cam import FullGrad
19 | from .eigen_cam import EigenCAM
20 | from .eigen_gradcam import EigenGradCAM
21 |
22 | from .base import CombinedWrapper
--------------------------------------------------------------------------------
/attribution/ablation_cam.py:
--------------------------------------------------------------------------------
1 | from .base import CAMWrapper
2 | from .utils import replace_layer_recursive, get_2d_projection
3 |
4 | import numpy as np
5 | import torch
6 | from torch.nn import functional as F
7 |
8 | class AblationLayer(torch.nn.Module):
9 | def __init__(self):
10 | super(AblationLayer, self).__init__()
11 |
12 | def objectiveness_mask_from_svd(self, feature_maps, threshold=0.01):
13 | projection = get_2d_projection(feature_maps[None, :])[0, :]
14 | projection = np.abs(projection)
15 | projection = projection - projection.min()
16 | projection = projection / projection.max()
17 | projection = projection > threshold
18 | return projection
19 |
20 | def feature_maps_to_be_ablated(
21 | self,
22 | feature_maps,
23 | ratio_channels_to_ablate=1.0):
24 | if ratio_channels_to_ablate == 1.0:
25 | self.indices = np.int32(range(feature_maps.shape[0]))
26 | return self.indices
27 |
28 | projection = self.objectiveness_mask_from_svd(feature_maps)
29 |
30 | scores = []
31 | for channel in feature_maps:
32 | normalized = np.abs(channel)
33 | normalized = normalized - normalized.min()
34 | normalized = normalized / np.max(normalized)
35 | score = (projection * normalized).sum() / normalized.sum()
36 | scores.append(score)
37 | scores = np.float32(scores)
38 |
39 | indices = list(np.argsort(scores))
40 |         high_score_indices = indices[::-1][
41 |             : int(len(indices) * ratio_channels_to_ablate)
42 |         ]
43 |         low_score_indices = indices[
44 |             : int(len(indices) * ratio_channels_to_ablate)]
45 | self.indices = np.int32(high_score_indices + low_score_indices)
46 | return self.indices
47 |
48 | def set_next_batch(
49 | self,
50 | input_batch_index,
51 | feature_maps: torch.Tensor,
52 | num_channels_to_ablate):
53 |         feature_map = feature_maps[input_batch_index].clone().unsqueeze(0)
54 |         self.feature_maps_ablation_layer = feature_map.repeat(num_channels_to_ablate, 1, 1, 1)
55 |
56 | def __call__(self, x):
57 | output = self.feature_maps_ablation_layer
58 | for i in range(output.size(0)):
59 |             # Commonly the minimum activation will be 0,
60 |             # so it makes sense to zero the channel out.
61 |             # However, depending on the architecture, the values
62 |             # can be negative; in that case we ablate with a very
63 |             # negative value instead, deviating from the paper.
64 | if torch.min(output) == 0:
65 | output[i, self.indices[i], :] = 0
66 | else:
67 | ABLATION_VALUE = 1e7
68 | output[i, self.indices[i], :] = torch.min(
69 | output) - ABLATION_VALUE
70 |
71 | return output
72 |
73 | class AblationCAM(CAMWrapper):
74 | def get_mask(self, img: torch.Tensor,
75 | target_class: torch.Tensor,
76 | target_layer: str,
77 | batch_size: int = 32,
78 | ratio_channels_to_ablate: float = 1.0):
79 |
80 | B, C, H, W = img.size()
81 | self.model.eval()
82 | self.model.zero_grad()
83 |
84 | # class-specific backpropagation
85 | logits = self.model(img)
86 | target = self._encode_one_hot(target_class, logits)
87 | self.model.zero_grad()
88 | logits.backward(gradient=target, retain_graph=True)
89 |
90 | get_targets = lambda o, target: o[target]
91 | if isinstance(target_class, torch.Tensor):
92 | target_class = target_class.cpu().tolist()
93 | if not isinstance(target_class, list):
94 | target_class = [target_class]
95 |
96 | original_scores = [get_targets(o, t).cpu().detach() for o, t in zip(logits, target_class)]
97 | feature_maps = self._find(self.feature_maps, target_layer)
98 |
99 | # save original layer and replace the model back to the original state later
100 | original_target_layer = self.get_target_module(target_layer)
101 | ablation_layer = AblationLayer()
102 | replace_layer_recursive(self.model, original_target_layer, ablation_layer)
103 |
104 | # get weights
105 | number_of_channels = feature_maps.size(1)
106 | weights = []
107 | with torch.no_grad():
108 | for batch_idx, (target, I) in enumerate(zip(target_class, img)):
109 | new_scores = []
110 | batch_tensor = I.repeat(batch_size, 1, 1, 1)
111 |
112 | channels_to_ablate = ablation_layer.feature_maps_to_be_ablated(
113 | feature_maps[batch_idx, :], ratio_channels_to_ablate
114 | )
115 | number_of_channels_to_ablate = len(channels_to_ablate)
116 |
117 | for i in range(0, number_of_channels_to_ablate, batch_size):
118 | if i + batch_size > number_of_channels_to_ablate:
119 | batch_tensor = batch_tensor[:(number_of_channels_to_ablate - i)]
120 |
121 | ablation_layer.set_next_batch(
122 | batch_idx, feature_maps, batch_tensor.size(0)
123 | )
124 |
125 | new_scores.extend([get_targets(o, target).cpu().detach() for o in self.model(batch_tensor)])
126 | ablation_layer.indices = ablation_layer.indices[batch_size:]
127 |
128 | new_scores = self.assemble_ablation_scores(
129 | new_scores, original_scores[batch_idx], channels_to_ablate, number_of_channels
130 | )
131 | weights.extend(new_scores)
132 |
133 | weights = np.float32(weights)
134 | weights = weights.reshape(feature_maps.shape[:2])
135 | original_scores = np.array(original_scores)[:, None]
136 | weights = (original_scores - weights) / original_scores
137 | weights = torch.from_numpy(weights).to(self.device)[:, :, None, None]
138 |
139 | cam = torch.mul(feature_maps, weights).sum(dim=1, keepdim=True)
140 | cam = F.relu(cam)
141 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
142 | cam = self.normalize_cam(cam)
143 |
144 | replace_layer_recursive(self.model, ablation_layer, original_target_layer)
145 |
146 | return cam
147 |
148 | def get_target_module(self, target_layer: str):
149 | for name, module in self.model.named_modules():
150 | if name == target_layer:
151 | return module
152 | raise ValueError(f"Layer {target_layer} not found in model")
153 |
154 | def assemble_ablation_scores(self,
155 | new_scores: list,
156 | original_score: float,
157 | ablated_channels: np.ndarray,
158 |                                  number_of_channels: int) -> list:
159 | """ Take the value from the channels that were ablated,
160 | and just set the original score for the channels that were skipped """
161 |
162 | index = 0
163 | result = []
164 | sorted_indices = np.argsort(ablated_channels)
165 | ablated_channels = ablated_channels[sorted_indices]
166 | new_scores = np.float32(new_scores)[sorted_indices]
167 |
168 | for i in range(number_of_channels):
169 | if index < len(ablated_channels) and ablated_channels[index] == i:
170 | weight = new_scores[index]
171 | index = index + 1
172 | else:
173 | weight = original_score
174 | result.append(weight)
175 |
176 | return result
--------------------------------------------------------------------------------
/attribution/bag_cam.py:
--------------------------------------------------------------------------------
1 | from .base import CAMWrapper
2 |
3 | import torch
4 | from torch.nn import functional as F
5 |
6 |
7 | class BagCAM(CAMWrapper):
8 | def get_mask(self, img: torch.Tensor,
9 | target_class: torch.Tensor,
10 | target_layer: str):
11 | B, C, H, W = img.size()
12 | self.model.eval()
13 | self.model.zero_grad()
14 |
15 | # class-specific backpropagation
16 | logits = self.model(img)
17 | target = self._encode_one_hot(target_class, logits)
18 | self.model.zero_grad()
19 | logits.backward(gradient=target, retain_graph=True)
20 |
21 | # get feature maps and gradients
22 | feature_maps = self._find(self.feature_maps, target_layer)
23 | gradients = self._find(self.gradients, target_layer)
24 |
25 | # generate CAM
26 | with torch.no_grad():
27 | term_2 = gradients * feature_maps
28 | term_1 = term_2 + 1
29 | term_1 = F.adaptive_avg_pool2d(term_1, 1)
30 | cam = F.relu(torch.mul(term_1, term_2)).sum(dim=1, keepdim=True)
31 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
32 | cam = self.normalize_cam(cam)
33 |
34 | return cam
--------------------------------------------------------------------------------
/attribution/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional, List, Union
3 | from .utils import normalize_saliency
4 |
5 | class Core(torch.nn.Module):
6 | def __init__(self, model: torch.nn.Module):
7 | super(Core, self).__init__()
8 | self.model = model
9 | self.hooks = list() # list of hooks
10 | self.device = next(model.parameters()).device
11 |
12 | def _encode_one_hot(self, targets: torch.Tensor, logits: torch.Tensor):
13 | if not isinstance(targets, torch.Tensor):
14 | targets = torch.tensor([targets], device=self.device)
15 | targets = targets.view(logits.size(0))
16 | one_hot = torch.nn.functional.one_hot(targets, num_classes=logits.size(-1)).to(self.device)
17 |         return one_hot.float()  # backward() needs a float gradient; one_hot returns int64
18 |
19 | def get_mask(self, img, target_class=None):
20 | raise NotImplementedError(
21 |             'A derived class should implement this method')
22 |
23 | def remove_hooks(self):
24 | for hook in self.hooks:
25 | hook.remove()
26 |
27 | def forward(self, img):
28 | logits = self.model(img)
29 | return logits
30 |
31 | class VanillaGradient(Core):
32 | def __init__(self, model):
33 | super(VanillaGradient, self).__init__(model)
34 |
35 | # return gradients
36 | def get_mask(self, img: torch.Tensor, target_class: torch.Tensor) -> torch.Tensor:
37 | self.model.eval()
38 | self.model.zero_grad()
39 |
40 | img = img.clone()
41 | img.requires_grad = True
42 | img.retain_grad()
43 |
44 | logits = self.model(img)
45 |
46 | target = self._encode_one_hot(target_class, logits)
47 | self.model.zero_grad()
48 | logits.backward(gradient=target, retain_graph=True)
49 | return img.grad.detach()
50 |
51 | def get_smoothed_mask(self, img, target_class, samples=25, std=0.15, process=lambda x: x**2):
52 | std = std * (torch.max(img) - torch.min(img)).detach().cpu().numpy()
53 |
54 | B, C, H, W = img.size()
55 | grad_sum = torch.zeros((B, C, H, W), device=self.device)
56 | for sample in range(samples):
57 | noise = torch.empty(img.size()).normal_(0, std).to(self.device)
58 | noise_image = img + noise
59 | grad_sum += process(self.get_mask(noise_image, target_class))
60 | return grad_sum / samples
61 |
62 |
63 | class CAMWrapper(Core):
64 | def __init__(self, model: torch.nn.Module, reshape_transform=None):
65 | super(CAMWrapper, self).__init__(model)
66 |
67 | self.feature_maps = dict()
68 | self.gradients = dict()
69 | self.reshape_transform = reshape_transform
70 |
71 | def save_feature_maps(name):
72 | def forward_hook(module, input, output):
73 | self.feature_maps[name] = output.detach()
74 |
75 | return forward_hook
76 |
77 | def save_gradients(name):
78 | def _store_grad(grad):
79 | self.gradients[name] = grad.detach()
80 | def forward_hook(module, input, output):
81 | if not hasattr(output, 'requires_grad') or not output.requires_grad:
82 | return
83 | output.register_hook(_store_grad)
84 |
85 | return forward_hook
86 |
87 | for name, module in self.model.named_modules():
88 | self.hooks.append(module.register_forward_hook(save_feature_maps(name)))
89 | self.hooks.append(module.register_forward_hook(save_gradients(name)))
90 |
91 | def _find(self, saved_dict, name: str):
92 | if name in saved_dict.keys():
93 | if self.reshape_transform is not None:
94 | return self.reshape_transform(saved_dict[name])
95 | return saved_dict[name]
96 |
97 | raise ValueError('Invalid layer name')
98 |
99 | def get_mask(self, img: torch.Tensor,
100 | target_class: torch.Tensor,
101 | target_layer: Union[str, List[str]]):
102 |         raise NotImplementedError('A derived class should implement this method')
103 |
104 | @torch.no_grad()
105 | def normalize_cam(self, cam: torch.Tensor):
106 | B, C, H, W = cam.size()
107 | cam = cam.view(cam.size(0), -1)
108 | cam -= cam.min(dim=1, keepdim=True)[0]
109 | cam = cam / (cam.max(dim=1, keepdim=True)[0]+1e-5)
110 | return cam.view(B, C, H, W)
111 |
112 |
113 | class CombinedWrapper(Core):
114 | def __init__(self, model: torch.nn.Module, gradient_net: Core, cam_net: CAMWrapper, reshape_transform=None):
115 | super(CombinedWrapper, self).__init__(model)
116 | self.gradient_net = gradient_net(model)
117 | self.cam_net = cam_net(model, reshape_transform)
118 |
119 | def get_mask(self, img: torch.Tensor, target_class: torch.Tensor, target_layer: str, **kwargs_for_gradient_net):
120 | B, C, H, W = img.size()
121 | self.model.eval()
122 | self.model.zero_grad()
123 | cam = self.cam_net.get_mask(img, target_class, target_layer)
124 | gradients = self.gradient_net.get_mask(img, target_class, **kwargs_for_gradient_net)
125 | gradients = normalize_saliency(gradients, return_device=self.device)
126 |
127 | with torch.no_grad():
128 | attribution = gradients * cam
129 | attribution = torch.nn.functional.relu(attribution)
130 | attribution = self.cam_net.normalize_cam(attribution)
131 |
132 | return attribution
--------------------------------------------------------------------------------
/attribution/blur_ig.py:
--------------------------------------------------------------------------------
1 | from .base import VanillaGradient
2 |
3 | import math
4 | import torch
5 | from torchvision.transforms import GaussianBlur
6 |
7 | def gaussian_blur(img: torch.Tensor, sigma: float):
8 | if sigma == 0:
9 | return img
10 |     kernel_size = int(4 * sigma + 0.5) + 1
11 |     kernel_size += 1 - kernel_size % 2  # GaussianBlur requires an odd kernel size
12 |     return GaussianBlur(kernel_size=kernel_size, sigma=sigma)(img)
13 |
14 | class BlurIG(VanillaGradient):
15 | def get_mask(self, img: torch.Tensor,
16 | target_class: torch.Tensor,
17 | max_sigma: int = 50,
18 | steps: int = 100,
19 | grad_step: float = 0.01,
20 | sqrt: bool = False,
21 | batch_size: int = 4):
22 | self.model.eval()
23 | self.model.zero_grad()
24 |
25 | if sqrt:
26 | sigmas = [math.sqrt(float(i) * max_sigma / float(steps)) for i in range(0, steps+1)]
27 | else:
28 | sigmas = [float(i) * max_sigma / float(steps) for i in range(0, steps+1)]
29 |
30 | step_vector_diff = [sigmas[i+1] - sigmas[i] for i in range(0, steps)]
31 | total_gradients = torch.zeros_like(img)
32 | x_step_batched = []
33 | gaussian_gradient_batched = []
34 |
35 |         for i in range(steps):
36 |             with torch.no_grad():
37 |                 x_step = gaussian_blur(img, sigmas[i])
38 |                 gaussian_gradients = (gaussian_blur(img, sigmas[i] + grad_step) - x_step) / grad_step
39 |             x_step_batched.append(x_step)
40 |             gaussian_gradient_batched.append(gaussian_gradients)
41 |             if len(x_step_batched) == batch_size or i == steps - 1:
42 |                 k = len(x_step_batched)  # the chunk stacks k blurred copies of the batch
43 |                 x_step_batched = torch.cat(x_step_batched, dim=0)
44 |                 x_step_batched.requires_grad = True
45 |                 outputs = torch.softmax(self.model(x_step_batched), dim=1)
46 |                 targets = torch.as_tensor(target_class, device=outputs.device).view(-1)
47 |                 outputs = outputs[torch.arange(outputs.size(0)), targets.repeat(outputs.size(0) // targets.numel())]
48 |                 gradients = torch.autograd.grad(outputs, x_step_batched, torch.ones_like(outputs))[0].detach()
49 |                 with torch.no_grad():
50 |                     chunk = torch.mul(torch.cat(gaussian_gradient_batched, dim=0), gradients)
51 |                     total_gradients += step_vector_diff[i] * chunk.view(k, *img.size()).sum(dim=0)  # sum over steps, keep batch
52 |                 x_step_batched = []
53 |                 gaussian_gradient_batched = []
54 |
55 | with torch.no_grad():
56 | blur_ig = total_gradients * -1.0
57 |
58 | return blur_ig
59 |
60 | def get_smoothed_mask(self, img: torch.Tensor,
61 | target_class: torch.Tensor,
62 | max_sigma: int = 50,
63 | steps: int = 100,
64 | grad_step: float = 0.01,
65 | sqrt: bool = False,
66 | batch_size: int = 4,
67 | samples: int = 25,
68 | std: float = 0.15,
69 | process=lambda x: x**2):
70 | std = std * (torch.max(img) - torch.min(img)).detach().cpu().numpy()
71 |
72 | B, C, H, W = img.size()
73 | grad_sum = torch.zeros((B, C, H, W), device=self.device)
74 | for sample in range(samples):
75 | noise = torch.empty(img.size()).normal_(0, std).to(self.device)
76 | noise_image = img + noise
77 | grad_sum += process(self.get_mask(noise_image, target_class, max_sigma, steps, grad_step, sqrt, batch_size))
78 | return grad_sum / samples
79 |
--------------------------------------------------------------------------------
/attribution/eigen_cam.py:
--------------------------------------------------------------------------------
1 | from .utils import get_2d_projection
2 | from .base import CAMWrapper
3 |
4 | import torch
5 | from torch.nn import functional as F
6 | from typing import List
7 |
8 | class EigenCAM(CAMWrapper):
9 | def get_mask(self, img: torch.Tensor,
10 | target_class: torch.Tensor,
11 | target_layer: str):
12 | B, C, H, W = img.size()
13 | self.model.eval()
14 | self.model.zero_grad()
15 |
16 | # class-specific backpropagation
17 | logits = self.model(img)
18 | target = self._encode_one_hot(target_class, logits)
19 | self.model.zero_grad()
20 | logits.backward(gradient=target, retain_graph=True)
21 |
22 |         # get feature maps (EigenCAM does not use the gradients)
23 | feature_maps = self._find(self.feature_maps, target_layer)
24 |
25 | # generate CAM
26 | with torch.no_grad():
27 | cam = get_2d_projection(feature_maps.cpu().numpy())
28 | cam = torch.from_numpy(cam).to(self.device)[:, None, :, :]
29 | cam = F.relu(cam)
30 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
31 | cam = self.normalize_cam(cam)
32 |
33 | return cam
--------------------------------------------------------------------------------
/attribution/eigen_gradcam.py:
--------------------------------------------------------------------------------
1 | from .utils import get_2d_projection
2 | from .base import CAMWrapper
3 |
4 | import torch
5 | from torch.nn import functional as F
6 | from typing import List
7 |
8 | class EigenGradCAM(CAMWrapper):
9 | def get_mask(self, img: torch.Tensor,
10 | target_class: torch.Tensor,
11 | target_layer: str):
12 | B, C, H, W = img.size()
13 | self.model.eval()
14 | self.model.zero_grad()
15 |
16 | # class-specific backpropagation
17 | logits = self.model(img)
18 | target = self._encode_one_hot(target_class, logits)
19 | self.model.zero_grad()
20 | logits.backward(gradient=target, retain_graph=True)
21 |
22 | # get feature maps and gradients
23 | feature_maps = self._find(self.feature_maps, target_layer)
24 | gradients = self._find(self.gradients, target_layer)
25 |
26 | # generate CAM
27 | with torch.no_grad():
28 | cam = get_2d_projection((gradients * feature_maps).cpu().numpy())
29 | cam = torch.from_numpy(cam).to(self.device)[:, None, :, :]
30 | cam = F.relu(cam)
31 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
32 | cam = self.normalize_cam(cam)
33 |
34 | return cam
--------------------------------------------------------------------------------
/attribution/fullgrad_cam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.nn import functional as F
4 | from typing import Optional
5 |
6 | from .base import CAMWrapper
7 | from .utils import find_layer_predicate_recursive
8 |
9 | class FullGrad(CAMWrapper):
10 | def __init__(self, model: torch.nn.Module, reshape_transform=None):
11 | if reshape_transform is not None:
12 | print('Warning: FullGrad may not work properly with ViT and Swin Transformer models.')
13 | super().__init__(model, reshape_transform)
14 |
15 | self.target_layers = find_layer_predicate_recursive(self.model, self.layer_with_2D_bias)
16 | self.bias_data = [self.get_bias_data(layer) for layer in self.target_layers]
17 |
18 | def layer_with_2D_bias(self, layer):
19 | bias_target_layers = [torch.nn.Conv2d, torch.nn.BatchNorm2d]
20 | if type(layer) in bias_target_layers and layer.bias is not None:
21 | return True
22 | return False
23 |
24 | def get_bias_data(self, layer):
25 | if isinstance(layer, torch.nn.BatchNorm2d):
26 | bias = - (layer.running_mean * layer.weight
27 | / torch.sqrt(layer.running_var + layer.eps)) + layer.bias
28 | return bias.data
29 | else:
30 | return layer.bias.data
31 |
32 | def get_mask(self, img: torch.Tensor,
33 | target_class: torch.Tensor,
34 | target_layer: Optional[str] = None):
35 | if target_layer is not None:
36 | print('Warning: target_layer is ignored in FullGrad. All bias layers will be used instead')
37 |
38 | B, C, H, W = img.size()
39 |         img = img.clone().requires_grad_(True)
40 | self.model.eval()
41 | self.model.zero_grad()
42 |
43 | # class-specific backpropagation
44 | logits = self.model(img)
45 | target = self._encode_one_hot(target_class, logits)
46 | self.model.zero_grad()
47 | logits.backward(gradient=target, retain_graph=True)
48 |
49 | target_layer_names = []
50 | for name, layer in self.model.named_modules():
51 | if layer in self.target_layers:
52 | target_layer_names.append(name)
53 | gradients_list = [self._find(self.gradients, layer_name) for layer_name in target_layer_names]
54 |
55 | input_gradients = img.grad.detach()
56 | cam_per_target_layer = []
57 |
58 | with torch.no_grad():
59 | gradient_multiplied_input = input_gradients * img
60 | gradient_multiplied_input = torch.abs(gradient_multiplied_input)
61 | gradient_multiplied_input = self.scale_input_across_batch_and_channels(gradient_multiplied_input)
62 |
63 | cam_per_target_layer.append(gradient_multiplied_input)
64 | assert len(gradients_list) == len(self.bias_data)
65 | for bias, gradients in zip(self.bias_data, gradients_list):
66 | bias = bias[None, :, None, None]
67 | bias_grad = torch.abs(bias * gradients)
68 | bias_grad = self.scale_input_across_batch_and_channels(bias_grad, (H, W))
69 | bias_grad = bias_grad.sum(dim=1, keepdim=True)
70 | cam_per_target_layer.append(bias_grad)
71 | cam_per_target_layer = torch.cat(cam_per_target_layer, dim=1)
72 | cam = cam_per_target_layer.sum(dim=1, keepdim=True)
73 | cam = F.relu(cam)
74 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
75 | cam = self.normalize_cam(cam)
76 |
77 | return cam
78 |
79 | @torch.no_grad()
80 | def scale_input_across_batch_and_channels(self, input_tensor: torch.Tensor, target_size=None):
81 | # target_size should be like (H, W)
82 | B, C, H, W = input_tensor.size()
83 | input_tensor = input_tensor.view(B, C, -1)
84 | input_tensor -= input_tensor.min(dim=2, keepdim=True)[0]
85 | input_tensor /= (input_tensor.max(dim=2, keepdim=True)[0] + 1e-7)
86 | input_tensor = input_tensor.view(B, C, H, W)
87 | if target_size is not None:
88 | input_tensor = F.interpolate(input_tensor, target_size, mode='bilinear', align_corners=False)
89 | return input_tensor
90 |
--------------------------------------------------------------------------------
/attribution/grad_cam.py:
--------------------------------------------------------------------------------
1 | from .base import CAMWrapper
2 |
3 | import torch
4 | from torch.nn import functional as F
5 |
6 | class GradCAM(CAMWrapper):
7 | def get_mask(self, img: torch.Tensor,
8 | target_class: torch.Tensor,
9 | target_layer: str):
10 |
11 | B, C, H, W = img.size()
12 | self.model.eval()
13 | self.model.zero_grad()
14 |
15 | # class-specific backpropagation
16 | logits = self.model(img)
17 | target = self._encode_one_hot(target_class, logits)
18 | self.model.zero_grad()
19 | logits.backward(gradient=target, retain_graph=True)
20 |
21 | # get feature maps and gradients
22 | feature_maps = self._find(self.feature_maps, target_layer)
23 | gradients = self._find(self.gradients, target_layer)
24 |
25 | # generate CAM
26 | with torch.no_grad():
27 | weights = F.adaptive_avg_pool2d(gradients, 1)
28 | cam = torch.mul(feature_maps, weights).sum(dim=1, keepdim=True)
29 | cam = F.relu(cam)
30 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
31 | cam = self.normalize_cam(cam)
32 |
33 | return cam
--------------------------------------------------------------------------------
/attribution/grad_cam_plusplus.py:
--------------------------------------------------------------------------------
1 | from .grad_cam import GradCAM
2 |
3 | import torch
4 | from torch.nn import functional as F
5 |
6 |
7 | class GradCAMPlusPlus(GradCAM):
8 | def get_mask(self, img: torch.Tensor,
9 | target_class: torch.Tensor,
10 | target_layer: str):
11 | B, C, H, W = img.size()
12 | self.model.eval()
13 | self.model.zero_grad()
14 |
15 | # class-specific backpropagation
16 | logits = self.model(img)
17 | target = self._encode_one_hot(target_class, logits)
18 | self.model.zero_grad()
19 | logits.backward(gradient=target, retain_graph=True)
20 |
21 | # get feature maps and gradients
22 | feature_maps = self._find(self.feature_maps, target_layer)
23 | gradients = self._find(self.gradients, target_layer)
24 |
25 | # generate CAM
26 | with torch.no_grad():
27 | gradients_power_2 = gradients**2
28 | gradients_power_3 = gradients_power_2 * gradients
29 | sum_feature_maps = torch.sum(feature_maps, dim=(2, 3))
30 | sum_feature_maps = sum_feature_maps[:, :, None, None]
31 | eps = 1e-6
32 | aij = gradients_power_2 / (2 * gradients_power_2 + sum_feature_maps * gradients_power_3 + eps)
33 | aij = torch.where(gradients != 0, aij, 0)
34 | weights = torch.maximum(gradients, torch.tensor(0, device=self.device)) * aij
35 | weights = torch.sum(weights, dim=(2, 3))
36 | weights = weights[:, :, None, None]
37 |
38 | cam = torch.mul(feature_maps, weights).sum(dim=1, keepdim=True)
39 | cam = F.relu(cam)
40 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
41 | cam = self.normalize_cam(cam)
42 |
43 | return cam
--------------------------------------------------------------------------------
/attribution/guided_backprop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Function
4 |
5 | from .base import VanillaGradient
6 | from .utils import replace_all_layer_type_recursive
7 |
8 |
9 | class GuidedBackpropReLU(Function):
10 |     @staticmethod
11 |     def forward(ctx, input_img):
12 |         # standard ReLU forward: keep only the positive inputs
13 |         positive_mask = (input_img > 0).type_as(input_img)
14 |         output = input_img * positive_mask
15 |         ctx.save_for_backward(input_img, output)
16 |         return output
17 |
18 |     @staticmethod
19 |     def backward(ctx, grad_output):
20 |         input_img, output = ctx.saved_tensors
21 |         # pass the gradient only where both the input and the
22 |         # incoming gradient are positive
23 |         positive_mask_1 = (input_img > 0).type_as(grad_output)
24 |         positive_mask_2 = (grad_output > 0).type_as(grad_output)
25 |         grad_input = grad_output * positive_mask_1 * positive_mask_2
26 |         return grad_input
27 |
28 |
29 | class GuidedBackpropReLUasModule(torch.nn.Module):
30 |     def __init__(self):
31 |         super(GuidedBackpropReLUasModule, self).__init__()
32 |
33 |     def forward(self, input_img):
34 |         return GuidedBackpropReLU.apply(input_img)
35 |
36 |
37 | class GuidedBackProp(VanillaGradient):
38 |     def get_mask(self, img: torch.Tensor, target_class: torch.Tensor) -> torch.Tensor:
39 |         # temporarily swap every ReLU for the guided version, take vanilla
40 |         # gradients, then restore the original ReLUs
41 |         replace_all_layer_type_recursive(self.model, torch.nn.ReLU, GuidedBackpropReLUasModule())
42 |         grads = super().get_mask(img, target_class)
43 |         replace_all_layer_type_recursive(self.model, GuidedBackpropReLUasModule, torch.nn.ReLU())
44 |         return grads
--------------------------------------------------------------------------------
/attribution/guided_gradcam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | from .base import CombinedWrapper
5 | from .grad_cam import GradCAM
6 | from .guided_backprop import GuidedBackProp
7 |
8 | class GuidedGradCAM(CombinedWrapper):
9 | def __init__(self, model: torch.nn.Module, reshape_transform=None):
10 | gradient_net = GuidedBackProp
11 | cam_net = GradCAM
12 | super(GuidedGradCAM, self).__init__(model, gradient_net, cam_net, reshape_transform)
--------------------------------------------------------------------------------
/attribution/guided_ig.py:
--------------------------------------------------------------------------------
1 | from .base import VanillaGradient
2 |
3 | import torch
4 | import numpy as np
5 | import math
6 |
7 | eps = 1e-9
8 |
9 | @torch.no_grad()
10 | def l1_distance(x, y):
11 | return torch.abs(x - y).sum()
12 |
13 | @torch.no_grad()
14 | def translate_x_to_alpha(x: torch.Tensor, x_input: torch.Tensor, x_baseline: torch.Tensor):
15 | '''Alpha shows the relative position of x in the straight line between x_baseline and x_input'''
16 | alpha = torch.where(x_input-x_baseline != 0,
17 | (x - x_baseline) / (x_input - x_baseline + eps), torch.tensor(np.nan).to(x.device))
18 | return alpha
19 |
20 | @torch.no_grad()
21 | def translate_alpha_to_x(alpha: torch.Tensor, x_input: torch.Tensor, x_baseline: torch.Tensor):
22 | '''Translate alpha to x'''
23 | assert 0.0 <= alpha <= 1.0, 'Alpha should be in the range [0, 1]'
24 | x = x_baseline + alpha * (x_input - x_baseline)
25 | return x
26 |
27 | class GuidedIG(VanillaGradient):
28 | def get_mask(self, img: torch.Tensor,
29 | target_class: torch.Tensor,
30 | baseline='None',
31 | steps=128,
32 | fraction=0.25,
33 | max_dist=0.02):
34 |
35 | self.model.eval()
36 | self.model.zero_grad()
37 |
38 | if baseline is None or baseline == 'None':
39 | baseline = torch.zeros_like(img, device=self.device)
40 | elif baseline == 'black':
41 | baseline = torch.ones_like(img, device=self.device) * torch.min(img).detach()
42 | elif baseline == 'white':
43 | baseline = torch.ones_like(img, device=self.device) * torch.max(img).detach()
44 | else:
45 | raise ValueError(f'Baseline {baseline} is not supported, use "black", "white" or None')
46 |
47 | return self.guided_ig_impl(img, target_class, baseline, steps, fraction, max_dist)
48 |
49 | def get_smoothed_mask(self, img: torch.Tensor,
50 | target_class: torch.Tensor,
51 | baseline = 'None',
52 | steps: int = 128,
53 | fraction: float = 0.25,
54 | max_dist: float = 0.02,
55 | samples: int = 25,
56 | std: float = 0.15,
57 | process=lambda x: x**2):
58 | std = std * (torch.max(img) - torch.min(img)).detach().cpu().numpy()
59 |
60 | B, C, H, W = img.size()
61 | grad_sum = torch.zeros((B, C, H, W), device=self.device)
62 | for sample in range(samples):
63 | noise = torch.empty(img.size()).normal_(0, std).to(self.device)
64 | noise_image = img + noise
65 | grad_sum += process(self.get_mask(noise_image, target_class, baseline, steps, fraction, max_dist))
66 | return grad_sum / samples
67 |
68 | def guided_ig_impl(self, img: torch.Tensor,
69 | target_class: torch.Tensor,
70 | baseline: torch.Tensor,
71 | steps: int,
72 | fraction: float,
73 | max_dist: float):
74 | guided_ig = torch.zeros_like(img, dtype=torch.float64, device=self.device)
75 |
76 | x = baseline.clone()
77 |
78 | l1_total = l1_distance(baseline, img)
79 | if l1_total.sum() < eps:
80 | return guided_ig
81 |
82 | for step in range(steps):
83 | gradients_actual = super().get_mask(x, target_class)
84 | gradients = gradients_actual.clone().detach()
85 | alpha = (step + 1.0) / steps
86 | alpha_min = max(alpha - max_dist, 0.0)
87 | alpha_max = min(alpha + max_dist, 1.0)
88 | x_min = translate_alpha_to_x(alpha_min, img, baseline).detach()
89 | x_max = translate_alpha_to_x(alpha_max, img, baseline).detach()
90 |
91 | with torch.no_grad():
92 | l1_target = l1_total * (1 - (step + 1) / steps)
93 |
94 | # Iterate until the desired L1 distance has been reached.
95 | gamma = torch.tensor(np.inf)
96 | while gamma > 1.0:
97 | x_old = x.clone().detach()
98 | x_alpha = translate_x_to_alpha(x, img, baseline).detach()
99 | x_alpha[torch.isnan(x_alpha)] = alpha_max
100 | x.requires_grad = False
101 | x[x_alpha < alpha_min] = x_min[x_alpha < alpha_min]
102 |
103 | l1_current = l1_distance(x, img)
104 |
105 | if math.isclose(l1_target, l1_current, rel_tol=eps, abs_tol=eps):
106 | with torch.no_grad():
107 | guided_ig += gradients_actual * (x - x_old)
108 | break
109 |
110 | gradients[x == x_max] = torch.tensor(np.inf)
111 |
112 | threshold = torch.quantile(torch.abs(gradients), fraction, dim=None, keepdim=False, interpolation='lower')
113 | s = torch.logical_and(torch.abs(gradients) <= threshold, gradients != torch.tensor(np.inf))
114 |
115 | with torch.no_grad():
116 | l1_s = (torch.abs(x - x_max) * s).sum()
117 |
118 | if l1_s > 0:
119 | gamma = (l1_current - l1_target) / l1_s
120 | else:
121 | gamma = torch.tensor(np.inf)
122 |
123 | if gamma > 1.0:
124 | x[s] = x_max[s]
125 | else:
126 | assert gamma >= 0.0, f'Gamma should be non-negative, but got {gamma.min()}'
127 | x[s] = translate_alpha_to_x(gamma, x_max, x)[s]
128 |
129 | with torch.no_grad():
130 | guided_ig += gradients_actual * (x - x_old)
131 |
132 | return guided_ig
133 |
134 |
135 |
136 |
--------------------------------------------------------------------------------
/attribution/hires_cam.py:
--------------------------------------------------------------------------------
1 | from .base import CAMWrapper
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from typing import List
6 |
7 | class HiResCAM(CAMWrapper):
8 | def get_mask(self, img: torch.Tensor,
9 | target_class: torch.Tensor,
10 | target_layer: str):
11 | B, C, H, W = img.size()
12 | self.model.eval()
13 | self.model.zero_grad()
14 |
15 | # class-specific backpropagation
16 | logits = self.model(img)
17 | target = self._encode_one_hot(target_class, logits)
18 | self.model.zero_grad()
19 | logits.backward(gradient=target, retain_graph=True)
20 |
21 | # get feature maps and gradients
22 | feature_maps = self._find(self.feature_maps, target_layer)
23 | gradients = self._find(self.gradients, target_layer)
24 |
25 | # generate CAM
26 | with torch.no_grad():
27 | cam = gradients * feature_maps
28 | cam = cam.sum(dim=1, keepdim=True)
29 | cam = F.relu(cam)
30 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
31 | cam = self.normalize_cam(cam)
32 |
33 | return cam
--------------------------------------------------------------------------------
/attribution/ig.py:
--------------------------------------------------------------------------------
1 | from .base import VanillaGradient
2 |
3 | import torch
4 | import numpy as np
5 |
6 | class IntegratedGradients(VanillaGradient):
7 | def get_mask(self, img: torch.Tensor,
8 | target_class: torch.Tensor,
9 | baseline='black',
10 | steps=128,
11 | process=lambda x: x):
12 | if baseline == 'black':
13 | baseline = torch.ones_like(img, device=self.device) * torch.min(img).detach()
14 | elif baseline == 'white':
15 | baseline = torch.ones_like(img, device=self.device) * torch.max(img).detach()
16 | else:
17 | baseline = torch.zeros_like(img, device=self.device)
18 |
19 | B, C, H, W = img.size()
20 | grad_sum = torch.zeros((B, C, H, W), device=self.device)
21 | image_diff = img - baseline
22 |
23 | for step, alpha in enumerate(np.linspace(0, 1, steps)):
24 | image_step = baseline + alpha * image_diff
25 | grad_sum += process(super(IntegratedGradients,
26 | self).get_mask(image_step, target_class))
27 | return grad_sum * image_diff.detach() / steps
28 |
29 | def get_smoothed_mask(self, img: torch.Tensor,
30 | target_class: torch.Tensor,
31 | baseline='black',
32 | steps=128,
33 | process_ig=lambda x: x, # used in self.get_mask
34 | samples=25,
35 | std=0.15,
36 | process=lambda x: x**2):
37 | std = std * (torch.max(img) - torch.min(img)).detach().cpu().numpy()
38 |
39 | B, C, H, W = img.size()
40 | grad_sum = torch.zeros((B, C, H, W), device=self.device)
41 | for sample in range(samples):
42 | noise = torch.empty(img.size()).normal_(0, std).to(self.device)
43 | noise_image = img + noise
44 | grad_sum += process(self.get_mask(noise_image, target_class, baseline, steps, process_ig))
45 | return grad_sum / samples
--------------------------------------------------------------------------------
/attribution/layer_cam.py:
--------------------------------------------------------------------------------
1 | from .base import CAMWrapper
2 |
3 | import torch
4 | from torch.nn import functional as F
5 |
6 |
7 | class LayerCAM(CAMWrapper):
8 | def get_mask(self, img: torch.Tensor,
9 | target_class: torch.Tensor,
10 | target_layer: str):
11 | B, C, H, W = img.size()
12 | self.model.eval()
13 | self.model.zero_grad()
14 |
15 | # class-specific backpropagation
16 | logits = self.model(img)
17 | target = self._encode_one_hot(target_class, logits)
18 | self.model.zero_grad()
19 | logits.backward(gradient=target, retain_graph=True)
20 |
21 | # get feature maps and gradients
22 | feature_maps = self._find(self.feature_maps, target_layer)
23 | gradients = self._find(self.gradients, target_layer)
24 |
25 | # generate CAM
26 | with torch.no_grad():
27 | spatial_weighted_feature_maps = torch.maximum(gradients, torch.tensor(0, device=self.device)) * feature_maps
28 | cam = spatial_weighted_feature_maps.sum(dim=1, keepdim=True)
29 | cam = F.relu(cam)
30 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
31 | cam = self.normalize_cam(cam)
32 |
33 | return cam
--------------------------------------------------------------------------------
/attribution/occlusion.py:
--------------------------------------------------------------------------------
1 | from .base import Core
2 |
3 | import torch
4 |
5 |
6 | class Occlusion(Core):
7 |
8 | @torch.no_grad()
9 | def get_mask(self, img, target_class, size=15, value=0.0):
10 | '''
11 | size: height and width of the occlusion window.
12 | value: value to replace values inside the occlusion window with.
13 | '''
14 | B, C, H, W = img.size()
15 | occlusion_scores = torch.zeros_like(img, device=self.device)
16 |         occlusion_window = torch.full((B, C, size, size), value, device=self.device)
17 |
18 | original_output = self.model(img)
19 | for row in range(1 + H - size):
20 | for col in range(1 + W - size):
21 | img_occluded = img.clone()
22 | img_occluded[:, :, row:row+size, col:col+size] = occlusion_window
23 | output = self.model(img_occluded)
24 | score_diff = original_output - output
25 | # the score_diff for the target class
26 | score_diff = score_diff[torch.arange(B), target_class]
27 | occlusion_scores[:, :, row:row+size, col:col+size] += score_diff[:, None, None, None]
28 |
29 | return occlusion_scores
30 |
--------------------------------------------------------------------------------
/attribution/score_cam.py:
--------------------------------------------------------------------------------
1 | from .base import CAMWrapper
2 |
3 | import torch
4 | from torch.nn import functional as F
5 |
6 |
7 | class ScoreCAM(CAMWrapper):
8 | def get_mask(self, img: torch.Tensor,
9 | target_class: torch.Tensor,
10 | target_layer: str,
11 | batch_size: int = 16,):
12 | B, C, H, W = img.size()
13 | self.model.eval()
14 | self.model.zero_grad()
15 |
16 | # class-specific backpropagation
17 | logits = self.model(img)
18 | target = self._encode_one_hot(target_class, logits)
19 | self.model.zero_grad()
20 | logits.backward(gradient=target, retain_graph=True)
21 |
22 |         # get feature maps (the weights come from forward scores, not gradients)
23 | feature_maps = self._find(self.feature_maps, target_layer)
24 |
25 | with torch.no_grad():
26 | upsample = torch.nn.UpsamplingBilinear2d(size=(H, W))
27 | upsampled = upsample(feature_maps)
28 | maxs = upsampled.view(upsampled.size(0), upsampled.size(1), -1).max(dim=-1)[0]
29 | mins = upsampled.view(upsampled.size(0), upsampled.size(1), -1).min(dim=-1)[0]
30 |
31 | maxs, mins = maxs[:, :, None, None], mins[:, :, None, None]
32 | upsampled = (upsampled - mins) / (maxs - mins + 1e-8)
33 |
34 | imgs = img[:, None, :, :] * upsampled[:, :, None, :, :]
35 |
36 | get_targets = lambda o, target: o[target]
37 | if isinstance(target_class, torch.Tensor):
38 | target_class = target_class.cpu().tolist()
39 | if not isinstance(target_class, list):
40 | target_class = [target_class]
41 |
42 | scores = []
43 | with torch.no_grad():
44 | for i in range(imgs.size(0)):
45 | input_img = imgs[i]
46 | for batch_i in range(0, input_img.size(0), batch_size):
47 | batch = input_img[batch_i: batch_i + batch_size, :]
48 | outputs = [get_targets(o, target_class[i]).detach() for o in self.model(batch)]
49 | scores.extend(outputs)
50 | scores = torch.tensor(scores)
51 | scores = scores.view(feature_maps.shape[0], feature_maps.shape[1])
52 | weights = F.softmax(scores, dim=-1)
53 | weights = weights.to(self.device)
54 | cam = torch.mul(feature_maps, weights[:, :, None, None]).sum(dim=1, keepdim=True)
55 | cam = F.relu(cam)
56 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
57 | cam = self.normalize_cam(cam)
58 |
59 | return cam
60 |
--------------------------------------------------------------------------------
/attribution/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .visualization_maps import normalize_saliency, visualize_single_saliency
2 | from .access_layers import find_layer_predicate_recursive, find_layer_types_recursive, replace_all_layer_type_recursive, replace_layer_recursive
3 | from .svd_on_feature_maps import get_2d_projection
4 | from .reshape_for_transformer import get_reshape_transform
--------------------------------------------------------------------------------
/attribution/utils/access_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def replace_layer_recursive(model: torch.nn.Module, old_layer: torch.nn.Module, new_layer: torch.nn.Module):
4 | for name, layer in model._modules.items():
5 | if layer == old_layer:
6 | model._modules[name] = new_layer
7 | return True
8 | elif replace_layer_recursive(layer, old_layer, new_layer):
9 | return True
10 | return False
11 |
12 |
13 | def replace_all_layer_type_recursive(model: torch.nn.Module, old_layer_type: type, new_layer: torch.nn.Module):
14 |     '''new_layer is an instance of the new layer type, not the type itself.'''
15 | for name, layer in model._modules.items():
16 | if isinstance(layer, old_layer_type):
17 | model._modules[name] = new_layer
18 | replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
19 |
20 |
21 | def find_layer_types_recursive(model: torch.nn.Module, layer_types: list):
22 | def predicate(layer):
23 | return type(layer) in layer_types
24 | return find_layer_predicate_recursive(model, predicate)
25 |
26 |
27 | def find_layer_predicate_recursive(model: torch.nn.Module, predicate):
28 | result = []
29 | for name, layer in model._modules.items():
30 | if predicate(layer):
31 | result.append(layer)
32 | result.extend(find_layer_predicate_recursive(layer, predicate))
33 | return result
34 |
--------------------------------------------------------------------------------
/attribution/utils/reshape_for_transformer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | def get_reshape_transform(has_cls_token=False):
5 | def reshape_transform(x: torch.Tensor):
6 | if not has_cls_token:
7 | # typically is the output of the swin transformer
8 | x = x.permute(0, 3, 1, 2)
9 | return x
10 | # typically is the output of the ViT
11 |         B, token_number, embed_dim = x.size()
12 |         token_number -= 1  # drop the leading CLS token
13 |         x = x[:, 1:, :]
14 |
15 |         img_width = int(math.sqrt(token_number))
16 |         assert img_width * img_width == token_number
17 |         x = x.view(B, img_width, img_width, embed_dim)
18 |
19 |         x = x.permute(0, 3, 1, 2)
20 | return x
21 |
22 | return reshape_transform
--------------------------------------------------------------------------------
/attribution/utils/svd_on_feature_maps.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def get_2d_projection(activation_batch):
5 | # TODO: use pytorch batch svd implementation
6 | activation_batch[np.isnan(activation_batch)] = 0
7 | projections = []
8 | for activations in activation_batch:
9 | reshaped_activations = (activations).reshape(
10 | activations.shape[0], -1).transpose()
11 |         # Centering before the SVD seems to be important here;
12 |         # otherwise the returned image can be negative
13 | reshaped_activations = reshaped_activations - \
14 | reshaped_activations.mean(axis=0)
15 | U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True)
16 | projection = reshaped_activations @ VT[0, :]
17 | projection = projection.reshape(activations.shape[1:])
18 | projections.append(projection)
19 | return np.float32(projections)
20 |
--------------------------------------------------------------------------------
/attribution/utils/visualization_maps.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | @torch.no_grad()
8 | def normalize_saliency(saliency_map, return_device=torch.device('cpu')):
9 | B, C, H, W = saliency_map.size()
10 | if C > 1: # the input image is multi-channel
11 | saliency_map = saliency_map.max(dim=1, keepdim=True)[0]
12 |     saliency_map = F.relu(saliency_map)
13 | # the shape is B x 1 x H x W, normalize the saliency map along the channel dimension
14 | saliency_map = saliency_map.view(saliency_map.size(0), -1)
15 | saliency_map -= saliency_map.min(dim=1, keepdim=True)[0]
16 |     saliency_map /= (saliency_map.max(dim=1, keepdim=True)[0] + 1e-8)  # avoid division by zero for all-zero maps
17 | saliency_map = saliency_map.view(B, 1, H, W)
18 | return saliency_map.to(return_device)
19 |
20 | def visualize_single_saliency(saliency_map, img_size=None):
21 | B, C, H, W = saliency_map.size()
22 | assert B == 1, 'The input saliency map should not be batch inputs'
23 | if saliency_map.max() > 1 or C > 1:
24 | saliency_map = normalize_saliency(saliency_map)
25 | saliency_map = saliency_map.view(H, W, 1)
26 | saliency_map = saliency_map.cpu().numpy()
27 | if img_size is not None:
28 | saliency_map = cv2.resize(saliency_map, (img_size[1], img_size[0]), interpolation=cv2.INTER_LINEAR)
29 | else:
30 | saliency_map = cv2.resize(saliency_map, (W, H), interpolation=cv2.INTER_LINEAR)
31 | saliency_map = cv2.applyColorMap(np.uint8(saliency_map * 255.0), cv2.COLORMAP_JET)
32 | saliency_map = cv2.cvtColor(saliency_map, cv2.COLOR_BGR2RGB)
33 | plt.axis('off')
34 | plt.imshow(saliency_map)
35 | return saliency_map
--------------------------------------------------------------------------------
/attribution/xgrad_cam.py:
--------------------------------------------------------------------------------
1 | from .grad_cam import GradCAM
2 |
3 | import torch
4 | from torch.nn import functional as F
5 |
6 |
7 | class XGradCAM(GradCAM):
8 | def get_mask(self, img: torch.Tensor,
9 | target_class: torch.Tensor,
10 | target_layer: str):
11 |
12 | B, C, H, W = img.size()
13 | self.model.eval()
14 | self.model.zero_grad()
15 |
16 | # class-specific backpropagation
17 | logits = self.model(img)
18 | target = self._encode_one_hot(target_class, logits)
19 | self.model.zero_grad()
20 | logits.backward(gradient=target, retain_graph=True)
21 |
22 | # get feature maps and gradients
23 | feature_maps = self._find(self.feature_maps, target_layer)
24 | gradients = self._find(self.gradients, target_layer)
25 |
26 | # generate CAM
27 | with torch.no_grad():
28 | sum_feature_maps = torch.sum(feature_maps, dim=(2, 3))
29 | sum_feature_maps = sum_feature_maps[:, :, None, None]
30 | eps = 1e-7
31 | weights = gradients * feature_maps / (sum_feature_maps + eps)
32 | weights = torch.sum(weights, dim=(2, 3))
33 | weights = weights[:, :, None, None]
34 | cam = torch.mul(feature_maps, weights).sum(dim=1, keepdim=True)
35 | cam = F.relu(cam)
36 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=False)
37 | cam = self.normalize_cam(cam)
38 |
39 | return cam
--------------------------------------------------------------------------------
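For reference, the `torch.no_grad()` block above computes the XGradCAM weighting: each channel's weight is its gradient scaled by each activation's share of the channel's total activation, with `eps` guarding the denominator. With feature maps $A^k$ and class score $y^c$:

```latex
w_k^c = \sum_{i,j} \frac{A^k_{ij}}{\sum_{u,v} A^k_{uv} + \varepsilon}
        \cdot \frac{\partial y^c}{\partial A^k_{ij}},
\qquad
\text{CAM}^c = \mathrm{ReLU}\Big(\sum_k w_k^c A^k\Big)
```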
/cam_visualization_examples.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import timm
5 | from timm.data import resolve_model_data_config
6 | from timm.data.transforms_factory import create_transform
7 | import requests
8 |
9 | from attribution import GradCAM, GradCAMPlusPlus, XGradCAM, BagCAM, ScoreCAM, LayerCAM, AblationCAM, FullGrad, EigenCAM, EigenGradCAM, HiResCAM
10 | from attribution.utils import normalize_saliency, visualize_single_saliency
11 |
12 |
13 | if __name__ == '__main__':
14 |
15 | # Load imagenet labels
16 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
17 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 |
20 | # Load a pretrained model
21 | model = timm.create_model('resnet50', pretrained=True)
22 | model = model.to(device)
23 | model.eval()
24 | config = resolve_model_data_config(model, None)
25 | transform = create_transform(**config)
26 |
27 | # Load an image
28 | dog = Image.open('examples/dog.png').convert('RGB')
29 | dog_tensor = transform(dog).unsqueeze(0)
30 | H, W = dog_tensor.shape[-2:]
31 |
32 | # Predict the image
33 | img = transform(dog).unsqueeze(0)
34 | img = torch.cat([img, img])
35 | img = img.to(device)
36 | output = model(img)
37 | target_index = torch.argmax(output, dim=1).cpu()
38 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()])
39 |
40 | # Get the target layer
41 | target_layer_candidates = list()
42 | for name, module in model.named_modules():
43 | print(name)
44 | target_layer_candidates.append(name)
45 | '''e.g.,
46 | --ResNet:
47 | conv1
48 | bn1
49 | act1
50 | maxpool
51 | layer1
52 | layer2
53 | layer3
54 | layer4
55 | global_pool
56 | fc
57 |
58 | --Xception41:
59 | stem
60 | blocks
61 | act
62 | head
63 | '''
64 | target_layer = input('Enter the target layer: ')
65 | while target_layer not in target_layer_candidates:
66 | print('Invalid layer name')
67 | target_layer = input('Enter the target layer: ')
68 |
69 | # GradCAM
70 | gradcam_net = GradCAM(model)
71 | grad_cam = normalize_saliency(gradcam_net.get_mask(img, target_index, target_layer))
72 | print('GradCAM', grad_cam.shape)
73 |
74 | # GradCAM++
75 | gradcam_plus_plus_net = GradCAMPlusPlus(model)
76 | grad_cam_plus_plus = normalize_saliency(gradcam_plus_plus_net.get_mask(img, target_index, target_layer))
77 | print('GradCAM++:', grad_cam_plus_plus.shape)
78 |
79 | # HiResCAM
80 | hirescam_net = HiResCAM(model)
81 | hires_cam = normalize_saliency(hirescam_net.get_mask(img, target_index, target_layer))
82 | print('HiResCAM:', hires_cam.shape)
83 |
84 | # XGradCAM
85 | xgradcam_net = XGradCAM(model)
86 | xgrad_cam = normalize_saliency(xgradcam_net.get_mask(img, target_index, target_layer))
87 | print('XGradCAM:', xgrad_cam.shape)
88 |
89 | # BagCAM
90 | bagcam_net = BagCAM(model)
91 | bag_cam = normalize_saliency(bagcam_net.get_mask(img, target_index, target_layer))
92 | print('BagCAM:', bag_cam.shape)
93 |
94 | # ScoreCAM
95 | scorecam_net = ScoreCAM(model)
96 | score_cam = normalize_saliency(scorecam_net.get_mask(img, target_index, target_layer))
97 | print('ScoreCAM', score_cam.shape)
98 |
99 | # EigenCAM
100 | eigencam_net = EigenCAM(model)
101 | eigen_cam = normalize_saliency(eigencam_net.get_mask(img, target_index, target_layer))
102 | print('EigenCAM', eigen_cam.shape)
103 |
104 | # EigenGradCAM
105 | eigengradcam_net = EigenGradCAM(model)
106 | eigen_grad_cam = normalize_saliency(eigengradcam_net.get_mask(img, target_index, target_layer))
107 | print('EigenGradCAM', eigen_grad_cam.shape)
108 |
109 | # LayerCAM
110 | layercam_net = LayerCAM(model)
111 | layer_cam = normalize_saliency(layercam_net.get_mask(img, target_index, target_layer))
112 | print('LayerCAM', layer_cam.shape)
113 |
114 | # AblationCAM
115 | ablationcam_net = AblationCAM(model)
116 | ablation_cam = normalize_saliency(ablationcam_net.get_mask(img, target_index, target_layer))
117 | print('AblationCAM', ablation_cam.shape)
118 |
119 | # FullGrad
120 | fullgrad_net = FullGrad(model)
121 | full_grad = normalize_saliency(fullgrad_net.get_mask(img, target_index, target_layer=None))
122 | print('FullGrad', full_grad.shape)
123 |
124 | # Visualize the saliency maps
125 | plt.figure(figsize=(16, 15))
126 | plt.subplot(3,5,1)
127 | plt.title('Input')
128 | plt.axis('off')
129 | plt.imshow(dog)
130 | plt.subplot(3,5,2)
131 | plt.title('GradCAM')
132 | visualize_single_saliency(grad_cam[0].unsqueeze(0))
133 | plt.subplot(3,5,3)
134 | plt.title('GradCAM++')
135 | visualize_single_saliency(grad_cam_plus_plus[0].unsqueeze(0))
136 | plt.subplot(3,5,4)
137 | plt.title('HiResCAM')
138 | visualize_single_saliency(hires_cam[0].unsqueeze(0))
139 | plt.subplot(3,5,5)
140 | plt.title('FullGrad')
141 | visualize_single_saliency(full_grad[0].unsqueeze(0))
142 | plt.subplot(3,5,6)
143 | plt.title('AblationCAM')
144 | visualize_single_saliency(ablation_cam[0].unsqueeze(0))
145 | plt.subplot(3,5,7)
146 | plt.title('ScoreCAM')
147 | visualize_single_saliency(score_cam[0].unsqueeze(0))
148 | plt.subplot(3,5,8)
149 | plt.title('EigenCAM')
150 | visualize_single_saliency(eigen_cam[0].unsqueeze(0))
151 | plt.subplot(3,5,9)
152 | plt.title('EigenGradCAM')
153 | visualize_single_saliency(eigen_grad_cam[0].unsqueeze(0))
154 | plt.subplot(3,5,10)
155 | plt.title('XGradCAM')
156 | visualize_single_saliency(xgrad_cam[0].unsqueeze(0))
157 | plt.subplot(3,5,11)
158 | plt.title('LayerCAM')
159 | visualize_single_saliency(layer_cam[0].unsqueeze(0))
160 | plt.subplot(3,5,12)
161 | plt.title('BagCAM')
162 | visualize_single_saliency(bag_cam[0].unsqueeze(0))
163 |
164 |
165 | plt.tight_layout()
166 | plt.savefig('examples/cam_visualization.png', bbox_inches='tight', pad_inches=0.5)
167 |
168 |
169 |
170 |
--------------------------------------------------------------------------------
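The `input()` prompt in the script above is handy for exploring a model's module names interactively; for non-interactive runs it can simply be replaced with a fixed name (a short sketch; the chosen layer matches the README's resnet50 CAM examples):

```python
import timm

model = timm.create_model('resnet50', pretrained=True)
target_layer_candidates = [name for name, _ in model.named_modules()]

target_layer = 'layer3'  # fixed choice instead of prompting
assert target_layer in target_layer_candidates, f'unknown layer: {target_layer}'
```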
/cam_visualization_for_transformers_examples.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import timm
5 | from timm.data import resolve_model_data_config
6 | from timm.data.transforms_factory import create_transform
7 | import requests
8 |
9 | from attribution import GradCAM, GradCAMPlusPlus, XGradCAM, BagCAM, ScoreCAM, LayerCAM, AblationCAM, FullGrad, EigenCAM, EigenGradCAM, HiResCAM
10 | from attribution.utils import normalize_saliency, visualize_single_saliency, get_reshape_transform
11 |
12 |
13 | if __name__ == '__main__':
14 |
15 | # Load imagenet labels
16 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
17 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 |
20 | # Load a pretrained model
21 | # model = timm.create_model('vit_tiny_patch16_224.augreg_in21k_ft_in1k', pretrained=True)
22 | model = timm.create_model('swin_base_patch4_window7_224.ms_in22k_ft_in1k', pretrained=True)
23 | model = model.to(device)
24 | model.eval()
25 | config = resolve_model_data_config(model, None)
26 | transform = create_transform(**config)
27 |
28 | # Load an image
29 | dog = Image.open('examples/dog.png').convert('RGB')
30 | dog_tensor = transform(dog).unsqueeze(0)
31 | H, W = dog_tensor.shape[-2:]
32 |
33 | # Predict the image
34 | img = transform(dog).unsqueeze(0)
35 | img = torch.cat([img, img])
36 | img = img.to(device)
37 | output = model(img)
38 | target_index = torch.argmax(output, dim=1).cpu()
39 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()])
40 |
41 | # Get the target layer
42 | target_layer_candidates = list()
43 | for name, module in model.named_modules():
44 | print(name)
45 | target_layer_candidates.append(name)
46 | '''e.g.,
47 | --Swin Transformer (swin_base_patch4_window7_224):
48 | patch_embed
49 | layers
50 | layers.0
51 | layers.1
52 | layers.2
53 | layers.3
54 | norm
55 | head
56 |
57 | the README uses `norm` as the target layer for Swin Transformer
58 | '''
65 | target_layer = input('Enter the target layer for swin_base_patch4_window7_224: ')
66 | while target_layer not in target_layer_candidates:
67 | print('Invalid layer name')
68 | target_layer = input('Enter the target layer: ')
69 |
70 | # GradCAM
71 | gradcam_net = GradCAM(model, get_reshape_transform(has_cls_token=False))
72 | grad_cam = normalize_saliency(gradcam_net.get_mask(img, target_index, target_layer))
73 | print('GradCAM', grad_cam.shape)
74 |
75 | # GradCAM++
76 | gradcam_plus_plus_net = GradCAMPlusPlus(model, get_reshape_transform(has_cls_token=False))
77 | grad_cam_plus_plus = normalize_saliency(gradcam_plus_plus_net.get_mask(img, target_index, target_layer))
78 | print('GradCAM++:', grad_cam_plus_plus.shape)
79 |
80 | # HiResCAM
81 | hirescam_net = HiResCAM(model, get_reshape_transform(has_cls_token=False))
82 | hires_cam = normalize_saliency(hirescam_net.get_mask(img, target_index, target_layer))
83 | print('HiResCAM:', hires_cam.shape)
84 |
85 | # XGradCAM
86 | xgradcam_net = XGradCAM(model, get_reshape_transform(has_cls_token=False))
87 | xgrad_cam = normalize_saliency(xgradcam_net.get_mask(img, target_index, target_layer))
88 | print('XGradCAM:', xgrad_cam.shape)
89 |
90 | # LayerCAM
91 | layercam_net = LayerCAM(model, get_reshape_transform(has_cls_token=False))
92 | layer_cam = normalize_saliency(layercam_net.get_mask(img, target_index, target_layer))
93 | print('LayerCAM', layer_cam.shape)
94 |
95 | # Visualize the saliency maps
96 | plt.figure(figsize=(18, 10))
97 | plt.subplot(2,6,1)
98 | plt.title('Input')
99 | plt.axis('off')
100 | plt.imshow(dog)
101 | plt.subplot(2,6,2)
102 | plt.title('Swin GradCAM')
103 | visualize_single_saliency(grad_cam[0].unsqueeze(0))
104 | plt.subplot(2,6,3)
105 | plt.title('Swin GradCAM++')
106 | visualize_single_saliency(grad_cam_plus_plus[0].unsqueeze(0))
107 | plt.subplot(2,6,4)
108 | plt.title('Swin HiResCAM')
109 | visualize_single_saliency(hires_cam[0].unsqueeze(0))
110 | plt.subplot(2,6,5)
111 | plt.title('Swin XGradCAM')
112 | visualize_single_saliency(xgrad_cam[0].unsqueeze(0))
113 | plt.subplot(2,6,6)
114 | plt.title('Swin LayerCAM')
115 | visualize_single_saliency(layer_cam[0].unsqueeze(0))
116 |
117 | model = timm.create_model('vit_tiny_patch16_224.augreg_in21k_ft_in1k', pretrained=True)
118 | model = model.to(device)
119 | model.eval()
120 | config = resolve_model_data_config(model, None)
121 | transform = create_transform(**config)
122 |
123 | # Load an image
124 | dog = Image.open('examples/dog.png').convert('RGB')
125 | dog_tensor = transform(dog).unsqueeze(0)
126 | H, W = dog_tensor.shape[-2:]
127 |
128 | # Predict the image
129 | img = transform(dog).unsqueeze(0)
130 | img = torch.cat([img, img])
131 | img = img.to(device)
132 | output = model(img)
133 | target_index = torch.argmax(output, dim=1).cpu()
134 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()])
135 |
136 | # Get the target layer
137 | target_layer_candidates = list()
138 | for name, module in model.named_modules():
139 | print(name)
140 | target_layer_candidates.append(name)
141 | '''e.g.,
142 | --ViT (vit_tiny_patch16_224):
143 | patch_embed
144 | blocks
145 | blocks.0
146 | ...
147 | blocks.11
148 | blocks.11.norm1
149 | norm
150 | head
151 |
152 | the README uses `blocks.11.norm1` as the target layer for ViT
153 | '''
160 | target_layer = input('Enter the target layer for vit_tiny_patch16_224: ')
161 | while target_layer not in target_layer_candidates:
162 | print('Invalid layer name')
163 | target_layer = input('Enter the target layer: ')
164 |
165 | # GradCAM
166 | gradcam_net = GradCAM(model, get_reshape_transform(has_cls_token=True))
167 | grad_cam = normalize_saliency(gradcam_net.get_mask(img, target_index, target_layer))
168 | print('GradCAM', grad_cam.shape)
169 |
170 | # GradCAM++
171 | gradcam_plus_plus_net = GradCAMPlusPlus(model, get_reshape_transform(has_cls_token=True))
172 | grad_cam_plus_plus = normalize_saliency(gradcam_plus_plus_net.get_mask(img, target_index, target_layer))
173 | print('GradCAM++:', grad_cam_plus_plus.shape)
174 |
175 | # HiResCAM
176 | hirescam_net = HiResCAM(model, get_reshape_transform(has_cls_token=True))
177 | hires_cam = normalize_saliency(hirescam_net.get_mask(img, target_index, target_layer))
178 | print('HiResCAM:', hires_cam.shape)
179 |
180 | # XGradCAM
181 | xgradcam_net = XGradCAM(model, get_reshape_transform(has_cls_token=True))
182 | xgrad_cam = normalize_saliency(xgradcam_net.get_mask(img, target_index, target_layer))
183 | print('XGradCAM:', xgrad_cam.shape)
184 |
185 | # LayerCAM
186 | layercam_net = LayerCAM(model, get_reshape_transform(has_cls_token=True))
187 | layer_cam = normalize_saliency(layercam_net.get_mask(img, target_index, target_layer))
188 | print('LayerCAM', layer_cam.shape)
189 |
190 | plt.subplot(2,6,8)
191 | plt.title('ViT GradCAM')
192 | visualize_single_saliency(grad_cam[0].unsqueeze(0))
193 | plt.subplot(2,6,9)
194 | plt.title('ViT GradCAM++')
195 | visualize_single_saliency(grad_cam_plus_plus[0].unsqueeze(0))
196 | plt.subplot(2,6,10)
197 | plt.title('ViT HiResCAM')
198 | visualize_single_saliency(hires_cam[0].unsqueeze(0))
199 | plt.subplot(2,6,11)
200 | plt.title('ViT XGradCAM')
201 | visualize_single_saliency(xgrad_cam[0].unsqueeze(0))
202 | plt.subplot(2,6,12)
203 | plt.title('ViT LayerCAM')
204 | visualize_single_saliency(layer_cam[0].unsqueeze(0))
205 |
206 | plt.tight_layout()
207 | plt.savefig('examples/cam_visualization_for_transformers.png', bbox_inches='tight', pad_inches=0.5)
208 |
209 |
210 |
211 |
--------------------------------------------------------------------------------
/combine_cam_and_gradients_visualization_examples.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import timm
5 | from timm.data import resolve_model_data_config
6 | from timm.data.transforms_factory import create_transform
7 | import requests
8 |
9 | from attribution import GradCAM, GradCAMPlusPlus, XGradCAM, BagCAM, ScoreCAM, LayerCAM, AblationCAM, FullGrad, EigenCAM, EigenGradCAM, HiResCAM
10 | from attribution import VanillaGradient, GuidedBackProp, IntegratedGradients, BlurIG, GuidedIG
11 | from attribution import CombinedWrapper, GuidedGradCAM
12 | from attribution.utils import normalize_saliency, visualize_single_saliency
13 |
14 |
15 | if __name__ == '__main__':
16 |
17 | # Load imagenet labels
18 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
19 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 |
22 | # Load a pretrained model
23 | model = timm.create_model('resnet50', pretrained=True)
24 | model = model.to(device)
25 | model.eval()
26 | config = resolve_model_data_config(model, None)
27 | transform = create_transform(**config)
28 |
29 | # Load an image
30 | dog = Image.open('examples/dog.png').convert('RGB')
31 | dog_tensor = transform(dog).unsqueeze(0)
32 | H, W = dog_tensor.shape[-2:]
33 |
34 | # Predict the image
35 | img = transform(dog).unsqueeze(0)
36 | img = torch.cat([img, img])
37 | img = img.to(device)
38 | output = model(img)
39 | target_index = torch.argmax(output, dim=1).cpu()
40 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()])
41 |
42 | # Get the target layer
43 | target_layer_candidates = list()
44 | for name, module in model.named_modules():
45 | print(name)
46 | target_layer_candidates.append(name)
47 | '''e.g.,
48 | --ResNet:
49 | conv1
50 | bn1
51 | act1
52 | maxpool
53 | layer1
54 | layer2
55 | layer3
56 | layer4
57 | global_pool
58 | fc
59 |
60 | --Xception41:
61 | stem
62 | blocks
63 | act
64 | head
65 | '''
66 | target_layer = input('Enter the target layer: ')
67 | while target_layer not in target_layer_candidates:
68 | print('Invalid layer name')
69 | target_layer = input('Enter the target layer: ')
70 |
71 | # Guided Grad-CAM
72 | net = GuidedGradCAM(model)
73 | guided_gradcam = normalize_saliency(net.get_mask(img, target_index, target_layer))
74 | print('Guided Grad-CAM', guided_gradcam.shape)
75 |
76 | # GuidedBackprop + FullGrad
77 | net = CombinedWrapper(model, GuidedBackProp, FullGrad)
78 | guided_fullgrad = normalize_saliency(net.get_mask(img, target_index, target_layer))
79 | print('GuidedBackProp + FullGrad', guided_fullgrad.shape)
80 |
81 | # BlurIG + EigenCAM
82 | net = CombinedWrapper(model, BlurIG, EigenCAM)
83 | kwargs = {'steps': 20}
84 | blurig_eigencam = normalize_saliency(net.get_mask(img, target_index, target_layer, **kwargs))
85 | print('BlurIG + EigenCAM', blurig_eigencam.shape)
86 |
87 | # Visualize the results
88 | plt.figure(figsize=(16, 5))
89 | plt.subplot(1, 4, 1)
90 | plt.title('Input')
91 | plt.axis('off')
92 | plt.imshow(dog)
93 | plt.subplot(1, 4, 2)
94 | plt.title('Guided Grad-CAM')
95 | visualize_single_saliency(guided_gradcam[0].unsqueeze(0))
96 | plt.subplot(1, 4, 3)
97 | plt.title('GuidedBackProp + FullGrad')
98 | visualize_single_saliency(guided_fullgrad[0].unsqueeze(0))
99 | plt.subplot(1, 4, 4)
100 | plt.title('BlurIG + EigenCAM')
101 | visualize_single_saliency(blurig_eigencam[0].unsqueeze(0))
102 |
103 | plt.tight_layout()
104 | plt.savefig('examples/combine_cam_and_gradients_visualization.png', bbox_inches='tight', pad_inches=0.5)
--------------------------------------------------------------------------------
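As defined by Selvaraju et al., Guided Grad-CAM is the elementwise product of the guided-backprop gradients and the upsampled Grad-CAM; `CombinedWrapper` presumably generalizes the same composition to any (gradient method, CAM method) pair:

```latex
\text{GuidedGradCAM}(x) = \text{GuidedBackprop}(x) \odot \mathrm{upsample}\big(\text{GradCAM}(x)\big)
```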
/examples/attribution_methods.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/attribution_methods.png
--------------------------------------------------------------------------------
/examples/cam_visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/cam_visualization.png
--------------------------------------------------------------------------------
/examples/cam_visualization_for_transformers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/cam_visualization_for_transformers.png
--------------------------------------------------------------------------------
/examples/cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/cat.png
--------------------------------------------------------------------------------
/examples/combine_cam_and_gradients_visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/combine_cam_and_gradients_visualization.png
--------------------------------------------------------------------------------
/examples/dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/dog.png
--------------------------------------------------------------------------------
/examples/dog_and_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/dog_and_cat.png
--------------------------------------------------------------------------------
/examples/gradients_visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/gradients_visualization.png
--------------------------------------------------------------------------------
/examples/gradients_visualization_for_transformers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/gradients_visualization_for_transformers.png
--------------------------------------------------------------------------------
/examples/perturbation_based_visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/perturbation_based_visualization.png
--------------------------------------------------------------------------------
/examples/quick_start.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riverback/pytorch_attribution/a27a05b308b736e38fb6c96b2b3ddce67d187ac3/examples/quick_start.png
--------------------------------------------------------------------------------
/gradients_visualization_examples.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import timm
5 | from timm.data import resolve_model_data_config
6 | from timm.data.transforms_factory import create_transform
7 | import requests
8 |
9 |
10 | from attribution import VanillaGradient, IntegratedGradients, BlurIG, GuidedIG, GuidedBackProp
11 | from attribution.utils import normalize_saliency, visualize_single_saliency
12 |
13 |
14 | if __name__ == '__main__':
15 |
16 | # Load imagenet labels
17 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
18 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 |
21 | # Load a pretrained model
22 | model = timm.create_model('resnet50', pretrained=True)
23 | model = model.to(device)
24 | model.eval()
25 | config = resolve_model_data_config(model, None)
26 | transform = create_transform(**config)
27 |
28 | # Load an image
29 | dog = Image.open('examples/dog.png').convert('RGB')
30 | dog_tensor = transform(dog).unsqueeze(0)
31 | H, W = dog_tensor.shape[-2:]
32 |
33 | # Predict the image
34 | img = transform(dog).unsqueeze(0)
35 | img = torch.cat([img, img])
36 | img = img.to(device)
37 | output = model(img)
38 | target_index = torch.argmax(output, dim=1).cpu()
39 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()])
40 |
41 | # Vanilla Gradient
42 | gradient_net = VanillaGradient(model)
43 | attribution_gradients = normalize_saliency(gradient_net.get_mask(img, target_index))
44 | attribution_smooth_gradients = normalize_saliency(gradient_net.get_smoothed_mask(img, target_index, samples=10, std=0.1))
45 |
46 | # Guided Backpropagation
47 | guided_bp_net = GuidedBackProp(model)
48 | attribution_guided_bp = normalize_saliency(guided_bp_net.get_mask(img, target_index))
49 |
50 | # Integrated Gradients
51 | ig_net = IntegratedGradients(model)
52 | attribution_ig = normalize_saliency(ig_net.get_mask(img, target_index, steps=100))
53 | attribution_smooth_ig = normalize_saliency(ig_net.get_smoothed_mask(img, target_index, steps=100, std=0.15, samples=10))
54 | # Blur Integrated Gradients
55 | blur_ig_net = BlurIG(model)
56 | attribution_blur_ig = normalize_saliency(blur_ig_net.get_mask(img, target_index, steps=100))
57 | attribution_smooth_blur_ig = normalize_saliency(blur_ig_net.get_smoothed_mask(img, target_index, steps=100, std=0.15, samples=10))
58 | # Guided Integrated Gradients
59 | guided_ig_net = GuidedIG(model)
60 | attribution_guided_ig = normalize_saliency(guided_ig_net.get_mask(img, target_index, steps=100))
61 | attribution_smooth_guided_ig = normalize_saliency(guided_ig_net.get_smoothed_mask(img, target_index, steps=100, std=0.15, samples=10))
62 |
63 | # Visualize the results
64 | plt.figure(figsize=(16, 10))
65 | plt.subplot(2, 5, 1)
66 | plt.title('Input')
67 | plt.axis('off')
68 | plt.imshow(dog)
69 | plt.subplot(2, 5, 6)
70 | plt.title('Guided Backprop')
71 | visualize_single_saliency(attribution_guided_bp[0].unsqueeze(0))
72 | plt.subplot(2, 5, 2)
73 | plt.title('Vanilla Gradient')
74 | visualize_single_saliency(attribution_gradients[0].unsqueeze(0))
75 | plt.subplot(2, 5, 7)
76 | plt.title('Smoothed Vanilla Gradient')
77 | visualize_single_saliency(attribution_smooth_gradients[0].unsqueeze(0))
78 | plt.subplot(2, 5, 3)
79 | plt.title('Integrated Gradients')
80 | visualize_single_saliency(attribution_ig[0].unsqueeze(0))
81 | plt.subplot(2, 5, 8)
82 | plt.title('Smoothed Integrated Gradients')
83 | visualize_single_saliency(attribution_smooth_ig[0].unsqueeze(0))
84 | plt.subplot(2, 5, 4)
85 | plt.title('Blur IG')
86 | visualize_single_saliency(attribution_blur_ig[0].unsqueeze(0))
87 | plt.subplot(2, 5, 9)
88 | plt.title('Smoothed Blur IG')
89 | visualize_single_saliency(attribution_smooth_blur_ig[0].unsqueeze(0))
90 | plt.subplot(2, 5, 5)
91 | plt.title('Guided IG')
92 | visualize_single_saliency(attribution_guided_ig[0].unsqueeze(0))
93 | plt.subplot(2, 5, 10)
94 | plt.title('Smoothed Guided IG')
95 | visualize_single_saliency(attribution_smooth_guided_ig[0].unsqueeze(0))
96 | plt.tight_layout()
97 | plt.savefig('examples/gradients_visualization.png', bbox_inches='tight', pad_inches=0.5)
--------------------------------------------------------------------------------
/gradients_visualization_for_transformers_examples.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import timm
5 | from timm.data import resolve_model_data_config
6 | from timm.data.transforms_factory import create_transform
7 | import requests
8 |
9 |
10 | from attribution import VanillaGradient, IntegratedGradients, BlurIG, GuidedIG, GuidedBackProp
11 | from attribution.utils import normalize_saliency, visualize_single_saliency
12 |
13 |
14 | if __name__ == '__main__':
15 |
16 | # Load imagenet labels
17 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
18 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 |
21 | # Load a pretrained model
22 | model = timm.create_model('vit_tiny_patch16_224.augreg_in21k_ft_in1k', pretrained=True)
23 | model = model.to(device)
24 | model.eval()
25 | config = resolve_model_data_config(model, None)
26 | transform = create_transform(**config)
27 |
28 | # Load an image
29 | dog = Image.open('examples/dog.png').convert('RGB')
30 | dog_tensor = transform(dog).unsqueeze(0)
31 | H, W = dog_tensor.shape[-2:]
32 |
33 | # Predict the image
34 | img = transform(dog).unsqueeze(0)
35 | img = torch.cat([img, img])
36 | img = img.to(device)
37 | output = model(img)
38 | target_index = torch.argmax(output, dim=1).cpu()
39 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()])
40 |
41 | # Vanilla Gradient
42 | gradient_net = VanillaGradient(model)
43 | attribution_gradients = normalize_saliency(gradient_net.get_mask(img, target_index))
44 | attribution_smooth_gradients = normalize_saliency(gradient_net.get_smoothed_mask(img, target_index, samples=10, std=0.1))
45 |
46 | # Guided Backpropagation
47 | guided_bp_net = GuidedBackProp(model)
48 | attribution_guided_bp = normalize_saliency(guided_bp_net.get_mask(img, target_index))
49 |
50 | # Integrated Gradients
51 | ig_net = IntegratedGradients(model)
52 | attribution_ig = normalize_saliency(ig_net.get_mask(img, target_index, steps=100))
53 | attribution_smooth_ig = normalize_saliency(ig_net.get_smoothed_mask(img, target_index, steps=100, std=0.15, samples=10))
54 | # Blur Integrated Gradients
55 | blur_ig_net = BlurIG(model)
56 | attribution_blur_ig = normalize_saliency(blur_ig_net.get_mask(img, target_index, steps=100))
57 | attribution_smooth_blur_ig = normalize_saliency(blur_ig_net.get_smoothed_mask(img, target_index, steps=100, std=0.15, samples=10))
58 | # Guided Integrated Gradients
59 | guided_ig_net = GuidedIG(model)
60 | attribution_guided_ig = normalize_saliency(guided_ig_net.get_mask(img, target_index, steps=100))
61 | attribution_smooth_guided_ig = normalize_saliency(guided_ig_net.get_smoothed_mask(img, target_index, steps=100, std=0.15, samples=10))
62 |
63 | # Visualize the results
64 | plt.figure(figsize=(16, 10))
65 | plt.subplot(2, 5, 1)
66 | plt.title('Input')
67 | plt.axis('off')
68 | plt.imshow(dog)
69 | plt.subplot(2, 5, 6)
70 | plt.title('Guided Backprop')
71 | visualize_single_saliency(attribution_guided_bp[0].unsqueeze(0))
72 | plt.subplot(2, 5, 2)
73 | plt.title('Vanilla Gradient')
74 | visualize_single_saliency(attribution_gradients[0].unsqueeze(0))
75 | plt.subplot(2, 5, 7)
76 | plt.title('Smoothed Vanilla Gradient')
77 | visualize_single_saliency(attribution_smooth_gradients[0].unsqueeze(0))
78 | plt.subplot(2, 5, 3)
79 | plt.title('Integrated Gradients')
80 | visualize_single_saliency(attribution_ig[0].unsqueeze(0))
81 | plt.subplot(2, 5, 8)
82 | plt.title('Smoothed Integrated Gradients')
83 | visualize_single_saliency(attribution_smooth_ig[0].unsqueeze(0))
84 | plt.subplot(2, 5, 4)
85 | plt.title('Blur IG')
86 | visualize_single_saliency(attribution_blur_ig[0].unsqueeze(0))
87 | plt.subplot(2, 5, 9)
88 | plt.title('Smoothed Blur IG')
89 | visualize_single_saliency(attribution_smooth_blur_ig[0].unsqueeze(0))
90 | plt.subplot(2, 5, 5)
91 | plt.title('Guided IG')
92 | visualize_single_saliency(attribution_guided_ig[0].unsqueeze(0))
93 | plt.subplot(2, 5, 10)
94 | plt.title('Smoothed Guided IG')
95 | visualize_single_saliency(attribution_smooth_guided_ig[0].unsqueeze(0))
96 | plt.tight_layout()
97 | plt.savefig('examples/gradients_visualization_for_transformers.png', bbox_inches='tight', pad_inches=0.5)
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .keep_remove_mask import kpm, kam, ram, knm, keep_remove_metrics
2 | from .correlation_value import class_specific_spearman_rank_corr
3 | from .insert_and_delete import Insert_Delete_Metric, insert_score, delete_score
--------------------------------------------------------------------------------
/metrics/correlation_value.py:
--------------------------------------------------------------------------------
1 | from torchmetrics.functional.regression import spearman_corrcoef
2 | import torch
3 |
4 | @torch.no_grad()
5 | def spearman_rank_correlation_value(att1: torch.Tensor, att2: torch.Tensor):
6 | assert att1.size() == att2.size(), 'Attribution maps must have the same size'
7 | att1 = att1.view(att1.size(0), -1)
8 | att2 = att2.view(att2.size(0), -1)
9 | corr = spearman_corrcoef(att1.T, att2.T)
10 | return corr.cpu()
11 |
12 | def class_specific_spearman_rank_corr(att_target, att_other_class_list):
13 | '''
14 | Return:
15 | corr: torch.Tensor, shape (Batch_size,)
16 | '''
17 | corr = torch.zeros(att_target.size(0))
18 | for att_other_class in att_other_class_list:
19 | class_corr = spearman_rank_correlation_value(att_target, att_other_class)
20 | corr += class_corr
21 | return corr / len(att_other_class_list)
22 |
23 | if __name__ == '__main__':
24 | device = torch.device('cuda')
25 | att1 = torch.rand(32, 1, 224, 224).to(device)
26 | att2 = -att1.clone().to(device)
27 | att3 = att1.clone().to(device)
28 | att4 = torch.rand(32, 1, 224, 224).to(device)
29 | att4[1] = -att1[1]
30 |
31 | att_other_class_list = [att2, att3, att4]
32 | corr = class_specific_spearman_rank_corr(att1, att_other_class_list)
33 | print('\n', corr)
--------------------------------------------------------------------------------
/metrics/insert_and_delete.py:
--------------------------------------------------------------------------------
1 | '''
2 | adapted from https://github.com/LMBTough/MFABA
3 | @article{zhu2023mfaba,
4 | title={MFABA: A More Faithful and Accelerated Boundary-based Attribution Method for Deep Neural Networks},
5 | author={Zhu, Zhiyu and Chen, Huaming and Zhang, Jiayu and Wang, Xinyi and Jin, Zhibo and Xue, Minhui and Zhu, Dongxiao and Choo, Kim-Kwang Raymond},
6 | journal={arXiv preprint arXiv:2312.13630},
7 | year={2023}
8 | }
9 | '''
10 |
11 | import torch
12 | import numpy as np
13 |
14 | @torch.no_grad()
15 | def insert_score(model, step, images, explanations):
16 | model.eval()
17 | B, C, H, W = images.size()
18 | if explanations.size(1) == 1:
19 | explanations = explanations.expand(-1, C, -1, -1)
20 | predictions = model(images)
21 | top, c = torch.max(predictions, -1)
22 | n_steps = (H * W + step - 1) // step
23 | scores = np.empty((B, n_steps + 1))
24 | salient_order = explanations.view(B, C, H * W).argsort(descending=True)
25 |
26 | start = torch.zeros_like(images, device=images.device)
27 | finish = images.clone()
28 | finish = finish.view(B, C, H * W)
29 |
30 | for i in range(n_steps + 1):
31 | pred = model(start)
32 | pred = torch.nn.functional.softmax(pred, dim=-1)
33 | scores[:, i] = pred[torch.arange(B), c].cpu().numpy()
34 | if i < n_steps:
35 | coords = salient_order[:, :, i * step:(i + 1) * step]
36 | # change the value of the pixels according to the coords
37 | start = start.view(B, C, H * W)
38 | start.scatter_(dim=2, index=coords, src=torch.gather(finish, dim=2, index=coords))
39 | start = start.view(B, C, H, W)
40 |
41 | scores = np.sum(scores, axis=0)
42 | xs = np.linspace(0, 1, scores.shape[0])
43 | auc = np.trapz(scores, dx=xs[1] - xs[0]) / B
44 |
45 | return auc, scores
46 |
47 | @torch.no_grad()
48 | def delete_score(model, step, images, explanations):
49 | model.eval()
50 | B, C, H, W = images.size()
51 | if explanations.size(1) == 1:
52 | explanations = explanations.expand(-1, C, -1, -1)
53 | predictions = model(images)
54 | top, c = torch.max(predictions, -1)
55 | n_steps = (H * W + step - 1) // step
56 | scores = np.empty((B, n_steps + 1))
57 | salient_order = explanations.view(B, C, H * W).argsort(descending=True)
58 |
59 | start = images.clone()
60 | finish = torch.zeros_like(images, device=images.device)
61 | finish = finish.view(B, C, H * W)
62 |
63 | for i in range(n_steps + 1):
64 | pred = model(start)
65 | pred = torch.nn.functional.softmax(pred, dim=-1)
66 | scores[:, i] = pred[torch.arange(B), c].cpu().numpy()
67 | if i < n_steps:
68 | coords = salient_order[:, :, i * step:(i + 1) * step]
69 | # change the value of the pixels according to the coords
70 | start = start.view(B, C, H * W)
71 | start.scatter_(dim=2, index=coords, src=torch.gather(finish, dim=2, index=coords))
72 | start = start.view(B, C, H, W)
73 |
74 | scores = np.sum(scores, axis=0)
75 | xs = np.linspace(0, 1, scores.shape[0])
76 | auc = np.trapz(scores, dx=xs[1] - xs[0]) / B
77 | return auc, scores
78 |
79 | class Insert_Delete_Metric():
80 |
81 | def __init__(self, model, mode, step):
82 | r"""Create deletion/insertion metric instance.
83 | Args:
84 | model (nn.Module): Black-box model being explained.
85 | mode (str): 'del' or 'ins'.
86 | step (int): number of pixels modified per iteration.
88 | """
89 | assert mode in ['del', 'ins']
90 | self.model = model
91 | self.mode = mode
92 | self.step = step
93 | self.substrate_fn = lambda x: torch.zeros_like(x, device=x.device)
94 |
95 | def single_run(self, image, explanation: torch.Tensor):
96 | ''' only one image
97 | image: torch.Tensor, 1, C, H, W
98 | explanation: torch.Tensor, 1, 1, H, W or 1, C, H, W
99 | '''
100 | _, C, H, W = image.size()
101 | if explanation.size(1) == 1:
102 | explanation = explanation.expand(-1, C, -1, -1)
103 | pred = self.model(image)
104 | top, c = torch.max(pred, 1)
105 | n_steps = (H * W + self.step - 1) // self.step
106 |
107 | if self.mode == 'del':
108 | start = image.clone()
109 | finish = self.substrate_fn(image)
110 | else:
111 | start = self.substrate_fn(image)
112 | finish = image.clone()
113 | finish = finish.view(-1, C, H * W)
114 |
115 | scores = np.empty(n_steps + 1)
116 | salient_order = explanation.view(-1, H * W).argsort(descending=True)
117 | for i in range(n_steps+1):
118 | pred = self.model(start)
119 | pred = torch.nn.functional.softmax(pred, dim=-1)
120 | scores[i] = pred[0, c].item()
121 | if i < n_steps:
122 | coords = salient_order[:, i * self.step:(i + 1) * self.step]
123 | start = start.view(-1, C, H * W)
124 | start[:, :, coords] = finish[:, :, coords]
125 | start = start.view(-1, C, H, W)
126 | auc = np.trapz(scores, dx=self.step)
127 | return auc, scores
128 |
129 | def batch_run(self, images, explanations):
130 | B, C, H, W = images.size()
131 | if explanations.size(1) == 1:
132 | explanations = explanations.expand(-1, C, -1, -1)
133 | predictions = self.model(images)
134 | top, c = torch.max(predictions, -1)
135 | n_steps = (H * W + self.step - 1) // self.step
136 | scores = np.empty((B, n_steps + 1))
137 | salient_order = explanations.view(B, C, H * W).argsort(descending=True)
138 | if self.mode == 'del':
139 | start = images.clone()
140 | finish = self.substrate_fn(images)
141 | else:
142 | start = self.substrate_fn(images)
143 | finish = images.clone()
144 | finish = finish.view(B, C, H * W)
145 | for i in range(n_steps + 1):
146 | pred = self.model(start)
147 | pred = torch.nn.functional.softmax(pred, dim=-1)
148 | scores[:, i] = pred[torch.arange(B), c].cpu().numpy()
149 | if i < n_steps:
150 | coords = salient_order[:, :, i * self.step:(i + 1) * self.step]
151 | start = start.view(B, C, H * W)
152 | start[torch.arange(B).unsqueeze(1), :, coords] = finish[torch.arange(B).unsqueeze(1), :, coords]
153 | start = start.view(B, C, H, W)
154 |
155 | auc = np.trapz(scores, dx=self.step, axis=1)
156 |
157 | return auc, scores
--------------------------------------------------------------------------------
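A minimal usage sketch for the two functional metrics above; the model and the random tensors are placeholders (in practice `explanations` would be normalized attribution maps computed for `images`):

```python
import timm
import torch

from metrics import insert_score, delete_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model('resnet50', pretrained=True).to(device).eval()

images = torch.rand(2, 3, 224, 224, device=device)        # placeholder batch
explanations = torch.rand(2, 1, 224, 224, device=device)  # placeholder maps

# step = number of pixels revealed (insertion) or erased (deletion) per iteration
ins_auc, ins_curve = insert_score(model, 1024, images, explanations)
del_auc, del_curve = delete_score(model, 1024, images, explanations)
print(f'insertion AUC: {ins_auc:.4f}, deletion AUC: {del_auc:.4f}')
```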
/metrics/keep_remove_mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms import GaussianBlur
3 | import numpy as np
4 |
10 | import math
12 | import numbers
14 | from torch import nn
15 | from torch.nn import functional as F
22 |
23 |
24 |
25 | class GaussianSmoothing(nn.Module):
26 | """
27 | Apply gaussian smoothing on a
28 | 1d, 2d or 3d tensor. Filtering is performed separately for each channel
29 | in the input using a depthwise convolution.
30 | Arguments:
31 | channels (int, sequence): Number of channels of the input tensors. Output will
32 | have this number of channels as well.
33 | kernel_size (int, sequence): Size of the gaussian kernel.
34 | sigma (float, sequence): Standard deviation of the gaussian kernel.
35 | dim (int, optional): The number of dimensions of the data.
36 | Default value is 2 (spatial).
37 | """
38 | def __init__(self, device, channels, kernel_size, sigma, dim=2):
39 | super(GaussianSmoothing, self).__init__()
40 | if isinstance(kernel_size, numbers.Number):
41 | kernel_size = [kernel_size] * dim
42 | if isinstance(sigma, numbers.Number):
43 | sigma = [sigma] * dim
44 |
45 | # The gaussian kernel is the product of the
46 | # gaussian function of each dimension.
47 | kernel = 1
48 | meshgrids = torch.meshgrid(
49 | [
50 | torch.arange(size, dtype=torch.float32)
51 | for size in kernel_size
52 | ]
53 | )
54 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
55 | mean = (size - 1) / 2
56 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
57 | torch.exp(-((mgrid - mean) / std) ** 2 / 2)
58 |
59 | # Make sure sum of values in gaussian kernel equals 1.
60 | kernel = kernel / torch.sum(kernel)
61 |
62 | # Reshape to depthwise convolutional weight
63 | kernel = kernel.view(1, 1, *kernel.size())
64 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
65 |
66 | self.register_buffer('weight', kernel)
67 | self.weight = self.weight.to(device)
68 | self.groups = channels
69 |
70 | if dim == 1:
71 | self.conv = F.conv1d
72 | elif dim == 2:
73 | self.conv = F.conv2d
74 | elif dim == 3:
75 | self.conv = F.conv3d
76 | else:
77 | raise RuntimeError(
78 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
79 | )
80 |
81 | def forward(self, input):
82 | """
83 | Apply gaussian filter to input.
84 | Arguments:
85 | input (torch.Tensor): Input to apply gaussian filter on.
86 | Returns:
87 | filtered (torch.Tensor): Filtered output.
88 | """
89 | return self.conv(input, weight=self.weight, groups=self.groups)
90 |
91 | @torch.no_grad()
92 | def attribution_to_mask(attribution, percent_unmasked, sort_order, perturbation, device):
93 |
94 | attribution = attribution.clone().detach()
95 | attribution = attribution.mean(dim=1)
96 | attribution = attribution.unsqueeze(1)
97 |
98 | zeros = torch.zeros(attribution.shape).to(device)
99 |
100 | #invert attribution for negative case
101 | if sort_order == 'negative' or sort_order == 'negative_target' or sort_order == 'negative_topk':
102 | attribution = -attribution
103 |
104 | if sort_order == 'absolute':
105 | attribution = torch.abs(attribution)
106 |
107 | # for 'positive' and 'negative', values of the opposite sign are always masked and not considered for the topk
108 | positives = torch.maximum(attribution, zeros)
109 | nb_positives = torch.count_nonzero(positives)
110 |
111 | orig_shape = positives.size()
112 | positives = positives.view(positives.size(0), 1, -1)
113 | nb_pixels = positives.size(2)
114 |
115 | if perturbation == 'keep':
116 | # find features to keep
117 | ret = torch.topk(positives, k=int(torch.minimum(torch.tensor(percent_unmasked*nb_pixels).to(device), nb_positives)), dim=2)
118 |
119 | if perturbation == 'remove':
120 | #set zeros to large value
121 | positives_wo_zero = positives.clone()
122 | positives_wo_zero[positives_wo_zero == 0.] = float("Inf")
123 | # find features to keep
124 | ret = torch.topk(positives_wo_zero, k=int(torch.minimum(torch.tensor(percent_unmasked*nb_pixels).to(device), nb_positives)), dim=2, largest=False)
125 |
126 | # Scatter to zero'd tensor
127 | res = torch.zeros_like(positives)
128 | res.scatter_(2, ret.indices, ret.values)
129 | res = res.view(*orig_shape)
130 |
131 | res = (res == 0).float() # set topk values to zero and all zeros to one
132 | res = res.repeat(1,3,1,1)
133 | return res
134 |
135 | @torch.no_grad()
136 | def keep_remove_metrics(sort_order, perturbation, model, images, labels, attribution, device, step_num=16):
137 | options = [(sort_order, perturbation)]
138 | vals = {}
139 |
140 | smoothing = GaussianSmoothing(device, 3, 51, 41)
141 | percent_unmasked_range = np.geomspace(0.01, 1.0, num=step_num)
142 |
143 | for percent_unmasked in percent_unmasked_range:
144 | vals[perturbation + "_" + sort_order + "_" + str(percent_unmasked)] = 0.0
145 |
146 | masks = []
147 | for sort_order, perturbation in options:
148 | for percent_unmasked in percent_unmasked_range:
149 | #create masked images
150 | for sample in range(attribution.shape[0]):
151 | mask = attribution_to_mask(attribution[sample].unsqueeze(0), percent_unmasked, sort_order, perturbation, device)
152 | masks.append(mask)
153 |
154 | mask = torch.cat(masks, dim=0)
155 |
156 | images_masked_pt = images.clone().repeat(int(mask.shape[0]/images.size(0)), 1, 1, 1)
157 | images_smoothed_pt = images.clone().repeat(int(mask.shape[0]/images.size(0)), 1, 1, 1)
158 | images_smoothed_pt = F.pad(images_smoothed_pt, (25,25,25,25), mode='reflect')
159 | images_smoothed_pt = smoothing(images_smoothed_pt)
160 | images_masked_pt[mask.bool()] = images_smoothed_pt[mask.bool()]
161 |
162 | #images_masked = normalize(torch.tensor(images_masked_np / 255.).unsqueeze(0).permute(0,3,1,2))
163 | images_masked = images_masked_pt
164 | images_masked = images_masked.to(device)
165 | out_masked = model(images_masked)
166 | out_masked = out_masked.softmax(dim=-1)
167 |
168 | #split out_masked in the chunks that correspond to the individual run
169 | option_runs = torch.split(out_masked, int(out_masked.shape[0]/len(options)))
170 | for o, (sort_order, perturbation) in enumerate(options):
171 | option_run = option_runs[o] # N, 1000
172 | percent_unmasked_runs = torch.split(option_run, int(option_run.shape[0]/len(percent_unmasked_range))) # N, 1000
173 | for p, percent_unmasked in enumerate(percent_unmasked_range):
174 | percent_unmasked_run = percent_unmasked_runs[p] # N, 1000
175 | #if len(percent_unmasked_run.shape) == 1:
176 | # percent_unmasked_run = percent_unmasked_run.unsqueeze(0)
177 |
178 | if sort_order == 'positive':
179 | vals[perturbation + "_" + sort_order + "_" + str(percent_unmasked)] += torch.gather(percent_unmasked_run, 1, labels.unsqueeze(-1)).sum().cpu().item()
180 | if sort_order == 'negative':
181 | vals[perturbation + "_" + sort_order + "_" + str(percent_unmasked)] += torch.gather(percent_unmasked_run, 1, labels.unsqueeze(-1)).sum().cpu().item()
182 | if sort_order == 'absolute':
183 | correct = (torch.max(percent_unmasked_run, 1)[1] == labels).float().sum().cpu().item()
184 | vals[perturbation + "_" + sort_order + "_" + str(percent_unmasked)] += correct
185 |
186 | for sort_order, perturbation in options:
187 | for percent_unmasked in percent_unmasked_range:
188 | vals[perturbation + "_" + sort_order + "_" + str(percent_unmasked)] /= images.shape[0]
189 |
190 | for sort_order, perturbation in options:
191 | xs = []
192 | ys = []
193 | for percent_unmasked in percent_unmasked_range:
194 | xs.append(percent_unmasked)
195 | ys.append(vals[perturbation + "_" + sort_order + "_" + str(percent_unmasked)])
196 | auc = np.trapz(ys, xs)
197 | xs = np.array(xs)
198 | ys = np.array(ys)
199 | return auc, vals, xs, ys
200 |
201 |
202 | def kpm(original_input, attribution, target_label, forward, step=0.05):
203 | """ kpm: keep positive mask,
204 | increase the numbers of input pixels that attribution is positive from high attribution to low attribution, other pixels are masked using the average value of masked pixels
205 | record acc score of the target label of each step, finally calculate the area under the curve
206 | the output acc is expected to increase from 0 to the acc of the original input
207 | original_input: batch original input, B*C*H*W
208 | attribution: batch attribution, B*1*H*W
209 | mask: batch target mask, B*1*H*W
210 | target_label: batch target label, B
211 | forward: forward function, forward(masked_input, target_label), return mean target score of the batch and accuracy of the batch
212 | step: step of masking attributed pixels, percentage of the input features of original input
213 | """
214 | B, C, H, W = original_input.size()
215 | # get sampling interval
216 | max_attribution = attribution.max()
217 | min_attribution = 0. # attribution.min()
218 | interval = (max_attribution - min_attribution) * step
219 | N = int(1 / step) + 1
220 |
221 | y_axis = []
222 |
223 | # get blurred input
224 | # blurred_original_input = GaussianBlur(51, 41)(original_input)
225 | blurred_original_input = torch.zeros_like(original_input, device=original_input.device)
226 |
227 | for i in range(N):
228 | # get mask regions
229 | mask_region = attribution < max_attribution - i * interval
230 | masked_input = original_input.clone()
231 | # use blurred_original_input to mask the input
232 | masked_input = torch.where(~mask_region, masked_input, blurred_original_input)
233 |
234 | # mean target score of the batch
235 | score, acc = forward(masked_input, target_label)
236 | y_axis.append(score)
237 |
238 | # plot and save the curve
239 | x_axis = np.arange(0, 1+step, step)
240 | y_axis = np.array(y_axis)
241 | auc = np.trapz(y_axis, x_axis)
242 |
243 | # plot
244 | # import matplotlib.pyplot as plt
245 | # plt.plot(x_axis, y_axis)
246 |
247 | return auc, x_axis, y_axis
248 |
249 |
250 | def knm(original_input, attribution, target_label, forward, step=0.05):
251 | """ knm: keep negative mask
252 | works like kpm, but keeps the lowest-attribution pixels first and masks the rest,
253 | progressively unmasking pixels from lowest to highest attribution (masked pixels are replaced with a blurred input here)
254 | """
255 |
256 | B, C, H, W = original_input.size()
257 | # get sampling interval
258 | max_attribution = 0. # attribution.max()
259 | min_attribution = attribution.min()
260 | interval = (max_attribution - min_attribution) * step
261 | N = int(1 / step) + 1
262 |
263 | y_axis = []
264 | # get blurred input
265 | blurred_original_input = GaussianBlur(51, 41)(original_input)
266 | for i in range(N):
267 | # get mask regions
268 | mask_region = attribution > min_attribution + i * interval
269 | masked_input = original_input.clone()
270 | # use blurred_original_input to mask the input
271 | masked_input = torch.where(~mask_region, masked_input, blurred_original_input)
272 |
273 | # mean target score of the batch
274 | score, acc = forward(masked_input, target_label)
275 | y_axis.append(score)
276 |
277 | x_axis = np.arange(0, 1+step, step)
278 | y_axis = np.array(y_axis)
279 | auc = np.trapz(y_axis, x_axis)
280 |
281 | return auc, x_axis, y_axis
282 |
283 |
284 | def kam(original_input, attribution, target_label, forward, step=0.05):
285 | """ kam: keep absolute mask
286 | works like kpm, but observes the prediction accuracy instead of the target-label score
287 | also uses the absolute attribution map instead of the positive attribution map
288 | """
289 |
290 | B, C, H, W = original_input.size()
291 | attribution = attribution.abs()
292 | # get sampling interval
293 | max_attribution = attribution.max()
294 | min_attribution = attribution.min()
295 | interval = (max_attribution - min_attribution) * step
296 | N = int(1 / step) + 1
297 |
298 | y_axis = []
299 | blurred_original_input = GaussianBlur(51, 41)(original_input)
300 | for i in range(N):
301 | # get mask regions
302 | mask_region = attribution < max_attribution - i * interval
303 | masked_input = original_input.clone()
304 | # use blurred_original_input to mask the input
305 | masked_input = torch.where(~mask_region, masked_input, blurred_original_input)
306 |
307 | # mean acc of the batch
308 | score, acc = forward(masked_input, target_label)
309 | y_axis.append(acc)
310 |
311 | x_axis = np.arange(0, 1+step, step)
312 | y_axis = np.array(y_axis)
313 | auc = np.trapz(y_axis, x_axis)
314 |
315 | return auc, x_axis, y_axis
316 |
317 |
318 | def ram(original_input, attribution, target_label, forward, step=0.05):
319 | """ ram: remove absolute mask
320 | works like knm, but observes the prediction accuracy instead of the target-label score
321 | also uses the absolute attribution map instead of the positive attribution map
322 | """
323 |
324 | B, C, H, W = original_input.size()
325 | attribution = attribution.abs()
326 | # get sampling interval
327 | max_attribution = attribution.max()
328 | min_attribution = attribution.min()
329 | interval = (max_attribution - min_attribution) * step
330 | N = int(1 / step) + 1
331 |
332 | y_axis = []
333 | blurred_original_input = GaussianBlur(51, 41)(original_input)
334 | for i in range(N):
335 | # get mask regions
336 | mask_region = attribution > max_attribution - i * interval
337 | masked_input = original_input.clone()
338 | # use blurred_original_input to mask the input
339 | masked_input = torch.where(~mask_region, masked_input, blurred_original_input)
340 |
341 | # mean acc of the batch
342 | score, acc = forward(masked_input, target_label)
343 | y_axis.append(acc)
344 |
345 | x_axis = np.arange(0, 1+step, step)
346 | y_axis = np.array(y_axis)
347 | auc = np.trapz(y_axis, x_axis)
348 |
349 | return auc, x_axis, y_axis
350 |
--------------------------------------------------------------------------------
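The `forward` argument of `kpm`/`knm`/`kam`/`ram` is user-supplied: it must return the mean target-class score and the accuracy for a masked batch. A minimal closure satisfying that contract (names here are illustrative):

```python
import torch

def make_forward(model):
    @torch.no_grad()
    def forward(masked_input, target_label):
        logits = model(masked_input)
        probs = logits.softmax(dim=-1)
        # mean softmax score of the target class over the batch
        score = probs[torch.arange(len(target_label)), target_label].mean().item()
        # top-1 accuracy against the target labels
        acc = (logits.argmax(dim=-1) == target_label).float().mean().item()
        return score, acc
    return forward

# e.g.: auc, xs, ys = kpm(images, attribution, labels, make_forward(model))
```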
/perturbation_based_attribution_visualization_examples.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import timm
5 | from timm.data import resolve_model_data_config
6 | from timm.data.transforms_factory import create_transform
7 | import requests
8 |
9 |
10 | from attribution import Occlusion
11 | from attribution.utils import normalize_saliency, visualize_single_saliency
12 |
13 |
14 | if __name__ == '__main__':
15 |
16 | # Load imagenet labels
17 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
18 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 |
21 | # Load a pretrained model
22 | model = timm.create_model('resnet50', pretrained=True)
23 | model = model.to(device)
24 | model.eval()
25 | config = resolve_model_data_config(model, None)
26 | transform = create_transform(**config)
27 |
28 | # Load an image
29 | dog = Image.open('examples/dog.png').convert('RGB')
30 | dog_tensor = transform(dog).unsqueeze(0)
31 | H, W = dog_tensor.shape[-2:]
32 |
33 | # Predict the image
34 | img = transform(dog).unsqueeze(0)
35 | img = torch.cat([img, img])
36 | img = img.to(device)
37 | output = model(img)
38 | target_index = torch.argmax(output, dim=1).cpu()
39 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()])
40 |
41 | # Occlusion
42 | occlusion_net = Occlusion(model)
43 | occlusion = normalize_saliency(occlusion_net.get_mask(img, target_index))
44 | occlusion_2 = normalize_saliency(occlusion_net.get_mask(img, target_index, size=30))
45 |
46 | # Visualize the results
47 | plt.figure(figsize=(16, 5))
48 | plt.subplot(1, 3, 1)
49 | plt.title('Input')
50 | plt.axis('off')
51 | plt.imshow(dog)
52 | plt.subplot(1, 3, 2)
53 | plt.title('Occlusion window-15')
54 | visualize_single_saliency(occlusion[0].unsqueeze(0))
55 | plt.subplot(1, 3, 3)
56 | plt.title('Occlusion window-30')
57 | visualize_single_saliency(occlusion_2[0].unsqueeze(0))
58 | plt.tight_layout()
59 | plt.savefig('examples/perturbation_based_visualization.png', bbox_inches='tight', pad_inches=0.5)
--------------------------------------------------------------------------------
/quick_start.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pyplot as plt
2 | from PIL import Image
3 | import requests
4 | import timm
5 | from timm.data import resolve_model_data_config
6 | from timm.data.transforms_factory import create_transform
7 | import torch
8 |
9 | from attribution import BlurIG, GradCAM, CombinedWrapper
10 | from attribution.utils import normalize_saliency, visualize_single_saliency
11 |
12 | # Load imagenet labels
13 | IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
14 | IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
15 |
16 | # Load model
17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 | model = timm.create_model('resnet50', pretrained=True)
19 | model = model.to(device)
20 | model.eval()
21 | config = resolve_model_data_config(model, None)
22 | transform = create_transform(**config)
23 |
24 | # Load image
25 | dog = Image.open('examples/dog.png').convert('RGB')
26 | dog_tensor = transform(dog).unsqueeze(0)
27 | H, W = dog_tensor.shape[-2:]
28 | img = transform(dog).unsqueeze(0)
29 |
30 | # We support batch input
31 | img = torch.cat([img, img])
32 | img = img.to(device)
33 | output = model(img)
34 | target_index = torch.argmax(output, dim=1).cpu()
35 | print('Predicted:', IMAGENET_1k_LABELS[target_index[0].item()])
36 |
37 | # Gradients visualization
38 | blur_ig_kwargs = {'steps': 100,
39 | 'batch_size': 4,
40 | 'max_sigma': 50,
41 | 'grad_step': 0.01,
42 | 'sqrt': False}
43 | blur_ig_net = BlurIG(model)
44 | blur_ig = normalize_saliency(blur_ig_net.get_mask(img, target_index, **blur_ig_kwargs))
45 |
46 | # CAM visualization
47 | gradcam_net = GradCAM(model)
48 | gradcam = normalize_saliency(
49 | gradcam_net.get_mask(img, target_index, target_layer='layer3'))
50 |
51 | # Combine Gradients and CAM visualization
52 | combined = CombinedWrapper(model, BlurIG, GradCAM)
53 | combined_saliency = normalize_saliency(
54 | combined.get_mask(img, target_index, target_layer='layer3', **blur_ig_kwargs))
55 |
56 | # Visualize
57 | plt.figure(figsize=(16, 5))
58 | plt.subplot(1, 4, 1)
59 | plt.imshow(dog)
60 | plt.title('Input Image')
61 | plt.axis('off')
62 | plt.subplot(1, 4, 2)
63 | visualize_single_saliency(blur_ig[0].unsqueeze(0))
64 | plt.title('Blur IG')
65 | plt.subplot(1, 4, 3)
66 | visualize_single_saliency(gradcam[0].unsqueeze(0))
67 | plt.title('GradCAM')
68 | plt.subplot(1, 4, 4)
69 | visualize_single_saliency(combined_saliency[0].unsqueeze(0))
70 | plt.title('Combined')
71 | plt.tight_layout()
72 | plt.savefig('examples/quick_start.png', bbox_inches='tight', pad_inches=0.5)
--------------------------------------------------------------------------------