├── README.md ├── assets └── readme.png ├── example.ipynb ├── gradcam.py ├── images ├── collies.JPG ├── multiple_dogs.jpg ├── snake.JPEG └── water-bird.JPEG ├── outputs ├── collies.JPG ├── multiple_dogs.jpg ├── snake.JPEG └── water-bird.JPEG └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | ### A simple PyTorch implementation of GradCAM[1] and GradCAM++[2] 2 |
3 |

4 | 5 |

6 | 7 | ### Supported torchvision models 8 | - alexnet 9 | - vgg 10 | - resnet 11 | - densenet 12 | - squeezenet 13 | 14 | ### Usage 15 | Please refer to `example.ipynb` for general usage. To learn how to set `target_layer_name` (passed as `layer_name` in `model_dict`) properly, refer to the docstrings of the layer-finding functions in `utils.py`. 16 | 17 | ### References 18 | [1] Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, Selvaraju et al., ICCV, 2017
# 19 | [2] Grad-CAM++: Generalized Gradient-based Visual Explanations for Deep Convolutional Networks, Chattopadhyay et al, WACV, 2018 20 |
# --------------------------------------------------------------------------------
# /assets/readme.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/assets/readme.png
# --------------------------------------------------------------------------------
# /gradcam.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

from utils import find_alexnet_layer, find_vgg_layer, find_resnet_layer, find_densenet_layer, find_squeezenet_layer


# (substring of the model type, layer finder) pairs, tried in this order. A model
# type uses the first finder whose key occurs in it case-insensitively, so e.g.
# 'resnet101' resolves through find_resnet_layer.
_LAYER_FINDERS = (
    ('vgg', find_vgg_layer),
    ('resnet', find_resnet_layer),
    ('densenet', find_densenet_layer),
    ('alexnet', find_alexnet_layer),
    ('squeezenet', find_squeezenet_layer),
)


class GradCAM(object):
    """Calculate GradCAM saliency map.

    A simple example:

        # initialize a model, model_dict and gradcam
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet.eval()
        model_dict = dict(type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))
        gradcam = GradCAM(model_dict)

        # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        img = load_img()
        normed_img = normalizer(img)

        # get a GradCAM saliency map on the class index 10.
        mask, logit = gradcam(normed_img, class_idx=10)

        # make heatmap from mask and synthesize saliency map using heatmap and img
        heatmap, cam_result = visualize_cam(mask, img)


    Args:
        model_dict (dict): a dictionary with the keys 'type' (model family such as
            'resnet'; 'model_type' is accepted as an alias), 'arch' (the model),
            'layer_name' (resolved by the matching find_*_layer helper of utils.py)
            and, optionally, 'input_size'.
        verbose (bool): whether to print the output size of the saliency map given
            'layer_name' and 'input_size' in model_dict.

    Raises:
        KeyError: if model_dict lacks the model type, 'arch' or 'layer_name'.
        ValueError: if the model type contains none of 'vgg', 'resnet',
            'densenet', 'alexnet' or 'squeezenet'.
    """

    def __init__(self, model_dict, verbose=False):
        # The docstrings used to advertise 'model_type' while the code only read
        # 'type'; accept both spellings.
        if 'type' in model_dict:
            model_type = model_dict['type']
        else:
            model_type = model_dict['model_type']
        layer_name = model_dict['layer_name']
        self.model_arch = model_dict['arch']

        # Overwritten by the hooks below on every forward / backward pass.
        self.gradients = dict()
        self.activations = dict()

        def backward_hook(module, grad_input, grad_output):
            # Gradient of the backpropagated score w.r.t. the layer output.
            self.gradients['value'] = grad_output[0]
            return None

        def forward_hook(module, inputs, output):
            self.activations['value'] = output
            return None

        target_layer = self._find_target_layer(model_type, layer_name)
        target_layer.register_forward_hook(forward_hook)
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch.
        # register_full_backward_hook replaces it, but it can fail when the hooked
        # output is later modified in place (e.g. by ReLU(inplace=True)), so the
        # legacy hook is kept -- confirm before migrating.
        target_layer.register_backward_hook(backward_hook)

        if verbose:
            try:
                input_size = model_dict['input_size']
            except KeyError:
                print("please specify size of input image in model_dict. e.g. {'input_size':(224, 224)}")
            else:
                device = 'cuda' if next(self.model_arch.parameters()).is_cuda else 'cpu'
                # Dry run on a blank image so the forward hook records the
                # activation shape of the target layer.
                self.model_arch(torch.zeros(1, 3, *(input_size), device=device))
                print('saliency_map size :', self.activations['value'].shape[2:])

    def _find_target_layer(self, model_type, layer_name):
        """Return the sub-module of self.model_arch that `layer_name` names for this model family.

        Raises:
            ValueError: if `model_type` matches no supported family (this used to
                surface as an UnboundLocalError).
        """
        lowered = model_type.lower()
        for key, find_layer in _LAYER_FINDERS:
            if key in lowered:
                return find_layer(self.model_arch, layer_name)
        raise ValueError('unsupported model type : {}'.format(model_type))

    def _backprop(self, input, class_idx, retain_graph):
        """Run the model on `input` and backpropagate one class score.

        Returns:
            (logit, gradients, activations): the model output, and the target
            layer's output gradient and output activation recorded by the hooks.
        """
        logit = self.model_arch(input)
        if class_idx is None:
            # logit.max(1)[-1] is the arg-max class index.
            score = logit[:, logit.max(1)[-1]].squeeze()
        else:
            score = logit[:, class_idx].squeeze()

        self.model_arch.zero_grad()
        # backward() without an explicit gradient needs a scalar score, hence the
        # documented batch size of 1.
        score.backward(retain_graph=retain_graph)
        return logit, self.gradients['value'], self.activations['value']

    @staticmethod
    def _to_mask(saliency_map, size):
        """ReLU, bilinearly resize to `size` and min-max scale into [0, 1].

        A constant map (e.g. all zeros after the ReLU) would produce 0/0 = NaN, so
        it is returned as all zeros instead.
        """
        saliency_map = F.relu(saliency_map)
        saliency_map = F.interpolate(saliency_map, size=size, mode='bilinear', align_corners=False)
        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        value_range = saliency_map_max - saliency_map_min
        if value_range > 0:
            saliency_map = (saliency_map - saliency_map_min).div(value_range)
        else:
            saliency_map = torch.zeros_like(saliency_map)
        return saliency_map.detach()

    def forward(self, input, class_idx=None, retain_graph=False):
        """
        Args:
            input: input image with shape of (1, 3, H, W)
            class_idx (int): class index for calculating GradCAM.
                    If not specified, the class index that makes the highest model prediction score will be used.
            retain_graph (bool): forwarded to backward(); keep the graph for another backward pass.
        Return:
            mask: saliency map of shape (1, 1, H, W) with values in [0, 1]
            logit: model output
        """
        b, c, h, w = input.size()
        logit, gradients, activations = self._backprop(input, class_idx, retain_graph)
        b, k, u, v = gradients.size()

        # GradCAM channel weights: spatial mean of dS/dA for every channel k.
        alpha = gradients.reshape(b, k, -1).mean(2)
        weights = alpha.view(b, k, 1, 1)

        saliency_map = (weights * activations).sum(1, keepdim=True)
        return self._to_mask(saliency_map, (h, w)), logit

    def __call__(self, input, class_idx=None, retain_graph=False):
        return self.forward(input, class_idx, retain_graph)


class GradCAMpp(GradCAM):
    """Calculate GradCAM++ saliency map.

    A simple example:

        # initialize a model, model_dict and gradcampp
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet.eval()
        model_dict = dict(type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))
        gradcampp = GradCAMpp(model_dict)

        # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        img = load_img()
        normed_img = normalizer(img)

        # get a GradCAM++ saliency map on the class index 10.
        mask, logit = gradcampp(normed_img, class_idx=10)

        # make heatmap from mask and synthesize saliency map using heatmap and img
        heatmap, cam_result = visualize_cam(mask, img)


    Args:
        model_dict (dict): a dictionary with the keys 'type' (model family such as
            'resnet'; 'model_type' is accepted as an alias), 'arch' (the model),
            'layer_name' (resolved by the matching find_*_layer helper of utils.py)
            and, optionally, 'input_size'.
        verbose (bool): whether to print the output size of the saliency map given
            'layer_name' and 'input_size' in model_dict.
    """

    def forward(self, input, class_idx=None, retain_graph=False):
        """
        Args:
            input: input image with shape of (1, 3, H, W)
            class_idx (int): class index for calculating GradCAM++.
                    If not specified, the class index that makes the highest model prediction score will be used.
            retain_graph (bool): forwarded to backward(); keep the graph for another backward pass.
        Return:
            mask: saliency map of shape (1, 1, H, W) with values in [0, 1]
            logit: model output
        """
        b, c, h, w = input.size()
        # gradients = dS/dA, activations = A
        logit, gradients, activations = self._backprop(input, class_idx, retain_graph)
        b, k, u, v = gradients.size()

        # alpha = g^2 / (2 g^2 + sum_ab A_ab g^3) with g = dS/dA, the closed form
        # for the exponential output Y = exp(S).
        alpha_num = gradients.pow(2)
        alpha_denom = gradients.pow(2).mul(2) + \
            activations.mul(gradients.pow(3)).reshape(b, k, u * v).sum(-1, keepdim=True).view(b, k, 1, 1)
        alpha_denom = torch.where(alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom))
        alpha = alpha_num.div(alpha_denom + 1e-7)

        # The paper weights ReLU(dY/dA) = exp(S) * ReLU(dS/dA). exp(S) is a single
        # positive scalar for the one image, so it rescales the whole map uniformly
        # and is cancelled by the min-max scaling in _to_mask. Dropping it avoids
        # float overflow (exp(S) == inf for S > ~88 in float32), which made the map NaN.
        positive_gradients = F.relu(gradients)
        weights = (alpha * positive_gradients).reshape(b, k, u * v).sum(-1).view(b, k, 1, 1)

        saliency_map = (weights * activations).sum(1, keepdim=True)
        # Resize to the input's own (H, W); this was previously hard-coded to 224x224.
        return self._to_mask(saliency_map, (h, w)), logit
# 182 |
# --------------------------------------------------------------------------------
# /images/collies.JPG:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/images/collies.JPG
# --------------------------------------------------------------------------------
# /images/multiple_dogs.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/images/multiple_dogs.jpg
# --------------------------------------------------------------------------------
# /images/snake.JPEG:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/images/snake.JPEG
# --------------------------------------------------------------------------------
# /images/water-bird.JPEG:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/images/water-bird.JPEG
# --------------------------------------------------------------------------------
# /outputs/collies.JPG:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/outputs/collies.JPG
# --------------------------------------------------------------------------------
# /outputs/multiple_dogs.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/outputs/multiple_dogs.jpg
# --------------------------------------------------------------------------------
# /outputs/snake.JPEG:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/outputs/snake.JPEG
# --------------------------------------------------------------------------------
# /outputs/water-bird.JPEG:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/1Konny/gradcam_plus_plus-pytorch/07fd6ece5010f7c1c9fbcc8155a60023819111d7/outputs/water-bird.JPEG
# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
import numpy as np
import torch


def visualize_cam(mask, img):
    """Make heatmap from mask and synthesize GradCAM result image using heatmap and img.
    Args:
        mask (torch.tensor): mask shape of (1, 1, H, W) and each element has value in range [0, 1]
        img (torch.tensor): img shape of (1, 3, H, W) and each pixel value is in range [0, 1]

    Return:
        heatmap (torch.tensor): heatmap img shape of (3, H, W)
        result (torch.tensor): synthesized GradCAM result of same shape with heatmap.
    """
    # OpenCV is only needed here; importing it lazily keeps the layer-finding and
    # normalization helpers usable without it.
    import cv2

    # Move to host memory first: NumPy cannot read CUDA tensors directly.
    mask = mask.detach().squeeze().cpu().numpy()
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = torch.from_numpy(heatmap).permute(2, 0, 1).float().div(255)
    # The color map comes back in (b, g, r) channel order; reorder to (r, g, b).
    b, g, r = heatmap.split(1)
    heatmap = torch.cat([r, g, b])

    result = heatmap + img.cpu()
    result = result.div(result.max()).squeeze()

    return heatmap, result


def find_resnet_layer(arch, target_layer_name):
    """Find resnet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision resnet models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'conv1'
            target_layer_name = 'layer1'
            target_layer_name = 'layer1_basicblock0'
            target_layer_name = 'layer1_basicblock0_relu'
            target_layer_name = 'layer1_bottleneck0'
            target_layer_name = 'layer1_bottleneck0_conv1'
            target_layer_name = 'layer1_bottleneck0_downsample'
            target_layer_name = 'layer1_bottleneck0_downsample_0'
            target_layer_name = 'avgpool'
            target_layer_name = 'fc'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        ValueError: for a 'layerN' name whose N is not 1-4.
        KeyError: if a named sub-module does not exist.
    """
    hierarchy = target_layer_name.split('_')
    if not hierarchy[0].startswith('layer'):
        # Top-level modules such as 'conv1', 'avgpool' or 'fc'.
        return arch._modules[target_layer_name]

    layer_num = int(hierarchy[0][len('layer'):])
    if layer_num not in (1, 2, 3, 4):
        raise ValueError('unknown layer : {}'.format(target_layer_name))
    target_layer = getattr(arch, 'layer{}'.format(layer_num))

    if len(hierarchy) >= 2:
        # 'bottleneck0' / 'basicblock0' (or a bare '0') index a block of the layer.
        # str.lstrip strips a character set rather than a prefix, so the prefix is
        # removed explicitly.
        block_name = hierarchy[1].lower()
        for prefix in ('bottleneck', 'basicblock'):
            if block_name.startswith(prefix):
                block_name = block_name[len(prefix):]
                break
        target_layer = target_layer[int(block_name)]

    # Remaining parts name nested sub-modules, e.g. 'downsample' and then '0'.
    for name in hierarchy[2:]:
        target_layer = target_layer._modules[name]

    return target_layer


def find_densenet_layer(arch, target_layer_name):
    """Find densenet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision densenet models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'features'
            target_layer_name = 'features_transition1'
            target_layer_name = 'features_transition1_norm'
            target_layer_name = 'features_denseblock2_denselayer12'
            target_layer_name = 'features_denseblock2_denselayer12_norm1'
            target_layer_name = 'classifier'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        KeyError: if a named sub-module does not exist.
    """
    # Each '_'-separated part names a child of the module found so far.
    target_layer = arch
    for name in target_layer_name.split('_'):
        target_layer = target_layer._modules[name]
    return target_layer


def _find_indexed_layer(arch, target_layer_name):
    """Resolve '<child>' or '<child>_<index>' where <child> is an indexable top-level module (e.g. nn.Sequential)."""
    hierarchy = target_layer_name.split('_')
    # Use the named top-level child; this used to always be arch.features, which
    # made 'classifier' and 'classifier_0' resolve to the wrong module.
    target_layer = arch._modules[hierarchy[0]]
    if len(hierarchy) == 2:
        target_layer = target_layer[int(hierarchy[1])]
    return target_layer


def find_vgg_layer(arch, target_layer_name):
    """Find vgg layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision vgg models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'features'
            target_layer_name = 'features_42'
            target_layer_name = 'classifier'
            target_layer_name = 'classifier_0'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.
    """
    return _find_indexed_layer(arch, target_layer_name)


def find_alexnet_layer(arch, target_layer_name):
    """Find alexnet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision alexnet models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'features'
            target_layer_name = 'features_0'
            target_layer_name = 'classifier'
            target_layer_name = 'classifier_0'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.
    """
    return _find_indexed_layer(arch, target_layer_name)


def find_squeezenet_layer(arch, target_layer_name):
    """Find squeezenet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision squeezenet models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'features_12'
            target_layer_name = 'features_12_expand3x3'
            target_layer_name = 'features_12_expand3x3_activation'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.
    """
    hierarchy = target_layer_name.split('_')
    target_layer = arch._modules[hierarchy[0]]

    if len(hierarchy) >= 2:
        target_layer = target_layer._modules[hierarchy[1]]

    if len(hierarchy) >= 3:
        # Sub-module names such as 'expand3x3_activation' contain '_' themselves,
        # so everything after the second part is one name.
        target_layer = target_layer._modules['_'.join(hierarchy[2:])]

    return target_layer


def _as_channel_stats(values, tensor):
    """Return per-channel `values` as a float32 (1, C, 1, 1) tensor on `tensor`'s device, ready to broadcast."""
    return torch.as_tensor(values, dtype=torch.float32, device=tensor.device).view(1, -1, 1, 1)


def denormalize(tensor, mean, std):
    """Undo `normalize`: return tensor * std + mean, applied per channel.

    Args:
        tensor: 4-D tensor, presumably (N, C, H, W) with len(mean) == len(std) == C.
        mean, std: per-channel statistics.

    Raises:
        TypeError: if `tensor` is not 4-D.
    """
    if tensor.ndimension() != 4:
        raise TypeError('tensor should be 4D')

    return tensor.mul(_as_channel_stats(std, tensor)).add(_as_channel_stats(mean, tensor))


def normalize(tensor, mean, std):
    """Return (tensor - mean) / std, applied per channel.

    Args:
        tensor: 4-D tensor, presumably (N, C, H, W) with len(mean) == len(std) == C.
        mean, std: per-channel statistics.

    Raises:
        TypeError: if `tensor` is not 4-D.
    """
    if tensor.ndimension() != 4:
        raise TypeError('tensor should be 4D')

    return tensor.sub(_as_channel_stats(mean, tensor)).div(_as_channel_stats(std, tensor))


class Normalize(object):
    """Callable pairing of `normalize` / `denormalize` with a fixed mean and std."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        return self.do(tensor)

    def do(self, tensor):
        """Normalize `tensor` with the stored statistics."""
        return normalize(tensor, self.mean, self.std)

    def undo(self, tensor):
        """Invert `do`."""
        return denormalize(tensor, self.mean, self.std)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
# --------------------------------------------------------------------------------