├── README.md
├── assets
└── readme.png
├── example.ipynb
├── gradcam.py
├── images
├── collies.JPG
├── multiple_dogs.jpg
├── snake.JPEG
└── water-bird.JPEG
├── outputs
├── collies.JPG
├── multiple_dogs.jpg
├── snake.JPEG
└── water-bird.JPEG
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | ### A simple PyTorch implementation of GradCAM [1] and GradCAM++ [2]
2 |
3 |
4 |
5 |
6 |
7 | ### Supported torchvision models
8 | - alexnet
9 | - vgg
10 | - resnet
11 | - densenet
12 | - squeezenet
13 |
14 | ### Usage
15 | Please refer to `example.ipynb` for general usage. To learn how to set `target_layer_name` properly (it is given to `GradCAM`/`GradCAMpp` as the `layer_name` entry of `model_dict`), refer to the docstrings of the layer-finding functions in `utils.py`.
16 |
17 | ### References:
18 | - [1] Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, Selvaraju et al., ICCV 2017
19 | - [2] Grad-CAM++: Generalized Gradient-based Visual Explanations for Deep Convolutional Networks, Chattopadhyay et al., WACV 2018
20 |
--------------------------------------------------------------------------------
/assets/readme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/assets/readme.png
--------------------------------------------------------------------------------
/gradcam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from utils import find_alexnet_layer, find_vgg_layer, find_resnet_layer, find_densenet_layer, find_squeezenet_layer
5 |
6 |
class GradCAM(object):
    """Calculate a GradCAM saliency map.

    A simple example:

        # initialize a model, model_dict and gradcam
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet.eval()
        model_dict = dict(type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))
        gradcam = GradCAM(model_dict)

        # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        img = load_img()
        normed_img = normalizer(img)

        # get a GradCAM saliency map on the class index 10.
        mask, logit = gradcam(normed_img, class_idx=10)

        # make heatmap from mask and synthesize saliency map using heatmap and img
        heatmap, cam_result = visualize_cam(mask, img)

    Args:
        model_dict (dict): must contain 'type' (or its alias 'model_type'),
            'arch' and 'layer_name'; 'input_size' is optional. 'type' is
            matched case-insensitively as a substring against 'vgg', 'resnet',
            'densenet', 'alexnet' and 'squeezenet', and 'layer_name' is passed
            to the matching layer-finding function from utils.py.
        verbose (bool): whether to print the spatial size of the saliency map
            produced at 'layer_name' for an input of size 'input_size'.

    Raises:
        KeyError: if model_dict lacks 'type'/'model_type', 'layer_name' or 'arch'.
        ValueError: if 'type' does not name a supported model family.
    """

    def __init__(self, model_dict, verbose=False):
        # The documented example spells the key 'model_type' while older code
        # read 'type'; accept both so either spelling works.
        model_type = model_dict.get('type', model_dict.get('model_type'))
        if model_type is None:
            raise KeyError("model_dict must contain 'type', e.g. {'type': 'resnet'}")
        layer_name = model_dict['layer_name']
        self.model_arch = model_dict['arch']

        # Filled by the hooks below during the forward / backward passes.
        self.gradients = dict()
        self.activations = dict()

        def backward_hook(module, grad_input, grad_output):
            # Gradient of the backpropagated score w.r.t. the layer output (dS/dA).
            self.gradients['value'] = grad_output[0]
            return None

        def forward_hook(module, input, output):
            # Output of the target layer (A).
            self.activations['value'] = output
            return None

        lowered_type = model_type.lower()
        if 'vgg' in lowered_type:
            target_layer = find_vgg_layer(self.model_arch, layer_name)
        elif 'resnet' in lowered_type:
            target_layer = find_resnet_layer(self.model_arch, layer_name)
        elif 'densenet' in lowered_type:
            target_layer = find_densenet_layer(self.model_arch, layer_name)
        elif 'alexnet' in lowered_type:
            target_layer = find_alexnet_layer(self.model_arch, layer_name)
        elif 'squeezenet' in lowered_type:
            target_layer = find_squeezenet_layer(self.model_arch, layer_name)
        else:
            # Previously fell through and crashed with UnboundLocalError.
            raise ValueError('unsupported model type : {}'.format(model_type))

        target_layer.register_forward_hook(forward_hook)
        # NOTE(review): Module.register_backward_hook is deprecated in recent
        # PyTorch. Its replacement, register_full_backward_hook, is documented
        # to reject in-place modification of the hooked module's inputs or
        # outputs (e.g. a following ReLU(inplace=True)), so it is kept here.
        target_layer.register_backward_hook(backward_hook)

        if verbose:
            try:
                input_size = model_dict['input_size']
            except KeyError:
                print("please specify size of input image in model_dict. e.g. {'input_size':(224, 224)}")
            else:
                # Use the model's actual device (works for e.g. 'cuda:1', not only 'cuda').
                device = next(self.model_arch.parameters()).device
                with torch.no_grad():
                    self.model_arch(torch.zeros(1, 3, *(input_size), device=device))
                print('saliency_map size :', self.activations['value'].shape[2:])

    def forward(self, input, class_idx=None, retain_graph=False):
        """Compute the GradCAM saliency map for one image.

        Args:
            input: input image with shape of (1, 3, H, W)
            class_idx (int): class index for calculating GradCAM.
                If not specified, the class index that makes the highest model prediction score will be used.
            retain_graph (bool): forwarded to ``backward`` so the graph can be reused.

        Return:
            mask: saliency map of shape (1, 1, H, W), min-max scaled to [0, 1]
                (all zeros when the map is constant), detached from the graph.
            logit: model output
        """
        _, _, h, w = input.size()

        logit = self.model_arch(input)
        if class_idx is None:
            score = logit[:, logit.max(1)[-1]].squeeze()
        else:
            score = logit[:, class_idx].squeeze()

        self.model_arch.zero_grad()
        score.backward(retain_graph=retain_graph)
        gradients = self.gradients['value']
        activations = self.activations['value']
        b, k, u, v = gradients.size()

        # GradCAM channel weights: spatially averaged gradients.
        alpha = gradients.reshape(b, k, -1).mean(2)
        weights = alpha.view(b, k, 1, 1)

        saliency_map = (weights * activations).sum(1, keepdim=True)
        saliency_map = F.relu(saliency_map)
        # F.upsample is deprecated; F.interpolate is the equivalent call.
        saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)

        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        value_range = saliency_map_max - saliency_map_min
        if value_range > 0:
            saliency_map = (saliency_map - saliency_map_min).div(value_range)
        else:
            # A constant map (e.g. all zeros after ReLU) would otherwise give 0/0 = NaN.
            saliency_map = torch.zeros_like(saliency_map)

        return saliency_map.detach(), logit

    def __call__(self, input, class_idx=None, retain_graph=False):
        return self.forward(input, class_idx, retain_graph)
111 |
112 |
class GradCAMpp(GradCAM):
    """Calculate a GradCAM++ saliency map.

    A simple example:

        # initialize a model, model_dict and gradcampp
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet.eval()
        model_dict = dict(type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))
        gradcampp = GradCAMpp(model_dict)

        # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        img = load_img()
        normed_img = normalizer(img)

        # get a GradCAM++ saliency map on the class index 10.
        mask, logit = gradcampp(normed_img, class_idx=10)

        # make heatmap from mask and synthesize saliency map using heatmap and img
        heatmap, cam_result = visualize_cam(mask, img)

    Args:
        model_dict (dict): same as for GradCAM: 'type', 'arch', 'layer_name'
            and optionally 'input_size'.
        verbose (bool): whether to print the spatial size of the saliency map
            produced at 'layer_name' for an input of size 'input_size'.
    """

    def __init__(self, model_dict, verbose=False):
        super(GradCAMpp, self).__init__(model_dict, verbose)

    def forward(self, input, class_idx=None, retain_graph=False):
        """Compute the GradCAM++ saliency map for one image.

        Args:
            input: input image with shape of (1, 3, H, W)
            class_idx (int): class index for calculating GradCAM++.
                If not specified, the class index that makes the highest model prediction score will be used.
            retain_graph (bool): forwarded to ``backward`` so the graph can be reused.

        Return:
            mask: saliency map of shape (1, 1, H, W) -- the input's spatial size --
                min-max scaled to [0, 1] (all zeros when the map is constant),
                detached from the graph.
            logit: model output
        """
        _, _, h, w = input.size()

        logit = self.model_arch(input)
        if class_idx is None:
            score = logit[:, logit.max(1)[-1]].squeeze()
        else:
            score = logit[:, class_idx].squeeze()

        self.model_arch.zero_grad()
        score.backward(retain_graph=retain_graph)
        gradients = self.gradients['value']  # dS/dA
        activations = self.activations['value']  # A
        b, k, u, v = gradients.size()

        # alpha = g^2 / (2 * g^2 + sum_ij(A_ij * g^3)), with the spatial sum taken per channel.
        grad_sq = gradients.pow(2)
        alpha_num = grad_sq
        spatial_sum = activations.mul(gradients.pow(3)).reshape(b, k, u * v).sum(-1, keepdim=True)
        alpha_denom = grad_sq.mul(2) + spatial_sum.view(b, k, 1, 1)
        # Avoid dividing by zero where the denominator vanishes.
        alpha_denom = torch.where(alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom))
        alpha = alpha_num.div(alpha_denom + 1e-7)

        # The paper uses ReLU(dY/dA) with Y = exp(S), i.e. ReLU(exp(S) * dS/dA).
        # exp(S) is one positive scalar, so it only rescales the final map and
        # cancels in the min-max normalization below; multiplying by it would
        # overflow to inf (then NaN via inf * 0) for large logits, so it is omitted.
        positive_gradients = F.relu(gradients)
        weights = (alpha * positive_gradients).reshape(b, k, u * v).sum(-1).view(b, k, 1, 1)

        saliency_map = (weights * activations).sum(1, keepdim=True)
        saliency_map = F.relu(saliency_map)
        # Upsample to the input's size (was hard-coded to 224x224).
        saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)

        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        value_range = saliency_map_max - saliency_map_min
        if value_range > 0:
            saliency_map = (saliency_map - saliency_map_min).div(value_range)
        else:
            # A constant map would otherwise give 0/0 = NaN.
            saliency_map = torch.zeros_like(saliency_map)

        return saliency_map.detach(), logit
182 |
--------------------------------------------------------------------------------
/images/collies.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/images/collies.JPG
--------------------------------------------------------------------------------
/images/multiple_dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/images/multiple_dogs.jpg
--------------------------------------------------------------------------------
/images/snake.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/images/snake.JPEG
--------------------------------------------------------------------------------
/images/water-bird.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/images/water-bird.JPEG
--------------------------------------------------------------------------------
/outputs/collies.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/outputs/collies.JPG
--------------------------------------------------------------------------------
/outputs/multiple_dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/outputs/multiple_dogs.jpg
--------------------------------------------------------------------------------
/outputs/snake.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/outputs/snake.JPEG
--------------------------------------------------------------------------------
/outputs/water-bird.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/outputs/water-bird.JPEG
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 |
def visualize_cam(mask, img):
    """Make heatmap from mask and synthesize GradCAM result image using heatmap and img.

    Args:
        mask (torch.tensor): mask shape of (1, 1, H, W) and each element has value in range [0, 1].
            May live on any device and may track gradients; it is detached and moved to CPU.
        img (torch.tensor): img shape of (1, 3, H, W) and each pixel value is in range [0, 1]

    Return:
        heatmap (torch.tensor): RGB heatmap img shape of (3, H, W), values in [0, 1]
        result (torch.tensor): synthesized GradCAM result of same shape with heatmap,
            scaled so that its maximum is 1.
    """
    # np.uint8(tensor) fails for CUDA / grad-tracking tensors, so convert explicitly.
    mask_np = mask.detach().squeeze().cpu().numpy()
    heatmap = cv2.applyColorMap(np.uint8(255 * mask_np), cv2.COLORMAP_JET)
    heatmap = torch.from_numpy(heatmap).permute(2, 0, 1).float().div(255)
    # applyColorMap yields BGR channel order; reorder to RGB.
    b, g, r = heatmap.split(1)
    heatmap = torch.cat([r, g, b])

    result = heatmap + img.cpu()
    result = result.div(result.max()).squeeze()

    return heatmap, result
24 |
25 |
def find_resnet_layer(arch, target_layer_name):
    """Find resnet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision resnet models
        target_layer_name (str): the name of layer with its hierarchical information,
            components separated by '_'. please refer to usages below.
            target_layer_name = 'conv1'
            target_layer_name = 'layer1'
            target_layer_name = 'layer1_basicblock0'
            target_layer_name = 'layer1_basicblock0_relu'
            target_layer_name = 'layer1_bottleneck0'
            target_layer_name = 'layer1_bottleneck0_conv1'
            target_layer_name = 'layer1_bottleneck0_downsample'
            target_layer_name = 'layer1_bottleneck0_downsample_0'
            target_layer_name = 'avgpool'
            target_layer_name = 'fc'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        ValueError: if the 'layerN' number is not 1-4 or an index cannot be parsed.
        KeyError: if a named sub-module does not exist.
    """
    if 'layer' not in target_layer_name:
        # Top-level module such as 'conv1', 'avgpool' or 'fc'.
        return arch._modules[target_layer_name]

    import string

    hierarchy = target_layer_name.split('_')
    # 'layerN' -> N (lstrip removes the characters of 'layer', leaving the digits).
    layer_num = int(hierarchy[0].lstrip('layer'))
    if layer_num not in (1, 2, 3, 4):
        raise ValueError('unknown layer : {}'.format(target_layer_name))
    target_layer = getattr(arch, 'layer{}'.format(layer_num))

    if len(hierarchy) >= 2:
        # 'bottleneck0' / 'basicblock0' -> 0: drop any alphabetic prefix, keep the index.
        block_index = int(hierarchy[1].lstrip(string.ascii_letters))
        target_layer = target_layer[block_index]

    # Any remaining components are child-module names, e.g. 'downsample', '0'.
    for module_name in hierarchy[2:]:
        target_layer = target_layer._modules[module_name]

    return target_layer
74 |
75 |
def find_densenet_layer(arch, target_layer_name):
    """Find densenet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision densenet models
        target_layer_name (str): the name of layer with its hierarchical information,
            components separated by '_'. please refer to usages below.
            target_layer_name = 'features'
            target_layer_name = 'features_transition1'
            target_layer_name = 'features_transition1_norm'
            target_layer_name = 'features_denseblock2_denselayer12'
            target_layer_name = 'features_denseblock2_denselayer12_norm1'
            target_layer_name = 'classifier'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        KeyError: if any component does not name a child module.
    """
    hierarchy = target_layer_name.split('_')
    target_layer = arch._modules[hierarchy[0]]

    # Walk every remaining component; names deeper than four parts used to be
    # silently truncated.
    for module_name in hierarchy[1:]:
        target_layer = target_layer._modules[module_name]

    return target_layer
107 |
108 |
def find_vgg_layer(arch, target_layer_name):
    """Find vgg layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision vgg models
        target_layer_name (str): the name of layer with its hierarchical information.
            The first '_'-separated component names a top-level module, and each
            following component is an integer index into it. please refer to usages below.
            target_layer_name = 'features'
            target_layer_name = 'features_42'
            target_layer_name = 'classifier'
            target_layer_name = 'classifier_0'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        KeyError: if the first component does not name a top-level module.
        ValueError: if an index component is not an integer.
    """
    hierarchy = target_layer_name.split('_')

    # Previously this always started from arch.features, so 'classifier_0'
    # silently returned features[0].
    target_layer = arch._modules[hierarchy[0]]

    for index in hierarchy[1:]:
        target_layer = target_layer[int(index)]

    return target_layer
132 |
133 |
def find_alexnet_layer(arch, target_layer_name):
    """Find alexnet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision alexnet models
        target_layer_name (str): the name of layer with its hierarchical information.
            The first '_'-separated component names a top-level module, and each
            following component is an integer index into it. please refer to usages below.
            target_layer_name = 'features'
            target_layer_name = 'features_0'
            target_layer_name = 'classifier'
            target_layer_name = 'classifier_0'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        KeyError: if the first component does not name a top-level module.
        ValueError: if an index component is not an integer.
    """
    parts = target_layer_name.split('_')

    # Resolve the named top-level module instead of always using arch.features
    # (which made 'classifier_0' return features[0]).
    module = arch._modules[parts[0]]

    for index in parts[1:]:
        module = module[int(index)]

    return module
157 |
158 |
def find_squeezenet_layer(arch, target_layer_name):
    """Find squeezenet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision squeezenet models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'features_12'
            target_layer_name = 'features_12_expand3x3'
            target_layer_name = 'features_12_expand3x3_activation'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        KeyError: if a component does not name a child module.
    """
    hierarchy = target_layer_name.split('_')
    target_layer = arch._modules[hierarchy[0]]

    if len(hierarchy) >= 2:
        target_layer = target_layer._modules[hierarchy[1]]

    if len(hierarchy) >= 3:
        # Child names inside a block may themselves contain '_' (e.g.
        # 'expand3x3_activation'), so everything after the block index is
        # re-joined into one module name. Longer names used to be ignored.
        target_layer = target_layer._modules['_'.join(hierarchy[2:])]

    return target_layer
185 |
186 |
def denormalize(tensor, mean, std):
    """Undo per-channel normalization: return ``tensor * std + mean``.

    Args:
        tensor (torch.Tensor): 4D tensor; dim 1 is the channel dim the statistics apply to.
        mean (sequence of float): per-channel means (one per channel, or a single value).
        std (sequence of float): per-channel standard deviations (same length rule as mean).

    Return:
        torch.Tensor: a new tensor of the same shape as ``tensor``.

    Raises:
        TypeError: if ``tensor`` is not 4D.
    """
    if tensor.ndimension() != 4:
        raise TypeError('tensor should be 4D')

    # Shape (1, C, 1, 1) broadcasts over batch and spatial dims; the channel
    # count is no longer hard-coded to 3.
    mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device).view(1, -1, 1, 1)
    std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device).view(1, -1, 1, 1)

    return tensor.mul(std).add(mean)
195 |
196 |
def normalize(tensor, mean, std):
    """Apply per-channel normalization: return ``(tensor - mean) / std``.

    Args:
        tensor (torch.Tensor): 4D tensor; dim 1 is the channel dim the statistics apply to.
        mean (sequence of float): per-channel means (one per channel, or a single value).
        std (sequence of float): per-channel standard deviations (same length rule as mean).

    Return:
        torch.Tensor: a new tensor of the same shape as ``tensor``.

    Raises:
        TypeError: if ``tensor`` is not 4D.
    """
    if tensor.ndimension() != 4:
        raise TypeError('tensor should be 4D')

    # Shape (1, C, 1, 1) broadcasts over batch and spatial dims; the channel
    # count is no longer hard-coded to 3.
    mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device).view(1, -1, 1, 1)
    std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device).view(1, -1, 1, 1)

    return tensor.sub(mean).div(std)
205 |
206 |
class Normalize(object):
    """Callable wrapper pairing normalize()/denormalize() with fixed statistics.

    Args:
        mean: per-channel means handed to normalize()/denormalize().
        std: per-channel standard deviations handed to normalize()/denormalize().
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Normalize ``tensor``; identical to ``self.do(tensor)``."""
        return self.do(tensor)

    def do(self, tensor):
        """Return ``normalize(tensor, self.mean, self.std)``."""
        stats = (self.mean, self.std)
        return normalize(tensor, *stats)

    def undo(self, tensor):
        """Return ``denormalize(tensor, self.mean, self.std)``."""
        stats = (self.mean, self.std)
        return denormalize(tensor, *stats)

    def __repr__(self):
        class_name = type(self).__name__
        details = '(mean={0}, std={1})'.format(self.mean, self.std)
        return class_name + details
--------------------------------------------------------------------------------