├── README.md
├── assets
└── readme.png
├── gradcam.py
├── images
├── collies.JPG
├── multiple_dogs.jpg
├── snake.JPEG
└── water-bird.JPEG
├── outputs
├── collies.JPG
├── multiple_dogs.jpg
├── snake.JPEG
└── water-bird.JPEG
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | ### A PyTorch implementation of Grad-CAM [1], Grad-CAM++ [2], and Smooth Grad-CAM++ [3]
2 |
3 |
4 |
5 |
6 |
7 | ### Supported torchvision models
8 | - alexnet
9 | - vgg
10 | - resnet
11 | - densenet
12 | - squeezenet
13 |
14 | ### Usage
15 | Please refer to `example.ipynb` for general usage, and to the docstring of each layer-finding function in `utils.py` to learn how to set `target_layer_name` (the `layer_name` entry of `model_dict`) properly.
16 |
17 | ### References:
18 | [1] Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, Selvaraju et al, ICCV, 2017
19 | [2] Grad-CAM++: Generalized Gradient-based Visual Explanations for Deep Convolutional Networks, Chattopadhyay et al, WACV, 2018
20 | [3] Smooth Grad-CAM++: An Enhanced Inference Level Visualization Technique for Deep Convolutional Neural Network Models, Omeiza et al, 2019
21 |
--------------------------------------------------------------------------------
/assets/readme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/assets/readme.png
--------------------------------------------------------------------------------
/gradcam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | from utils import find_alexnet_layer, find_vgg_layer, find_resnet_layer, find_densenet_layer, find_squeezenet_layer
6 |
7 |
class GradCAM(object):
    """Calculate GradCAM saliency map.

    A simple example:

        # initialize a model, model_dict and gradcam
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet.eval()
        model_dict = dict(model_type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))
        gradcam = GradCAM(model_dict)

        # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        img = load_img()
        normed_img = normalizer(img)

        # get a GradCAM saliency map on the class index 10.
        mask, logit = gradcam(normed_img, class_idx=10)

        # make heatmap from mask and synthesize saliency map using heatmap and img
        heatmap, cam_result = visualize_cam(mask, img)


    Args:
        model_dict (dict): a dictionary that contains 'model_type' (the legacy key 'type' is
            also accepted), 'arch', 'layer_name' and, optionally, 'input_size' as keys.
        verbose (bool): whether to print the output size of the saliency map given
            'layer_name' and 'input_size' in model_dict.

    Raises:
        KeyError: if model_dict has neither 'type' nor 'model_type', or lacks
            'layer_name' / 'arch'.
        ValueError: if the model type matches none of the supported families.
    """

    def __init__(self, model_dict, verbose=False):
        # The docstring examples use 'model_type' while older callers pass 'type';
        # accept both, preferring the legacy key so existing callers are unaffected.
        if 'type' in model_dict:
            model_type = model_dict['type']
        elif 'model_type' in model_dict:
            model_type = model_dict['model_type']
        else:
            raise KeyError("model_dict must contain 'model_type' (or the legacy key 'type')")
        layer_name = model_dict['layer_name']
        self.model_arch = model_dict['arch']

        # Filled by the hooks below on every forward / backward pass.
        self.gradients = dict()
        self.activations = dict()

        def backward_hook(module, grad_input, grad_output):
            self.gradients['value'] = grad_output[0]
            return None

        def forward_hook(module, input, output):
            self.activations['value'] = output
            return None

        # Substring match, checked in this order, selects the layer finder.
        finders = (
            ('vgg', find_vgg_layer),
            ('resnet', find_resnet_layer),
            ('densenet', find_densenet_layer),
            ('alexnet', find_alexnet_layer),
            ('squeezenet', find_squeezenet_layer),
        )
        lowered_type = model_type.lower()
        for family, finder in finders:
            if family in lowered_type:
                target_layer = finder(self.model_arch, layer_name)
                break
        else:
            # Previously this fell through and failed later on an unbound local.
            raise ValueError('unsupported model type : {}'.format(model_type))

        target_layer.register_forward_hook(forward_hook)
        target_layer.register_backward_hook(backward_hook)

        if verbose:
            if 'input_size' not in model_dict:
                print("please specify size of input image in model_dict. e.g. {'input_size':(224, 224)}")
            else:
                input_size = model_dict['input_size']
                # Run a dummy input on the model's own device so the forward hook
                # records the activation shape of the target layer.
                device = next(self.model_arch.parameters()).device
                self.model_arch(torch.zeros(1, 3, *(input_size), device=device))
                print('saliency_map size :', self.activations['value'].shape[2:])

    def forward(self, input, class_idx=None, retain_graph=False):
        """
        Args:
            input: input image with shape of (1, 3, H, W)
            class_idx (int): class index for calculating GradCAM.
                If not specified, the class index that makes the highest model prediction score will be used.
            retain_graph (bool): forwarded to ``backward``.
        Return:
            mask: saliency map of the same spatial dimension with input, rescaled to [0, 1]
                (all zeros when the map is constant)
            logit: model output
        """
        b, c, h, w = input.size()

        logit = self.model_arch(input)
        if class_idx is None:
            score = logit[:, logit.max(1)[-1]].squeeze()
        else:
            score = logit[:, class_idx].squeeze()

        self.model_arch.zero_grad()
        score.backward(retain_graph=retain_graph)
        gradients = self.gradients['value']
        activations = self.activations['value']
        b, k, u, v = gradients.size()

        # Channel weights are the spatial mean of the gradients.
        alpha = gradients.reshape(b, k, -1).mean(2)
        weights = alpha.view(b, k, 1, 1)

        saliency_map = (weights * activations).sum(1, keepdim=True)
        saliency_map = F.relu(saliency_map)
        saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)

        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        value_range = saliency_map_max - saliency_map_min
        if value_range == 0:
            # A constant map would otherwise become NaN through 0 / 0.
            saliency_map = torch.zeros_like(saliency_map)
        else:
            saliency_map = (saliency_map - saliency_map_min).div(value_range).detach()

        return saliency_map, logit

    def __call__(self, input, class_idx=None, retain_graph=False):
        return self.forward(input, class_idx, retain_graph)
112 |
113 |
class GradCAMpp(GradCAM):
    """Calculate GradCAM++ saliency map.

    A simple example:

        # initialize a model, model_dict and gradcampp
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet.eval()
        model_dict = dict(model_type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))
        gradcampp = GradCAMpp(model_dict)

        # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        img = load_img()
        normed_img = normalizer(img)

        # get a GradCAM++ saliency map on the class index 10.
        mask, logit = gradcampp(normed_img, class_idx=10)

        # make heatmap from mask and synthesize saliency map using heatmap and img
        heatmap, cam_result = visualize_cam(mask, img)


    Args:
        model_dict (dict): a dictionary that contains 'model_type', 'arch', 'layer_name', 'input_size'(optional) as keys.
        verbose (bool): whether to print the output size of the saliency map given 'layer_name' and 'input_size' in model_dict.
    """

    def __init__(self, model_dict, verbose=False):
        super(GradCAMpp, self).__init__(model_dict, verbose)

    def forward(self, input, class_idx=None, retain_graph=False):
        """
        Args:
            input: input image with shape of (1, 3, H, W)
            class_idx (int): class index for calculating GradCAM++.
                If not specified, the class index that makes the highest model prediction score will be used.
            retain_graph (bool): forwarded to ``backward``.
        Return:
            mask: saliency map of the same spatial dimension with input, rescaled to [0, 1]
                (all zeros when the map is constant)
            logit: model output
        """
        b, c, h, w = input.size()

        logit = self.model_arch(input)
        if class_idx is None:
            score = logit[:, logit.max(1)[-1]].squeeze()
        else:
            score = logit[:, class_idx].squeeze()

        self.model_arch.zero_grad()
        score.backward(retain_graph=retain_graph)
        gradients = self.gradients['value']  # dS/dA
        activations = self.activations['value']  # A
        b, k, u, v = gradients.size()

        grad_pow2 = gradients.pow(2)
        # Per-channel spatial sum of A * (dS/dA)^3, broadcast back over (u, v).
        weighted_pow3 = activations.mul(gradients.pow(3)).reshape(b, k, u * v) \
            .sum(-1, keepdim=True).view(b, k, 1, 1)

        alpha_num = grad_pow2
        alpha_denom = grad_pow2.mul(2) + weighted_pow3
        # Replace exact zeros by one so the division below stays finite.
        alpha_denom = torch.where(alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom))

        alpha = alpha_num.div(alpha_denom + 1e-7)
        positive_gradients = F.relu(score.exp() * gradients)  # ReLU(dY/dA) == ReLU(exp(S)*dS/dA))
        weights = (alpha * positive_gradients).reshape(b, k, u * v).sum(-1).view(b, k, 1, 1)

        saliency_map = (weights * activations).sum(1, keepdim=True)
        saliency_map = F.relu(saliency_map)
        # Upsample to the actual input resolution rather than a fixed 224x224.
        saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)

        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        value_range = saliency_map_max - saliency_map_min
        if value_range == 0:
            # A constant map would otherwise become NaN through 0 / 0.
            saliency_map = torch.zeros_like(saliency_map)
        else:
            saliency_map = (saliency_map - saliency_map_min).div(value_range).detach()

        return saliency_map, logit
183 |
class SmoothGradCAMpp(GradCAM):
    """Calculate Smooth-GradCAM++ saliency map.

    A simple example:

        # initialize a model, model_dict and smgradcampp
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet.eval()
        model_dict = dict(model_type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))
        smgradcampp = SmoothGradCAMpp(model_dict)

        # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        img = load_img()
        normed_img = normalizer(img)

        # get a Smooth-GradCAM++ saliency map on the class index 10.
        mask, logit = smgradcampp(normed_img, class_idx=10)

        # make heatmap from mask and synthesize saliency map using heatmap and img
        heatmap, cam_result = visualize_cam(mask, img)


    Args:
        model_dict (dict): a dictionary that contains 'model_type', 'arch', 'layer_name', 'input_size'(optional) as keys.
        verbose (bool): whether to print the output size of the saliency map given 'layer_name' and 'input_size' in model_dict.
    """

    def __init__(self, model_dict, verbose=False):
        super(SmoothGradCAMpp, self).__init__(model_dict, verbose)

    def addNoise(self, image, noiselevel):
        """
        Args:
            image: input image with shape of (1, 3, H, W)
            noiselevel: noise percentage divided by 100
        Return:
            img: image with added gaussian noise (same shape, dtype and device as image)
        """
        # Standard deviation is a fraction of the image's value range.
        bitrange = float(image.max() - image.min())
        # Draw noise with the image's own shape instead of a fixed (3, 224, 224).
        noise = np.random.normal(scale=noiselevel * bitrange, size=tuple(image.shape))
        noise = torch.from_numpy(noise.astype('float32')).to(device=image.device, dtype=image.dtype)
        img = image.add(noise)

        return img

    def forward(self, input, class_idx=None, retain_graph=False, n=50, noiselevel=0.1):
        """
        Args:
            input: input image with shape of (1, 3, H, W)
            class_idx (int): class index for calculating the saliency map.
                If not specified, the class index that makes the highest model prediction score
                on the clean input will be used.
            retain_graph (bool): forwarded to ``backward``.
            n (int): number of noisy samples to average over; must be at least 1.
            noiselevel (float): noise percentage divided by 100, see ``addNoise``.
        Return:
            mask: saliency map of the same spatial dimension with input, rescaled to [0, 1]
                (all zeros when the map is constant)
            logit: model output for the last noisy sample
        Raises:
            ValueError: if n is smaller than 1.
        """
        if n < 1:
            raise ValueError('n must be at least 1, got {}'.format(n))

        b, c, h, w = input.size()

        if class_idx is None:
            class_idx = (self.model_arch(input)).max(1)[-1]

        # Running sums over the noisy samples (averaged after the loop).
        grad_pow2_sum = 0
        grad_pow3_sum = 0
        relu_input_sum = 0

        for _ in range(n):
            input_ = self.addNoise(input, noiselevel)
            logit = self.model_arch(input_)
            score = logit[:, class_idx].squeeze()

            self.model_arch.zero_grad()
            score.backward(retain_graph=retain_graph)
            gradients = self.gradients['value']  # dS/dA
            activations = self.activations['value']  # A

            grad_pow2_sum = grad_pow2_sum + gradients.pow(2)
            grad_pow3_sum = grad_pow3_sum + gradients.pow(3)
            # Detach the score so the graph of each sample is not kept alive.
            relu_input_sum = relu_input_sum + score.exp().detach().mul(gradients)

        b, k, u, v = gradients.size()

        relu_input = relu_input_sum.div(n)
        grad_pow2_mean = grad_pow2_sum.div(n)
        grad_pow3_mean = grad_pow3_sum.div(n)

        alpha_num = grad_pow2_mean
        # The activations used here are those of the last noisy sample.
        alpha_denom = grad_pow2_mean.mul(2) + \
            activations.mul(grad_pow3_mean).reshape(b, k, u * v).sum(-1, keepdim=True).view(b, k, 1, 1)
        # Replace exact zeros by one so the division below stays finite.
        alpha_denom = torch.where(alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom))

        alpha = alpha_num.div(alpha_denom + 1e-7)
        positive_gradients = F.relu(relu_input)  # ReLU(dY/dA) == ReLU(exp(S)*dS/dA))
        weights = (alpha * positive_gradients).reshape(b, k, u * v).sum(-1).view(b, k, 1, 1)

        saliency_map = (weights * activations).sum(1, keepdim=True)
        saliency_map = F.relu(saliency_map)
        # Upsample to the actual input resolution rather than a fixed 224x224.
        saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)

        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        value_range = saliency_map_max - saliency_map_min
        if value_range == 0:
            # A constant map would otherwise become NaN through 0 / 0.
            saliency_map = torch.zeros_like(saliency_map)
        else:
            saliency_map = (saliency_map - saliency_map_min).div(value_range).detach()

        return saliency_map, logit

    def __call__(self, input, class_idx=None, retain_graph=False, n=50, noiselevel=0.1):
        # Expose n and noiselevel, which the inherited __call__ cannot forward.
        return self.forward(input, class_idx, retain_graph, n=n, noiselevel=noiselevel)
282 |
283 |
--------------------------------------------------------------------------------
/images/collies.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/images/collies.JPG
--------------------------------------------------------------------------------
/images/multiple_dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/images/multiple_dogs.jpg
--------------------------------------------------------------------------------
/images/snake.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/images/snake.JPEG
--------------------------------------------------------------------------------
/images/water-bird.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/images/water-bird.JPEG
--------------------------------------------------------------------------------
/outputs/collies.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/outputs/collies.JPG
--------------------------------------------------------------------------------
/outputs/multiple_dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/outputs/multiple_dogs.jpg
--------------------------------------------------------------------------------
/outputs/snake.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/outputs/snake.JPEG
--------------------------------------------------------------------------------
/outputs/water-bird.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefannc/GradCAM-Pytorch/cdd4e4b93de975220e68cacfec40cf8704a294ed/outputs/water-bird.JPEG
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 |
def visualize_cam(mask, img):
    """Make heatmap from mask and synthesize GradCAM result image using heatmap and img.

    Args:
        mask (torch.tensor): mask shape of (1, 1, H, W) and each element has value in range [0, 1]
        img (torch.tensor): img shape of (1, 3, H, W) and each pixel value is in range [0, 1]

    Return:
        heatmap (torch.tensor): heatmap img shape of (3, H, W)
        result (torch.tensor): synthesized GradCAM result of same shape with heatmap.
    """
    # Bring the mask to host memory first: numpy cannot read CUDA or
    # grad-tracking tensors, and img is moved to the CPU below as well.
    if torch.is_tensor(mask):
        mask = mask.detach().cpu().numpy()

    heatmap = cv2.applyColorMap(np.uint8(255 * mask.squeeze()), cv2.COLORMAP_JET)
    heatmap = torch.from_numpy(heatmap).permute(2, 0, 1).float().div(255)
    # The colour-mapped image is treated as B, G, R; reorder the channels to R, G, B.
    blue, green, red = heatmap.split(1)
    heatmap = torch.cat([red, green, blue])

    # Overlay and rescale so the brightest value becomes 1.
    result = heatmap + img.cpu()
    result = result.div(result.max()).squeeze()

    return heatmap, result
24 |
25 |
def find_resnet_layer(arch, target_layer_name):
    """Find resnet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision resnet models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'conv1'
            target_layer_name = 'layer1'
            target_layer_name = 'layer1_basicblock0'
            target_layer_name = 'layer1_basicblock0_relu'
            target_layer_name = 'layer1_bottleneck0'
            target_layer_name = 'layer1_bottleneck0_conv1'
            target_layer_name = 'layer1_bottleneck0_downsample'
            target_layer_name = 'layer1_bottleneck0_downsample_0'
            target_layer_name = 'avgpool'
            target_layer_name = 'fc'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        ValueError: if the stage ('layerN') does not exist on arch or the block index cannot be parsed.
        KeyError: if a named sub-module does not exist.
    """
    hierarchy = target_layer_name.split('_')
    stage_token = hierarchy[0]

    if not stage_token.startswith('layer'):
        # Top-level modules such as 'conv1', 'avgpool' or 'fc'.
        return arch._modules[target_layer_name]

    # Slice off the 'layer' prefix explicitly; str.lstrip would strip a set of
    # characters rather than a prefix.
    stage_suffix = stage_token[len('layer'):]
    target_layer = None
    if stage_suffix.isdigit():
        target_layer = getattr(arch, 'layer{}'.format(int(stage_suffix)), None)
    if target_layer is None:
        raise ValueError('unknown layer : {}'.format(target_layer_name))

    if len(hierarchy) >= 2:
        block_token = hierarchy[1].lower()
        for prefix in ('bottleneck', 'basicblock'):
            if block_token.startswith(prefix):
                block_token = block_token[len(prefix):]
                break
        try:
            block_index = int(block_token)
        except ValueError:
            raise ValueError('cannot parse block index from {!r} in layer name {!r}'.format(
                hierarchy[1], target_layer_name))
        target_layer = target_layer[block_index]

    # Follow every remaining component so deeply nested names resolve to the
    # innermost module instead of silently stopping early.
    for child_name in hierarchy[2:]:
        target_layer = target_layer._modules[child_name]

    return target_layer
74 |
75 |
def find_densenet_layer(arch, target_layer_name):
    """Resolve a densenet sub-module from its underscore-separated path.

    Args:
        arch: default torchvision densenet models
        target_layer_name (str): path of the layer; each '_'-separated component names a
            child in the ``_modules`` mapping of the previous one, for example
            target_layer_name = 'features'
            target_layer_name = 'features_transition1'
            target_layer_name = 'features_transition1_norm'
            target_layer_name = 'features_denseblock2_denselayer12'
            target_layer_name = 'features_denseblock2_denselayer12_norm1'
            target_layer_name = 'classifier'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.
    """
    path = target_layer_name.split('_')

    # Up to four components are followed; for longer names only the first
    # three levels are descended.
    depth = len(path) if len(path) <= 4 else 3

    node = arch
    for key in path[:depth]:
        node = node._modules[key]

    return node
107 |
108 |
def find_vgg_layer(arch, target_layer_name):
    """Find vgg layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision vgg models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'features'
            target_layer_name = 'features_42'
            target_layer_name = 'classifier'
            target_layer_name = 'classifier_0'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        ValueError: if the name has more than two components or its first component
            is not a top-level module of arch.
    """
    hierarchy = target_layer_name.split('_')
    if len(hierarchy) > 2:
        raise ValueError('unknown layer : {}'.format(target_layer_name))

    # Honour the requested container ('features', 'classifier', ...) instead of
    # always looking inside arch.features.
    target_layer = arch._modules.get(hierarchy[0])
    if target_layer is None:
        raise ValueError('unknown layer : {}'.format(target_layer_name))

    if len(hierarchy) == 2:
        target_layer = target_layer[int(hierarchy[1])]

    return target_layer
132 |
133 |
def find_alexnet_layer(arch, target_layer_name):
    """Find alexnet layer to calculate GradCAM and GradCAM++

    Args:
        arch: default torchvision alexnet models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'features'
            target_layer_name = 'features_0'
            target_layer_name = 'classifier'
            target_layer_name = 'classifier_0'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.

    Raises:
        ValueError: if the name has more than two components or its first component
            is not a top-level module of arch.
    """
    hierarchy = target_layer_name.split('_')
    if len(hierarchy) > 2:
        raise ValueError('unknown layer : {}'.format(target_layer_name))

    # Honour the requested container ('features', 'classifier', ...) instead of
    # always looking inside arch.features.
    target_layer = arch._modules.get(hierarchy[0])
    if target_layer is None:
        raise ValueError('unknown layer : {}'.format(target_layer_name))

    if len(hierarchy) == 2:
        target_layer = target_layer[int(hierarchy[1])]

    return target_layer
157 |
158 |
def find_squeezenet_layer(arch, target_layer_name):
    """Resolve a squeezenet sub-module from its underscore-separated path.

    Args:
        arch: default torchvision squeezenet models
        target_layer_name (str): path of the layer, for example
            target_layer_name = 'features_12'
            target_layer_name = 'features_12_expand3x3'
            target_layer_name = 'features_12_expand3x3_activation'

    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.
    """
    path = target_layer_name.split('_')
    depth = len(path)

    node = arch._modules[path[0]]
    if depth == 1:
        return node

    node = node._modules[path[1]]
    if depth == 3:
        return node._modules[path[2]]
    if depth == 4:
        # The last two components form one module name that itself contains
        # an underscore, e.g. 'expand3x3_activation'.
        return node._modules['_'.join(path[2:4])]

    return node
185 |
186 |
def denormalize(tensor, mean, std):
    """Undo a per-channel normalization: return ``tensor * std + mean``.

    Args:
        tensor (torch.tensor): 4D tensor whose second dimension is the channel dimension.
        mean: per-channel means, one value per channel of tensor.
        std: per-channel standard deviations, one value per channel of tensor.

    Return:
        torch.tensor of the same shape as tensor.

    Raises:
        TypeError: if tensor is not 4D.
    """
    if tensor.ndimension() != 4:
        raise TypeError('tensor should be 4D')

    # Shape the statistics as (1, C, 1, 1) so they broadcast over batch and
    # spatial dimensions for any channel count, not only three.
    mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device).view(1, -1, 1, 1)
    std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device).view(1, -1, 1, 1)

    return tensor.mul(std).add(mean)
195 |
196 |
def normalize(tensor, mean, std):
    """Apply a per-channel normalization: return ``(tensor - mean) / std``.

    Args:
        tensor (torch.tensor): 4D tensor whose second dimension is the channel dimension.
        mean: per-channel means, one value per channel of tensor.
        std: per-channel standard deviations, one value per channel of tensor.

    Return:
        torch.tensor of the same shape as tensor.

    Raises:
        TypeError: if tensor is not 4D.
    """
    if tensor.ndimension() != 4:
        raise TypeError('tensor should be 4D')

    # Shape the statistics as (1, C, 1, 1) so they broadcast over batch and
    # spatial dimensions for any channel count, not only three.
    mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device).view(1, -1, 1, 1)
    std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device).view(1, -1, 1, 1)

    return tensor.sub(mean).div(std)
205 |
206 |
class Normalize(object):
    """Callable that bundles a mean/std pair for normalizing 4D tensors.

    Calling the object (or ``do``) applies ``normalize``; ``undo`` applies
    ``denormalize`` with the same statistics.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __repr__(self):
        return '{0}(mean={1}, std={2})'.format(self.__class__.__name__, self.mean, self.std)

    def __call__(self, tensor):
        # Routed through ``do`` so the two entry points stay equivalent.
        return self.do(tensor)

    def do(self, tensor):
        """Return tensor normalized with the stored mean and std."""
        return normalize(tensor, self.mean, self.std)

    def undo(self, tensor):
        """Return tensor with the stored normalization reverted."""
        return denormalize(tensor, self.mean, self.std)
--------------------------------------------------------------------------------