├── README.md ├── assets └── readme.png ├── gradcam.py ├── images ├── collies.JPG ├── multiple_dogs.jpg ├── snake.JPEG └── water-bird.JPEG ├── outputs ├── collies.JPG ├── multiple_dogs.jpg ├── snake.JPEG └── water-bird.JPEG └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | ### A Pytorch implementation of GradCAM[1], GradCAM++[2], and Smooth-GradCAM++ [3] 2 |
3 |

4 | 5 |

6 | 7 | ### Supported torchvision models 8 | - alexnet 9 | - vgg 10 | - resnet 11 | - densenet 12 | - squeezenet 13 | 14 | ### Usage 15 | Please refer to `example.ipynb` for general usage, and to the docstring of each layer-finding function in `utils.py` to learn how to set `target_layer_name` (passed as `layer_name` in `model_dict`) properly. 16 | 17 | ### References: 18 | [1] Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, Selvaraju et al., ICCV, 2017
# README references (continued from the previous line of the repository dump):
# [2] Grad-CAM++: Generalized Gradient-based Visual Explanations for Deep Convolutional Networks, Chattopadhyay et al, WACV, 2018
# [3] Smooth Grad-CAM++: An Enhanced Inference Level Visualization Technique for Deep Convolutional Neural Network Models, Omeiza et al, 2019
#
# --------------------------------------------------------------------------------
# /assets/readme.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/assets/readme.png
# --------------------------------------------------------------------------------
# /gradcam.py:
# --------------------------------------------------------------------------------
"""GradCAM, GradCAM++ and Smooth-GradCAM++ saliency maps for torchvision-style CNNs."""
import torch
import torch.nn.functional as F
import numpy as np

from utils import find_alexnet_layer, find_vgg_layer, find_resnet_layer, find_densenet_layer, find_squeezenet_layer


# (keyword, finder) pairs, tested in this order against the lower-cased model type.
_LAYER_FINDERS = (
    ('vgg', find_vgg_layer),
    ('resnet', find_resnet_layer),
    ('densenet', find_densenet_layer),
    ('alexnet', find_alexnet_layer),
    ('squeezenet', find_squeezenet_layer),
)


def _saliency_from_weights(weights, activations, size):
    """Combine channel weights and activations into a normalized saliency map.

    Args:
        weights: per-channel weights, viewed as (b, k, 1, 1) by the callers.
        activations: activations recorded by the forward hook on the target layer.
        size (tuple): (H, W) of the model input; the map is upsampled to this size.

    Return:
        detached saliency map of shape (b, 1, H, W), min-max scaled to [0, 1].
        A constant map (max == min) yields all zeros instead of NaN.
    """
    saliency_map = (weights * activations).sum(1, keepdim=True)
    saliency_map = F.relu(saliency_map)
    # F.interpolate is the non-deprecated equivalent of F.upsample.
    saliency_map = F.interpolate(saliency_map, size=size, mode='bilinear', align_corners=False)
    saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
    value_range = saliency_map_max - saliency_map_min
    if value_range.item() == 0:
        # Avoid 0 / 0 when the layer contributes nothing (or equally everywhere).
        return torch.zeros_like(saliency_map).detach()
    return (saliency_map - saliency_map_min).div(value_range).detach()


def _gradcampp_weights(grad_sq, grad_cube, activations, score_gradients):
    """Compute GradCAM++ channel weights.

    Args:
        grad_sq: (mean of) squared gradients of the score w.r.t. the activations, (b, k, u, v).
        grad_cube: (mean of) cubed gradients, same shape.
        activations: activations of the target layer, same shape.
        score_gradients: (mean of) exp(score) * gradients, i.e. the input of the final ReLU.

    Return:
        weights of shape (b, k, 1, 1).
    """
    b, k, u, v = grad_sq.size()
    alpha_denom = grad_sq.mul(2) + \
        activations.mul(grad_cube).view(b, k, u * v).sum(-1, keepdim=True).view(b, k, 1, 1)
    # Replace exact zeros so the division below stays finite.
    alpha_denom = torch.where(alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom))

    alpha = grad_sq.div(alpha_denom + 1e-7)
    positive_gradients = F.relu(score_gradients)  # ReLU(dY/dA) == ReLU(exp(S)*dS/dA))
    return (alpha * positive_gradients).view(b, k, u * v).sum(-1).view(b, k, 1, 1)


class GradCAM(object):
    """Calculate GradCAM saliency map.

    A simple example:

        # initialize a model, model_dict and gradcam
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet.eval()
        model_dict = dict(model_type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))
        gradcam = GradCAM(model_dict)

        # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        img = load_img()
        normed_img = normalizer(img)

        # get a GradCAM saliency map on the class index 10.
        mask, logit = gradcam(normed_img, class_idx=10)

        # make heatmap from mask and synthesize saliency map using heatmap and img
        heatmap, cam_result = visualize_cam(mask, img)


    Args:
        model_dict (dict): a dictionary that contains 'model_type' (the legacy key 'type' is also
            accepted), 'arch', 'layer_name', 'input_size'(optional) as keys.
        verbose (bool): whether to print output size of the saliency map given 'layer_name' and 'input_size' in model_dict.

    Raises:
        KeyError: if a required key is missing from model_dict.
        ValueError: if the model type matches none of the supported architectures.
    """
    def __init__(self, model_dict, verbose=False):
        # The docstrings document 'model_type'; earlier code only read 'type'. Accept both.
        if 'model_type' in model_dict:
            model_type = model_dict['model_type']
        else:
            model_type = model_dict['type']
        layer_name = model_dict['layer_name']
        self.model_arch = model_dict['arch']

        self.gradients = dict()
        self.activations = dict()

        def backward_hook(module, grad_input, grad_output):
            self.gradients['value'] = grad_output[0]

        def forward_hook(module, input, output):
            self.activations['value'] = output

        target_layer = self._find_target_layer(model_type, layer_name)

        # NOTE: the legacy (non-full) backward hook is kept on purpose. Only grad_output is read,
        # and the full-hook variant rejects target layers that operate in place (e.g. ReLU(inplace=True)).
        self._hook_handles = [
            target_layer.register_forward_hook(forward_hook),
            target_layer.register_backward_hook(backward_hook),
        ]

        if verbose:
            try:
                input_size = model_dict['input_size']
            except KeyError:
                print("please specify size of input image in model_dict. e.g. {'input_size':(224, 224)}")
            else:
                device = next(self.model_arch.parameters()).device
                # A dummy forward pass fills self.activations through the forward hook.
                self.model_arch(torch.zeros(1, 3, *(input_size), device=device))
                print('saliency_map size :', self.activations['value'].shape[2:])

    def _find_target_layer(self, model_type, layer_name):
        """Resolve layer_name inside self.model_arch using the finder matching model_type."""
        lowered = model_type.lower()
        for keyword, finder in _LAYER_FINDERS:
            if keyword in lowered:
                return finder(self.model_arch, layer_name)
        raise ValueError(
            'unsupported model type : {} (expected one containing {})'.format(
                model_type, ', '.join(keyword for keyword, _ in _LAYER_FINDERS)))

    def remove_hooks(self):
        """Detach the forward/backward hooks this object registered on the target layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _class_score(self, logit, class_idx):
        """Return the logit of class_idx, or of the top-scoring class when class_idx is None."""
        if class_idx is None:
            class_idx = logit.max(1)[-1]
        return logit[:, class_idx].squeeze()

    def _backpropagate(self, score, retain_graph):
        """Backpropagate score and return the (gradients, activations) captured by the hooks."""
        self.model_arch.zero_grad()
        score.backward(retain_graph=retain_graph)
        return self.gradients['value'], self.activations['value']

    def forward(self, input, class_idx=None, retain_graph=False):
        """
        Args:
            input: input image with shape of (1, 3, H, W)
            class_idx (int): class index for calculating GradCAM.
                    If not specified, the class index that makes the highest model prediction score will be used.
            retain_graph (bool): forwarded to the backward call on the class score.
        Return:
            mask: saliency map of the same spatial dimension with input
            logit: model output
        """
        b, c, h, w = input.size()

        logit = self.model_arch(input)
        score = self._class_score(logit, class_idx)
        gradients, activations = self._backpropagate(score, retain_graph)
        b, k, u, v = gradients.size()

        # GradCAM weight of a channel = spatial mean of its gradient.
        alpha = gradients.view(b, k, -1).mean(2)
        weights = alpha.view(b, k, 1, 1)

        saliency_map = _saliency_from_weights(weights, activations, (h, w))
        return saliency_map, logit

    def __call__(self, input, class_idx=None, retain_graph=False):
        return self.forward(input, class_idx, retain_graph)


class GradCAMpp(GradCAM):
    """Calculate GradCAM++ saliency map.

    A simple example:

        # initialize a model, model_dict and gradcampp
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet.eval()
        model_dict = dict(model_type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))
        gradcampp = GradCAMpp(model_dict)

        # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        img = load_img()
        normed_img = normalizer(img)

        # get a GradCAM saliency map on the class index 10.
        mask, logit = gradcampp(normed_img, class_idx=10)

        # make heatmap from mask and synthesize saliency map using heatmap and img
        heatmap, cam_result = visualize_cam(mask, img)


    Args:
        model_dict (dict): a dictionary that contains 'model_type' (the legacy key 'type' is also
            accepted), 'arch', 'layer_name', 'input_size'(optional) as keys.
        verbose (bool): whether to print output size of the saliency map given 'layer_name' and 'input_size' in model_dict.
    """
    def __init__(self, model_dict, verbose=False):
        super(GradCAMpp, self).__init__(model_dict, verbose)

    def forward(self, input, class_idx=None, retain_graph=False):
        """
        Args:
            input: input image with shape of (1, 3, H, W)
            class_idx (int): class index for calculating GradCAM.
                    If not specified, the class index that makes the highest model prediction score will be used.
            retain_graph (bool): forwarded to the backward call on the class score.
        Return:
            mask: saliency map of the same spatial dimension with input
            logit: model output
        """
        b, c, h, w = input.size()

        logit = self.model_arch(input)
        score = self._class_score(logit, class_idx)
        gradients, activations = self._backpropagate(score, retain_graph)  # dS/dA, A

        weights = _gradcampp_weights(
            gradients.pow(2),
            gradients.pow(3),
            activations,
            # Only the value of exp(S) is needed, so drop its autograd history.
            score.detach().exp() * gradients,
        )

        # Upsample to the real input size rather than a hard-coded 224x224.
        saliency_map = _saliency_from_weights(weights, activations, (h, w))
        return saliency_map, logit


class SmoothGradCAMpp(GradCAM):
    """Calculate Smooth-GradCAM++ saliency map.

    A simple example:

        # initialize a model, model_dict and gradcampp
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet.eval()
        model_dict = dict(model_type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))
        smgradcampp = SmoothGradCAMpp(model_dict)

        # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        img = load_img()
        normed_img = normalizer(img)

        # get a GradCAM saliency map on the class index 10.
        mask, logit = smgradcampp(normed_img, class_idx=10)

        # make heatmap from mask and synthesize saliency map using heatmap and img
        heatmap, cam_result = visualize_cam(mask, img)


    Args:
        model_dict (dict): a dictionary that contains 'model_type' (the legacy key 'type' is also
            accepted), 'arch', 'layer_name', 'input_size'(optional) as keys.
        verbose (bool): whether to print output size of the saliency map given 'layer_name' and 'input_size' in model_dict.
    """
    def __init__(self, model_dict, verbose=False):
        super(SmoothGradCAMpp, self).__init__(model_dict, verbose)

    def addNoise(self, image, noiselevel):
        """
        Args:
            image: input image with shape of (1, 3, H, W)
            noiselevel: noise percentage divided by 100
        Return:
            img: image with added gaussian noise
        """
        # Standard deviation is a fraction of the image's value range.
        bitrange = float(image.max() - image.min())
        # One (C, H, W) noise sample, broadcast over the batch dimension; the shape follows
        # the image instead of assuming 3x224x224.
        noise = np.random.normal(scale=noiselevel * bitrange, size=tuple(image.shape[-3:]))
        noise = torch.from_numpy(noise.astype('float32')).to(device=image.device, dtype=image.dtype)
        img = image.add(noise)

        return img

    def forward(self, input, class_idx=None, retain_graph=False, n=50, noiselevel=0.1):
        """
        Args:
            input: input image with shape of (1, 3, H, W)
            class_idx (int): class index for calculating GradCAM.
                    If not specified, the class index that makes the highest model prediction score will be used.
            retain_graph (bool): forwarded to the backward call on the class score.
            n (int): number of noisy samples to average over; must be at least 1.
            noiselevel (float): noise standard deviation as a fraction of the input value range.
        Return:
            mask: saliency map of the same spatial dimension with input
            logit: model output for the last noisy sample
        Raises:
            ValueError: if n is smaller than 1.
        """
        if n < 1:
            raise ValueError('n must be a positive integer, got {}'.format(n))

        b, c, h, w = input.size()

        if class_idx is None:
            # Only the predicted index is needed here, so no graph has to be built.
            with torch.no_grad():
                class_idx = (self.model_arch(input)).max(1)[-1]

        # Running sums instead of per-sample lists: same result, without holding n gradient
        # tensors (and n autograd graphs) alive.
        grad_sq_sum = None
        grad_cube_sum = None
        score_grad_sum = None

        for _ in range(n):
            input_ = self.addNoise(input, noiselevel)
            logit = self.model_arch(input_)
            score = logit[:, class_idx].squeeze()

            gradients, activations = self._backpropagate(score, retain_graph)  # dS/dA, A

            grad_sq = gradients.pow(2)
            grad_cube = gradients.pow(3)
            score_grad = score.detach().exp().mul(gradients)

            if grad_sq_sum is None:
                grad_sq_sum, grad_cube_sum, score_grad_sum = grad_sq, grad_cube, score_grad
            else:
                grad_sq_sum = grad_sq_sum + grad_sq
                grad_cube_sum = grad_cube_sum + grad_cube
                score_grad_sum = score_grad_sum + score_grad

        # The activations of the last noisy sample are combined with the averaged gradients.
        weights = _gradcampp_weights(
            grad_sq_sum.div(n),
            grad_cube_sum.div(n),
            activations,
            score_grad_sum.div(n),
        )

        # Upsample to the real input size rather than a hard-coded 224x224.
        saliency_map = _saliency_from_weights(weights, activations, (h, w))
        return saliency_map, logit


# --------------------------------------------------------------------------------
# /images/collies.JPG:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/images/collies.JPG
# --------------------------------------------------------------------------------
# /images/multiple_dogs.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/images/multiple_dogs.jpg
# --------------------------------------------------------------------------------
# /images/snake.JPEG:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/images/snake.JPEG
# --------------------------------------------------------------------------------
# /images/water-bird.JPEG:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/images/water-bird.JPEG
# --------------------------------------------------------------------------------
# /outputs/collies.JPG:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/outputs/collies.JPG
# --------------------------------------------------------------------------------
# /outputs/multiple_dogs.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/outputs/multiple_dogs.jpg
# --------------------------------------------------------------------------------
# /outputs/snake.JPEG:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/outputs/snake.JPEG
# --------------------------------------------------------------------------------
# /outputs/water-bird.JPEG:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/outputs/water-bird.JPEG
# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
"""Visualization, layer-lookup and normalization helpers for the GradCAM classes."""
import re

import cv2
import numpy as np
import torch


# Matches the numeric suffix of names such as 'layer3' or 'bottleneck12'.
_TRAILING_DIGITS = re.compile(r'(\d+)$')


def _get_child(module, name):
    """Return the direct child of `module` registered under `name`.

    Raises:
        KeyError: if `module` has no child with that name.
    """
    try:
        return module._modules[name]
    except KeyError:
        raise KeyError("{} has no child module named '{}' (available: {})".format(
            type(module).__name__, name, ', '.join(module._modules)))


def _trailing_index(token, target_layer_name):
    """Parse the integer suffix of `token`; raise ValueError when there is none."""
    match = _TRAILING_DIGITS.search(token)
    if match is None:
        raise ValueError('unknown layer : {}'.format(target_layer_name))
    return int(match.group(1))


def visualize_cam(mask, img):
    """Make heatmap from mask and synthesize GradCAM result image using heatmap and img.
    Args:
        mask (torch.tensor): mask shape of (1, 1, H, W) and each element has value in range [0, 1]
        img (torch.tensor): img shape of (1, 3, H, W) and each pixel value is in range [0, 1]

    Return:
        heatmap (torch.tensor): heatmap img shape of (3, H, W)
        result (torch.tensor): synthesized GradCAM result of same shape with heatmap.
    """
    # Move to host memory and drop autograd history before handing the data to OpenCV.
    mask_np = mask.detach().cpu().squeeze().numpy()
    heatmap = cv2.applyColorMap(np.uint8(255 * mask_np), cv2.COLORMAP_JET)
    heatmap = torch.from_numpy(heatmap).permute(2, 0, 1).float().div(255)
    # OpenCV returns channels as B, G, R; reverse them to R, G, B.
    heatmap = heatmap.flip(0)

    result = heatmap + img.detach().cpu()
    peak = result.max()
    if peak.item() != 0:
        # Skip rescaling for an all-zero blend, which would otherwise become NaN.
        result = result.div(peak)
    result = result.squeeze()

    return heatmap, result


def find_resnet_layer(arch, target_layer_name):
    """Find resnet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision resnet models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'conv1'
            target_layer_name = 'layer1'
            target_layer_name = 'layer1_basicblock0'
            target_layer_name = 'layer1_basicblock0_relu'
            target_layer_name = 'layer1_bottleneck0'
            target_layer_name = 'layer1_bottleneck0_conv1'
            target_layer_name = 'layer1_bottleneck0_downsample'
            target_layer_name = 'layer1_bottleneck0_downsample_0'
            target_layer_name = 'avgpool'
            target_layer_name = 'fc'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        ValueError: if the stage or block index cannot be parsed, or the stage is not 1-4.
        KeyError: if a named sub-module does not exist.
    """
    if 'layer' not in target_layer_name:
        return _get_child(arch, target_layer_name)

    hierarchy = target_layer_name.split('_')

    # Parse the index explicitly: str.lstrip strips a character set, not a prefix.
    layer_num = _trailing_index(hierarchy[0], target_layer_name)
    if layer_num not in (1, 2, 3, 4):
        raise ValueError('unknown layer : {}'.format(target_layer_name))
    target_layer = getattr(arch, 'layer{}'.format(layer_num))

    if len(hierarchy) >= 2:
        block_num = _trailing_index(hierarchy[1].lower(), target_layer_name)
        target_layer = target_layer[block_num]

    # Remaining components name nested sub-modules, e.g. 'downsample' then '0'.
    for name in hierarchy[2:]:
        target_layer = _get_child(target_layer, name)

    return target_layer


def find_densenet_layer(arch, target_layer_name):
    """Find densenet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision densenet models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'features'
            target_layer_name = 'features_transition1'
            target_layer_name = 'features_transition1_norm'
            target_layer_name = 'features_denseblock2_denselayer12'
            target_layer_name = 'features_denseblock2_denselayer12_norm1'
            target_layer_name = 'classifier'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        KeyError: if a named sub-module does not exist.
    """
    target_layer = arch
    # Every '_'-separated component names one more level of nesting.
    for name in target_layer_name.split('_'):
        target_layer = _get_child(target_layer, name)

    return target_layer


def _find_indexed_layer(arch, target_layer_name):
    """Resolve '<child>' or '<child>_<index>' where <child> is an indexable container of `arch`."""
    hierarchy = target_layer_name.split('_')
    # Honour the requested container ('features' or 'classifier') instead of always using features.
    target_layer = _get_child(arch, hierarchy[0])

    if len(hierarchy) == 2:
        target_layer = target_layer[int(hierarchy[1])]

    return target_layer


def find_vgg_layer(arch, target_layer_name):
    """Find vgg layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision vgg models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'features'
            target_layer_name = 'features_42'
            target_layer_name = 'classifier'
            target_layer_name = 'classifier_0'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        KeyError: if the first name component is not a child of arch.
    """
    return _find_indexed_layer(arch, target_layer_name)


def find_alexnet_layer(arch, target_layer_name):
    """Find alexnet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision alexnet models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'features'
            target_layer_name = 'features_0'
            target_layer_name = 'classifier'
            target_layer_name = 'classifier_0'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        KeyError: if the first name component is not a child of arch.
    """
    return _find_indexed_layer(arch, target_layer_name)


def find_squeezenet_layer(arch, target_layer_name):
    """Find squeezenet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision squeezenet models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'features_12'
            target_layer_name = 'features_12_expand3x3'
            target_layer_name = 'features_12_expand3x3_activation'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        KeyError: if a named sub-module does not exist.
    """
    hierarchy = target_layer_name.split('_')
    target_layer = _get_child(arch, hierarchy[0])

    if len(hierarchy) >= 2:
        target_layer = _get_child(target_layer, hierarchy[1])

    if len(hierarchy) >= 3:
        # Third-level names may themselves contain '_' (e.g. 'expand3x3_activation'),
        # so the remaining components are re-joined into a single child name.
        target_layer = _get_child(target_layer, '_'.join(hierarchy[2:]))

    return target_layer


def _channel_stats(tensor, mean, std):
    """Validate `tensor` and return (mean, std) expanded to its shape, dtype and device.

    Raises:
        TypeError: if `tensor` is not 4D.
    """
    if not tensor.ndimension() == 4:
        raise TypeError('tensor should be 4D')

    # Integer images are normalized in float32, matching the previous FloatTensor behaviour.
    dtype = tensor.dtype if tensor.is_floating_point() else torch.float32
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device).view(1, -1, 1, 1).expand_as(tensor)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device).view(1, -1, 1, 1).expand_as(tensor)
    return mean, std


def denormalize(tensor, mean, std):
    """Undo per-channel normalization: return tensor * std + mean.

    Args:
        tensor (torch.tensor): 4D tensor with channels on dimension 1.
        mean: per-channel means (sequence of numbers).
        std: per-channel standard deviations (sequence of numbers).

    Raises:
        TypeError: if tensor is not 4D.
    """
    mean, std = _channel_stats(tensor, mean, std)

    return tensor.mul(std).add(mean)


def normalize(tensor, mean, std):
    """Apply per-channel normalization: return (tensor - mean) / std.

    Args:
        tensor (torch.tensor): 4D tensor with channels on dimension 1.
        mean: per-channel means (sequence of numbers).
        std: per-channel standard deviations (sequence of numbers).

    Raises:
        TypeError: if tensor is not 4D.
    """
    mean, std = _channel_stats(tensor, mean, std)

    return tensor.sub(mean).div(std)


class Normalize(object):
    """Callable that normalizes 4D image tensors with fixed per-channel statistics.

    Calling the object is the same as calling `do`; `undo` reverses it.
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        return self.do(tensor)

    def do(self, tensor):
        """Return the normalized tensor."""
        return normalize(tensor, self.mean, self.std)

    def undo(self, tensor):
        """Return the de-normalized tensor."""
        return denormalize(tensor, self.mean, self.std)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


# --------------------------------------------------------------------------------