├── LICENSE
├── README.md
├── RISE
│   ├── evaluate_auc_metrics.py
│   ├── evaluation.py
│   ├── pytorch_grad_cam
│   │   ├── __init__.py
│   │   ├── ablation_cam.py
│   │   ├── activations_and_gradients.py
│   │   ├── base_cam.py
│   │   ├── eigen_cam.py
│   │   ├── eigen_grad_cam.py
│   │   ├── grad_cam.py
│   │   ├── grad_cam_plusplus.py
│   │   ├── guided_backprop.py
│   │   ├── score_cam.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── image.py
│   │   │   └── svd_on_activations.py
│   │   └── xgrad_cam.py
│   └── utils.py
├── TorchRay
│   ├── CHANGELOG.md
│   ├── CODE_OF_CONDUCT.md
│   ├── CONTRIBUTING.md
│   ├── LICENSE
│   ├── MANIFEST.in
│   ├── Makefile
│   ├── README.md
│   ├── docs
│   │   ├── attribution.rst
│   │   ├── benchmark.rst
│   │   ├── conf.py
│   │   ├── index.rst
│   │   ├── make-docs
│   │   ├── pointing.csv
│   │   ├── static
│   │   │   └── css
│   │   │       └── equations.css
│   │   └── utils.rst
│   ├── examples
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── attribution_benchmark.py
│   │   ├── attribution_benchmark_debug.py
│   │   ├── contrastive_excitation_backprop.py
│   │   ├── contrastive_excitation_backprop_manual.py
│   │   ├── deconvnet.py
│   │   ├── deconvnet_manual.py
│   │   ├── excitation_backprop.py
│   │   ├── excitation_backprop_manual.py
│   │   ├── extremal_perturbation.py
│   │   ├── grad_cam.py
│   │   ├── grad_cam_manual.py
│   │   ├── gradient.py
│   │   ├── gradient_manual.py
│   │   ├── guided_backprop.py
│   │   ├── guided_backprop_manual.py
│   │   ├── linear_approx.py
│   │   ├── linear_approx_manual.py
│   │   └── rise.py
│   ├── packaging
│   │   └── meta.yaml
│   ├── scripts
│   │   └── torchrayrc
│   ├── setup.py
│   └── torchray
│       ├── VERSION
│       ├── __init__.py
│       ├── attribution
│       │   ├── __init__.py
│       │   ├── common.py
│       │   ├── deconvnet.py
│       │   ├── excitation_backprop.py
│       │   ├── extremal_perturbation.py
│       │   ├── grad_cam.py
│       │   ├── gradient.py
│       │   ├── guided_backprop.py
│       │   ├── linear_approx.py
│       │   ├── resnet_maxpool.py
│       │   └── rise.py
│       ├── benchmark
│       │   ├── __init__.py
│       │   ├── datasets.py
│       │   ├── evaluate_finegrained_gradcam_energy_inside_bbox.py
│       │   ├── evaluate_imagenet_excitation_backprop_energy_inside_bbox.py
│       │   ├── evaluate_imagenet_gradcam_energy_inside_bbox.py
│       │   ├── evaluate_swav_imagenet_gradcam_energy_inside_bbox.py
│       │   ├── imagenet_classes.txt
│       │   ├── logging_mongo.py
│       │   ├── models.py
│       │   ├── pointing_game.py
│       │   ├── resnet_multigpu_cgc.py
│       │   ├── server.py
│       │   ├── swav_resnet_cgc.py
│       │   └── vision.py
│       └── utils.py
├── baseline_train_eval.py
├── datasets
│   ├── imagefolder_cgc_ssl.py
│   └── vision.py
├── misc
│   └── teaser_image.png
├── models
│   ├── resnet_multigpu_cgc.py
│   ├── swav_resnet_cgc.py
│   └── utils.py
├── train_eval_cgc.py
├── train_eval_gc_l2_no_contrast_baseline.py
└── train_imagenet_1pc_swav_cgc_unlabeled.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 UCDvision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Consistent-Explanations-by-Contrastive-Learning
2 | Official PyTorch code for CVPR 2022 paper - [Consistent Explanations by Contrastive Learning][1]
3 |
4 |
5 | Post-hoc explanation methods, e.g., Grad-CAM, enable humans to inspect the spatial regions responsible for a particular network decision. However, it is shown that such explanations are not always consistent with human priors, such as consistency across image transformations. Given an interpretation algorithm, e.g., Grad-CAM, we introduce a novel training method to train the model to produce more consistent explanations. Since obtaining the ground truth for a desired model interpretation is not a well-defined task, we adopt ideas from contrastive self-supervised learning, and apply them to the interpretations of the model rather than its embeddings. We show that our method, Contrastive Grad-CAM Consistency (CGC), results in Grad-CAM interpretation heatmaps that are more consistent with human annotations while still achieving comparable classification accuracy. Moreover, our method acts as a regularizer and improves the accuracy on limited-data, fine-grained classification settings. In addition, because our method does not rely on annotations, it allows for the incorporation of unlabeled data into training, which enables better generalization of the model.
6 |
7 | ![Teaser image][teaser]
8 |
9 |
10 |
11 | ## Bibtex
12 | ```
13 | @InProceedings{Pillai_2022_CVPR,
14 | author = {Pillai, Vipin and Abbasi Koohpayegani, Soroush and Ouligian, Ashley and Fong, Dennis and Pirsiavash, Hamed},
15 | title = {Consistent Explanations by Contrastive Learning},
16 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
17 | month = {June},
18 | year = {2022}
19 | }
20 | ```
21 |
22 | ## Pre-requisites
23 | - PyTorch 1.3 - Please install [PyTorch](https://pytorch.org/get-started/locally/) and CUDA if you don't have them installed.
24 |
25 | ## Datasets
26 | - [ImageNet - 1K](https://www.image-net.org/download.php)
27 | - [CUB-200](https://vision.cornell.edu/se3/caltech-ucsd-birds-200/)
28 | - [FGVC-Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)
29 | - [Stanford Cars-196](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)
30 | - [VGG Flowers-102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/)
31 |
32 | ## Training
33 |
34 | #### Train and evaluate a ResNet50 model on the ImageNet dataset using our CGC loss
35 | ```
36 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_eval_cgc.py /datasets/imagenet -a resnet50 -p 100 -j 8 -b 256 --lr 0.1 --lambda 0.5 -t 0.5 --save_dir <save_dir> --log_dir <log_dir>
37 | ```
38 |
39 | #### Train and evaluate a ResNet50 model on a 1% labeled subset of ImageNet, using the rest as unlabeled data. The model is initialized from SwAV
40 | For the command below, the SwAV pre-trained checkpoint (passed via `--resume`) can be downloaded from the SwAV GitHub repository - https://github.com/facebookresearch/swav
41 | We use the checkpoint listed in the first row (800 epochs, 75.3% ImageNet top-1 acc.) of that repository's Model Zoo.
42 |
43 | ```
44 | CUDA_VISIBLE_DEVICES=0,1 python train_imagenet_1pc_swav_cgc_unlabeled.py -a resnet50 -b 128 -j 8 --lambda 0.25 -t 0.5 --epochs 50 --lr 0.02 --lr_last_layer 5 --resume <swav_checkpoint> --save_dir <save_dir> --log_dir <log_dir> 2>&1 | tee <log_file>
45 | ```
46 |
47 |
48 |
49 | ## Checkpoints
50 | * ResNet50 model pre-trained on ImageNet - [link](https://drive.google.com/drive/folders/1n7lFew0CdWuYCpR1kImMt7UC7_vsO5CT?usp=sharing)
51 |
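The released checkpoint is a regular `torch.save` dictionary with a `state_dict` entry. A minimal loading sketch, following the loading logic in `RISE/evaluate_auc_metrics.py` (the local file name below is only a placeholder):

```python
import torch
import torchvision.models as models

ckpt_path = 'resnet50_cgc.pth.tar'  # placeholder path to the downloaded checkpoint

net = models.resnet50()
state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
# strip the 'module.' prefix if the model was saved with nn.DataParallel
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
net.load_state_dict(state_dict)
net.eval()
```
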
52 | ## Evaluation
53 |
54 | #### Evaluate model checkpoint using Content Heatmap (CH) evaluation metric
55 | We use the evaluation code adapted from the TorchRay framework.
56 | * Change directory to TorchRay and install the library. Please refer to the [TorchRay](https://github.com/facebookresearch/TorchRay) repository for full documentation and instructions.
57 | * `cd TorchRay`
58 | * `python setup.py install`
59 |
60 | * Change directory to `TorchRay/torchray/benchmark`
61 | * `cd torchray/benchmark`
62 |
63 | * For the ImageNet & CUB-200 datasets, this evaluation expects the following structure for the validation images and bounding-box XML annotations (see the layout sketch below)
64 | * <data_dir>/val/*.JPEG - Flat list of validation images
65 | * <data_dir>/annotation/*.xml - Flat list of annotation XML files
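
For illustration, the expected layout looks like this (`<data_dir>` and the file names are placeholders):

```
<data_dir>/
├── val/
│   ├── ILSVRC2012_val_00000001.JPEG
│   └── ...
└── annotation/
    ├── ILSVRC2012_val_00000001.xml
    └── ...
```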
66 |
67 | ##### Evaluate ResNet50 models trained on the full ImageNet dataset
68 | ```
69 | CUDA_VISIBLE_DEVICES=0 python evaluate_imagenet_gradcam_energy_inside_bbox.py -j 0 -b 1 --resume <checkpoint_path> --input_resize 448 -a resnet50
70 | ```
71 |
72 | ##### Evaluate ResNet50 models trained on the CUB-200 fine-grained dataset
73 | ```
74 | CUDA_VISIBLE_DEVICES=0 python evaluate_finegrained_gradcam_energy_inside_bbox.py --dataset cub -j 0 -b 1 --resume <checkpoint_path> --input_resize 448 -a resnet50
75 | ```
76 |
77 | ##### Evaluate ResNet50 models initialized from SwAV and trained on the 1% labeled subset of ImageNet, with the rest as unlabeled data
78 | ```
79 | CUDA_VISIBLE_DEVICES=0 python evaluate_swav_imagenet_gradcam_energy_inside_bbox.py -j 0 -b 1 --resume <checkpoint_path> --input_resize 448 -a resnet50
80 | ```
81 |
82 |
83 |
84 | #### Evaluate model checkpoint using Insertion AUC (IAUC) evaluation metric
85 | Change directory to RISE/ and use the commands below:
86 |
87 | ##### Evaluate pre-trained ResNet50 model
88 | ```
89 | CUDA_VISIBLE_DEVICES=0 python evaluate_auc_metrics.py --pretrained
90 | ```
91 |
92 | ##### Evaluate ResNet50 model trained using our CGC method
93 | ```
94 | CUDA_VISIBLE_DEVICES=0 python evaluate_auc_metrics.py --ckpt-path <checkpoint_path>
95 | ```
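
For reference, the insertion score is computed in `RISE/evaluation.py`: pixels are revealed in order of decreasing Grad-CAM saliency, the softmax probability of the model's originally predicted class is recorded after every step, and the resulting curve is reduced to a single number with a trapezoidal area estimate. A minimal sketch of that final reduction, mirroring the `auc` helper in `RISE/evaluation.py` (the example scores are illustrative only):

```python
import numpy as np

def auc(arr):
    """Normalized area under the insertion curve (trapezoidal rule)."""
    return (arr.sum() - arr[0] / 2 - arr[-1] / 2) / (arr.shape[0] - 1)

# scores recorded after each insertion step, e.g. 28 steps plus the initial state
scores = np.linspace(0.1, 0.9, 29)
print(auc(scores))  # higher is better
```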
96 |
97 |
98 |
99 | ## License
100 | This project is licensed under the MIT License.
101 |
102 | [1]: https://arxiv.org/pdf/2110.00527.pdf
103 | [teaser]: https://github.com/UMBCvision/Consistent-Explanations-by-Contrastive-Learning/blob/main/misc/teaser_image.png
104 |
--------------------------------------------------------------------------------
/RISE/evaluate_auc_metrics.py:
--------------------------------------------------------------------------------
1 | ## Code to evaluate the pre-trained model and our CGC trained model with Insertion AUC score.
2 |
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.backends.cudnn as cudnn
9 | import torchvision.datasets as datasets
10 | import torchvision.models as models
11 | import torch.nn.functional as F
12 | from utils import *
13 | from evaluation import CausalMetric, auc, gkern
14 | from pytorch_grad_cam import GradCAM
15 | import argparse
16 |
17 | parser = argparse.ArgumentParser(description='PyTorch AUC Metric Evaluation')
18 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
19 | help='use pre-trained model')
20 | parser.add_argument('--ckpt-path', dest='ckpt_path', type=str, help='path to checkpoint file')
21 |
22 | def main():
23 | args = parser.parse_args()
24 |
25 | cudnn.benchmark = True
26 |
27 | scores = {'del': [], 'ins': []}
28 | if args.pretrained:
29 | net = models.resnet50(pretrained=True)
30 | else:
31 | net = models.resnet50()
32 | state_dict = torch.load(args.ckpt_path)['state_dict']
33 |
34 | # remove the module prefix if model was saved with DataParallel
35 | state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
36 | # load params
37 | net.load_state_dict(state_dict)
38 |
39 | target_layer = net.layer4[-1]
40 | cam = GradCAM(model=net, target_layer=target_layer, use_cuda=True)
41 |
42 |     # we process the 50k ImageNet val images in 10 sets of 5k each and compute the mean
43 | for i in range(10):
44 | auc_score = get_auc_per_data_subset(i, net, cam)
45 | scores['ins'].append(auc_score)
46 | print('Finished evaluating the insertion metrics...')
47 |
48 | print('----------------------------------------------------------------')
49 | print('Final:\nInsertion - {:.5f}'.format(np.mean(scores['ins'])))
50 |
51 |
52 | def get_auc_per_data_subset(range_index, net, cam):
53 | batch_size = 100
54 | data_loader = torch.utils.data.DataLoader(
55 | dataset=datasets.ImageFolder('/nfs3/datasets/imagenet/val/', preprocess),
56 | batch_size=batch_size, shuffle=False,
57 | num_workers=8, pin_memory=True, sampler=RangeSampler(range(5000 * range_index, 5000 * (range_index + 1))))
58 |
59 | net = net.train()
60 |
61 | images = []
62 | targets = []
63 | gcam_exp = []
64 |
65 | for j, (img, trg) in enumerate(tqdm(data_loader, total=len(data_loader), desc='Loading images')):
66 | grayscale_gradcam = cam(input_tensor=img, target_category=trg)
67 | for k in range(batch_size):
68 | images.append(img[k])
69 | targets.append(trg[k])
70 | gcam_exp.append(grayscale_gradcam[k])
71 |
72 | images = torch.stack(images).cpu().numpy()
73 | gcam_exp = np.stack(gcam_exp)
74 | images = np.asarray(images)
75 | gcam_exp = np.asarray(gcam_exp)
76 |
77 | images = images.reshape((-1, 3, 224, 224))
78 | gcam_exp = gcam_exp.reshape((-1, 224, 224))
79 | print('Finished obtaining CAM')
80 |
81 | model = nn.Sequential(net, nn.Softmax(dim=1))
82 | model = model.eval()
83 | model = model.cuda()
84 |
85 | for p in model.parameters():
86 | p.requires_grad = False
87 |
88 | # To use multiple GPUs
89 | ddp_model = nn.DataParallel(model)
90 |
91 | # we use blur as the substrate function
92 | klen = 11
93 | ksig = 5
94 | kern = gkern(klen, ksig)
95 | # Function that blurs input image
96 | blur = lambda x: F.conv2d(x, kern, padding=klen // 2)
97 |
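    # 224 * 8 = 1792 pixels are revealed per step, i.e. 28 steps (plus the initial blurred image) for 224x224 inputs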
98 | insertion = CausalMetric(ddp_model, 'ins', 224 * 8, substrate_fn=blur)
99 |
100 | # Evaluate insertion
101 | h = insertion.evaluate(torch.from_numpy(images.astype('float32')), gcam_exp, batch_size)
102 |
103 | model = model.train()
104 | for p in model.parameters():
105 | p.requires_grad = True
106 |
107 | return auc(h.mean(1))
108 |
109 |
110 | if __name__ == '__main__':
111 | main()
112 |
--------------------------------------------------------------------------------
/RISE/evaluation.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from tqdm import tqdm
3 | from scipy.ndimage.filters import gaussian_filter
4 |
5 | from utils import *
6 |
7 | HW = 224 * 224 # image area
8 | n_classes = 1000
9 |
10 | def gkern(klen, nsig):
11 | """Returns a Gaussian kernel array.
12 | Convolution with it results in image blurring."""
13 | # create nxn zeros
14 | inp = np.zeros((klen, klen))
15 | # set element at the middle to one, a dirac delta
16 | inp[klen//2, klen//2] = 1
17 | # gaussian-smooth the dirac, resulting in a gaussian filter mask
18 | k = gaussian_filter(inp, nsig)
19 | kern = np.zeros((3, 3, klen, klen))
20 | kern[0, 0] = k
21 | kern[1, 1] = k
22 | kern[2, 2] = k
23 | return torch.from_numpy(kern.astype('float32'))
24 |
25 | def auc(arr):
26 | """Returns normalized Area Under Curve of the array."""
27 | return (arr.sum() - arr[0] / 2 - arr[-1] / 2) / (arr.shape[0] - 1)
28 |
29 | class CausalMetric():
30 |
31 | def __init__(self, model, mode, step, substrate_fn):
32 | r"""Create deletion/insertion metric instance.
33 |
34 | Args:
35 | model (nn.Module): Black-box model being explained.
36 | mode (str): 'del' or 'ins'.
37 | step (int): number of pixels modified per one iteration.
38 | substrate_fn (func): a mapping from old pixels to new pixels.
39 | """
40 | assert mode in ['del', 'ins']
41 | self.model = model
42 | self.mode = mode
43 | self.step = step
44 | self.substrate_fn = substrate_fn
45 |
46 | def single_run(self, img_tensor, explanation, verbose=0, save_to=None):
47 | r"""Run metric on one image-saliency pair.
48 |
49 | Args:
50 | img_tensor (Tensor): normalized image tensor.
51 | explanation (np.ndarray): saliency map.
52 | verbose (int): in [0, 1, 2].
53 | 0 - return list of scores.
54 | 1 - also plot final step.
55 | 2 - also plot every step and print 2 top classes.
56 | save_to (str): directory to save every step plots to.
57 |
58 | Return:
59 | scores (nd.array): Array containing scores at every step.
60 | """
61 | pred = self.model(img_tensor.cuda())
62 | top, c = torch.max(pred, 1)
63 | c = c.cpu().numpy()[0]
64 | n_steps = (HW + self.step - 1) // self.step
65 |
66 | if self.mode == 'del':
67 | title = 'Deletion game'
68 | ylabel = 'Pixels deleted'
69 | start = img_tensor.clone()
70 | finish = self.substrate_fn(img_tensor)
71 | elif self.mode == 'ins':
72 | title = 'Insertion game'
73 | ylabel = 'Pixels inserted'
74 | start = self.substrate_fn(img_tensor)
75 | finish = img_tensor.clone()
76 |
77 | scores = np.empty(n_steps + 1)
78 | # Coordinates of pixels in order of decreasing saliency
79 | salient_order = np.flip(np.argsort(explanation.reshape(-1, HW), axis=1), axis=-1)
80 | for i in range(n_steps+1):
81 | pred = self.model(start.cuda())
82 | pr, cl = torch.topk(pred, 2)
83 | if verbose == 2:
84 | print('{}: {:.3f}'.format(get_class_name(cl[0][0]), float(pr[0][0])))
85 | print('{}: {:.3f}'.format(get_class_name(cl[0][1]), float(pr[0][1])))
86 | scores[i] = pred[0, c]
87 | # Render image if verbose, if it's the last step or if save is required.
88 | if verbose == 2 or (verbose == 1 and i == n_steps) or save_to:
89 | plt.figure(figsize=(10, 5))
90 | plt.subplot(121)
91 | plt.title('{} {:.1f}%, P={:.4f}'.format(ylabel, 100 * i / n_steps, scores[i]))
92 | plt.axis('off')
93 | tensor_imshow(start[0])
94 |
95 | plt.subplot(122)
96 | plt.plot(np.arange(i+1) / n_steps, scores[:i+1])
97 | plt.xlim(-0.1, 1.1)
98 | plt.ylim(0, 1.05)
99 | plt.fill_between(np.arange(i+1) / n_steps, 0, scores[:i+1], alpha=0.4)
100 | plt.title(title)
101 | plt.xlabel(ylabel)
102 | plt.ylabel(get_class_name(c))
103 | if save_to:
104 | plt.savefig(save_to + '/{:03d}.png'.format(i))
105 | plt.close()
106 | else:
107 | plt.show()
108 | if i < n_steps:
109 | coords = salient_order[:, self.step * i:self.step * (i + 1)]
110 | start.cpu().numpy().reshape(1, 3, HW)[0, :, coords] = finish.cpu().numpy().reshape(1, 3, HW)[0, :, coords]
111 | return scores
112 |
113 | def evaluate(self, img_batch, exp_batch, batch_size):
114 | r"""Efficiently evaluate big batch of images.
115 |
116 | Args:
117 | img_batch (Tensor): batch of images.
118 | exp_batch (np.ndarray): batch of explanations.
119 | batch_size (int): number of images for one small batch.
120 |
121 | Returns:
122 | scores (nd.array): Array containing scores at every step for every image.
123 | """
124 | n_samples = img_batch.shape[0]
125 | predictions = torch.FloatTensor(n_samples, n_classes)
126 | # assert n_samples % batch_size == 0
127 | for i in tqdm(range(n_samples // batch_size), desc='Predicting labels'):
128 | preds = self.model(img_batch[i*batch_size:(i+1)*batch_size].cuda()).cpu()
129 | predictions[i*batch_size:(i+1)*batch_size] = preds
130 | if n_samples % batch_size != 0:
131 | start_index = n_samples // batch_size
132 | preds = self.model(img_batch[start_index*batch_size:].cuda()).cpu()
133 | predictions[start_index*batch_size:] = preds
134 |
135 | top = np.argmax(predictions, -1)
136 | n_steps = (HW + self.step - 1) // self.step
137 | scores = np.empty((n_steps + 1, n_samples))
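        # pixel indices sorted by decreasing saliency, one row per image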
138 | salient_order = np.flip(np.argsort(exp_batch.reshape(-1, HW), axis=1), axis=-1)
139 | r = np.arange(n_samples).reshape(n_samples, 1)
140 |
141 | substrate = torch.zeros_like(img_batch)
142 | for j in tqdm(range(n_samples // batch_size), desc='Substrate'):
143 | substrate[j*batch_size:(j+1)*batch_size] = self.substrate_fn(img_batch[j*batch_size:(j+1)*batch_size])
144 |
145 | if n_samples % batch_size != 0:
146 | start_index = n_samples // batch_size
147 | substrate[start_index*batch_size:] = self.substrate_fn(img_batch[start_index*batch_size:])
148 |
149 | if self.mode == 'del':
150 | caption = 'Deleting '
151 | start = img_batch.clone()
152 | finish = substrate
153 | elif self.mode == 'ins':
154 | caption = 'Inserting '
155 | start = substrate
156 | finish = img_batch.clone()
157 |
158 | # While not all pixels are changed
159 | for i in tqdm(range(n_steps+1), desc=caption + 'pixels'):
160 | # Iterate over batches
161 | for j in range(n_samples // batch_size):
162 | # Compute new scores
163 | preds = self.model(start[j*batch_size:(j+1)*batch_size].cuda())
164 | preds = preds.cpu().numpy()[range(batch_size), top[j*batch_size:(j+1)*batch_size]]
165 | scores[i, j*batch_size:(j+1)*batch_size] = preds
166 | if n_samples % batch_size != 0:
167 | start_index = n_samples // batch_size
168 | preds = self.model(start[start_index*batch_size:].cuda())
169 | preds = preds.cpu().numpy()[range(n_samples%batch_size), top[start_index*batch_size:]]
170 | scores[i, start_index*batch_size:] = preds
171 |
172 | # Change specified number of most salient pixels to substrate pixels
173 | coords = salient_order[:, self.step * i:self.step * (i + 1)]
174 | start.cpu().numpy().reshape(n_samples, 3, HW)[r, :, coords] = finish.cpu().numpy().reshape(n_samples, 3, HW)[r, :, coords]
175 | print('AUC: {}'.format(auc(scores.mean(1))))
176 | return scores
177 |
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/__init__.py:
--------------------------------------------------------------------------------
1 | from pytorch_grad_cam.grad_cam import GradCAM
2 | from pytorch_grad_cam.ablation_cam import AblationCAM
3 | from pytorch_grad_cam.xgrad_cam import XGradCAM
4 | from pytorch_grad_cam.grad_cam_plusplus import GradCAMPlusPlus
5 | from pytorch_grad_cam.score_cam import ScoreCAM
6 | from pytorch_grad_cam.eigen_cam import EigenCAM
7 | from pytorch_grad_cam.eigen_grad_cam import EigenGradCAM
8 | from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/ablation_cam.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import tqdm
5 | from pytorch_grad_cam.base_cam import BaseCAM
6 |
7 | class AblationLayer(torch.nn.Module):
8 | def __init__(self, layer, reshape_transform, indices):
9 | super(AblationLayer, self).__init__()
10 |
11 | self.layer = layer
12 | self.reshape_transform = reshape_transform
13 | # The channels to zero out:
14 | self.indices = indices
15 |
16 | def forward(self, x):
17 |         return self.__call__(x)
18 |
19 | def __call__(self, x):
20 | output = self.layer(x)
21 |
22 | # Hack to work with ViT,
23 | # Since the activation channels are last and not first like in CNNs
24 | # Probably should remove it?
25 | if self.reshape_transform is not None:
26 | output = output.transpose(1, 2)
27 |
28 | for i in range(output.size(0)):
29 |
30 | # Commonly the minimum activation will be 0,
31 | # And then it makes sense to zero it out.
32 | # However depending on the architecture,
33 | # If the values can be negative, we use very negative values
34 | # to perform the ablation, deviating from the paper.
35 | if torch.min(output) == 0:
36 | output[i, self.indices[i], :] = 0
37 | else:
38 | ABLATION_VALUE = 1e5
39 | output[i, self.indices[i], :] = torch.min(output) - ABLATION_VALUE
40 |
41 | if self.reshape_transform is not None:
42 | output = output.transpose(2, 1)
43 |
44 | return output
45 |
46 | def replace_layer_recursive(model, old_layer, new_layer):
47 | for name, layer in model._modules.items():
48 | if layer == old_layer:
49 | model._modules[name] = new_layer
50 | return True
51 | elif replace_layer_recursive(layer, old_layer, new_layer):
52 | return True
53 | return False
54 |
55 | class AblationCAM(BaseCAM):
56 | def __init__(self, model, target_layer, use_cuda=False,
57 | reshape_transform=None):
58 | super(AblationCAM, self).__init__(model, target_layer, use_cuda,
59 | reshape_transform)
60 |
61 | def get_cam_weights(self,
62 | input_tensor,
63 | target_category,
64 | activations,
65 | grads):
66 | with torch.no_grad():
67 | outputs = self.model(input_tensor).cpu().numpy()
68 | original_scores = []
69 | for i in range(input_tensor.size(0)):
70 | original_scores.append(outputs[i, target_category[i]])
71 | original_scores = np.float32(original_scores)
72 |
73 | ablation_layer = AblationLayer(self.target_layer,
74 | self.reshape_transform, indices=[])
75 | replace_layer_recursive(self.model, self.target_layer, ablation_layer)
76 |
77 |
78 | if hasattr(self, "batch_size"):
79 | BATCH_SIZE = self.batch_size
80 | else:
81 | BATCH_SIZE = 32
82 |
83 | number_of_channels = activations.shape[1]
84 | weights = []
85 |
86 | with torch.no_grad():
87 |
88 | # Iterate over the input batch
89 | for tensor, category in zip(input_tensor, target_category):
90 | batch_tensor = tensor.repeat(BATCH_SIZE, 1, 1, 1)
91 | for i in tqdm.tqdm(range(0, number_of_channels, BATCH_SIZE)):
92 | ablation_layer.indices = list(range(i, i + BATCH_SIZE))
93 |
94 | if i + BATCH_SIZE > number_of_channels:
95 | keep = number_of_channels - i
96 | batch_tensor = batch_tensor[:keep]
97 | ablation_layer.indices = ablation_layer.indices[:keep]
98 | score = self.model(batch_tensor)[:, category].cpu().numpy()
99 | weights.extend(score)
100 |
101 | weights = np.float32(weights)
102 | weights = weights.reshape(activations.shape[:2])
103 | original_scores = original_scores[:, None]
104 | weights = (original_scores - weights) / original_scores
105 |
106 | #replace the model back to the original state
107 | replace_layer_recursive(self.model, ablation_layer, self.target_layer)
108 | return weights
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/activations_and_gradients.py:
--------------------------------------------------------------------------------
1 | class ActivationsAndGradients:
2 | """ Class for extracting activations and
3 |     registering gradients from targeted intermediate layers """
4 |
5 | def __init__(self, model, target_layer, reshape_transform):
6 | self.model = model
7 | self.gradients = []
8 | self.activations = []
9 | self.reshape_transform = reshape_transform
10 |
11 | target_layer.register_forward_hook(self.save_activation)
12 | target_layer.register_backward_hook(self.save_gradient)
13 |
14 | def save_activation(self, module, input, output):
15 | activation = output
16 | if self.reshape_transform is not None:
17 | activation = self.reshape_transform(activation)
18 | self.activations.append(activation.cpu().detach())
19 |
20 | def save_gradient(self, module, grad_input, grad_output):
21 | # Gradients are computed in reverse order
22 | grad = grad_output[0]
23 | if self.reshape_transform is not None:
24 | grad = self.reshape_transform(grad)
25 | self.gradients = [grad.cpu().detach()] + self.gradients
26 |
27 | def __call__(self, x):
28 | self.gradients = []
29 | self.activations = []
30 | return self.model(x)
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/base_cam.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import ttach as tta
5 | from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
6 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
7 |
8 |
9 | class BaseCAM:
10 | def __init__(self,
11 | model,
12 | target_layer,
13 | use_cuda=False,
14 | reshape_transform=None):
15 | self.model = model.eval()
16 | self.target_layer = target_layer
17 | self.cuda = use_cuda
18 | if self.cuda:
19 | self.model = model.cuda()
20 | self.reshape_transform = reshape_transform
21 | self.activations_and_grads = ActivationsAndGradients(self.model,
22 | target_layer, reshape_transform)
23 |
24 | def forward(self, input_img):
25 | return self.model(input_img)
26 |
27 | def get_cam_weights(self,
28 | input_tensor,
29 | target_category,
30 | activations,
31 | grads):
32 | raise Exception("Not Implemented")
33 |
34 | def get_loss(self, output, target_category):
35 | loss = 0
36 | for i in range(len(target_category)):
37 | loss = loss + output[i, target_category[i]]
38 | return loss
39 |
40 | def get_cam_image(self,
41 | input_tensor,
42 | target_category,
43 | activations,
44 | grads,
45 | eigen_smooth=False):
46 | weights = self.get_cam_weights(input_tensor, target_category, activations, grads)
47 | weighted_activations = weights[:, :, None, None] * activations
48 | if eigen_smooth:
49 | cam = get_2d_projection(weighted_activations)
50 | else:
51 | # import pdb
52 | # pdb.set_trace()
53 | cam = weighted_activations.sum(axis=1)
54 | return cam
55 |
56 | def forward(self, input_tensor, target_category=None, eigen_smooth=False):
57 |
58 | if self.cuda:
59 | input_tensor = input_tensor.cuda()
60 |
61 | output = self.activations_and_grads(input_tensor)
62 |
63 | if type(target_category) is int:
64 | target_category = [target_category] * input_tensor.size(0)
65 |
66 | if target_category is None:
67 | target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
68 | else:
69 | assert(len(target_category) == input_tensor.size(0))
70 |
71 | self.model.zero_grad()
72 | loss = self.get_loss(output, target_category)
73 | loss.backward(retain_graph=True)
74 |
75 | activations = self.activations_and_grads.activations[-1].cpu().data.numpy()
76 | grads = self.activations_and_grads.gradients[-1].cpu().data.numpy()
77 |
78 | cam = self.get_cam_image(input_tensor, target_category,
79 | activations, grads, eigen_smooth)
80 |
81 | cam = np.maximum(cam, 0)
82 |
83 | result = []
84 | for img in cam:
85 | img = cv2.resize(img, input_tensor.shape[-2:][::-1])
86 | img = img - np.min(img)
87 | img = img / (np.max(img) + 1e-8)
88 | result.append(img)
89 | result = np.float32(result)
90 | return result
91 |
92 | def forward_augmentation_smoothing(self,
93 | input_tensor,
94 | target_category=None,
95 | eigen_smooth=False):
96 | transforms = tta.Compose(
97 | [
98 | tta.HorizontalFlip(),
99 | tta.Multiply(factors=[0.9, 1, 1.1]),
100 | ]
101 | )
102 | cams = []
103 | for transform in transforms:
104 | augmented_tensor = transform.augment_image(input_tensor)
105 | cam = self.forward(augmented_tensor,
106 | target_category, eigen_smooth)
107 |
108 | # The ttach library expects a tensor of size BxCxHxW
109 | cam = cam[:, None, :, :]
110 | cam = torch.from_numpy(cam)
111 | cam = transform.deaugment_mask(cam)
112 |
113 | # Back to numpy float32, HxW
114 | cam = cam.numpy()
115 | cam = cam[:, 0, :, :]
116 | cams.append(cam)
117 |
118 | cam = np.mean(np.float32(cams), axis=0)
119 | return cam
120 |
121 | def __call__(self,
122 | input_tensor,
123 | target_category=None,
124 | aug_smooth=False,
125 | eigen_smooth=False):
126 | if aug_smooth is True:
127 | return self.forward_augmentation_smoothing(input_tensor,
128 | target_category, eigen_smooth)
129 |
130 | return self.forward(input_tensor,
131 | target_category, eigen_smooth)
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/eigen_cam.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from pytorch_grad_cam.base_cam import BaseCAM
5 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
6 |
7 | # https://arxiv.org/abs/2008.00299
8 | class EigenCAM(BaseCAM):
9 | def __init__(self, model, target_layer, use_cuda=False,
10 | reshape_transform=None):
11 | super(EigenCAM, self).__init__(model, target_layer, use_cuda,
12 | reshape_transform)
13 |
14 | def get_cam_image(self,
15 | input_tensor,
16 | target_category,
17 | activations,
18 | grads,
19 | eigen_smooth):
20 | return get_2d_projection(activations)
21 |
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/eigen_grad_cam.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from pytorch_grad_cam.base_cam import BaseCAM
5 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
6 |
7 | # Like Eigen CAM: https://arxiv.org/abs/2008.00299
8 | # But multiply the activations x gradients
9 | class EigenGradCAM(BaseCAM):
10 | def __init__(self, model, target_layer, use_cuda=False,
11 | reshape_transform=None):
12 | super(EigenGradCAM, self).__init__(model, target_layer, use_cuda,
13 | reshape_transform)
14 |
15 | def get_cam_image(self,
16 | input_tensor,
17 | target_category,
18 | activations,
19 | grads,
20 | eigen_smooth):
21 | return get_2d_projection(grads*activations)
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/grad_cam.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from pytorch_grad_cam.base_cam import BaseCAM
5 |
6 |
7 | class GradCAM(BaseCAM):
8 | def __init__(self, model, target_layer, use_cuda=False, reshape_transform=None):
9 | super(GradCAM, self).__init__(model, target_layer, use_cuda, reshape_transform)
10 |
11 | def get_cam_weights(self,
12 | input_tensor,
13 | target_category,
14 | activations,
15 | grads):
16 | return np.mean(grads, axis=(2, 3))
17 |
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/grad_cam_plusplus.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from pytorch_grad_cam.base_cam import BaseCAM
5 |
6 | class GradCAMPlusPlus(BaseCAM):
7 | def __init__(self, model, target_layer, use_cuda=False,
8 | reshape_transform=None):
9 | super(GradCAMPlusPlus, self).__init__(model, target_layer, use_cuda,
10 | reshape_transform)
11 |
12 | def get_cam_weights(self, input_tensor,
13 | target_category,
14 | activations,
15 | grads):
16 | grads_power_2 = grads**2
17 | grads_power_3 = grads_power_2*grads
18 | # Equation 19 in https://arxiv.org/abs/1710.11063
19 | sum_activations = np.sum(activations, axis=(2, 3))
20 | eps = 0.000001
21 | aij = grads_power_2 / (2*grads_power_2 +
22 | sum_activations[:, :, None, None]*grads_power_3 + eps)
23 | # Now bring back the ReLU from eq.7 in the paper,
24 | # And zero out aijs where the activations are 0
25 | aij = np.where(grads != 0, aij, 0)
26 |
27 | weights = np.maximum(grads, 0)*aij
28 | weights = np.sum(weights, axis=(2, 3))
29 | return weights
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/guided_backprop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Function
4 |
5 |
6 | class GuidedBackpropReLU(Function):
7 | @staticmethod
8 | def forward(self, input_img):
9 | positive_mask = (input_img > 0).type_as(input_img)
10 | output = torch.addcmul(torch.zeros(input_img.size()).type_as(input_img), input_img, positive_mask)
11 | self.save_for_backward(input_img, output)
12 | return output
13 |
14 | @staticmethod
15 | def backward(self, grad_output):
16 | input_img, output = self.saved_tensors
17 | grad_input = None
18 |
19 | positive_mask_1 = (input_img > 0).type_as(grad_output)
20 | positive_mask_2 = (grad_output > 0).type_as(grad_output)
21 | grad_input = torch.addcmul(torch.zeros(input_img.size()).type_as(input_img),
22 | torch.addcmul(torch.zeros(input_img.size()).type_as(input_img), grad_output,
23 | positive_mask_1), positive_mask_2)
24 | return grad_input
25 |
26 |
27 | class GuidedBackpropReLUModel:
28 | def __init__(self, model, use_cuda):
29 | self.model = model
30 | self.model.eval()
31 | self.cuda = use_cuda
32 | if self.cuda:
33 | self.model = self.model.cuda()
34 |
35 | def forward(self, input_img):
36 | return self.model(input_img)
37 |
38 | def recursive_replace_relu_with_guidedrelu(self, module_top):
39 | for idx, module in module_top._modules.items():
40 | self.recursive_replace_relu_with_guidedrelu(module)
41 | if module.__class__.__name__ == 'ReLU':
42 | module_top._modules[idx] = GuidedBackpropReLU.apply
43 |
44 | def recursive_replace_guidedrelu_with_relu(self, module_top):
45 | try:
46 | for idx, module in module_top._modules.items():
47 | self.recursive_replace_guidedrelu_with_relu(module)
48 | if module == GuidedBackpropReLU.apply:
49 | module_top._modules[idx] = torch.nn.ReLU()
50 | except:
51 | pass
52 |
53 |
54 | def __call__(self, input_img, target_category=None):
55 | # replace ReLU with GuidedBackpropReLU
56 | self.recursive_replace_relu_with_guidedrelu(self.model)
57 |
58 | if self.cuda:
59 | input_img = input_img.cuda()
60 |
61 | input_img = input_img.requires_grad_(True)
62 |
63 | output = self.forward(input_img)
64 |
65 | if target_category is None:
66 | target_category = np.argmax(output.cpu().data.numpy())
67 |
68 | loss = output[0, target_category]
69 | loss.backward(retain_graph=True)
70 |
71 | output = input_img.grad.cpu().data.numpy()
72 | output = output[0, :, :, :]
73 | output = output.transpose((1, 2, 0))
74 |
75 | # replace GuidedBackpropReLU back with ReLU
76 | self.recursive_replace_guidedrelu_with_relu(self.model)
77 |
78 | return output
79 |
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/score_cam.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import tqdm
5 | from pytorch_grad_cam.base_cam import BaseCAM
6 |
7 | class ScoreCAM(BaseCAM):
8 | def __init__(self, model, target_layer, use_cuda=False, reshape_transform=None):
9 | super(ScoreCAM, self).__init__(model, target_layer, use_cuda,
10 | reshape_transform=reshape_transform)
11 |
12 | def get_cam_weights(self,
13 | input_tensor,
14 | target_category,
15 | activations,
16 | grads):
17 | with torch.no_grad():
18 | upsample = torch.nn.UpsamplingBilinear2d(size=input_tensor.shape[-2 : ])
19 | activation_tensor = torch.from_numpy(activations)
20 | if self.cuda:
21 | activation_tensor = activation_tensor.cuda()
22 |
23 | upsampled = upsample(activation_tensor)
24 |
25 | maxs = upsampled.view(upsampled.size(0),
26 | upsampled.size(1), -1).max(dim=-1)[0]
27 | mins = upsampled.view(upsampled.size(0),
28 | upsampled.size(1), -1).min(dim=-1)[0]
29 | maxs, mins = maxs[:, :, None, None], mins[:, :, None, None]
30 | upsampled = (upsampled - mins) / (maxs - mins)
31 |
32 | input_tensors = input_tensor[:, None, :, :]*upsampled[:, :, None, :, :]
33 |
34 | if hasattr(self, "batch_size"):
35 | BATCH_SIZE = self.batch_size
36 | else:
37 | BATCH_SIZE = 16
38 |
39 | scores = []
40 | for batch_index, tensor in enumerate(input_tensors):
41 | category = target_category[batch_index]
42 | for i in tqdm.tqdm(range(0, tensor.size(0), BATCH_SIZE)):
43 | batch = tensor[i : i + BATCH_SIZE, :]
44 | outputs = self.model(batch).cpu().numpy()[:, category]
45 | scores.extend(outputs)
46 | scores = torch.Tensor(scores)
47 | scores = scores.view(activations.shape[0], activations.shape[1])
48 |
49 | weights = torch.nn.Softmax(dim=-1)(scores).numpy()
50 | return weights
51 |
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from pytorch_grad_cam.utils.image import deprocess_image
2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
3 | from pytorch_grad_cam.utils.image import preprocess_image
4 |
5 |
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/utils/image.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from torchvision.transforms import Compose, Normalize, ToTensor
5 |
6 |
7 | def preprocess_image(img: np.ndarray, mean=None, std=None) -> torch.Tensor:
8 | if std is None:
9 | std = [0.5, 0.5, 0.5]
10 | if mean is None:
11 | mean = [0.5, 0.5, 0.5]
12 |
13 | preprocessing = Compose([
14 | ToTensor(),
15 | Normalize(mean=mean, std=std)
16 | ])
17 |
18 | return preprocessing(img.copy()).unsqueeze(0)
19 |
20 |
21 | def deprocess_image(img):
22 | """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """
23 | img = img - np.mean(img)
24 | img = img / (np.std(img) + 1e-5)
25 | img = img * 0.1
26 | img = img + 0.5
27 | img = np.clip(img, 0, 1)
28 | return np.uint8(img * 255)
29 |
30 |
31 | def show_cam_on_image(img: np.ndarray,
32 | mask: np.ndarray,
33 | use_rgb: bool = False,
34 | colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
35 |     """ This function overlays the cam mask on the image as a heatmap.
36 | By default the heatmap is in BGR format.
37 |
38 | :param img: The base image in RGB or BGR format.
39 | :param mask: The cam mask.
40 | :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
41 | :param colormap: The OpenCV colormap to be used.
42 | :returns: The default image with the cam overlay.
43 | """
44 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
45 | if use_rgb:
46 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
47 | heatmap = np.float32(heatmap) / 255
48 |
49 | if np.max(img) > 1:
50 |         raise Exception("The input image should be np.float32 in the range [0, 1]")
51 |
52 | cam = heatmap + img
53 | cam = cam / np.max(cam)
54 | return np.uint8(255 * cam)
55 |
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/utils/svd_on_activations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def get_2d_projection(activation_batch):
4 | # TBD: use pytorch batch svd implementation
5 | projections = []
6 | for activations in activation_batch:
7 | reshaped_activations = (activations).reshape(activations.shape[0], -1).transpose()
8 | # Centering before the SVD seems to be important here,
9 | # Otherwise the image returned is negative
10 | reshaped_activations = reshaped_activations - reshaped_activations.mean(axis=0)
11 | U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True)
12 | projection = reshaped_activations @ VT[0, :]
13 | projection = projection.reshape(activations.shape[1 : ])
14 | projections.append(projection)
15 | return np.float32(projections)
--------------------------------------------------------------------------------
/RISE/pytorch_grad_cam/xgrad_cam.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from pytorch_grad_cam.base_cam import BaseCAM
5 |
6 | class XGradCAM(BaseCAM):
7 | def __init__(self, model, target_layer, use_cuda=False, reshape_transform=None):
8 | super(XGradCAM, self).__init__(model, target_layer, use_cuda, reshape_transform)
9 |
10 | def get_cam_weights(self,
11 | input_tensor,
12 | target_category,
13 | activations,
14 | grads):
15 | sum_activations = np.sum(activations, axis=(2, 3))
16 | eps = 1e-7
17 | weights = grads * activations / (sum_activations[:, :, None, None] + eps)
18 | weights = weights.sum(axis=(2, 3))
19 | return weights
--------------------------------------------------------------------------------
/RISE/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 | import torch
4 | from torch.utils.data.sampler import Sampler
5 | from torchvision import transforms, datasets
6 | from PIL import Image
7 |
8 |
9 | # Dummy class to store arguments
10 | class Dummy():
11 | pass
12 |
13 |
14 | # Function that opens image from disk, normalizes it and converts to tensor
15 | read_tensor = transforms.Compose([
16 | lambda x: Image.open(x),
17 | transforms.Resize((224, 224)),
18 | transforms.ToTensor(),
19 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
20 | std=[0.229, 0.224, 0.225]),
21 | lambda x: torch.unsqueeze(x, 0)
22 | ])
23 |
24 |
25 | # Plots image from tensor
26 | def tensor_imshow(inp, title=None, **kwargs):
27 | """Imshow for Tensor."""
28 | inp = inp.numpy().transpose((1, 2, 0))
29 | # Mean and std for ImageNet
30 | mean = np.array([0.485, 0.456, 0.406])
31 | std = np.array([0.229, 0.224, 0.225])
32 | inp = std * inp + mean
33 | inp = np.clip(inp, 0, 1)
34 | plt.imshow(inp, **kwargs)
35 | if title is not None:
36 | plt.title(title)
37 |
38 |
39 | # Given label number returns class name
40 | def get_class_name(c):
41 | labels = np.loadtxt('synset_words.txt', str, delimiter='\t')
42 | return ' '.join(labels[c].split(',')[0].split()[1:])
43 |
44 |
45 | # Image preprocessing function
46 | preprocess = transforms.Compose([
47 | transforms.Resize((224, 224)),
48 | transforms.ToTensor(),
49 | # Normalization for ImageNet
50 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
51 | std=[0.229, 0.224, 0.225]),
52 | ])
53 |
54 |
55 | # Sampler for pytorch loader. Given range r loader will only
56 | # return dataset[r] instead of whole dataset.
57 | class RangeSampler(Sampler):
58 | def __init__(self, r):
59 | self.r = r
60 |
61 | def __iter__(self):
62 | return iter(self.r)
63 |
64 | def __len__(self):
65 | return len(self.r)
66 |
--------------------------------------------------------------------------------
/TorchRay/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # 1.0.0 (October 2019)
2 |
3 | * Initial public release.
4 |
--------------------------------------------------------------------------------
/TorchRay/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
4 | Please read the [full text](https://code.fb.com/codeofconduct/)
5 | so that you can understand what actions will and will not be tolerated.
6 |
--------------------------------------------------------------------------------
/TorchRay/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to TorchRay
2 |
3 | We want to make contributing to this project as easy and transparent as possible.
4 |
5 | ## Pull Requests
6 |
7 | We actively welcome your pull requests.
8 |
9 | 1. Fork the repo and create your branch from `master`.
10 | 2. If you've added code that should be tested, add tests.
11 | 3. If you've changed APIs, update the documentation.
12 | 4. Ensure the test passes (for basic testing use `examples.run_all_examples()`)
13 | 5. Make sure your code lints.
14 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
15 |
16 | ## Contributor License Agreement ("CLA")
17 |
18 | In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Facebook's open source projects.
19 |
20 | Complete your CLA here: <https://code.facebook.com/cla>
21 |
22 | ## Issues
23 |
24 | We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue.
25 |
26 | ## Coding Style
27 |
28 | We follow the PEP8 standard using flake8 for linting. For the documentation, we use Sphinx and Napoleon using the Google style for the comments, and LaTeX math for equations.
29 |
30 | ## License
31 |
32 | By contributing to TorchRay, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree.
33 |
--------------------------------------------------------------------------------
/TorchRay/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include CHANGELOG.md
2 | include CODE_OF_CONDUCT.md
3 | include CONTRIBUTING.md
4 | include LICENSE
5 | include README.md
6 | include torchray/VERSION
7 | include docs
8 | include docs/html
9 | include scripts
10 | include scripts/torchrayrc
11 | recursive-include docs/html *.html *.png *.gif *.js *.css *.eot *.ttf *.woff *.woff2 *.svg *.inv *.txt
12 | recursive-include torchray *.py *.txt
13 | recursive-include scripts *.sh
14 |
--------------------------------------------------------------------------------
/TorchRay/Makefile:
--------------------------------------------------------------------------------
1 | VER=$(shell head -n 1 torchray/VERSION)
2 |
3 | define get_conda_env
4 | $(shell which python | xargs dirname | xargs dirname)
5 | endef
6 |
7 | .PHONY: all
8 | all:
9 |
10 |
11 | .PHONY: dist
12 | dist:
13 | rm -rf dist
14 | docs/make-docs
15 | python3 setup.py sdist
16 | tar -tzvf dist/torchray-*.tar.gz
17 |
18 |
19 | .PHONY: conda
20 | conda:
21 | mkdir -p dist
22 | VER=$(VER) conda build -c defaults -c pytorch -c conda-forge \
23 | --no-anaconda-upload --python 3.7 packaging/meta.yaml
24 |
25 |
26 | .PHONY: conda-install
27 | conda-install:
28 | conda install --use-local torchray
29 |
30 |
31 | .PHONY: docs
32 | docs:
33 | rm -rf docs/html
34 | docs/make-docs
35 |
36 |
37 | .PHONY: pub
38 | pub: docs
39 | git push pub master
40 | git push -f pub gh-pages
41 | git push -f --tags pub
42 |
43 | touch docs/html/.nojekyll
44 | git -C docs/html init
45 | git -C docs/html remote add pub git@github.com:facebookresearch/TorchRay.git
46 | git -C docs/html add .
47 | git -C docs/html commit -m "add documentation"
48 | git -C docs/html push -f pub master:gh-pages
49 |
50 |
51 | .PHONY: tag
52 | tag:
53 | git tag -f v$(VER)
54 |
55 |
56 | .PHONY: distclean
57 | distclean:
58 | rm -rf dist
59 |
--------------------------------------------------------------------------------
/TorchRay/README.md:
--------------------------------------------------------------------------------
1 | # TorchRay
2 |
3 | The *TorchRay* package implements several visualization methods for deep
4 | convolutional neural networks using PyTorch. In this release, TorchRay focuses
5 | on *attribution*, namely the problem of determining which part of the input,
6 | usually an image, is responsible for the value computed by a neural network.
7 |
8 | *TorchRay* is research oriented: in addition to implementing well known
9 | techniques from the literature, it provides code for reproducing results that
10 | appear in several papers, in order to support *reproducible research*.
11 |
12 | *TorchRay* was initially developed to support the paper:
13 |
14 | * *Understanding deep networks via extremal perturbations and smooth masks.*
15 | Fong, Patrick, Vedaldi.
16 | Proceedings of the International Conference on Computer Vision (ICCV), 2019.
17 |
18 | ## Examples
19 |
20 | The package contains several usage examples in the
21 | [`examples`](https://github.com/facebookresearch/TorchRay/tree/master/examples)
22 | subdirectory.
23 |
24 | Here is a complete example for using GradCAM:
25 |
26 | ```python
27 | from torchray.attribution.grad_cam import grad_cam
28 | from torchray.benchmark import get_example_data, plot_example
29 |
30 | # Obtain example data.
31 | model, x, category_id, _ = get_example_data()
32 |
33 | # Grad-CAM backprop.
34 | saliency = grad_cam(model, x, category_id, saliency_layer='features.29')
35 |
36 | # Plots.
37 | plot_example(x, saliency, 'grad-cam backprop', category_id)
38 | ```
39 |
40 | ## Requirements
41 |
42 | TorchRay requires:
43 |
44 | * Python 3.4 or greater
45 | * pytorch 1.1.0 or greater
46 | * matplotlib
47 |
48 | For benchmarking, it also requires:
49 |
50 | * torchvision 0.3.0 or greater
51 | * pycocotools
52 | * mongodb (suggested)
53 | * pymongo (suggested)
54 |
55 | On Linux/macOS, using conda you can install these as follows:
56 |
57 | ```bash
58 | while read requirement; do conda install \
59 | -c defaults -c pytorch -c conda-forge --yes $requirement; done <<EOF
60 | pytorch>=1.1.0
61 | pycocotools
62 | torchvision>=0.3.0
63 | mongodb
64 | pymongo
65 | EOF
66 | ```
67 |
68 | ## Installing TorchRay
69 |
70 | Using `pip`:
71 |
72 | ```shell
73 | pip install torchray
74 | ```
75 |
76 | From source:
77 |
78 | ```shell
79 | python setup.py install
80 | ```
81 |
82 | or
83 |
84 | ```shell
85 | pip install .
86 | ```
87 |
88 | ## Full documentation
89 |
90 | The full documentation can be found
91 | [here](https://facebookresearch.github.io/TorchRay).
92 |
93 | ## Changes
94 |
95 | See the [CHANGELOG](CHANGELOG.md).
96 |
97 | ## Join the TorchRay community
98 |
99 | * Website: https://github.com/facebookresearch/TorchRay
100 |
101 | See the [CONTRIBUTING](CONTRIBUTING.md) file for how to help out.
102 |
103 | ## The team
104 |
105 | TorchRay has been primarily developed by Ruth C. Fong and Andrea Vedaldi.
106 |
107 | ## License
108 |
109 | TorchRay is CC-BY-NC licensed, as found in the [LICENSE](LICENSE) file.
110 |
--------------------------------------------------------------------------------
/TorchRay/docs/attribution.rst:
--------------------------------------------------------------------------------
1 | .. _backprop:
2 |
3 | Attribution
4 | ===========
5 |
6 | *Attribution* is the problem of determining which part of the input,
7 | e.g. an image, is responsible for the value computed by a predictor
8 | such as a neural network.
9 |
10 | Formally, let :math:`\mathbf{x}` be the input to a convolutional neural
11 | network, e.g., a :math:`N \times C \times H \times W` real tensor. The neural
12 | network is a function :math:`\Phi` mapping :math:`\mathbf{x}` to a scalar
13 | output :math:`z \in \mathbb{R}`. Thus the goal is to find which of the
14 | elements of :math:`\mathbf{x}` are "most responsible" for the outcome
15 | :math:`z`.
16 |
17 | Some attribution methods are "black box" approaches, in the sense that they
18 | ignore the nature of the function :math:`\Phi` (however, most assume that it is
19 | at least possible to compute the gradient of :math:`\Phi` efficiently). Most
20 | attribution methods, however, are "white box" approaches, in the sense that
21 | they exploit the knowledge of the structure of :math:`\Phi`.
22 |
23 | :ref:`Backpropagation methods <backpropagation>` are "white box" visualization
24 | approaches that build on backpropagation, thus leveraging the functionality
25 | already implemented in standard deep learning toolboxes such as
26 | PyTorch.
27 |
28 | :ref:`Perturbation methods <perturbation>` are "black box" visualization
29 | approaches that generate attribution visualizations by perturbing the input
30 | and observing the changes in a model's output.
31 |
32 | TorchRay implements the following methods:
33 |
34 | * Backpropagation methods
35 |
36 | * Deconvolution (:mod:`.deconvnet`)
37 | * Excitation backpropagation (:mod:`.excitation_backprop`)
38 | * Gradient [1]_ (:mod:`.gradient`)
39 | * Grad-CAM (:mod:`.grad_cam`)
40 | * Guided backpropagation (:mod:`.guided_backprop`)
41 | * Linear approximation (:mod:`.linear_approx`)
42 |
43 | * Perturbation methods
44 |
45 | * Extremal perturbation [1]_ (:mod:`.extremal_perturbation`)
46 | * RISE (:mod:`.rise`)
47 |
48 | .. rubric:: Footnotes
49 |
50 | .. [1] The :mod:`.gradient` and :mod:`.extremal_perturbation` methods actually
51 | straddle the boundaries between white and black box methods, as they
52 | only require the ability to compute the gradient of the predictor,
53 | which does not necessarily require to know the predictor internals.
54 | However, in TorchRay both are implemented using backpropagation.
55 |
56 | .. _backpropagation:
57 |
58 | Backpropagation methods
59 | -----------------------
60 |
61 | Backpropagation methods work by tweaking the backpropagation algorithm that, on
62 | its own, computes the gradient of tensor functions. Formally, a neural network
63 | :math:`\Phi` is a collection :math:`\Phi_1,\dots,\Phi_n` of :math:`n` layers.
64 | Each layer is in itself a "smaller" function inputting and outputting tensors,
65 | called *activations* (for simplicity, we call activations the network input and
66 | parameter tensors as well). Layers are interconnected in a *Directed Acyclic
67 | Graph* (DAG). The DAG is bipartite with some nodes representing the activation
68 | tensors and the other nodes representing the layers, with interconnections
69 | between layers and input/output tensors in the obvious way. The DAG sources are
70 | the network's input and parameter tensors and the DAG sinks are the network's
71 | output tensors.
72 |
73 | The main goal of a deep neural network toolbox such as PyTorch is to evaluate
74 | the function :math:`\Phi` implemented by the DAG as well as its gradients with
75 | respect to various tensors (usually the model parameters). The calculation of
76 | the gradients, which uses backpropagation, associates to the forward DAG a
77 | backward DAG, obtained as follows:
78 |
79 | * Activation tensors :math:`\mathbf{x}_j` become gradient tensors
80 | :math:`d\mathbf{x}_j` (preserving their shape).
81 | * Forward layers :math:`\Phi_i` become backward layers :math:`\Phi_i^*`.
82 | * All arrows are reversed.
83 | * Additional arrows connecting the activation tensors :math:`\mathbf{x}_i`
84 | as inputs to the corresponding backward function :math:`\Phi_i^*` are added
85 | as well.
86 |
87 | Backpropagation methods modify the backward graph in order to generate a
88 | visualization of the network forward pass. Additionally, inputs as well as
89 | intermediate activations can be inspected to obtain different visualizations.
90 | These two concepts are explained next.
91 |
92 | Changing the backward propagation rules
93 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
94 |
95 | Changing the backward propagation rules amounts to redefining the functions
96 | :math:`\Phi_i^*`. After doing so, the "gradients" computed by backpropagation
97 | change their meaning into something useful for visualization. We call these
98 | modified gradients *pseudo-gradients*.
99 |
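The snippet below is a minimal, self-contained sketch of what such a
redefinition looks like: a :class:`torch.autograd.Function` whose ``backward``
static method implements a new rule (here the guided backprop rule, mirroring
the pattern used by TorchRay's own DeConvNet and guided backprop
implementations). The class name ``MyReLU`` is purely illustrative.

.. code-block:: python

    import torch

    class MyReLU(torch.autograd.Function):
        """ReLU whose backward pass has been redefined (illustrative only)."""

        @staticmethod
        def forward(ctx, input):
            # The forward pass is the ordinary ReLU.
            ctx.save_for_backward(input)
            return input.clamp(min=0)

        @staticmethod
        def backward(ctx, grad_output):
            # The backward pass no longer returns the true gradient: it also
            # discards negative incoming gradients, so what propagates is a
            # pseudo-gradient rather than the gradient itself.
            input, = ctx.saved_tensors
            return grad_output.clamp(min=0) * (input > 0).float()

One would then call ``MyReLU.apply`` in place of the standard ReLU; the context
managers described next automate this substitution.
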
100 | TorchRay provides a number of context managers that enable patching PyTorch
101 | functions on the fly in order to change the backward propagation rules for
102 | a segment of code. For example, let ``x`` be an input tensor and ``model``
103 | a deep classification network. Furthermore, let ``category_id`` be the
104 | index of the class for which we want to attribute input regions. The following
105 | code uses :mod:`.guided_backprop` to compute and store the pseudo gradient in
106 | ``x.grad``.
107 |
108 | .. code-block:: python
109 |
110 | from torchray.attribution.guided_backprop import GuidedBackpropContext
111 |
112 | x.requires_grad_(True)
113 |
114 | with GuidedBackpropContext():
115 | y = model(x)
116 | z = y[0, category_id]
117 | z.backward()
118 |
119 | At this point, ``x.grad`` contains the "guided gradient" computed by this
120 | method. This gradient is usually flattened along the channel dimension to
121 | produce a saliency map for visualization:
122 |
123 | .. code-block:: python
124 |
125 | from torchray.attribution.common import gradient_to_saliency
126 |
127 | saliency = gradient_to_saliency(x)
128 |
129 | TorchRay also contains some wrapper code, such as
130 | :func:`.guided_backprop.guided_backprop`, that combines these steps in a way
131 | that works for common networks.
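
With the same ``model``, ``x`` and ``category_id`` as above, the wrapper
reduces the whole procedure to a single call (this mirrors the
``examples/guided_backprop.py`` script shipped with the package):

.. code-block:: python

    from torchray.attribution.guided_backprop import guided_backprop

    saliency = guided_backprop(model, x, category_id)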
132 |
133 | Probing intermediate activations and gradients
134 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
135 |
136 | Most visualization methods are based on inspecting the activations when the
137 | network is evaluated and the pseudo-gradients during backpropagation. This is
138 | generally easy for input tensors. For intermediate tensors, when using PyTorch's
139 | functional interface, this is also easy: simply call ``retain_grad()`` in
140 | order to retain the gradient of an intermediate tensor:
141 |
142 | .. code-block:: python
143 |
144 |     from torch.nn.functional import relu, conv2d
145 |     from torchray.attribution.guided_backprop import GuidedBackpropContext
146 |
147 |     with GuidedBackpropContext():
148 |         y = conv2d(x, weight)
149 |         y.requires_grad_(True)
150 |         y.retain_grad()
151 |         z = relu(y)[0, class_index].sum()  # sum over space to get a scalar
152 |         z.backward()
153 |
154 |     # Now y and y.grad contain the activation and the guided gradient,
155 |     # respectively.
156 |
157 | However, in PyTorch most network components are implemented as
158 | :class:`torch.nn.Module` objects. In this case, it is not obvious how to access
159 | a specific layer's information. In order to simplify this process, the library
160 | provides the :class:`Probe` class:
161 |
162 | .. code-block:: python
163 |
164 |     from torchray.attribution.common import Probe
165 |     from torchray.attribution.guided_backprop import GuidedBackpropContext
166 |
167 |     # Attach a probe to the last conv layer.
168 |     probe = Probe(alexnet.features[11])
169 |
170 |     with GuidedBackpropContext():
171 |         y = alexnet(x)
172 |         z = y[0, class_index]
173 |         z.backward()
174 |
175 |     # Now probe.data[0] contains the activations of that layer and
176 |     # probe.data[0].grad contains the corresponding guided (pseudo)
177 |     # gradients.
178 |
179 | The probe automatically applies :func:`torch.Tensor.requires_grad_` and
180 | :func:`torch.Tensor.retain_grad` as needed. You can use ``probe.remove()`` to
181 | remove the probe from the network once you are done.
182 |
183 | Limitations
184 | ^^^^^^^^^^^
185 |
186 | Except for the gradient method, backpropagation methods require modifying
187 | the backward function of each layer. TorchRay implements these rules
188 | as originally defined by the authors of each method. However, as new
189 | neural network layers are introduced, the default behavior, which is to
190 | leave backpropagation unchanged, may be inappropriate or suboptimal for
191 | them.
192 |
193 | .. _perturbation:
194 |
195 | Perturbation methods
196 | --------------------
197 |
198 | Perturbation methods work by changing the input to the neural network in a
199 | controlled manner and observing the effect on the output generated by the
200 | network. Attribution can be achieved by occluding (setting to zero) specific
201 | parts of the image and checking whether this has a strong effect on the output.
202 | This can be thought of as a form of sensitivity analysis that is still
203 | specific to a given input but, unlike the gradient method, is not differential.
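
By way of illustration only (TorchRay's own perturbation methods,
:mod:`.extremal_perturbation` and :mod:`.rise`, are considerably more
sophisticated), a naive occlusion analysis can be sketched as follows, with
``model``, ``x`` and ``category_id`` assumed to be defined as in the previous
section:

.. code-block:: python

    import torch

    def occlusion_saliency(model, x, category_id, patch=32, stride=16):
        """Naive sketch: saliency = drop of the target score when a patch is zeroed."""
        model.eval()
        _, _, height, width = x.shape
        saliency = torch.zeros(1, 1, height, width)
        with torch.no_grad():
            base_score = model(x)[0, category_id].item()
            for i in range(0, height - patch + 1, stride):
                for j in range(0, width - patch + 1, stride):
                    x_occluded = x.clone()
                    x_occluded[:, :, i:i + patch, j:j + patch] = 0
                    drop = base_score - model(x_occluded)[0, category_id].item()
                    saliency[:, :, i:i + patch, j:j + patch] += drop
        return saliency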
204 |
205 |
206 | DeConvNet
207 | ---------
208 |
209 | .. automodule:: torchray.attribution.deconvnet
210 | :members:
211 | :show-inheritance:
212 |
213 |
214 | Excitation backprop
215 | -------------------
216 |
217 | .. automodule:: torchray.attribution.excitation_backprop
218 | :members:
219 | :show-inheritance:
220 |
221 | Extremal perturbation
222 | ---------------------
223 |
224 | .. automodule:: torchray.attribution.extremal_perturbation
225 | :members:
226 | :show-inheritance:
227 |
228 | Gradient
229 | --------
230 |
231 | .. automodule:: torchray.attribution.gradient
232 | :members:
233 | :show-inheritance:
234 |
235 | Grad-CAM
236 | --------
237 |
238 | .. automodule:: torchray.attribution.grad_cam
239 | :members:
240 | :show-inheritance:
241 |
242 | Guided backprop
243 | ---------------
244 |
245 | .. automodule:: torchray.attribution.guided_backprop
246 | :members:
247 | :show-inheritance:
248 |
249 | Linear approximation
250 | --------------------
251 |
252 | .. automodule:: torchray.attribution.linear_approx
253 | :members:
254 | :show-inheritance:
255 |
256 | RISE
257 | ----
258 |
259 | .. automodule:: torchray.attribution.rise
260 | :members:
261 | :show-inheritance:
262 |
263 | Common code
264 | -----------
265 |
266 | .. automodule:: torchray.attribution.common
267 | :members:
268 | :show-inheritance:
--------------------------------------------------------------------------------
/TorchRay/docs/benchmark.rst:
--------------------------------------------------------------------------------
1 | .. _benchmark:
2 |
3 | Benchmarking
4 | ============
5 |
6 | This module contains code for benchmarking attribution methods, including
7 | reproducing several published results. In addition to implementations of
8 | benchmarking protocols (:mod:`.pointing_game`), the module also provides
9 | implementations of *reference datasets* and *reference models* used in prior
10 | research work, properly converted to PyTorch. Overall, these implementations
11 | closely reproduce prior results, notably the ones in the [EBP]_ paper.
12 |
13 | A standard benchmarking suite is included in this library as
14 | :mod:`examples.standard_suite`. For slow methods, a computer cluster may be
15 | required for evaluation (we do not include explicit support for clusters, but
16 | it is easy to add on top of this example code).
17 |
18 | It is also recommended to turn on logging (see
19 | :mod:`torchray.benchmark.logging`), which allows the driver to
20 | use MongoDB to store partial benchmarking results as it goes.
21 | Computations can then be cached and reused to resume the calculations
22 | after a crash or other issue. In order to start the logging server, use
23 |
24 | .. code:: shell
25 |
26 | $ python -m torchray.benchmark.server
27 |
28 | The server parameters (address, port, etc.) can be configured by writing
29 | a ``.torchrayrc`` file in your current or home directory. The package
30 | contains an example configuration file. The server creates a regular
31 | MongoDB database (by default in ``./data/db``) which can be manually
32 | explored by means of the MongoDB shell.
33 |
34 | By default, the driver writes data in the ``./data/`` subfolder.
35 | You can change that via the configuration file, or, possibly more easily,
36 | add a symbolic link to where you want to store the data.
37 |
38 | The data include the *datasets* (PASCAL VOC, COCO, ImageNet; see
39 | :mod:`torchray.benchmark.datasets`). These must be downloaded manually and
40 | stored in ``./data/datasets/{voc,coco,imagenet}`` unless this is changed via
41 | the configuration file. Note that these datasets can be very large (many GBs).
42 |
43 | The data also include *reference models* (see
44 | :mod:`torchray.benchmark.models`).
45 |
46 | .. automodule:: torchray.benchmark
47 | :members:
48 | :show-inheritance:
49 |
50 | Pointing Game
51 | -------------
52 |
53 | The *Pointing Game* [EBP]_ assesses the quality of an attribution method by
54 | testing how well it can extract from a predictor a response correlated with the
55 | presence of known object categories in the image.
56 |
57 | Given an input image :math:`x` containing an object of category :math:`c`, the
58 | attribution method is applied to the predictor in order to find the part of the
59 | image responsible for predicting :math:`c`. The attribution method usually
60 | returns a saliency heatmap. The latter must then be converted into a single point
61 | :math:`(u,v)` that is "most likely" to be contained by an object of that class.
62 | The specific way the point is obtained is method-dependent.
63 |
64 | The attribution method then scores a hit if the point is within a *tolerance*
65 | :math:`\tau` (set to 15 pixels by default) of the image region :math:`\Omega`
66 | containing that object:
67 |
68 | .. math::
69 | \operatorname{hit}(u,v|\Omega)
70 | = [ \exists (u',v') \in \Omega : \|(u,v) - (u',v')\| \leq \tau].
71 |
72 | The point coordinates :math:`(u,v)` are also indices :math:`x_{ncvu}` in the
73 | input image tensor :math:`x`.
74 |
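The following is a minimal, self-contained sketch of this test (the actual
implementation lives in :mod:`torchray.benchmark.pointing_game`); the
heatmap-to-point conversion shown here simply takes the argmax of the saliency
map, and all names are illustrative:

.. code-block:: python

    import torch

    def pointing_game_hit(saliency, object_mask, tolerance=15):
        """Pointing game sketch for one image and one class.

        saliency: (H, W) heatmap; object_mask: (H, W) boolean region Omega.
        """
        # Convert the heatmap into a single point (here: its argmax).
        v, u = divmod(int(saliency.argmax()), saliency.shape[1])

        # Hit if some pixel of the region is within `tolerance` of the point.
        coords = torch.nonzero(object_mask)  # (num_pixels, 2): rows, cols
        vs, us = coords[:, 0], coords[:, 1]
        distances = ((us - u).float() ** 2 + (vs - v).float() ** 2).sqrt()
        return bool((distances <= tolerance).any())
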
75 | RISE [RISE]_ and Extremal Perturbation [EP]_ results are averaged over 3 runs.
76 |
77 | .. csv-table:: Pointing game results
78 | :widths: auto
79 | :header-rows: 2
80 | :stub-columns: 1
81 | :file: pointing.csv
82 |
83 |
84 | .. automodule:: torchray.benchmark.pointing_game
85 | :members:
86 | :show-inheritance:
87 |
88 | Datasets
89 | --------
90 |
91 | .. automodule:: torchray.benchmark.datasets
92 | :members:
93 | :show-inheritance:
94 |
95 | .. autodata:: IMAGENET_CLASSES
96 | :annotation:
97 |
98 | .. autodata:: VOC_CLASSES
99 | :annotation:
100 |
101 | .. autodata:: COCO_CLASSES
102 | :annotation:
103 |
104 | Reference models
105 | ----------------
106 |
107 | .. automodule:: torchray.benchmark.models
108 | :members:
109 | :show-inheritance:
110 |
111 | Logging with MongoDB
112 | --------------------
113 |
114 | .. automodule:: torchray.benchmark.logging
115 | :members:
116 | :show-inheritance:
117 |
118 |
--------------------------------------------------------------------------------
/TorchRay/docs/conf.py:
--------------------------------------------------------------------------------
1 | # conf.py - Sphinx configuration
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | import os
5 | import sys
6 |
7 |
8 | def setup(app):
9 | app.add_stylesheet('css/equations.css')
10 |
11 |
12 | author = 'TorchRay Contributors'
13 | copyright = 'TorchRay Contributors'
14 | project = 'TorchRay'
15 | release = 'beta'
16 | version = '1.0'
17 |
18 | extensions = [
19 | 'sphinx.ext.autodoc',
20 | 'sphinx.ext.mathjax',
21 | 'sphinx.ext.ifconfig',
22 | 'sphinx.ext.viewcode',
23 | 'sphinx.ext.napoleon',
24 | ]
25 |
26 | exclude_patterns = ['html']
27 | master_doc = 'index'
28 | pygments_style = None
29 | source_suffix = ['.rst', '.md']
30 |
31 | # HTML documentation.
32 | html_theme = 'sphinx_rtd_theme'
33 | html_theme_options = {
34 | 'analytics_id': '',
35 | 'canonical_url': '',
36 | 'display_version': True,
37 | 'logo_only': False,
38 | 'prev_next_buttons_location': 'bottom',
39 | 'style_external_links': False,
40 |
41 | # Toc options
42 | 'collapse_navigation': True,
43 | 'includehidden': True,
44 | 'navigation_depth': 4,
45 | 'sticky_navigation': True,
46 | 'titles_only': False
47 | }
48 | html_static_path = ['static']
49 |
--------------------------------------------------------------------------------
/TorchRay/docs/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to TorchRay
2 | ===================
3 |
4 | The *TorchRay* package implements several visualization methods for deep
5 | convolutional neural networks using PyTorch. In this release, TorchRay focuses
6 | on *attribution*, namely the problem of determining which part of the input,
7 | e.g., an image, is responsible for the value computed by a neural network.
8 |
9 | *TorchRay* is research-oriented. In addition to implementing well known
10 | techniques from the literature, it provides code for reproducing results that
11 | appear in several papers, and can thus be a tool for *reproducible research*.
12 |
13 | For downloads, installation instructions, and access to the source code
14 | use the
15 | `GitHub repository <https://github.com/facebookresearch/TorchRay>`_.
16 |
17 | .. toctree::
18 | attribution
19 | benchmark
20 | utils
21 |
22 | Indices
23 | =======
24 |
25 | * :ref:`genindex`
26 | * :ref:`modindex`
27 | * :ref:`search`
--------------------------------------------------------------------------------
/TorchRay/docs/make-docs:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | PYTHON=${PYTHON:-python}
3 | ${PYTHON} -m sphinx -b html docs docs/html
4 |
--------------------------------------------------------------------------------
/TorchRay/docs/pointing.csv:
--------------------------------------------------------------------------------
1 | ,voc_2007,voc_2007,voc_2007,voc_2007,coco,coco,coco,coco
2 | ,vgg16,vgg16,resnet50,resnet50,vgg16,vgg16,resnet50,resnet50
3 | center,69.6,42.4,69.6,42.4,27.8,19.5,27.8,19.5
4 | gradient,76.3,56.9,72.3,56.8,37.7,31.4,35.0,29.4
5 | deconvnet,67.5,44.2,68.6,44.7,30.7,23.0,30.0,21.9
6 | guided_backprop,75.9,53.0,77.2,59.4,39.1,31.4,42.1,35.3
7 | excitation_backprop,77.1,56.6,84.5,70.8,39.8,32.8,49.6,43.9
8 | contrastive_excitation_backprop,79.9,66.5,90.7,82.1,49.7,44.3,58.5,53.6
9 | rise,86.9,75.1,86.4,78.8,50.8,45.3,54.7,50.0
10 | grad_cam,86.6,74.0,90.4,82.3,54.2,49.0,57.3,52.3
11 | extremal_perturbation,88.0,76.1,88.9,78.7,51.5,45.9,56.5,51.5
--------------------------------------------------------------------------------
/TorchRay/docs/static/css/equations.css:
--------------------------------------------------------------------------------
1 | /* Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. */
2 |
3 | .math {
4 | text-align: left;
5 | }
6 |
7 | .eqno {
8 | float: right;
9 | }
--------------------------------------------------------------------------------
/TorchRay/docs/utils.rst:
--------------------------------------------------------------------------------
1 | Utilities
2 | =========
3 |
4 | .. contents:: :local:
5 |
6 | .. automodule:: torchray.utils
7 | :members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/TorchRay/examples/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Define :func:`run_all_examples` to run all examples of saliency methods
3 | (excluding :mod:`examples.standard_suite`).
4 | """
5 |
6 | __all__ = ['run_all_examples']
7 |
8 | from matplotlib import pyplot as plt
9 |
10 |
11 | def run_all_examples():
12 | """Run all examples."""
13 |
14 | plt.figure()
15 | import examples.extremal_perturbation
16 | plt.draw()
17 | plt.pause(0.001)
18 |
19 | plt.figure()
20 | import examples.deconvnet_manual
21 | plt.draw()
22 | plt.pause(0.001)
23 |
24 | plt.figure()
25 | import examples.deconvnet
26 | plt.draw()
27 | plt.pause(0.001)
28 |
29 | plt.figure()
30 | import examples.grad_cam_manual
31 | plt.draw()
32 | plt.pause(0.001)
33 |
34 | plt.figure()
35 | import examples.grad_cam
36 | plt.draw()
37 | plt.pause(0.001)
38 |
39 | plt.figure()
40 | import examples.contrastive_excitation_backprop_manual
41 | plt.draw()
42 | plt.pause(0.001)
43 |
44 | plt.figure()
45 | import examples.contrastive_excitation_backprop
46 | plt.draw()
47 | plt.pause(0.001)
48 |
49 | plt.figure()
50 | import examples.excitation_backprop_manual
51 | plt.draw()
52 | plt.pause(0.001)
53 |
54 | plt.figure()
55 | import examples.excitation_backprop
56 | plt.draw()
57 | plt.pause(0.001)
58 |
59 | plt.figure()
60 | import examples.guided_backprop_manual
61 | plt.draw()
62 | plt.pause(0.001)
63 |
64 | plt.figure()
65 | import examples.guided_backprop
66 | plt.draw()
67 | plt.pause(0.001)
68 |
69 | plt.figure()
70 | import examples.gradient_manual
71 | plt.draw()
72 | plt.pause(0.001)
73 |
74 | plt.figure()
75 | import examples.gradient
76 | plt.draw()
77 | plt.pause(0.001)
78 |
79 | plt.figure()
80 | import examples.rise
81 | plt.draw()
82 | plt.pause(0.001)
83 |
--------------------------------------------------------------------------------
/TorchRay/examples/__main__.py:
--------------------------------------------------------------------------------
1 | """
2 | Run all examples of saliency methods (excluding
3 | :mod:`examples.standard_suite`).
4 | """
5 | from . import run_all_examples
6 |
7 | if __name__ == "__main__":
8 | run_all_examples()
9 |
--------------------------------------------------------------------------------
/TorchRay/examples/contrastive_excitation_backprop.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.excitation_backprop import contrastive_excitation_backprop
2 | from torchray.benchmark import get_example_data, plot_example
3 |
4 | # Obtain example data.
5 | model, x, category_id, _ = get_example_data()
6 |
7 | # Contrastive excitation backprop.
8 | saliency = contrastive_excitation_backprop(
9 | model,
10 | x,
11 | category_id,
12 | saliency_layer='features.9',
13 | contrast_layer='features.30',
14 | classifier_layer='classifier.6',
15 | )
16 |
17 | # Plots.
18 | plot_example(x, saliency, 'contrastive excitation backprop', category_id)
19 |
--------------------------------------------------------------------------------
/TorchRay/examples/contrastive_excitation_backprop_manual.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.common import Probe, get_module
2 | from torchray.attribution.excitation_backprop import ExcitationBackpropContext
3 | from torchray.attribution.excitation_backprop import gradient_to_contrastive_excitation_backprop_saliency
4 | from torchray.benchmark import get_example_data, plot_example
5 |
6 | # Obtain example data.
7 | model, x, category_id, _ = get_example_data()
8 |
9 | # Contrastive excitation backprop.
10 | input_layer = get_module(model, 'features.9')
11 | contrast_layer = get_module(model, 'features.30')
12 | classifier_layer = get_module(model, 'classifier.6')
13 |
14 | input_probe = Probe(input_layer, target='output')
15 | contrast_probe = Probe(contrast_layer, target='output')
16 |
17 | with ExcitationBackpropContext():
18 | y = model(x)
19 | z = y[0, category_id]
20 | classifier_layer.weight.data.neg_()
21 | z.backward()
22 |
23 | classifier_layer.weight.data.neg_()
24 |
25 | contrast_probe.contrast = [contrast_probe.data[0].grad]
26 |
27 | y = model(x)
28 | z = y[0, category_id]
29 | z.backward()
30 |
31 | saliency = gradient_to_contrastive_excitation_backprop_saliency(input_probe.data[0])
32 |
33 | input_probe.remove()
34 | contrast_probe.remove()
35 |
36 | # Plots.
37 | plot_example(x, saliency, 'contrastive excitation backprop', category_id)
38 |
--------------------------------------------------------------------------------
/TorchRay/examples/deconvnet.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.deconvnet import deconvnet
2 | from torchray.benchmark import get_example_data, plot_example
3 |
4 | # Obtain example data.
5 | model, x, category_id, _ = get_example_data()
6 |
7 | # DeConvNet method.
8 | saliency = deconvnet(model, x, category_id)
9 |
10 | # Plots.
11 | plot_example(x, saliency, 'deconvnet', category_id)
12 |
--------------------------------------------------------------------------------
/TorchRay/examples/deconvnet_manual.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.common import gradient_to_saliency
2 | from torchray.attribution.deconvnet import DeConvNetContext
3 | from torchray.benchmark import get_example_data, plot_example
4 |
5 | # Obtain example data.
6 | model, x, category_id, _ = get_example_data()
7 |
8 | # DeConvNet method.
9 | x.requires_grad_(True)
10 |
11 | with DeConvNetContext():
12 | y = model(x)
13 | z = y[0, category_id]
14 | z.backward()
15 |
16 | saliency = gradient_to_saliency(x)
17 |
18 | # Plots.
19 | plot_example(x, saliency, 'deconvnet', category_id)
20 |
--------------------------------------------------------------------------------
/TorchRay/examples/excitation_backprop.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.excitation_backprop import excitation_backprop
2 | from torchray.benchmark import get_example_data, plot_example
3 |
4 | # Obtain example data.
5 | model, x, category_id, _ = get_example_data()
6 |
7 | # Contrastive excitation backprop.
8 | saliency = excitation_backprop(
9 | model,
10 | x,
11 | category_id,
12 | saliency_layer='features.9',
13 | )
14 |
15 | # Plots.
16 | plot_example(x, saliency, 'excitation backprop', category_id)
17 |
--------------------------------------------------------------------------------
/TorchRay/examples/excitation_backprop_manual.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.common import Probe, get_module
2 | from torchray.attribution.excitation_backprop import ExcitationBackpropContext
3 | from torchray.attribution.excitation_backprop import gradient_to_excitation_backprop_saliency
4 | from torchray.benchmark import get_example_data, plot_example
5 |
6 | # Obtain example data.
7 | model, x, category_id, _ = get_example_data()
8 |
9 | # Contrastive excitation backprop.
10 | saliency_layer = get_module(model, 'features.9')
11 | saliency_probe = Probe(saliency_layer, target='output')
12 |
13 | with ExcitationBackpropContext():
14 | y = model(x)
15 | z = y[0, category_id]
16 | z.backward()
17 |
18 | saliency = gradient_to_excitation_backprop_saliency(saliency_probe.data[0])
19 |
20 | saliency_probe.remove()
21 |
22 | # Plots.
23 | plot_example(x, saliency, 'excitation backprop', category_id)
24 |
--------------------------------------------------------------------------------
/TorchRay/examples/extremal_perturbation.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.extremal_perturbation import extremal_perturbation, contrastive_reward
2 | from torchray.benchmark import get_example_data, plot_example
3 | from torchray.utils import get_device
4 |
5 | # Obtain example data.
6 | model, x, category_id_1, category_id_2 = get_example_data()
7 |
8 | # Run on GPU if available.
9 | device = get_device()
10 | model.to(device)
11 | x = x.to(device)
12 |
13 | # Extremal perturbation backprop.
14 | masks_1, _ = extremal_perturbation(
15 | model, x, category_id_1,
16 | reward_func=contrastive_reward,
17 | debug=True,
18 | areas=[0.12],
19 | )
20 |
21 | masks_2, _ = extremal_perturbation(
22 | model, x, category_id_2,
23 | reward_func=contrastive_reward,
24 | debug=True,
25 | areas=[0.05],
26 | )
27 |
28 | # Plots.
29 | plot_example(x, masks_1, 'extremal perturbation', category_id_1)
30 | plot_example(x, masks_2, 'extremal perturbation', category_id_2)
31 |
--------------------------------------------------------------------------------
/TorchRay/examples/grad_cam.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.grad_cam import grad_cam
2 | from torchray.benchmark import get_example_data, plot_example
3 |
4 | # Obtain example data.
5 | model, x, category_id, _ = get_example_data()
6 |
7 | # Grad-CAM backprop.
8 | saliency = grad_cam(model, x, category_id, saliency_layer='features.29')
9 |
10 | # Plots.
11 | plot_example(x, saliency, 'grad-cam backprop', category_id)
12 |
--------------------------------------------------------------------------------
/TorchRay/examples/grad_cam_manual.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.common import Probe, get_module
2 | from torchray.attribution.grad_cam import gradient_to_grad_cam_saliency
3 | from torchray.benchmark import get_example_data, plot_example
4 |
5 | # Obtain example data.
6 | model, x, category_id, _ = get_example_data()
7 |
8 | # Grad-CAM backprop.
9 | saliency_layer = get_module(model, 'features.29')
10 |
11 | probe = Probe(saliency_layer, target='output')
12 |
13 | y = model(x)
14 | z = y[0, category_id]
15 | z.backward()
16 |
17 | saliency = gradient_to_grad_cam_saliency(probe.data[0])
18 |
19 | # Plots.
20 | plot_example(x, saliency, 'grad-cam backprop', category_id)
21 |
--------------------------------------------------------------------------------
/TorchRay/examples/gradient.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.gradient import gradient
2 | from torchray.benchmark import get_example_data, plot_example
3 |
4 | # Obtain example data.
5 | model, x, category_id, _ = get_example_data()
6 |
7 | # Gradient method.
8 | saliency = gradient(model, x, category_id)
9 |
10 | # Plots.
11 | plot_example(x, saliency, 'gradient', category_id)
12 |
--------------------------------------------------------------------------------
/TorchRay/examples/gradient_manual.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.common import gradient_to_saliency
2 | from torchray.benchmark import get_example_data, plot_example
3 |
4 | # Obtain example data.
5 | model, x, category_id, _ = get_example_data()
6 |
7 | # Gradient method.
8 |
9 | x.requires_grad_(True)
10 | y = model(x)
11 | z = y[0, category_id]
12 | z.backward()
13 |
14 | saliency = gradient_to_saliency(x)
15 |
16 | # Plots.
17 | plot_example(x, saliency, 'gradient', category_id)
18 |
--------------------------------------------------------------------------------
/TorchRay/examples/guided_backprop.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.guided_backprop import guided_backprop
2 | from torchray.benchmark import get_example_data, plot_example
3 |
4 | # Obtain example data.
5 | model, x, category_id, _ = get_example_data()
6 |
7 | # Guided backprop.
8 | saliency = guided_backprop(model, x, category_id)
9 |
10 | # Plots.
11 | plot_example(x, saliency, 'guided backprop', category_id)
12 |
--------------------------------------------------------------------------------
/TorchRay/examples/guided_backprop_manual.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.common import gradient_to_saliency
2 | from torchray.attribution.guided_backprop import GuidedBackpropContext
3 | from torchray.benchmark import get_example_data, plot_example
4 |
5 | # Obtain example data.
6 | model, x, category_id, _ = get_example_data()
7 |
8 | # Guided backprop.
9 | x.requires_grad_(True)
10 |
11 | with GuidedBackpropContext():
12 | y = model(x)
13 | z = y[0, category_id]
14 | z.backward()
15 |
16 | saliency = gradient_to_saliency(x)
17 |
18 | # Plots.
19 | plot_example(x, saliency, 'guided backprop', category_id)
20 |
--------------------------------------------------------------------------------
/TorchRay/examples/linear_approx.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.linear_approx import linear_approx
2 | from torchray.benchmark import get_example_data, plot_example
3 |
4 | # Obtain example data.
5 | model, x, category_id, _ = get_example_data()
6 |
7 | # Linear approximation backprop.
8 | saliency = linear_approx(model, x, category_id, saliency_layer='features.29')
9 |
10 | # Plots.
11 | plot_example(x, saliency, 'linear approx', category_id)
12 |
--------------------------------------------------------------------------------
/TorchRay/examples/linear_approx_manual.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.common import Probe, get_module
2 | from torchray.attribution.linear_approx import gradient_to_linear_approx_saliency
3 | from torchray.benchmark import get_example_data, plot_example
4 |
5 | # Obtain example data.
6 | model, x, category_id, _ = get_example_data()
7 |
8 | # Linear approximation.
9 | saliency_layer = get_module(model, 'features.29')
10 |
11 | probe = Probe(saliency_layer, target='output')
12 |
13 | y = model(x)
14 | z = y[0, category_id]
15 | z.backward()
16 |
17 | saliency = gradient_to_linear_approx_saliency(probe.data[0])
18 |
19 | # Plots.
20 | plot_example(x, saliency, 'linear approx', category_id)
21 |
--------------------------------------------------------------------------------
/TorchRay/examples/rise.py:
--------------------------------------------------------------------------------
1 | from torchray.attribution.rise import rise
2 | from torchray.benchmark import get_example_data, plot_example
3 | from torchray.utils import get_device
4 |
5 | # Obtain example data.
6 | model, x, category_id, _ = get_example_data()
7 |
8 | # Run on GPU if available.
9 | device = get_device()
10 | model.to(device)
11 | x = x.to(device)
12 |
13 | # RISE method.
14 | saliency = rise(model, x)
15 | saliency = saliency[:, category_id].unsqueeze(0)
16 |
17 | # Plots.
18 | plot_example(x, saliency, 'RISE', category_id)
19 |
--------------------------------------------------------------------------------
/TorchRay/packaging/meta.yaml:
--------------------------------------------------------------------------------
1 | package:
2 | name: torchray
3 | version: {{ VER }}
4 |
5 | source:
6 | git_rev: v{{ VER }}
7 | git_url: git@github.com:facebookresearch/TorchRay.git
8 |
9 | requirements:
10 | build:
11 | - python >=3.4
12 |
13 | run:
14 | - python >=3.4
15 | - pytorch >=1.1.0
16 | - torchvision >=0.3.0
17 | - pycocotools >=2.0.0
18 | - matplotlib
19 |
20 |
21 | build:
22 | noarch: python
23 | number: 0
24 | script: python setup.py install --single-version-externally-managed --record=record.txt
25 |
26 | about:
27 | home: https://github.com/facebookresearch/TorchRay
28 | license: Attribution-NonCommercial 4.0 International
29 |   summary: 'A PyTorch library for visualizing convnets.'
30 | description: |
31 | The *TorchRay* package implements several visualization methods for deep
32 | convolutional neural networks using PyTorch. In this release, TorchRay focuses
33 | on *attribution*, namely the problem of determining which part of the input,
34 | usually an image, is responsible for the value computed by a neural network.
35 |
36 | *TorchRay* is research oriented: in addition to implementing well known
37 |     techniques from the literature, it provides code for reproducing results that
38 | appear in several papers, and can thus be a tool for *reproducible research*.
39 | dev_url: https://github.com/facebookresearch/TorchRay
40 | doc_url: https://facebookresearch.github.io/TorchRay
41 | doc_source_url: https://github.com/facebookresearch/TorchRay
--------------------------------------------------------------------------------
/TorchRay/scripts/torchrayrc:
--------------------------------------------------------------------------------
1 | {
2 | "mongo": {
3 | "server": "mongod",
4 | "hostname": "localhost",
5 | "port": 27017,
6 | "database": "./data/db"
7 | },
8 | "benchmark": {
9 | "imagenet_dir": "./data/datasets/imagenet",
10 | "models_dir": "./data/models",
11 | "experiments_dir": "./data"
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/TorchRay/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | from setuptools import setup
4 |
5 | setup(
6 | name='torchray',
7 | version=open('torchray/VERSION').readline(),
8 | packages=[
9 | 'torchray',
10 | 'torchray.attribution',
11 | 'torchray.benchmark'
12 | ],
13 | package_data={
14 | 'torchray': ['VERSION'],
15 | 'torchray.benchmark': ['*.txt']
16 | },
17 | url='http://pypi.python.org/pypi/torchray/',
18 | author='Andrea Vedaldi',
19 | author_email='vedaldi@fb.com',
20 | license='Creative Commons Attribution-Noncommercial 4.0 International',
21 | description='TorchRay is a PyTorch library of visualization methods for convnets.',
22 | long_description=open('README.md').read(),
23 | long_description_content_type="text/markdown",
24 | classifiers=[
25 | 'Development Status :: 5 - Production/Stable',
26 | 'Intended Audience :: Developers',
27 | 'Topic :: Machine Learning :: Neural Networks',
28 | 'License :: OSI Approved :: Creative Commons Attribution-Noncommercial 4.0 International',
29 | 'Programming Language :: Python :: 3',
30 | ],
31 | install_requires=[
32 | 'importlib_resources',
33 | 'matplotlib',
34 | 'packaging',
35 | 'pycocotools >= 2.0.0',
36 | 'pymongo',
37 | 'requests',
38 | 'torch >= 1.1',
39 | 'torchvision >= 0.3.0',
40 | ],
41 | setup_requires=[
42 | 'cython',
43 | 'numpy',
44 | ]
45 | )
46 |
--------------------------------------------------------------------------------
/TorchRay/torchray/VERSION:
--------------------------------------------------------------------------------
1 | 1.0.0
2 |
--------------------------------------------------------------------------------
/TorchRay/torchray/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | import importlib.resources as resources
3 | except ImportError:
4 | import importlib_resources as resources
5 |
6 | with resources.open_text('torchray', 'VERSION') as f:
7 | __version__ = f.readlines()[0].rstrip()
8 |
--------------------------------------------------------------------------------
/TorchRay/torchray/attribution/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/CGC/a66d87240863e19cc43d11c9e715ca447614042d/TorchRay/torchray/attribution/__init__.py
--------------------------------------------------------------------------------
/TorchRay/torchray/attribution/deconvnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | r"""
4 | This module implements the *deconvolution* method of [DECONV]_ for visualizing
5 | deep networks. The simplest interface is given by the :func:`deconvnet`
6 | function:
7 |
8 | .. literalinclude:: ../examples/deconvnet.py
9 | :language: python
10 | :linenos:
11 |
12 | Alternatively, it is possible to run the method "manually". DeConvNet is a
13 | backpropagation method, and thus works by changing the definition of the
14 | backward functions of some layers. The modified ReLU is implemented by class
15 | :class:`DeConvNetReLU`; however, this is rarely used directly; instead, one
16 | uses the :class:`DeConvNetContext` context, as follows:
17 |
18 | .. literalinclude:: ../examples/deconvnet_manual.py
19 | :language: python
20 | :linenos:
21 |
22 | See also :ref:`Backpropagation methods <backprop>` for further examples
23 | and discussion.
24 |
25 | Theory
26 | ~~~~~~
27 |
28 | The only change is a modified definition of the backward ReLU function:
29 |
30 | .. math::
31 | \operatorname{ReLU}^*(x,p) =
32 | \begin{cases}
33 | p, & \mathrm{if}~ p > 0,\\
34 | 0, & \mathrm{otherwise} \\
35 | \end{cases}
36 |
37 | Warning:
38 |
39 | DeConvNets are defined for "standard" networks that use ReLU operations.
40 | Further modifications may be required for more complex or new networks
41 | that use other type of non-linearities.
42 |
43 | References:
44 |
45 | .. [DECONV] Zeiler and Fergus,
46 | *Visualizing and Understanding Convolutional Networks*,
47 | ECCV 2014,
48 | ``__.
49 | """
50 |
51 | __all__ = ["DeConvNetContext", "deconvnet"]
52 |
53 | import torch
54 |
55 | from .common import ReLUContext, saliency
56 |
57 |
58 | class DeConvNetReLU(torch.autograd.Function):
59 | """DeConvNet ReLU autograd function.
60 |
61 | This is an autograd function that redefines the ``relu`` function
62 | to match the DeConvNet ReLU definition.
63 | """
64 |
65 | @staticmethod
66 | def forward(ctx, input):
67 | """DeConvNet ReLU forward function."""
68 | return input.clamp(min=0)
69 |
70 | @staticmethod
71 | def backward(ctx, grad_output):
72 | """DeConvNet ReLU backward function."""
73 | return grad_output.clamp(min=0)
74 |
75 |
76 | class DeConvNetContext(ReLUContext):
77 | """DeConvNet context.
78 |
79 | This context modifies the computation of gradient to match the DeConvNet
80 | definition.
81 |
82 | See :mod:`torchray.attribution.deconvnet` for how to use it.
83 | """
84 |
85 | def __init__(self):
86 | super(DeConvNetContext, self).__init__(DeConvNetReLU)
87 |
88 |
89 | def deconvnet(*args, context_builder=DeConvNetContext, **kwargs):
90 | """DeConvNet method.
91 |
92 | The function takes the same arguments as :func:`.common.saliency`, with
93 | the defaults required to apply the DeConvNet method, and supports the
94 | same arguments and return values.
95 | """
96 | return saliency(*args, context_builder=context_builder, **kwargs)
97 |
--------------------------------------------------------------------------------
/TorchRay/torchray/attribution/grad_cam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | """
4 | This module provides an implementation of the *Grad-CAM* method of [GRADCAM]_
5 | for saliency visualization. The simplest interface is given by the
6 | :func:`grad_cam` function:
7 |
8 | .. literalinclude:: ../examples/grad_cam.py
9 | :language: python
10 | :linenos:
11 |
12 | Alternatively, it is possible to run the method "manually". Grad-CAM backprop
13 | is a variant of the gradient method, applied at an intermediate layer:
14 |
15 | .. literalinclude:: ../examples/grad_cam_manual.py
16 | :language: python
17 | :linenos:
18 |
19 | Note that the function :func:`gradient_to_grad_cam_saliency` is used to convert
20 | activations and gradients to a saliency map.
21 |
22 | See also :ref:`backprop` for further examples and discussion.
23 |
24 | Theory
25 | ~~~~~~
26 |
27 | Grad-CAM can be seen as a variant of the *gradient* method
28 | (:mod:`torchray.attribution.gradient`) with two differences:
29 |
30 | 1. The saliency is measured at an intermediate layer of the network, usually at
31 | the output of the last convolutional layer.
32 |
33 | 2. Saliency is defined as the clamped product of forward activation and
34 | backward gradient at that layer.
35 |
36 | References:
37 |
38 | .. [GRADCAM] Ramprasaath R. Selvaraju, Abhishek Das, Ramakrishna Vedantam,
39 | Michael Cogswell, Devi Parikh and Dhruv Batra,
40 | *Visual Explanations from Deep Networks via Gradient-based
41 | Localization,*
42 | ICCV 2017,
43 | ``__.
44 | """
45 |
46 | __all__ = ["grad_cam"]
47 |
48 | import torch
49 | from .common import saliency
50 |
51 |
52 | def gradient_to_grad_cam_saliency(x):
53 | r"""Convert activation and gradient to a Grad-CAM saliency map.
54 |
55 | The tensor :attr:`x` must have a valid gradient ``x.grad``.
56 |     The function then computes the saliency map :math:`s` given by:
57 |
58 | .. math::
59 |
60 | s_{n1u} = \max\{0, \sum_{c}x_{ncu}\cdot dx_{ncu}\}
61 |
62 | Args:
63 | x (:class:`torch.Tensor`): activation tensor with a valid gradient.
64 |
65 | Returns:
66 | :class:`torch.Tensor`: saliency map.
67 | """
68 | # Apply global average pooling (GAP) to gradient.
69 | grad_weight = torch.mean(x.grad, (2, 3), keepdim=True)
70 |
71 | # Linearly combine activations and GAP gradient weights.
72 | saliency_map = torch.sum(x * grad_weight, 1, keepdim=True)
73 |
74 | # Apply ReLU to visualization.
75 | saliency_map = torch.clamp(saliency_map, min=0)
76 |
77 | return saliency_map
78 |
79 |
80 | def grad_cam(*args,
81 | saliency_layer,
82 | gradient_to_saliency=gradient_to_grad_cam_saliency,
83 | **kwargs):
84 | r"""Grad-CAM method.
85 |
86 | The function takes the same arguments as :func:`.common.saliency`, with
87 | the defaults required to apply the Grad-CAM method, and supports the
88 | same arguments and return values.
89 | """
90 | return saliency(*args,
91 | saliency_layer=saliency_layer,
92 | gradient_to_saliency=gradient_to_saliency,
93 | **kwargs,)
94 |
--------------------------------------------------------------------------------
/TorchRay/torchray/attribution/gradient.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | r"""
4 | This module implements the *gradient* method of [GRAD]_ for visualizing a deep
5 | network. It is a backpropagation method, and in fact the simplest of them all
6 | as it coincides with standard backpropagation. The simplest way to use this
7 | method is via the :func:`gradient` function:
8 |
9 | .. literalinclude:: ../examples/gradient.py
10 | :language: python
11 | :linenos:
12 |
13 | Alternatively, one can do so manually, as follows
14 |
15 | .. literalinclude:: ../examples/gradient_manual.py
16 | :language: python
17 | :linenos:
18 |
19 | Note that in this example, for visualization, the gradient is
20 | converted into an image by postprocessing it with the function
21 | :func:`torchray.attribution.common.gradient_to_saliency`.
22 |
23 | See also :ref:`backprop` for further examples.
24 |
25 | References:
26 |
27 | .. [GRAD] Karen Simonyan, Andrea Vedaldi and Andrew Zisserman,
28 | *Deep Inside Convolutional Networks:
29 | Visualising Image Classification Models and Saliency Maps,*
30 | ICLR workshop, 2014,
31 | ``__.
32 | """
33 |
34 | __all__ = ["gradient"]
35 |
36 | from .common import saliency
37 |
38 |
39 | def gradient(*args, context_builder=None, **kwargs):
40 | r"""Gradient method
41 |
42 | The function takes the same arguments as :func:`.common.saliency`, with
43 | the defaults required to apply the gradient method, and supports the
44 | same arguments and return values.
45 | """
46 | assert context_builder is None
47 | return saliency(*args, **kwargs)
48 |
--------------------------------------------------------------------------------
/TorchRay/torchray/attribution/guided_backprop.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | r"""
4 | This module implements the *guided backpropagation* method of [GUIDED]_ for
5 | visualizing deep networks. The simplest interface is given by the
6 | :func:`guided_backprop` function:
7 |
8 | .. literalinclude:: ../examples/guided_backprop.py
9 | :language: python
10 | :linenos:
11 |
12 | Alternatively, it is possible to run the method "manually". Guided backprop is
13 | a backpropagation method, and thus works by changing the definition of the
14 | backward functions of some layers. This can be done using the
15 | :class:`GuidedBackpropContext` context:
16 |
17 | .. literalinclude:: ../examples/guided_backprop_manual.py
18 | :language: python
19 | :linenos:
20 |
21 | See also :ref:`backprop` for further examples.
22 |
23 | Theory
24 | ~~~~~~
25 |
26 | Guided backprop is a backpropagation method, and thus it works by changing the
27 | definition of the backward functions of some layers. The only change is a
28 | modified definition of the backward ReLU function:
29 |
30 | .. math::
31 | \operatorname{ReLU}^*(x,p) =
32 | \begin{cases}
33 | p, & \mathrm{if}~p > 0 ~\mathrm{and}~ x > 0,\\
34 | 0, & \mathrm{otherwise} \\
35 | \end{cases}
36 |
37 | The modified ReLU is implemented by class :class:`GuidedBackpropReLU`.
38 |
39 | References:
40 |
41 | .. [GUIDED] Springenberg et al.,
42 | *Striving for simplicity: The all convolutional net*,
43 | ICLR Workshop 2015,
44 | ``__.
45 | """
46 |
47 | __all__ = ['GuidedBackpropContext', 'guided_backprop']
48 |
49 | import torch
50 |
51 | from .common import ReLUContext, saliency
52 |
53 |
54 | class GuidedBackpropReLU(torch.autograd.Function):
55 | """This class implements a ReLU function with the guided backprop rules."""
56 | @staticmethod
57 | def forward(ctx, input):
58 | """Guided backprop ReLU forward function."""
59 | ctx.save_for_backward(input)
60 | return input.clamp(min=0)
61 |
62 | @staticmethod
63 | def backward(ctx, grad_output):
64 | """Guided backprop ReLU backward function."""
65 | input, = ctx.saved_tensors
66 | grad_input = grad_output.clone()
67 | grad_input[input < 0] = 0
68 | grad_input = grad_input.clamp(min=0)
69 | return grad_input
70 |
71 |
72 | class GuidedBackpropContext(ReLUContext):
73 | r"""GuidedBackprop context.
74 |
75 | This context modifies the computation of gradients
76 | to match the guided backpropagaton definition.
77 |
78 | See :mod:`torchray.attribution.guided_backprop` for how to use it.
79 | """
80 |
81 | def __init__(self):
82 | super(GuidedBackpropContext, self).__init__(GuidedBackpropReLU)
83 |
84 |
85 | def guided_backprop(*args, context_builder=GuidedBackpropContext, **kwargs):
86 | r"""Guided backprop.
87 |
88 | The function takes the same arguments as :func:`.common.saliency`, with
89 | the defaults required to apply the guided backprop method, and supports the
90 | same arguments and return values.
91 | """
92 | return saliency(*args,
93 | context_builder=context_builder,
94 | **kwargs)
95 |
--------------------------------------------------------------------------------
/TorchRay/torchray/attribution/linear_approx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | r"""
4 | This module provides an implementation of the *linear approximation* method
5 | for saliency visualization. The simplest interface is given by the
6 | :func:`linear_approx` function:
7 |
8 | .. literalinclude:: ../examples/linear_approx.py
9 | :language: python
10 | :linenos:
11 |
12 | Alternatively, it is possible to run the method "manually". Linear
13 | approximation is a variant of the gradient method, applied at an intermediate
14 | layer:
15 |
16 | .. literalinclude:: ../examples/linear_approx_manual.py
17 | :language: python
18 | :linenos:
19 |
20 | Note that the function :func:`gradient_to_linear_approx_saliency` is used to
21 | convert activations and gradients to a saliency map.
22 | """
23 |
24 | __all__ = ['gradient_to_linear_approx_saliency', 'linear_approx']
25 |
26 |
27 | import torch
28 | from .common import saliency
29 |
30 |
31 | def gradient_to_linear_approx_saliency(x):
32 |     r"""Returns the linear approximation of a tensor.
33 |
34 | The tensor :attr:`x` must have a valid gradient ``x.grad``.
35 |     The function then computes the saliency map :math:`s` given by:
36 |
37 | .. math::
38 |
39 | s_{n1u} = \sum_{c} x_{ncu} \cdot dx_{ncu}
40 |
41 | Args:
42 | x (:class:`torch.Tensor`): activation tensor with a valid gradient.
43 |
44 | Returns:
45 | :class:`torch.Tensor`: Saliency map.
46 | """
47 | viz = torch.sum(x * x.grad, 1, keepdim=True)
48 | return viz
49 |
50 |
51 | def linear_approx(*args,
52 | gradient_to_saliency=gradient_to_linear_approx_saliency,
53 | **kwargs):
54 | """Linear approximation.
55 |
56 | The function takes the same arguments as :func:`.common.saliency`, with
57 | the defaults required to apply the linear approximation method, and
58 | supports the same arguments and return values.
59 | """
60 | return saliency(*args,
61 | gradient_to_saliency=gradient_to_saliency,
62 | **kwargs)
63 |
--------------------------------------------------------------------------------
/TorchRay/torchray/attribution/resnet_maxpool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | try:
4 | from torch.hub import load_state_dict_from_url
5 | except ImportError:
6 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
7 |
8 | '''
9 | This version of resnet replaces AdaptiveAvgPool2d with AdaptiveMaxPool2d
10 | '''
11 |
12 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
13 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
14 | 'wide_resnet50_2', 'wide_resnet101_2']
15 |
16 |
17 | model_urls = {
18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
27 | }
28 |
29 |
30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
31 | """3x3 convolution with padding"""
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
33 | padding=dilation, groups=groups, bias=False, dilation=dilation)
34 |
35 |
36 | def conv1x1(in_planes, out_planes, stride=1):
37 | """1x1 convolution"""
38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
39 |
40 |
41 | class BasicBlock(nn.Module):
42 | expansion = 1
43 | __constants__ = ['downsample']
44 |
45 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
46 | base_width=64, dilation=1, norm_layer=None):
47 | super(BasicBlock, self).__init__()
48 | if norm_layer is None:
49 | norm_layer = nn.BatchNorm2d
50 | if groups != 1 or base_width != 64:
51 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
52 | if dilation > 1:
53 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
54 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
55 | self.conv1 = conv3x3(inplanes, planes, stride)
56 | self.bn1 = norm_layer(planes)
57 | self.relu = nn.ReLU(inplace=True)
58 | self.conv2 = conv3x3(planes, planes)
59 | self.bn2 = norm_layer(planes)
60 | self.downsample = downsample
61 | self.stride = stride
62 |
63 | def forward(self, x):
64 | identity = x
65 |
66 | out = self.conv1(x)
67 | out = self.bn1(out)
68 | out = self.relu(out)
69 |
70 | out = self.conv2(out)
71 | out = self.bn2(out)
72 |
73 | if self.downsample is not None:
74 | identity = self.downsample(x)
75 |
76 | out += identity
77 | out = self.relu(out)
78 |
79 | return out
80 |
81 |
82 | class Bottleneck(nn.Module):
83 | expansion = 4
84 | __constants__ = ['downsample']
85 |
86 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
87 | base_width=64, dilation=1, norm_layer=None):
88 | super(Bottleneck, self).__init__()
89 | if norm_layer is None:
90 | norm_layer = nn.BatchNorm2d
91 | width = int(planes * (base_width / 64.)) * groups
92 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
93 | self.conv1 = conv1x1(inplanes, width)
94 | self.bn1 = norm_layer(width)
95 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
96 | self.bn2 = norm_layer(width)
97 | self.conv3 = conv1x1(width, planes * self.expansion)
98 | self.bn3 = norm_layer(planes * self.expansion)
99 | self.relu = nn.ReLU(inplace=True)
100 | self.downsample = downsample
101 | self.stride = stride
102 |
103 | def forward(self, x):
104 | identity = x
105 |
106 | out = self.conv1(x)
107 | out = self.bn1(out)
108 | out = self.relu(out)
109 |
110 | out = self.conv2(out)
111 | out = self.bn2(out)
112 | out = self.relu(out)
113 |
114 | out = self.conv3(out)
115 | out = self.bn3(out)
116 |
117 | if self.downsample is not None:
118 | identity = self.downsample(x)
119 |
120 | out += identity
121 | out = self.relu(out)
122 |
123 | return out
124 |
125 |
126 | class ResNet(nn.Module):
127 |
128 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
129 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
130 | norm_layer=None):
131 | super(ResNet, self).__init__()
132 | if norm_layer is None:
133 | norm_layer = nn.BatchNorm2d
134 | self._norm_layer = norm_layer
135 |
136 | self.inplanes = 64
137 | self.dilation = 1
138 | if replace_stride_with_dilation is None:
139 | # each element in the tuple indicates if we should replace
140 | # the 2x2 stride with a dilated convolution instead
141 | replace_stride_with_dilation = [False, False, False]
142 | if len(replace_stride_with_dilation) != 3:
143 | raise ValueError("replace_stride_with_dilation should be None "
144 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
145 | self.groups = groups
146 | self.base_width = width_per_group
147 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
148 | bias=False)
149 | self.bn1 = norm_layer(self.inplanes)
150 | self.relu = nn.ReLU(inplace=True)
151 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
152 | self.layer1 = self._make_layer(block, 64, layers[0])
153 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
154 | dilate=replace_stride_with_dilation[0])
155 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
156 | dilate=replace_stride_with_dilation[1])
157 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
158 | dilate=replace_stride_with_dilation[2])
159 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
160 | self.maxpool_last = nn.AdaptiveMaxPool2d((1, 1))
161 | self.fc = nn.Linear(512 * block.expansion, num_classes)
162 |
163 | for m in self.modules():
164 | if isinstance(m, nn.Conv2d):
165 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
166 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
167 | nn.init.constant_(m.weight, 1)
168 | nn.init.constant_(m.bias, 0)
169 |
170 | # Zero-initialize the last BN in each residual branch,
171 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
172 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
173 | if zero_init_residual:
174 | for m in self.modules():
175 | if isinstance(m, Bottleneck):
176 | nn.init.constant_(m.bn3.weight, 0)
177 | elif isinstance(m, BasicBlock):
178 | nn.init.constant_(m.bn2.weight, 0)
179 |
180 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
181 | norm_layer = self._norm_layer
182 | downsample = None
183 | previous_dilation = self.dilation
184 | if dilate:
185 | self.dilation *= stride
186 | stride = 1
187 | if stride != 1 or self.inplanes != planes * block.expansion:
188 | downsample = nn.Sequential(
189 | conv1x1(self.inplanes, planes * block.expansion, stride),
190 | norm_layer(planes * block.expansion),
191 | )
192 |
193 | layers = []
194 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
195 | self.base_width, previous_dilation, norm_layer))
196 | self.inplanes = planes * block.expansion
197 | for _ in range(1, blocks):
198 | layers.append(block(self.inplanes, planes, groups=self.groups,
199 | base_width=self.base_width, dilation=self.dilation,
200 | norm_layer=norm_layer))
201 |
202 | return nn.Sequential(*layers)
203 |
204 | def _forward_impl(self, x):
205 | # See note [TorchScript super()]
206 | x = self.conv1(x)
207 | x = self.bn1(x)
208 | x = self.relu(x)
209 | x = self.maxpool(x)
210 |
211 | x = self.layer1(x)
212 | x = self.layer2(x)
213 | x = self.layer3(x)
214 | x = self.layer4(x)
215 |
216 | # x = self.avgpool(x)
217 | x = self.maxpool_last(x)
218 | x = torch.flatten(x, 1)
219 | x = self.fc(x)
220 |
221 | return x
222 |
223 | def forward(self, x):
224 | return self._forward_impl(x)
225 |
226 |
227 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
228 | model = ResNet(block, layers, **kwargs)
229 | if pretrained:
230 | state_dict = load_state_dict_from_url(model_urls[arch],
231 | progress=progress)
232 | model.load_state_dict(state_dict)
233 | return model
234 |
235 |
236 | def resnet18(pretrained=False, progress=True, **kwargs):
237 | r"""ResNet-18 model from
238 | `"Deep Residual Learning for Image Recognition" `_
239 |
240 | Args:
241 | pretrained (bool): If True, returns a model pre-trained on ImageNet
242 | progress (bool): If True, displays a progress bar of the download to stderr
243 | """
244 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
245 | **kwargs)
246 |
247 |
248 | def resnet34(pretrained=False, progress=True, **kwargs):
249 | r"""ResNet-34 model from
250 | `"Deep Residual Learning for Image Recognition" `_
251 |
252 | Args:
253 | pretrained (bool): If True, returns a model pre-trained on ImageNet
254 | progress (bool): If True, displays a progress bar of the download to stderr
255 | """
256 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
257 | **kwargs)
258 |
259 |
260 | def resnet50(pretrained=False, progress=True, **kwargs):
261 | r"""ResNet-50 model from
262 | `"Deep Residual Learning for Image Recognition" `_
263 |
264 | Args:
265 | pretrained (bool): If True, returns a model pre-trained on ImageNet
266 | progress (bool): If True, displays a progress bar of the download to stderr
267 | """
268 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
269 | **kwargs)
270 |
271 |
272 | def resnet101(pretrained=False, progress=True, **kwargs):
273 | r"""ResNet-101 model from
274 | `"Deep Residual Learning for Image Recognition" `_
275 |
276 | Args:
277 | pretrained (bool): If True, returns a model pre-trained on ImageNet
278 | progress (bool): If True, displays a progress bar of the download to stderr
279 | """
280 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
281 | **kwargs)
282 |
283 |
284 | def resnet152(pretrained=False, progress=True, **kwargs):
285 | r"""ResNet-152 model from
286 | `"Deep Residual Learning for Image Recognition" `_
287 |
288 | Args:
289 | pretrained (bool): If True, returns a model pre-trained on ImageNet
290 | progress (bool): If True, displays a progress bar of the download to stderr
291 | """
292 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
293 | **kwargs)
294 |
295 |
296 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
297 | r"""ResNeXt-50 32x4d model from
298 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
299 |
300 | Args:
301 | pretrained (bool): If True, returns a model pre-trained on ImageNet
302 | progress (bool): If True, displays a progress bar of the download to stderr
303 | """
304 | kwargs['groups'] = 32
305 | kwargs['width_per_group'] = 4
306 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
307 | pretrained, progress, **kwargs)
308 |
309 |
310 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
311 | r"""ResNeXt-101 32x8d model from
312 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
313 |
314 | Args:
315 | pretrained (bool): If True, returns a model pre-trained on ImageNet
316 | progress (bool): If True, displays a progress bar of the download to stderr
317 | """
318 | kwargs['groups'] = 32
319 | kwargs['width_per_group'] = 8
320 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
321 | pretrained, progress, **kwargs)
322 |
323 |
324 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
325 | r"""Wide ResNet-50-2 model from
326 | `"Wide Residual Networks" `_
327 |
328 | The model is the same as ResNet except that the number of bottleneck channels
329 | is twice as large in every block. The number of channels in the outer 1x1
330 | convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
331 | channels, whereas in Wide ResNet-50-2 it has 2048-1024-2048.
332 |
333 | Args:
334 | pretrained (bool): If True, returns a model pre-trained on ImageNet
335 | progress (bool): If True, displays a progress bar of the download to stderr
336 | """
337 | kwargs['width_per_group'] = 64 * 2
338 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
339 | pretrained, progress, **kwargs)
340 |
341 |
342 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
343 | r"""Wide ResNet-101-2 model from
344 | `"Wide Residual Networks" `_
345 |
346 | The model is the same as ResNet except that the number of bottleneck channels
347 | is twice as large in every block. The number of channels in the outer 1x1
348 | convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
349 | channels, whereas in Wide ResNet-50-2 it has 2048-1024-2048.
350 |
351 | Args:
352 | pretrained (bool): If True, returns a model pre-trained on ImageNet
353 | progress (bool): If True, displays a progress bar of the download to stderr
354 | """
355 | kwargs['width_per_group'] = 64 * 2
356 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
357 | pretrained, progress, **kwargs)
--------------------------------------------------------------------------------
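The class above swaps the usual global average pooling for nn.AdaptiveMaxPool2d((1, 1)) before the classifier (see maxpool_last in _forward_impl). A minimal smoke test, as a sketch only: it assumes this file is importable as torchray.attribution.resnet_maxpool, per the repository layout, and that the constructor keeps torchvision's default 1000-way classifier.

import torch
from torchray.attribution import resnet_maxpool

# Build the max-pooled ResNet-50 variant without downloading weights.
model = resnet_maxpool.resnet50(pretrained=False)
model.eval()

# A random batch stands in for a real, normalized image here.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 1000]) with the default classifier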
/TorchRay/torchray/attribution/rise.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | r"""
4 | This module provides an implementation of the *RISE* method of [RISE]_ for
5 | saliency visualization. This is given by the :func:`rise` function, which
6 | can be used as follows:
7 |
8 | .. literalinclude:: ../examples/rise.py
9 | :language: python
10 | :linenos:
11 |
12 | References:
13 |
14 | .. [RISE] V. Petsiuk, A. Das and K. Saenko
15 | *RISE: Randomized Input Sampling for Explanation of Black-box
16 | Models,*
17 | BMVC 2018,
18 |               `<https://arxiv.org/pdf/1806.07421.pdf>`__.
19 | """
20 |
21 | __all__ = ['rise', 'rise_class']
22 |
23 | import numpy as np
24 |
25 | import torch
26 | import torch.nn.functional as F
27 | from .common import resize_saliency
28 |
29 |
30 | def _upsample_reflect(x, size, interpolate_mode="bilinear"):
31 | r"""Upsample 4D :class:`torch.Tensor` with reflection padding.
32 |
33 | Args:
34 | x (:class:`torch.Tensor`): 4D tensor to interpolate.
35 | size (int or list or tuple of ints): target size
36 | interpolate_mode (str): mode to pass to
37 |             :func:`torch.nn.functional.interpolate` function call
38 | (default: "bilinear").
39 |
40 | Returns:
41 | :class:`torch.Tensor`: upsampled tensor.
42 | """
43 | # Check and get input size.
44 | assert len(x.shape) == 4
45 | orig_size = x.shape[2:]
46 |
47 | # Check target size.
48 | if not isinstance(size, tuple) and not isinstance(size, list):
49 | assert isinstance(size, int)
50 | size = (size, size)
51 | assert len(size) == 2
52 |
53 | # Ensure upsampling.
54 | for i, o_s in enumerate(orig_size):
55 | assert o_s <= size[i]
56 |
57 | # Get size of input cell when interpolated.
58 | cell_size = [int(np.ceil(s / orig_size[i])) for i, s in enumerate(size)]
59 |
60 | # Get size of interpolated input with padding.
61 | pad_size = [int(cell_size[i] * (orig_size[i] + 2))
62 | for i in range(len(orig_size))]
63 |
64 | # Pad input with reflection padding.
65 | x_padded = F.pad(x, (1, 1, 1, 1), mode="reflect")
66 |
67 |     # Interpolate the padded input.
68 | x_up = F.interpolate(x_padded,
69 | pad_size,
70 | mode=interpolate_mode,
71 | align_corners=False)
72 |
73 | # Slice interpolated input to size.
74 | x_new = x_up[:,
75 | :,
76 | cell_size[0]:cell_size[0] + size[0],
77 | cell_size[1]:cell_size[1] + size[1]]
78 |
79 | return x_new
80 |
81 |
82 | def rise_class(*args, target, **kwargs):
83 | r"""Class-specific RISE.
84 |
85 |     This function has all the arguments of :func:`rise` with the following
86 | additional argument and returns a class-specific saliency map for the
87 | given :attr:`target` class(es).
88 |
89 | Args:
90 | target (int, :class:`torch.Tensor`, list, or :class:`np.ndarray`):
91 | target label(s) that can be cast to :class:`torch.long`.
92 | """
93 | saliency = rise(*args, **kwargs)
94 | assert len(saliency.shape) == 4
95 | if not isinstance(target, torch.Tensor):
96 | target = torch.tensor(target, dtype=torch.long, device=saliency.device)
97 | assert isinstance(target, torch.Tensor)
98 | assert target.dtype == torch.long
99 | assert len(target) == len(saliency)
100 |
101 | class_saliency = torch.cat([saliency[i, t].unsqueeze(0).unsqueeze(1)
102 | for i, t in enumerate(target)], dim=0)
103 | output_shape = list(saliency.shape)
104 | output_shape[1] = 1
105 | assert list(class_saliency.shape) == output_shape
106 |
107 | return class_saliency
108 |
109 |
110 | def rise(model,
111 | input,
112 | target=None,
113 | seed=0,
114 | num_masks=8000,
115 | num_cells=7,
116 | filter_masks=None,
117 | batch_size=32,
118 | p=0.5,
119 | resize=False,
120 | resize_mode='bilinear'):
121 | r"""RISE.
122 |
123 | Args:
124 | model (:class:`torch.nn.Module`): a model.
125 | input (:class:`torch.Tensor`): input tensor.
126 | seed (int, optional): manual seed used to generate random numbers.
127 | Default: ``0``.
128 | num_masks (int, optional): number of RISE random masks to use.
129 | Default: ``8000``.
130 | num_cells (int, optional): number of cells for one spatial dimension
131 | in low-res RISE random mask. Default: ``7``.
132 | filter_masks (:class:`torch.Tensor`, optional): If given, use the
133 | provided pre-computed filter masks. Default: ``None``.
134 |         batch_size (int, optional): batch size to use. Default: ``32``.
135 |         p (float, optional): with probability p, a low-res cell is set to 1
136 |             (kept); otherwise, it is 0. Default: ``0.5``.
137 | resize (bool or tuple of ints, optional): If True, resize saliency map
138 | to size of :attr:`input`. If False, don't resize. If (width,
139 | height) tuple, resize to (width, height). Default: ``False``.
140 | resize_mode (str, optional): If resize is not None, use this mode for
141 | the resize function. Default: ``'bilinear'``.
142 |
143 | Returns:
144 | :class:`torch.Tensor`: RISE saliency map.
145 | """
146 | with torch.no_grad():
147 | # Get device of input (i.e., GPU).
148 | dev = input.device
149 |
150 | # Initialize saliency mask and mask normalization term.
151 | input_shape = input.shape
152 | saliency_shape = list(input_shape)
153 |
154 | height = input_shape[2]
155 | width = input_shape[3]
156 |
157 | out = model(input)
158 | num_classes = out.shape[1]
159 |
160 | saliency_shape[1] = num_classes
161 | saliency = torch.zeros(saliency_shape, device=dev)
162 |
163 | # Number of spatial dimensions.
164 | nsd = len(input.shape) - 2
165 | assert nsd == 2
166 |
167 | # Spatial size of low-res grid cell.
168 | cell_size = tuple([int(np.ceil(s / num_cells))
169 | for s in input_shape[2:]])
170 |
171 | # Spatial size of upsampled mask with buffer (input size + cell size).
172 | up_size = tuple([input_shape[2 + i] + cell_size[i]
173 | for i in range(nsd)])
174 |
175 | # Save current random number generator state.
176 | state = torch.get_rng_state()
177 |
178 | # Set seed.
179 | torch.manual_seed(seed)
180 |
181 | if filter_masks is not None:
182 | assert len(filter_masks) == num_masks
183 |
184 | num_chunks = (num_masks + batch_size - 1) // batch_size
185 | for chunk in range(num_chunks):
186 | # Generate RISE random masks on the fly.
187 | mask_bs = min(num_masks - batch_size * chunk, batch_size)
188 |
189 | if filter_masks is None:
190 | # Generate low-res, random binary masks.
191 | grid = (torch.rand(mask_bs, 1, *((num_cells,) * nsd),
192 | device=dev) < p).float()
193 |
194 | # Upsample low-res masks to input shape + buffer.
195 | masks_up = _upsample_reflect(grid, up_size)
196 |
197 | # Save final RISE masks with random shift.
198 | masks = torch.empty(mask_bs, 1, *input_shape[2:], device=dev)
199 | shift_x = torch.randint(0,
200 | cell_size[0],
201 | (mask_bs,),
202 | device='cpu')
203 | shift_y = torch.randint(0,
204 | cell_size[1],
205 | (mask_bs,),
206 | device='cpu')
207 | for i in range(mask_bs):
208 | masks[i] = masks_up[i,
209 | :,
210 | shift_x[i]:shift_x[i] + height,
211 | shift_y[i]:shift_y[i] + width]
212 | else:
213 | masks = filter_masks[
214 | chunk * batch_size:chunk * batch_size + mask_bs]
215 |
216 | # Accumulate saliency mask.
217 | for i, inp in enumerate(input):
218 | out = torch.sigmoid(model(inp.unsqueeze(0) * masks))
219 | if len(out.shape) == 4:
220 | # TODO: Consider handling FC outputs more flexibly.
221 | assert out.shape[2] == 1
222 | assert out.shape[3] == 1
223 | out = out[:, :, 0, 0]
224 | sal = torch.matmul(out.data.transpose(0, 1),
225 | masks.view(mask_bs, height * width))
226 | sal = sal.view((num_classes, height, width))
227 | saliency[i] = saliency[i] + sal
228 |
229 | # Normalize saliency mask.
230 | saliency /= num_masks
231 |
232 | # Restore original random number generator state.
233 | torch.set_rng_state(state)
234 |
235 | # Resize saliency mask if needed.
236 | saliency = resize_saliency(input,
237 | saliency,
238 | resize,
239 | mode=resize_mode)
240 | return saliency
241 |
--------------------------------------------------------------------------------
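As the module docstring above notes, rise() produces one saliency map per class and rise_class() then selects the map for the requested target label(s). A minimal sketch of the call pattern, reusing get_example_data from torchray.benchmark (defined in the next file); the small num_masks is only to keep the sketch quick and yields a noisier map than the default 8000.

from torchray.attribution.rise import rise_class
from torchray.benchmark import get_example_data

# Pre-trained VGG16 in eval mode, an example image, and two category ids.
model, x, category_id, _ = get_example_data()

# Class-specific RISE map, resized back to the input resolution.
saliency = rise_class(model, x, target=[category_id],
                      num_masks=1000, resize=True)
print(saliency.shape)  # (1, 1, H, W), matching the input spatial size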
/TorchRay/torchray/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 | r"""This script provides a few functions for getting and plotting example data.
2 | """
3 | import os
4 | import torchvision
5 | from matplotlib import pyplot as plt
6 |
7 | from .datasets import * # noqa
8 | from .models import * # noqa
9 |
10 |
11 | def get_example_data(arch='vgg16', shape=224):
12 | """Get example data to demonstrate visualization techniques.
13 |
14 | Args:
15 | arch (str, optional): name of torchvision.models architecture.
16 | Default: ``'vgg16'``.
17 | shape (int or tuple of int, optional): shape to resize input image to.
18 | Default: ``224``.
19 |
20 | Returns:
21 | (:class:`torch.nn.Module`, :class:`torch.Tensor`, int, int): a tuple
22 | containing
23 |
24 | - a convolutional neural network model in evaluation mode.
25 | - a sample input tensor image.
26 | - the ImageNet category id of an object in the image.
27 | - the ImageNet category id of another object in the image.
28 |
29 | """
30 |
31 | # Get a network pre-trained on ImageNet.
32 | model = torchvision.models.__dict__[arch](pretrained=True)
33 |
34 | # Switch to eval mode to make the visualization deterministic.
35 | model.eval()
36 |
37 | # We do not need grads for the parameters.
38 | for param in model.parameters():
39 | param.requires_grad_(False)
40 |
41 | # Download an example image from wikimedia.
42 | import requests
43 | from io import BytesIO
44 | from PIL import Image
45 |
46 | url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/7/7f/Arthur_Heyer_-_Dog_and_Cats.jpg/592px-Arthur_Heyer_-_Dog_and_Cats.jpg'
47 | response = requests.get(url)
48 | img = Image.open(BytesIO(response.content))
49 |
50 | # Pre-process the image and convert into a tensor
51 | transform = torchvision.transforms.Compose([
52 | torchvision.transforms.Resize(shape),
53 | torchvision.transforms.CenterCrop(shape),
54 | torchvision.transforms.ToTensor(),
55 | torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
56 | std=[0.229, 0.224, 0.225]),
57 | ])
58 |
59 | x = transform(img).unsqueeze(0)
60 |
61 | # bulldog category id.
62 | category_id_1 = 245
63 |
64 | # persian cat category id.
65 | category_id_2 = 285
66 |
67 | # Move model and input to device.
68 | from torchray.utils import get_device
69 | dev = get_device()
70 | model = model.to(dev)
71 | x = x.to(dev)
72 |
73 | return model, x, category_id_1, category_id_2
74 |
75 |
76 | def plot_example(input,
77 | saliency,
78 | method,
79 | category_id,
80 | show_plot=False,
81 | save_path=None):
82 | """Plot an example.
83 |
84 | Args:
85 | input (:class:`torch.Tensor`): 4D tensor containing input images.
86 | saliency (:class:`torch.Tensor`): 4D tensor containing saliency maps.
87 | method (str): name of saliency method.
88 | category_id (int): ID of ImageNet category.
89 | show_plot (bool, optional): If True, show plot. Default: ``False``.
90 | save_path (str, optional): Path to save figure to. Default: ``None``.
91 | """
92 | from torchray.utils import imsc
93 | from torchray.benchmark.datasets import IMAGENET_CLASSES
94 |
95 | if isinstance(category_id, int):
96 | category_id = [category_id]
97 |
98 | batch_size = len(input)
99 |
100 | plt.clf()
101 | for i in range(batch_size):
102 | class_i = category_id[i % len(category_id)]
103 |
104 | plt.subplot(batch_size, 2, 1 + 2 * i)
105 | imsc(input[i])
106 | plt.title('input image', fontsize=8)
107 |
108 | plt.subplot(batch_size, 2, 2 + 2 * i)
109 | imsc(saliency[i], interpolation='none')
110 | plt.title('{} for category {} ({})'.format(
111 | method, IMAGENET_CLASSES[class_i], class_i), fontsize=8)
112 |
113 | # Save figure if path is specified.
114 | if save_path:
115 | save_dir = os.path.dirname(os.path.abspath(save_path))
116 | # Create directory if necessary.
117 | if not os.path.exists(save_dir):
118 | os.makedirs(save_dir)
119 | ext = os.path.splitext(save_path)[1].strip('.')
120 | plt.savefig(save_path, format=ext, bbox_inches='tight')
121 |
122 | # Show plot if desired.
123 | if show_plot:
124 | plt.show()
125 |
--------------------------------------------------------------------------------
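get_example_data and plot_example are meant to be combined with any attribution method in the package. A sketch using Grad-CAM from torchray.attribution.grad_cam; the layer name 'features.29' is an assumption here (a late convolutional layer of torchvision's VGG16) and would need adjusting for other architectures.

from torchray.attribution.grad_cam import grad_cam
from torchray.benchmark import get_example_data, plot_example

model, x, category_id, _ = get_example_data(arch='vgg16')

# Saliency at an assumed late convolutional layer of VGG16.
saliency = grad_cam(model, x, category_id, saliency_layer='features.29')

# Save the side-by-side figure instead of opening a window.
plot_example(x, saliency, 'Grad-CAM', category_id,
             show_plot=False, save_path='./gradcam_example.png')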
/TorchRay/torchray/benchmark/evaluate_finegrained_gradcam_energy_inside_bbox.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel
7 | import torch.backends.cudnn as cudnn
8 | import torch.optim
9 | import torch.utils.data.distributed
10 | import torchvision.transforms as transforms
11 | import resnet_multigpu_cgc as resnet
12 | import cv2
13 | import datasets as pointing_datasets
14 |
15 | """
16 | Here, we evaluate the content heatmap (Grad-CAM heatmap within object bounding box) on the fine-grained datasets.
17 | """
18 |
19 | model_names = ['resnet18', 'resnet50']
20 |
21 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
22 | parser.add_argument('data', metavar='DIR', help='path to dataset')
23 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
24 | choices=model_names,
25 | help='model architecture: ' +
26 | ' | '.join(model_names) +
27 | ' (default: resnet18)')
28 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
29 | help='number of data loading workers (default: 16)')
30 | parser.add_argument('-b', '--batch-size', default=256, type=int,
31 |                     metavar='N', help='mini-batch size (default: 256)')
32 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
33 | help='use pre-trained model')
34 | parser.add_argument('-g', '--num-gpus', default=1, type=int,
35 |                     metavar='N', help='number of GPUs to use (default: 1)')
36 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
37 | help='path to latest checkpoint (default: none)')
38 | parser.add_argument('--input_resize', default=224, type=int,
39 | metavar='N', help='Resize for smallest side of input (default: 224)')
40 | parser.add_argument('--maxpool', dest='maxpool', action='store_true',
41 | help='use maxpool version of the model')
42 | parser.add_argument('--dataset', type=str, default='imagenet',
43 | help='dataset to use: [imagenet, cub, aircraft, flowers, cars]')
44 |
45 |
46 | def main():
47 | global args
48 | args = parser.parse_args()
49 |
50 | if args.dataset == 'cub':
51 | num_classes = 200
52 | elif args.dataset == 'aircraft':
53 | num_classes = 90
54 | elif args.dataset == 'flowers':
55 | num_classes = 102
56 | elif args.dataset == 'cars':
57 | num_classes = 196
58 |
59 | print("=> creating model '{}' for '{}'".format(args.arch, args.dataset))
60 | if args.arch.startswith('resnet'):
61 | model = resnet.__dict__[args.arch](num_classes=num_classes)
62 | else:
63 | print('Other archs not supported')
64 | exit()
65 | model = torch.nn.DataParallel(model).cuda()
66 |
67 | if args.resume:
68 | print("=> loading checkpoint '{}'".format(args.resume))
69 | checkpoint = torch.load(args.resume)
70 | if 'state_dict' in checkpoint:
71 | model.load_state_dict(checkpoint['state_dict'])
72 | elif 'model' in checkpoint:
73 | model.load_state_dict(checkpoint['model'])
74 | else:
75 | print('Checkpoint format not supported')
76 | exit()
77 |
78 | if (not args.resume) and (not args.pretrained):
79 | assert False, "Please specify either the pre-trained model or checkpoint for evaluation"
80 |
81 | cudnn.benchmark = True
82 |
83 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
84 | std=[0.229, 0.224, 0.225])
85 |
86 | # In this version, we do not resize the images. We feed the full image and use AdaptivePooling before FC.
87 | # We will resize Gradcam heatmap to image size and compare the actual bbox co-ordinates
88 | val_dataset = pointing_datasets.ImageNetDetection(args.data,
89 | transform=transforms.Compose([
90 | transforms.Resize(args.input_resize),
91 | transforms.ToTensor(),
92 | normalize,
93 | ]))
94 |
95 | # we set batch size=1 since we are loading full resolution images.
96 | val_loader = torch.utils.data.DataLoader(
97 | val_dataset, batch_size=1, shuffle=False,
98 | num_workers=args.workers, pin_memory=True)
99 |
100 | validate_multi(val_loader, val_dataset, model)
101 |
102 |
103 | def validate_multi(val_loader, val_dataset, model):
104 | batch_time = AverageMeter()
105 | heatmap_inside_bbox = AverageMeter()
106 |
107 | # switch to evaluate mode
108 | model.eval()
109 |
110 | zero_count = 0
111 | total_count = 0
112 | end = time.time()
113 | for i, (images, annotation, targets) in enumerate(val_loader):
114 | total_count += 1
115 | images = images.cuda(non_blocking=True)
116 | targets = targets.cuda(non_blocking=True)
117 |
118 | # we assume batch size == 1 and unwrap the first elem of every list in annotation object
119 | annotation = unwrap_dict(annotation)
120 | image_size = val_dataset.as_image_size(annotation)
121 |
122 | output, feats = model(images, return_feats=True)
123 | output_gradcam = compute_gradcam(output, feats, targets)
124 | output_gradcam_np = output_gradcam.data.cpu().numpy()[0] # since we have batch size==1
125 | resized_output_gradcam = cv2.resize(output_gradcam_np, image_size)
126 | spatial_sum = resized_output_gradcam.sum()
127 | if spatial_sum <= 0:
128 | zero_count += 1
129 | continue
130 |
131 |         # Normalize resized_output_gradcam so that it can be treated as a probability map
132 | resized_output_gradcam = resized_output_gradcam / spatial_sum
133 |
134 | mask = pointing_datasets.imagenet_as_mask(annotation, targets[0].item())
135 |
136 | mask = mask.type(torch.ByteTensor)
137 | mask = mask.cpu().data.numpy()
138 |
139 | gcam_inside_gt_mask = mask * resized_output_gradcam
140 | # Now we sum the heatmap inside the object bounding box
141 | total_gcam_inside_gt_mask = gcam_inside_gt_mask.sum()
142 | heatmap_inside_bbox.update(total_gcam_inside_gt_mask)
143 |
144 | if i % 1000 == 0:
145 | print('\nResults after {} examples: '.format(i+1))
146 | print('Curr % of heatmap inside bbox: {:.4f} ({:.4f})'.format(heatmap_inside_bbox.val * 100,
147 | heatmap_inside_bbox.avg * 100))
148 |
149 | # measure elapsed time
150 | batch_time.update(time.time() - end)
151 | end = time.time()
152 |
153 | print('\nFinal Results - ')
154 | print('\n\n% of heatmap inside bbox: {:.4f}'.format(heatmap_inside_bbox.avg * 100))
155 | print('Zero GC found for {}/{} samples'.format(zero_count, total_count))
156 |
157 | return
158 |
159 |
160 | def compute_gradcam(output, feats, target):
161 | """
162 |     Compute the Grad-CAM heatmap for the given target category
163 | :param output:
164 | :param feats:
165 | :return:
166 | """
167 | eps = 1e-8
168 | relu = nn.ReLU(inplace=True)
169 |
170 | target = target.cpu().numpy()
171 | # target = np.argmax(output.cpu().data.numpy(), axis=-1)
172 | one_hot = np.zeros((output.shape[0], output.shape[-1]), dtype=np.float32)
173 | indices_range = np.arange(output.shape[0])
174 | one_hot[indices_range, target[indices_range]] = 1
175 | one_hot = torch.from_numpy(one_hot)
176 | one_hot.requires_grad = True
177 |
178 | # Compute the Grad-CAM for the original image
179 | one_hot_cuda = torch.sum(one_hot.cuda() * output)
180 | dy_dz1, = torch.autograd.grad(one_hot_cuda, feats, grad_outputs=torch.ones(one_hot_cuda.size()).cuda(),
181 | retain_graph=True, create_graph=True)
182 | dy_dz_sum1 = dy_dz1.sum(dim=2).sum(dim=2)
183 | gcam512_1 = dy_dz_sum1.unsqueeze(-1).unsqueeze(-1) * feats
184 | gradcam = gcam512_1.sum(dim=1)
185 | gradcam = relu(gradcam)
186 | spatial_sum1 = gradcam.sum(dim=[1, 2]).unsqueeze(-1).unsqueeze(-1)
187 | gradcam = (gradcam / (spatial_sum1 + eps)) + eps
188 |
189 | return gradcam
190 |
191 |
192 | def unwrap_dict(dict_object):
193 | new_dict = {}
194 | for k, v in dict_object.items():
195 | if k == 'object':
196 | new_v_list = []
197 | for elem in v:
198 | new_v_list.append(unwrap_dict(elem))
199 | new_dict[k] = new_v_list
200 | continue
201 | if isinstance(v, dict):
202 | new_v = unwrap_dict(v)
203 | elif isinstance(v, list) and len(v) == 1:
204 | new_v = v[0]
205 | else:
206 | new_v = v
207 | new_dict[k] = new_v
208 | return new_dict
209 |
210 |
211 | class AverageMeter(object):
212 | """Computes and stores the average and current value"""
213 | def __init__(self):
214 | self.reset()
215 |
216 | def reset(self):
217 | self.val = 0
218 | self.avg = 0
219 | self.sum = 0
220 | self.count = 0
221 |
222 | def update(self, val, n=1):
223 | self.val = val
224 | self.sum += val * n
225 | self.count += n
226 | self.avg = self.sum / self.count
227 |
228 |
229 | if __name__ == '__main__':
230 | main()
231 |
--------------------------------------------------------------------------------
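The "content heatmap" percentages printed by this script reduce to one quantity: the fraction of the normalized heatmap's energy that falls inside the ground-truth mask. A self-contained toy illustration of that metric (NumPy only; the arrays below are made up and independent of the dataset classes used above):

import numpy as np

def energy_inside_mask(heatmap, mask):
    # Fraction of (non-negative) heatmap mass that falls inside a binary mask.
    total = heatmap.sum()
    if total <= 0:
        return None  # degenerate heatmap; such samples are skipped above
    return float((heatmap / total * mask).sum())

heatmap = np.zeros((4, 4))
heatmap[1:3, 1:3] = 1.0   # all energy sits in the central 2x2 block
mask = np.zeros((4, 4))
mask[0:2, 0:2] = 1.0      # ground-truth box covers the top-left 2x2 corner

print(energy_inside_mask(heatmap, mask))  # 0.25: one of the four unit cells is inside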
/TorchRay/torchray/benchmark/evaluate_imagenet_excitation_backprop_energy_inside_bbox.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel
7 | import torch.backends.cudnn as cudnn
8 | import torch.optim
9 | import torch.utils.data.distributed
10 | import torchvision.transforms as transforms
11 | import torchvision.models.resnet as resnet
12 | import datasets as pointing_datasets
13 | from torchray.attribution.excitation_backprop import contrastive_excitation_backprop
14 | from torchray.attribution.excitation_backprop import update_resnet
15 | from models import resnet_to_fc
16 | from torchray.attribution.common import get_pointing_gradient
17 |
18 |
19 | """
20 | Here, we evaluate the content heatmap (Excitation Backprop heatmap within object bounding box) on imagenet dataset.
21 | """
22 |
23 | model_names = ['resnet18', 'resnet50']
24 |
25 | parser = argparse.ArgumentParser(description='Pointing game evaluation for ImageNet using Contrastive Excitation Backprop')
26 | parser.add_argument('data', metavar='DIR', help='path to dataset')
27 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
28 | choices=model_names,
29 | help='model architecture: ' +
30 | ' | '.join(model_names) +
31 | ' (default: resnet18)')
32 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
33 | help='number of data loading workers (default: 16)')
34 | parser.add_argument('-b', '--batch-size', default=256, type=int,
35 |                     metavar='N', help='mini-batch size (default: 256)')
36 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
37 | help='use pre-trained model')
38 | parser.add_argument('-g', '--num-gpus', default=1, type=int,
39 |                     metavar='N', help='number of GPUs to use (default: 1)')
40 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
41 | help='path to latest checkpoint (default: none)')
42 | parser.add_argument('--input_resize', default=224, type=int,
43 | metavar='N', help='Resize for smallest side of input (default: 224)')
44 |
45 |
46 | def main():
47 | global args
48 | args = parser.parse_args()
49 |
50 | if args.pretrained:
51 | print("=> using pre-trained model '{}'".format(args.arch))
52 | model = resnet.__dict__[args.arch](pretrained=True)
53 | else:
54 | print("=> creating model '{}'".format(args.arch))
55 | model = resnet.__dict__[args.arch]()
56 | model = torch.nn.DataParallel(model)
57 |
58 | if args.resume:
59 | print("=> loading checkpoint '{}'".format(args.resume))
60 | checkpoint = torch.load(args.resume)
61 | model.load_state_dict(checkpoint['state_dict'])
62 |
63 | if (not args.resume) and (not args.pretrained):
64 | assert False, "Please specify either the pre-trained model or checkpoint for evaluation"
65 |
66 | model = model._modules['module']
67 |
68 | model = resnet_to_fc(model)
69 | model.avgpool = torch.nn.AvgPool2d((7, 7), stride=1)
70 | model = update_resnet(model, debug=True)
71 | model = model.cuda()
72 | cudnn.benchmark = False
73 |
74 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
75 | std=[0.229, 0.224, 0.225])
76 |
77 | # In this version, we do not resize the images. We feed the full image and use AdaptivePooling before FC.
78 | # We will resize Gradcam heatmap to image size and compare the actual bbox co-ordinates
79 | val_dataset = pointing_datasets.ImageNetDetection(args.data,
80 | transform=transforms.Compose([
81 | # transforms.Resize((224, 224)),
82 | transforms.Resize(args.input_resize),
83 | transforms.ToTensor(),
84 | normalize,
85 | ]))
86 |
87 | # we set batch size=1 since we are loading full resolution images.
88 | val_loader = torch.utils.data.DataLoader(
89 | val_dataset, batch_size=1, shuffle=False,
90 | num_workers=args.workers, pin_memory=True)
91 |
92 | validate_multi(val_loader, val_dataset, model)
93 |
94 |
95 | def validate_multi(val_loader, val_dataset, model):
96 | batch_time = AverageMeter()
97 | heatmap_inside_gt_mask = AverageMeter()
98 |
99 | # switch to evaluate mode
100 | model.eval()
101 | # We do not need grads for the parameters.
102 | for param in model.parameters():
103 | param.requires_grad_(False)
104 |
105 | # prepare vis layer, contrast layer and probes
106 | contrast_layer = 'avgpool'
107 |
108 | zero_saliency_count = 0
109 | total_count = 0
110 | end = time.time()
111 | for i, (images, annotation, targets) in enumerate(val_loader):
112 | total_count += 1
113 | images = images.cuda(non_blocking=True)
114 | targets = targets.cuda(non_blocking=True)
115 | # we assume batch size == 1 and unwrap the first elem of every list in annotation object
116 | annotation = unwrap_dict(annotation)
117 | image_size = val_dataset.as_image_size(annotation)
118 |
119 | class_id = targets[0].item()
120 |
121 | saliency = contrastive_excitation_backprop(model, images, class_id,
122 | saliency_layer='layer3',
123 | contrast_layer=contrast_layer,
124 | resize=image_size,
125 | get_backward_gradient=get_pointing_gradient
126 | )
127 | saliency = saliency.squeeze() # since we have batch size==1
128 |
129 | resized_saliency = saliency.data.cpu().numpy()
130 |
131 | if np.isnan(resized_saliency).any():
132 | zero_saliency_count += 1
133 | continue
134 | spatial_sum = resized_saliency.sum()
135 | if spatial_sum <= 0:
136 | zero_saliency_count += 1
137 | continue
138 | resized_saliency = resized_saliency / spatial_sum
139 |
140 | # Now, we obtain the mask corresponding to the ground truth bounding boxes
141 | # Skip if all boxes for class_id are marked difficult.
142 | objs = annotation['annotation']['object']
143 | if not isinstance(objs, list):
144 | objs = [objs]
145 | objs = [obj for obj in objs if pointing_datasets._IMAGENET_CLASS_TO_INDEX[obj['name']] == class_id]
146 | if all([bool(int(obj['difficult'])) for obj in objs]):
147 | continue
148 | gt_mask = pointing_datasets.imagenet_as_mask(annotation, class_id)
149 | gt_mask = gt_mask.type(torch.ByteTensor)
150 | gt_mask = gt_mask.cpu().data.numpy()
151 | gcam_inside_gt_mask = gt_mask * resized_saliency
152 | total_gcam_inside_gt_mask = gcam_inside_gt_mask.sum()
153 | heatmap_inside_gt_mask.update(total_gcam_inside_gt_mask)
154 |
155 | if i % 1000 == 0:
156 | print('\nCurr % of heatmap inside GT mask: {:.4f} ({:.4f})'.format(heatmap_inside_gt_mask.val * 100,
157 | heatmap_inside_gt_mask.avg * 100))
158 |
159 | # measure elapsed time
160 | batch_time.update(time.time() - end)
161 | end = time.time()
162 |
163 | print('\n\n% of heatmap inside GT mask: {:.4f}'.format(heatmap_inside_gt_mask.avg * 100))
164 | print('\n Zero Saliency found for {} / {} images.'.format(zero_saliency_count, total_count))
165 |
166 | return
167 |
168 |
169 | def compute_gradcam(output, feats, target):
170 | """
171 |     Compute the Grad-CAM heatmap for the given target category
172 | :param output:
173 | :param feats:
174 | :return:
175 | """
176 | eps = 1e-8
177 | relu = nn.ReLU(inplace=True)
178 |
179 | target = target.cpu().numpy()
180 | # target = np.argmax(output.cpu().data.numpy(), axis=-1)
181 | one_hot = np.zeros((output.shape[0], output.shape[-1]), dtype=np.float32)
182 | indices_range = np.arange(output.shape[0])
183 | one_hot[indices_range, target[indices_range]] = 1
184 | one_hot = torch.from_numpy(one_hot)
185 | one_hot.requires_grad = True
186 |
187 | # Compute the Grad-CAM for the original image
188 | one_hot_cuda = torch.sum(one_hot.cuda() * output)
189 | dy_dz1, = torch.autograd.grad(one_hot_cuda, feats, grad_outputs=torch.ones(one_hot_cuda.size()).cuda(),
190 | retain_graph=True, create_graph=True)
191 | gcam512_1 = dy_dz1 * feats
192 | gradcam = gcam512_1.sum(dim=1)
193 | gradcam = relu(gradcam)
194 | spatial_sum1 = gradcam.sum(dim=[1, 2]).unsqueeze(-1).unsqueeze(-1)
195 | gradcam = (gradcam / (spatial_sum1 + eps)) + eps
196 |
197 | return gradcam
198 |
199 |
200 | def unwrap_dict(dict_object):
201 | new_dict = {}
202 | for k, v in dict_object.items():
203 | if k == 'object':
204 | new_v_list = []
205 | for elem in v:
206 | new_v_list.append(unwrap_dict(elem))
207 | new_dict[k] = new_v_list
208 | continue
209 | if isinstance(v, dict):
210 | new_v = unwrap_dict(v)
211 | elif isinstance(v, list) and len(v) == 1:
212 | new_v = v[0]
213 | # if isinstance(new_v, dict):
214 | # new_v = unwrap_dict(new_v)
215 | else:
216 | new_v = v
217 | new_dict[k] = new_v
218 | return new_dict
219 |
220 |
221 | class AverageMeter(object):
222 | """Computes and stores the average and current value"""
223 | def __init__(self):
224 | self.reset()
225 |
226 | def reset(self):
227 | self.val = 0
228 | self.avg = 0
229 | self.sum = 0
230 | self.count = 0
231 |
232 | def update(self, val, n=1):
233 | self.val = val
234 | self.sum += val * n
235 | self.count += n
236 | self.avg = self.sum / self.count
237 |
238 |
239 | if __name__ == '__main__':
240 | main()
241 |
--------------------------------------------------------------------------------
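unwrap_dict shows up in all of these evaluation scripts because, with batch_size=1, the default DataLoader collation wraps every scalar annotation field in a length-1 list while keeping 'object' as a list of per-object dicts. A small illustration of the intended behaviour on a made-up annotation fragment (all values below are invented, and unwrap_dict is the function defined in the script above):

annotation = {
    'annotation': {
        'filename': ['example.jpg'],                      # length-1 lists from collation
        'size': {'width': ['500'], 'height': ['375']},
        'object': [{'name': ['n02084071'], 'difficult': ['0']}],
    }
}

print(unwrap_dict(annotation))
# {'annotation': {'filename': 'example.jpg',
#                 'size': {'width': '500', 'height': '375'},
#                 'object': [{'name': 'n02084071', 'difficult': '0'}]}}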
/TorchRay/torchray/benchmark/evaluate_imagenet_gradcam_energy_inside_bbox.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel
7 | import torch.backends.cudnn as cudnn
8 | import torch.optim
9 | import torch.utils.data.distributed
10 | import torchvision.transforms as transforms
11 | import resnet_multigpu_cgc as resnet
12 | import cv2
13 | import datasets as pointing_datasets
14 |
15 | """
16 | Here, we evaluate the content heatmap (Grad-CAM heatmap within object bounding box) on the imagenet dataset.
17 | """
18 |
19 | model_names = ['resnet18', 'resnet50']
20 |
21 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
22 | parser.add_argument('data', metavar='DIR', help='path to dataset')
23 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
24 | choices=model_names,
25 | help='model architecture: ' +
26 | ' | '.join(model_names) +
27 | ' (default: resnet18)')
28 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
29 | help='number of data loading workers (default: 16)')
30 | parser.add_argument('-b', '--batch-size', default=256, type=int,
31 |                     metavar='N', help='mini-batch size (default: 256)')
32 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
33 | help='use pre-trained model')
34 | parser.add_argument('-g', '--num-gpus', default=1, type=int,
35 | metavar='N', help='number of GPUs to match (default: 4)')
36 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
37 | help='path to latest checkpoint (default: none)')
38 | parser.add_argument('--input_resize', default=224, type=int,
39 | metavar='N', help='Resize for smallest side of input (default: 224)')
40 |
41 |
42 | def main():
43 | global args
44 | args = parser.parse_args()
45 |
46 | if args.pretrained:
47 | print("=> using pre-trained model '{}'".format(args.arch))
48 | if args.arch.startswith('resnet'):
49 | model = resnet.__dict__[args.arch](pretrained=True)
50 | else:
51 | assert False, 'Unsupported architecture: {}'.format(args.arch)
52 | else:
53 | print("=> creating model '{}'".format(args.arch))
54 | if args.arch.startswith('resnet'):
55 | model = resnet.__dict__[args.arch]()
56 |
57 | model = torch.nn.DataParallel(model).cuda()
58 |
59 | if args.resume:
60 | print("=> loading checkpoint '{}'".format(args.resume))
61 | checkpoint = torch.load(args.resume)
62 | model.load_state_dict(checkpoint['state_dict'])
63 |
64 | if (not args.resume) and (not args.pretrained):
65 | assert False, "Please specify either the pre-trained model or checkpoint for evaluation"
66 |
67 | cudnn.benchmark = True
68 |
69 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
70 |
71 | # Here, we don't resize the images. We feed the full image and use AdaptivePooling before FC.
72 | # We will resize Gradcam heatmap to image size and compare the actual bbox co-ordinates
73 | val_dataset = pointing_datasets.ImageNetDetection(args.data,
74 | transform=transforms.Compose([
75 | transforms.Resize(args.input_resize),
76 | transforms.ToTensor(),
77 | normalize,
78 | ]))
79 |
80 | # we set batch size=1 since we are loading full resolution images.
81 | val_loader = torch.utils.data.DataLoader(
82 | val_dataset, batch_size=1, shuffle=False,
83 | num_workers=args.workers, pin_memory=True)
84 |
85 | validate_multi(val_loader, val_dataset, model)
86 |
87 |
88 | def validate_multi(val_loader, val_dataset, model):
89 | batch_time = AverageMeter()
90 | heatmap_inside_bbox = AverageMeter()
91 |
92 | # switch to evaluate mode
93 | model.eval()
94 |
95 | end = time.time()
96 | for i, (images, annotation, targets) in enumerate(val_loader):
97 | images = images.cuda(non_blocking=True)
98 | targets = targets.cuda(non_blocking=True)
99 |
100 | # we assume batch size == 1 and unwrap the first elem of every list in annotation object
101 | annotation = unwrap_dict(annotation)
102 | image_size = val_dataset.as_image_size(annotation)
103 |
104 | output, feats = model(images, vanilla_with_feats=True)
105 | output_gradcam = compute_gradcam(output, feats, targets)
106 | output_gradcam_np = output_gradcam.data.cpu().numpy()[0] # since we have batch size==1
107 | resized_output_gradcam = cv2.resize(output_gradcam_np, image_size)
108 | spatial_sum = resized_output_gradcam.sum()
109 | if spatial_sum <= 0:
110 | # We ignore images with zero Grad-CAM
111 | continue
112 |
113 |         # Normalize resized_output_gradcam so that it can be treated as a probability map
114 | resized_output_gradcam = resized_output_gradcam / spatial_sum
115 |
116 | mask = pointing_datasets.imagenet_as_mask(annotation, targets[0].item())
117 |
118 | mask = mask.type(torch.ByteTensor)
119 | mask = mask.cpu().data.numpy()
120 |
121 | gcam_inside_gt_mask = mask * resized_output_gradcam
122 |
123 | # Now we sum the heatmap inside the object bounding box
124 | total_gcam_inside_gt_mask = gcam_inside_gt_mask.sum()
125 | heatmap_inside_bbox.update(total_gcam_inside_gt_mask)
126 |
127 | if i % 1000 == 0:
128 | print('\nResults after {} examples: '.format(i+1))
129 | print('Curr % of heatmap inside bbox: {:.4f} ({:.4f})'.format(heatmap_inside_bbox.val * 100,
130 | heatmap_inside_bbox.avg * 100))
131 |
132 | # measure elapsed time
133 | batch_time.update(time.time() - end)
134 | end = time.time()
135 |
136 | print('\nFinal Results - ')
137 | print('\n\n% of heatmap inside bbox: {:.4f}'.format(heatmap_inside_bbox.avg * 100))
138 |
139 | return
140 |
141 |
142 | def compute_gradcam(output, feats, target):
143 | """
144 |     Compute the Grad-CAM heatmap for the given target category
145 | :param output:
146 | :param feats:
147 | :param target:
148 | :return:
149 | """
150 | eps = 1e-8
151 | relu = nn.ReLU(inplace=True)
152 |
153 | target = target.cpu().numpy()
154 | one_hot = np.zeros((output.shape[0], output.shape[-1]), dtype=np.float32)
155 | indices_range = np.arange(output.shape[0])
156 | one_hot[indices_range, target[indices_range]] = 1
157 | one_hot = torch.from_numpy(one_hot)
158 | one_hot.requires_grad = True
159 |
160 | # Compute the Grad-CAM for the original image
161 | one_hot_cuda = torch.sum(one_hot.cuda() * output)
162 | dy_dz1, = torch.autograd.grad(one_hot_cuda, feats, grad_outputs=torch.ones(one_hot_cuda.size()).cuda(),
163 | retain_graph=True, create_graph=True)
164 |     # Use the elementwise product of gradients and features to preserve the gradients' spatial locations
165 | gcam512_1 = dy_dz1 * feats
166 | gradcam = gcam512_1.sum(dim=1)
167 | gradcam = relu(gradcam)
168 | spatial_sum1 = gradcam.sum(dim=[1, 2]).unsqueeze(-1).unsqueeze(-1)
169 | gradcam = (gradcam / (spatial_sum1 + eps)) + eps
170 |
171 | return gradcam
172 |
173 |
174 | def unwrap_dict(dict_object):
175 | new_dict = {}
176 | for k, v in dict_object.items():
177 | if k == 'object':
178 | new_v_list = []
179 | for elem in v:
180 | new_v_list.append(unwrap_dict(elem))
181 | new_dict[k] = new_v_list
182 | continue
183 | if isinstance(v, dict):
184 | new_v = unwrap_dict(v)
185 | elif isinstance(v, list) and len(v) == 1:
186 | new_v = v[0]
187 | else:
188 | new_v = v
189 | new_dict[k] = new_v
190 | return new_dict
191 |
192 |
193 | class AverageMeter(object):
194 | """Computes and stores the average and current value"""
195 | def __init__(self):
196 | self.reset()
197 |
198 | def reset(self):
199 | self.val = 0
200 | self.avg = 0
201 | self.sum = 0
202 | self.count = 0
203 |
204 | def update(self, val, n=1):
205 | self.val = val
206 | self.sum += val * n
207 | self.count += n
208 | self.avg = self.sum / self.count
209 |
210 |
211 | if __name__ == '__main__':
212 | main()
213 |
--------------------------------------------------------------------------------
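The comments in compute_gradcam distinguish the classic Grad-CAM weighting (spatially pooled gradients used as per-channel weights) from the elementwise-product variant that keeps the gradients' spatial layout. A toy comparison of the two reductions on random stand-in tensors (shapes follow the (N, C, H, W) convention of the feature maps above):

import torch
import torch.nn.functional as F

feats = torch.randn(1, 512, 7, 7)   # stand-in for the conv feature maps
grads = torch.randn(1, 512, 7, 7)   # stand-in for d(score)/d(feats)

# Classic Grad-CAM: per-channel weights from spatially summed gradients.
weights = grads.sum(dim=(2, 3), keepdim=True)          # (1, 512, 1, 1)
cam_classic = F.relu((weights * feats).sum(dim=1))     # (1, 7, 7)

# Variant used above: keep the gradients' spatial layout.
cam_elementwise = F.relu((grads * feats).sum(dim=1))   # (1, 7, 7)

print(cam_classic.shape, cam_elementwise.shape)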
/TorchRay/torchray/benchmark/evaluate_swav_imagenet_gradcam_energy_inside_bbox.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel
7 | import torch.backends.cudnn as cudnn
8 | import torch.optim
9 | import torch.utils.data.distributed
10 | import torchvision.transforms as transforms
11 | import swav_resnet_cgc as resnet
12 | import os
13 | import cv2
14 | import datasets as pointing_datasets
15 |
16 |
17 | """
18 | Here, we evaluate the content heatmap (Grad-CAM heatmap within object bounding box) on the imagenet dataset.
19 | """
20 |
21 | model_names = ['resnet18', 'resnet50']
22 |
23 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
24 | parser.add_argument('data', metavar='DIR', help='path to dataset')
25 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
26 | choices=model_names,
27 | help='model architecture: ' +
28 | ' | '.join(model_names) +
29 | ' (default: resnet18)')
30 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
31 | help='number of data loading workers (default: 16)')
32 | parser.add_argument('-b', '--batch-size', default=256, type=int,
33 |                     metavar='N', help='mini-batch size (default: 256)')
34 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
35 | help='use pre-trained model')
36 | parser.add_argument('-g', '--num-gpus', default=1, type=int,
37 |                     metavar='N', help='number of GPUs to use (default: 1)')
38 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
39 | help='path to latest checkpoint (default: none)')
40 | parser.add_argument('--input_resize', default=224, type=int,
41 | metavar='N', help='Resize for smallest side of input (default: 224)')
42 | parser.add_argument('--maxpool', dest='maxpool', action='store_true',
43 | help='use maxpool version of the model')
44 |
45 |
46 | def main():
47 | global args
48 | args = parser.parse_args()
49 |
50 | if args.pretrained:
51 | print("=> using pre-trained model '{}'".format(args.arch))
52 | if args.arch.startswith('resnet'):
53 | model = resnet.__dict__[args.arch](pretrained=True)
54 | else:
55 | assert False, 'Unsupported architecture: {}'.format(args.arch)
56 | else:
57 | print("=> creating model '{}'".format(args.arch))
58 | if args.arch.startswith('resnet'):
59 | model = resnet.__dict__[args.arch]()
60 |
61 | if args.resume:
62 | if os.path.isfile(args.resume):
63 | print("=> loading checkpoint '{}'".format(args.resume))
64 | state_dict = torch.load(args.resume)
65 | if 'state_dict' in state_dict:
66 | state_dict = state_dict['state_dict']
67 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
68 | # pdb.set_trace()
69 | for k, v in model.state_dict().items():
70 | if k not in list(state_dict):
71 | print('key "{}" could not be found in provided state dict'.format(k))
72 | elif state_dict[k].shape != v.shape:
73 | print('key "{}" is of different shape in model and provided state dict'.format(k))
74 | state_dict[k] = v
75 | model.load_state_dict(state_dict, strict=False)
76 | else:
77 | print("=> no checkpoint found at '{}'".format(args.resume))
78 |
79 | model = torch.nn.DataParallel(model).cuda()
80 |
81 | if (not args.resume) and (not args.pretrained):
82 | assert False, "Please specify either the pre-trained model or checkpoint for evaluation"
83 |
84 | cudnn.benchmark = True
85 |
86 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
87 | std=[0.229, 0.224, 0.225])
88 |
89 | # In this version, we do not resize the images. We feed the full image and use AdaptivePooling before FC.
90 | # We will resize Gradcam heatmap to image size and compare the actual bbox co-ordinates
91 | val_dataset = pointing_datasets.ImageNetDetection(args.data,
92 | transform=transforms.Compose([
93 | transforms.Resize(args.input_resize),
94 | transforms.ToTensor(),
95 | normalize,
96 | ]))
97 |
98 | # we set batch size=1 since we are loading full resolution images.
99 | val_loader = torch.utils.data.DataLoader(
100 | val_dataset, batch_size=1, shuffle=False,
101 | num_workers=args.workers, pin_memory=True)
102 |
103 | validate_multi(val_loader, val_dataset, model)
104 |
105 |
106 | def validate_multi(val_loader, val_dataset, model):
107 | batch_time = AverageMeter()
108 | heatmap_inside_bbox = AverageMeter()
109 |
110 | # switch to evaluate mode
111 | model.eval()
112 |
113 | zero_count = 0
114 | total_count = 0
115 | end = time.time()
116 | for i, (images, annotation, targets) in enumerate(val_loader):
117 | total_count += 1
118 | images = images.cuda(non_blocking=True)
119 | targets = targets.cuda(non_blocking=True)
120 |
121 | # we assume batch size == 1 and unwrap the first elem of every list in annotation object
122 | annotation = unwrap_dict(annotation)
123 | image_size = val_dataset.as_image_size(annotation)
124 |
125 | output, feats = model(images, vanilla_with_feats=True)
126 | output_gradcam = compute_gradcam(output, feats, targets)
127 | output_gradcam_np = output_gradcam.data.cpu().numpy()[0] # since we have batch size==1
128 | resized_output_gradcam = cv2.resize(output_gradcam_np, image_size)
129 | spatial_sum = resized_output_gradcam.sum()
130 | if spatial_sum <= 0:
131 | zero_count += 1
132 | continue
133 |
134 |         # Normalize resized_output_gradcam so that it can be treated as a probability map
135 | resized_output_gradcam = resized_output_gradcam / spatial_sum
136 |
137 | mask = pointing_datasets.imagenet_as_mask(annotation, targets[0].item())
138 |
139 | mask = mask.type(torch.ByteTensor)
140 | mask = mask.cpu().data.numpy()
141 |
142 | gcam_inside_gt_mask = mask * resized_output_gradcam
143 | # Now we sum the heatmap inside the object bounding box
144 | total_gcam_inside_gt_mask = gcam_inside_gt_mask.sum()
145 | heatmap_inside_bbox.update(total_gcam_inside_gt_mask)
146 |
147 | if i % 1000 == 0:
148 | print('\nResults after {} examples: '.format(i+1))
149 | print('Curr % of heatmap inside bbox: {:.4f} ({:.4f})'.format(heatmap_inside_bbox.val * 100,
150 | heatmap_inside_bbox.avg * 100))
151 |
152 | # measure elapsed time
153 | batch_time.update(time.time() - end)
154 | end = time.time()
155 |
156 | print('\nFinal Results - ')
157 | print('\n\n% of heatmap inside bbox: {:.4f}'.format(heatmap_inside_bbox.avg * 100))
158 | print('Zero GC found for {}/{} samples'.format(zero_count, total_count))
159 |
160 | return
161 |
162 |
163 | def compute_gradcam(output, feats, target):
164 | """
165 |     Compute the Grad-CAM heatmap for the given target category
166 | :param output:
167 | :param feats:
168 | :return:
169 | """
170 | eps = 1e-8
171 | relu = nn.ReLU(inplace=True)
172 |
173 | target = target.cpu().numpy()
174 | # target = np.argmax(output.cpu().data.numpy(), axis=-1)
175 | one_hot = np.zeros((output.shape[0], output.shape[-1]), dtype=np.float32)
176 | indices_range = np.arange(output.shape[0])
177 | one_hot[indices_range, target[indices_range]] = 1
178 | one_hot = torch.from_numpy(one_hot)
179 | one_hot.requires_grad = True
180 |
181 | # Compute the Grad-CAM for the original image
182 | one_hot_cuda = torch.sum(one_hot.cuda() * output)
183 | dy_dz1, = torch.autograd.grad(one_hot_cuda, feats, grad_outputs=torch.ones(one_hot_cuda.size()).cuda(),
184 | retain_graph=True, create_graph=True)
185 | dy_dz_sum1 = dy_dz1.sum(dim=2).sum(dim=2)
186 | gcam512_1 = dy_dz_sum1.unsqueeze(-1).unsqueeze(-1) * feats
187 |     # Comment out the two lines above and uncomment the line below to use the elementwise product of gradients and features, preserving their spatial locations
188 | # gcam512_1 = dy_dz1 * feats
189 | gradcam = gcam512_1.sum(dim=1)
190 | gradcam = relu(gradcam)
191 | spatial_sum1 = gradcam.sum(dim=[1, 2]).unsqueeze(-1).unsqueeze(-1)
192 | gradcam = (gradcam / (spatial_sum1 + eps)) + eps
193 |
194 | return gradcam
195 |
196 |
197 | def unwrap_dict(dict_object):
198 | new_dict = {}
199 | for k, v in dict_object.items():
200 | if k == 'object':
201 | new_v_list = []
202 | for elem in v:
203 | new_v_list.append(unwrap_dict(elem))
204 | new_dict[k] = new_v_list
205 | continue
206 | if isinstance(v, dict):
207 | new_v = unwrap_dict(v)
208 | elif isinstance(v, list) and len(v) == 1:
209 | new_v = v[0]
210 | # if isinstance(new_v, dict):
211 | # new_v = unwrap_dict(new_v)
212 | else:
213 | new_v = v
214 | new_dict[k] = new_v
215 | return new_dict
216 |
217 |
218 | class AverageMeter(object):
219 | """Computes and stores the average and current value"""
220 | def __init__(self):
221 | self.reset()
222 |
223 | def reset(self):
224 | self.val = 0
225 | self.avg = 0
226 | self.sum = 0
227 | self.count = 0
228 |
229 | def update(self, val, n=1):
230 | self.val = val
231 | self.sum += val * n
232 | self.count += n
233 | self.avg = self.sum / self.count
234 |
235 |
236 | if __name__ == '__main__':
237 | main()
238 |
--------------------------------------------------------------------------------
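The checkpoint-loading block above strips the 'module.' prefix that DataParallel adds and falls back to the model's own weights whenever a key is missing or has the wrong shape. The same logic as a small standalone helper (a sketch; the function name and the fallback messages are illustrative, not part of the repository):

import torch

def remap_state_dict(model, checkpoint):
    # Unwrap an optional 'state_dict' entry and strip DataParallel prefixes.
    state_dict = checkpoint.get('state_dict', checkpoint)
    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    # Keep the model's own weights for missing or mis-shaped keys.
    for k, v in model.state_dict().items():
        if k not in state_dict:
            print('key "{}" not found in checkpoint; keeping model weights'.format(k))
        elif state_dict[k].shape != v.shape:
            print('key "{}" has a mismatched shape; keeping model weights'.format(k))
            state_dict[k] = v
    return state_dict

# Usage: model.load_state_dict(remap_state_dict(model, torch.load(ckpt_path)), strict=False)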
/TorchRay/torchray/benchmark/logging_mongo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | r"""
4 | This module provides functions to log information (e.g., benchmark
5 | results) to a MongoDB database.
6 |
7 | See :mod:`examples.standard_suite` for an example of how to use MongoDB for
8 | logging benchmark results.
9 |
10 | To start a MongoDB server, use
11 |
12 | .. code:: shell
13 |
14 | $ python -m torchray.benchmark.server
15 |
16 | """
17 | import io
18 | import pickle
19 |
20 | import bson
21 | import numpy as np
22 | import pymongo
23 | import torch
24 |
25 | from torchray.utils import get_config
26 |
27 | __all__ = [
28 | 'mongo_connect',
29 | 'mongo_save',
30 | 'mongo_load',
31 | 'data_to_mongo',
32 | 'data_from_mongo',
33 | 'last_lines'
34 | ]
35 |
36 | _MONGO_MAX_TRIES = 10
37 |
38 |
39 | def mongo_connect(database):
40 | """
41 |     Connect to a MongoDB server and return a
42 | :class:`pymongo.database.Database` object.
43 |
44 | Args:
45 | database (str): name of database.
46 |
47 | Returns:
48 | :class:`pymongo.database.Database`: database.
49 | """
50 | try:
51 | config = get_config()
52 | hostname = f"{config['mongo']['hostname']}:{config['mongo']['port']}"
53 | client = pymongo.MongoClient(hostname)
54 | client.server_info()
55 | database = client[database]
56 | return database
57 | except pymongo.errors.ServerSelectionTimeoutError as error:
58 | raise Exception(
59 |             f"Cannot connect to MongoDB at {hostname}") from error
60 |
61 |
62 | def mongo_save(database, collection_key, id_key, data):
63 | """Save results to MongoDB database.
64 |
65 | Args:
66 | database (:class:`pymongo.database.Database`): MongoDB database to save
67 | results to.
68 | collection_key (str): name of collection.
69 | id_key (str): id key with which to store :attr:`data`.
70 | data (:class:`bson.binary.Binary` or dict): data to store in
71 | :attr:`db`.
72 | """
73 | collection = database[collection_key].with_options(
74 | write_concern=pymongo.WriteConcern(w=1))
75 | tries_left = _MONGO_MAX_TRIES
76 | while tries_left > 0:
77 | tries_left -= 1
78 | try:
79 | collection.replace_one(
80 | {'_id': id_key},
81 | data,
82 | upsert=True
83 | )
84 | return
85 | except (pymongo.errors.WriteConcernError, pymongo.errors.WriteError):
86 | if tries_left == 0:
87 | print(
88 | f"Warning: could not write entry to mongodb after"
89 | f" {_MONGO_MAX_TRIES} attempts."
90 | )
91 | raise
92 |
93 |
94 | def mongo_load(database, collection_key, id_key):
95 | """Load data from MongoDB database.
96 |
97 | Args:
98 | database (:class:`pymongo.database.Database`): MongoDB database to save
99 | results to.
100 | collection_key (str): name of collection.
101 | id_key (str): id key to look up data.
102 |
103 | Returns:
104 | retrieved data (returns None if no data with :attr:`id_key` is found).
105 | """
106 | return database[collection_key].find_one({'_id': id_key})
107 |
108 |
109 | def data_to_mongo(data):
110 | """Prepare data to be stored in a MongoDB database.
111 |
112 | Args:
113 | data (dict, :class:`torch.Tensor`, or :class:`np.ndarray`): data to
114 | prepare for storage in a MongoDB dataset (if dict, items are
115 | recursively prepared for storage). If the underlying data is
116 | not :class:`torch.Tensor` or :class:`np.ndarray`, then :attr:`data`
117 | is returned as is.
118 |
119 | Returns:
120 | :class:`bson.binary.Binary` or dict of :class:`bson.binary.Binary`:
121 | correctly formatted data to store in a MongoDB database.
122 | """
123 | if isinstance(data, dict):
124 | return {k: data_to_mongo(v) for k, v in data.items()}
125 | if isinstance(data, torch.Tensor):
126 | bytes_data = io.BytesIO()
127 | torch.save(data, bytes_data)
128 | bytes_data.seek(0)
129 | binary = bson.binary.Binary(bytes_data.read())
130 | return binary
131 | if isinstance(data, np.ndarray):
132 | return bson.binary.Binary(pickle.dumps(data, protocol=2),
133 | subtype=128)
134 | return data
135 |
136 |
137 | def data_from_mongo(mongo_data, map_location=None):
138 | """Decode data stored in a MongoDB database.
139 |
140 | Args:
141 | mongo_data (:class:`bson.binary.Binary` or dict):
142 | data to decode (if dict, items are recursively decoded). If
143 | the underlying data type is not `:class:torch.Tensor` or
144 | something stored using :mod:`pickle`, then :attr:`mongo_data`
145 | is returned as is.
146 | map_location (function, :class:`torch.device`, str or dict): where to
147 | remap storage locations (see :func:`torch.load` for more details).
148 | Default: ``None``.
149 |
150 | Returns:
151 | decoded data.
152 | """
153 |
154 | if isinstance(mongo_data, dict):
155 | return {k: data_from_mongo(v) for k, v in mongo_data.items()}
156 | if isinstance(mongo_data, bson.binary.Binary):
157 | try:
158 | bytes_data = io.BytesIO(mongo_data)
159 | return torch.load(bytes_data, map_location=map_location)
160 | # If the underlying data is a numpy array, it throws a ValueError here.
161 | except Exception:
162 | pass
163 | try:
164 | return pickle.loads(mongo_data)
165 | except Exception:
166 | pass
167 | return mongo_data
168 |
169 |
170 | def last_lines(string, num_lines):
171 | """Extract the last few lines from a string.
172 |
173 |     The function extracts the last :attr:`num_lines` lines from :attr:`string`.
174 |     If :attr:`num_lines` is a negative number, then it extracts the first lines
175 |     instead. It also skips lines beginning with ``'Figure('``.
176 |
177 | Args:
178 | string (str): string.
179 | num_lines (int): number of lines to extract.
180 |
181 | Returns:
182 | str: substring.
183 | """
184 | if string is None:
185 | return ''
186 | lines = string.strip().split('\n')
187 | lines = [l for l in lines if not l.startswith('Figure(')]
188 | if not lines:
189 | return ''
190 | if num_lines > 0:
191 | min_lines = min(num_lines, len(lines))
192 | lines_ = lines[-min_lines:]
193 | if num_lines < len(lines):
194 | lines_ = ['[...]'] + lines_
195 |     else:
196 | num_lines = -num_lines
197 | min_lines = min(num_lines, len(lines))
198 | lines_ = lines[:min_lines]
199 | if num_lines < len(lines):
200 | lines_ = lines_ + ['[...]']
201 |
202 | return '\n'.join(lines_)
203 |
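
Usage note (not part of the file above): a minimal sketch of how these helpers fit together, assuming the bundled TorchRay package is importable and a MongoDB instance is reachable on localhost. The database, collection, and id names are made up for illustration.

    import pymongo
    import torch

    from torchray.benchmark.logging_mongo import (
        data_from_mongo, data_to_mongo, mongo_load, mongo_save)

    # Encode a tensor-valued result so it fits in a BSON document.
    result = {'saliency': torch.rand(1, 1, 7, 7), 'hit': 1}
    encoded = data_to_mongo(result)

    # Store and retrieve it under an experiment-specific id, then decode it again.
    db = pymongo.MongoClient('localhost', 27017)['cgc_demo']
    mongo_save(db, 'pointing_game', 'image_0001', encoded)
    decoded = data_from_mongo(mongo_load(db, 'pointing_game', 'image_0001'))
    assert torch.allclose(result['saliency'], decoded['saliency'])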
--------------------------------------------------------------------------------
/TorchRay/torchray/benchmark/pointing_game.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | r"""
4 | The :mod:`pointing_game` module implements the pointing game benchmark.
5 | The basic benchmark is implemented by the :class:`PointingGame` class. However,
6 | for benchmarking purposes it is recommended to use the wrapper class
7 | :class:`PointingGameBenchmark` instead. This class supports *PASCAL VOC 2007
8 | test* and *COCO 2014 val* with the modifications used in [EBP]_, including the
9 | ability to run on their "difficult" subsets as defined in the original paper.
10 |
11 | The class can be used as follows:
12 |
13 | 1. Obtain a dataset (usually COCO, PASCAL VOC or ImageNet detection) and choose a subset.
14 | 2. Initialize an instance of :class:`PointingGameBenchmark`.
15 | 3. For each image in the dataset:
16 |
17 | 1. For each class in the image:
18 |
19 | 1. Run the attribution method, usually resulting in a saliency map for
20 | class :math:`c`.
21 | 2. Convert the result to a point, usually by finding the maximizer of the
22 | saliency map.
23 | 3. Use the :func:`PointingGameBenchmark.evaluate` function to run the
24 | test and accumulate the statistics.
25 | 4. Extract the :attr:`PointingGame.hits` and :attr:`PointingGame.misses` or
26 | ``print`` the instance to display the results.
27 | """
28 |
29 | import torch
30 | from torchvision import datasets as ds
31 | import pdb
32 | import datasets as ads
33 | import numpy as np
34 | from PIL import Image
35 |
36 |
37 | class PointingGame:
38 | r"""Pointing game.
39 |
40 | Args:
41 | num_classes (int): number of classes in the dataset.
42 | tolerance (int, optional): tolerance (in pixels) of margin around
43 | ground truth annotation. Default: 15.
44 |
45 | Attributes:
46 | hits (:class:`torch.Tensor`): :attr:`num_classes`-dimensional vector of
47 | hits counts.
48 | misses (:class:`torch.Tensor`): :attr:`num_classes`-dimensional vector
49 | of misses counts.
50 | """
51 |
52 | def __init__(self, num_classes, tolerance=15):
53 | assert isinstance(num_classes, int)
54 | assert isinstance(tolerance, int)
55 | self.num_classes = num_classes
56 | self.tolerance = tolerance
57 | self.hits = torch.zeros((num_classes,), dtype=torch.float64)
58 | self.misses = torch.zeros((num_classes,), dtype=torch.float64)
59 |
60 | def evaluate(self, mask, point):
61 | r"""Evaluate a point prediction.
62 |
63 | The function tests whether the prediction :attr:`point` is within a
64 | certain tolerance of the object ground-truth region :attr:`mask`
65 | expressed as a boolean occupancy map.
66 |
67 | Use the :func:`reset` method to clear all counters.
68 |
69 | Args:
70 | mask (:class:`torch.Tensor`): :math:`\{0,1\}^{H\times W}`.
71 | point (tuple of ints): predicted point :math:`(u,v)`.
72 |
73 | Returns:
74 | int: +1 if the point hits the object; otherwise -1.
75 | """
76 | # Get an acceptance region around the point. There is a hit whenever
77 | # the acceptance region collides with the class mask.
78 | # pdb.set_trace()
79 | # v, u = torch.meshgrid((
80 | # (torch.arange(mask.shape[0],
81 | # dtype=torch.float32) - point[1])**2,
82 | # (torch.arange(mask.shape[1],
83 | # dtype=torch.float32) - point[0])**2,
84 | # ))
85 |
86 | v, u = torch.meshgrid((
87 | (torch.arange(mask.shape[0],
88 | dtype=torch.float32) - point[0]) ** 2,
89 | (torch.arange(mask.shape[1],
90 | dtype=torch.float32) - point[1]) ** 2,
91 | ))
92 | accept = (v + u) < self.tolerance**2
93 |
94 | # Test for a hit with the corresponding class.
95 | hit = (mask & accept).view(-1).any()
96 |
97 | # code to debug GT mask and accept mask
98 | # mask_np = mask.numpy().astype(np.uint8)*255
99 | # mask_im = Image.fromarray(mask_np)
100 | # mask_im.save('mask_coco_val_im_03.png')
101 | # accept_np = accept.numpy().astype(np.uint8) * 255
102 | # accept_im = Image.fromarray(accept_np)
103 | # accept_im.save('accept_mask_coco_val_03.png')
104 |
105 | return +1 if hit else -1
106 |
107 | def aggregate(self, hit, class_id):
108 | """Add pointing result from one example."""
109 | if hit == 0:
110 | return
111 | if hit == 1:
112 | self.hits[class_id] += 1
113 | elif hit == -1:
114 | self.misses[class_id] += 1
115 | else:
116 | assert False
117 |
118 | def reset(self):
119 | """Reset hits and misses."""
120 | self.hits = torch.zeros_like(self.hits)
121 | self.misses = torch.zeros_like(self.misses)
122 |
123 | @property
124 | def class_accuracies(self):
125 | """
126 | (:class:`torch.Tensor`): :attr:`num_classes`-dimensional vector
127 | containing per-class accuracy.
128 | """
129 | return self.hits / (self.hits + self.misses).clamp(min=1)
130 |
131 | @property
132 | def accuracy(self):
133 | """
134 | (:class:`torch.Tensor`): mean accuracy, computed by averaging
135 | :attr:`class_accuracies`.
136 | """
137 | return self.class_accuracies.mean()
138 |
139 | # def __str__(self):
140 | # class_accuracies = self.class_accuracies
141 | # return '{:4.1f}% ['.format(100 * class_accuracies.mean()) + " ".join([
142 | # '{}:{:4.1f}%'.format(c, 100 * a)
143 | # for c, a in enumerate(class_accuracies)
144 | # ]) + ']'
145 |
146 | def __str__(self):
147 | class_accuracies = self.class_accuracies
148 | return 'Pointing game mean accuracy: {:4.1f}% '.format(100 * class_accuracies.mean())
149 |
150 |
151 | class PointingGameBenchmark(PointingGame):
152 | """Pointing game benchmark on standard datasets.
153 |
154 | The pointing game should be initialized with a dataset, set to either:
155 |
156 | * (:class:`torchvision.VOCDetection`) VOC 2007 *test* subset.
157 | * (:class:`torchvision.CocoDetection`) COCO *val2014* subset.
158 |
159 | Args:
160 | dataset (:class:`torchvision.VisionDataset`): The dataset.
161 | tolerance (int): the tolerance for the pointing game. Default: ``15``.
162 | difficult (bool): whether to use the difficult subset.
163 | Default: ``False``.
164 | """
165 |
166 | def __init__(self, dataset, tolerance=15, difficult=False):
167 | if isinstance(dataset, ds.VOCDetection):
168 | num_classes = 20
169 | elif isinstance(dataset, ads.ImageNetDetection):
170 | num_classes = 1000
171 | elif isinstance(dataset, ds.CocoDetection):
172 | num_classes = 80
173 | else:
174 | assert False, 'Only VOCDetection, ImageNetDetection and CocoDetection are supported.'
175 |
176 | super(PointingGameBenchmark, self).__init__(
177 | num_classes=num_classes, tolerance=tolerance)
178 | self.dataset = dataset
179 | self.difficult = difficult
180 |
181 | if difficult:
182 | def load_flags(name):
183 | try:
184 | import importlib.resources as res
185 | except ImportError:
186 | import importlib_resources as res
187 | with res.open_text('torchray.benchmark', name) as file:
188 | rows = [[x for x in row.split('\t')] for row in file]
189 | return {
190 | row[0]: [bool(int(x)) for x in row[1:]]
191 | for row in rows
192 | }
193 | if isinstance(self.dataset, ds.VOCDetection):
194 | self.difficult_flags = load_flags(
195 | 'pointing_game_ebp_voc07_difficult.txt')
196 | elif isinstance(self.dataset, ds.CocoDetection):
197 | self.difficult_flags = load_flags(
198 | 'pointing_game_ebp_coco_difficult.txt')
199 | else:
200 |                 assert False, 'The difficult subset is supported only for the VOC and COCO datasets.'
201 |
202 | def evaluate(self, label, class_id, point):
203 |         """Evaluate a label-class-point triplet.
204 |
205 | Args:
206 |             label (dict): a label in VOC, COCO, or ImageNet detection format.
207 | class_id (int): a class id.
208 | point (iterable): a point specified as a pair of u, v coordinates.
209 |
210 | Returns:
211 | int: +1 if the point hits the object, -1 if the point misses the
212 | object, and 0 if the point is skipped during evaluation.
213 | """
214 |
215 | # Skip if testing on the EBP difficult subset and the image/class pair
216 | # is an easy one.
217 | if self.difficult:
218 | if isinstance(self.dataset, ds.VOCDetection):
219 | image_name = label['annotation']['filename'].split('.')[0]
220 | elif isinstance(self.dataset, ds.CocoDetection):
221 | image_id = label[0]['image_id']
222 | image = self.dataset.coco.loadImgs(image_id)[0]
223 | image_name = image['file_name'].split('.')[0]
224 | else:
225 | assert False, 'Only VOC and COCO datasets are supported for the difficult subset'
226 |
227 | if image_name in self.difficult_flags:
228 | if not self.difficult_flags[image_name][class_id]:
229 | return 0
230 |
231 | # Get the mask for all occurrences of class_id.
232 | if isinstance(self.dataset, ds.VOCDetection):
233 | # Skip if all boxes for class_id are PASCAL difficult.
234 | objs = label['annotation']['object']
235 | if not isinstance(objs, list):
236 | objs = [objs]
237 | objs = [obj for obj in objs if
238 | ads.VOC_CLASSES.index(obj['name']) == class_id
239 | ]
240 | if all([bool(int(obj['difficult'])) for obj in objs]):
241 | return 0
242 | mask = ads.voc_as_mask(label, class_id)
243 |
244 | elif isinstance(self.dataset, ads.ImageNetDetection):
245 | # Skip if all boxes for class_id are marked difficult.
246 | objs = label['annotation']['object']
247 | if not isinstance(objs, list):
248 | objs = [objs]
249 | objs = [obj for obj in objs if
250 | ads._IMAGENET_CLASS_TO_INDEX[obj['name']] == class_id # we assume single name elem per object
251 | ]
252 | if all([bool(int(obj['difficult'])) for obj in objs]):
253 | return 0
254 | mask = ads.imagenet_as_mask(label, class_id)
255 |
256 | elif isinstance(self.dataset, ds.CocoDetection):
257 | mask = ads.coco_as_mask(self.dataset, label, class_id)
258 |
259 | assert mask is not None
260 | return super(PointingGameBenchmark, self).evaluate(mask, point)
261 |
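
Usage note (not part of the file above): a rough sketch of the evaluation loop described in the module docstring. The model, the attribution call, and the helper that lists the classes present in a label are placeholders (`run_attribution` and `classes_in` do not exist in this repository); only the `PointingGameBenchmark` calls are real.

    import torch
    from torchvision import datasets as ds

    from torchray.benchmark.pointing_game import PointingGameBenchmark

    dataset = ds.VOCDetection('data/voc', year='2007', image_set='test')
    benchmark = PointingGameBenchmark(dataset, tolerance=15)

    for image, label in dataset:
        for class_id in classes_in(label):                      # hypothetical helper
            saliency = run_attribution(model, image, class_id)  # hypothetical, returns an HxW map
            # Use the maximizer of the saliency map as the predicted (row, col) point.
            point = (saliency == saliency.max()).nonzero()[0].tolist()
            hit = benchmark.evaluate(label, class_id, point)
            benchmark.aggregate(hit, class_id)

    print(benchmark)  # reports the mean pointing game accuracy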
--------------------------------------------------------------------------------
/TorchRay/torchray/benchmark/server.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | r"""
4 | This module is used to start and run a MongoDB server.
5 |
6 | To start a MongoDB server, use
7 |
8 | .. code:: shell
9 |
10 | $ python -m torchray.benchmark.server
11 |
12 | """
13 | import subprocess
14 | from torchray.utils import get_config
15 |
16 |
17 | def run_server():
18 | """Runs an instance of MongoDB as a logging server."""
19 | config = get_config()
20 | command = [
21 | config['mongo']['server'],
22 | '--dbpath', config['mongo']['database'],
23 | '--bind_ip', config['mongo']['hostname'],
24 | '--port', str(config['mongo']['port'])
25 | ]
26 | print(f"Command: {' '.join(command)}.")
27 | code = subprocess.call(command, cwd=".")
28 | print(f"Return code {code}")
29 |
30 |
31 | if __name__ == '__main__':
32 | run_server()
33 |
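
Usage note (not part of the file above): `run_server()` only assembles a `mongod` command line from the configuration returned by `torchray.utils.get_config()`. Based on the keys read above, the relevant part of that configuration looks roughly like the sketch below; the actual values and the location of the configuration file depend on your setup.

    config = {
        'mongo': {
            'server': 'mongod',           # path to the MongoDB daemon binary
            'database': './data/mongo',   # passed as --dbpath
            'hostname': 'localhost',      # passed as --bind_ip
            'port': 27017,                # passed as --port
        }
    }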
--------------------------------------------------------------------------------
/TorchRay/torchray/benchmark/swav_resnet_cgc.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | #
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
10 | """3x3 convolution with padding"""
11 | return nn.Conv2d(
12 | in_planes,
13 | out_planes,
14 | kernel_size=3,
15 | stride=stride,
16 | padding=dilation,
17 | groups=groups,
18 | bias=False,
19 | dilation=dilation,
20 | )
21 |
22 |
23 | def conv1x1(in_planes, out_planes, stride=1):
24 | """1x1 convolution"""
25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
26 |
27 |
28 | class BasicBlock(nn.Module):
29 | expansion = 1
30 | __constants__ = ["downsample"]
31 |
32 | def __init__(
33 | self,
34 | inplanes,
35 | planes,
36 | stride=1,
37 | downsample=None,
38 | groups=1,
39 | base_width=64,
40 | dilation=1,
41 | norm_layer=None,
42 | ):
43 | super(BasicBlock, self).__init__()
44 | if norm_layer is None:
45 | norm_layer = nn.BatchNorm2d
46 | if groups != 1 or base_width != 64:
47 | raise ValueError("BasicBlock only supports groups=1 and base_width=64")
48 | if dilation > 1:
49 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
50 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
51 | self.conv1 = conv3x3(inplanes, planes, stride)
52 | self.bn1 = norm_layer(planes)
53 | self.relu = nn.ReLU(inplace=True)
54 | self.conv2 = conv3x3(planes, planes)
55 | self.bn2 = norm_layer(planes)
56 | self.downsample = downsample
57 | self.stride = stride
58 |
59 | def forward(self, x):
60 | identity = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | out = self.conv2(out)
67 | out = self.bn2(out)
68 |
69 | if self.downsample is not None:
70 | identity = self.downsample(x)
71 |
72 | out += identity
73 | out = self.relu(out)
74 |
75 | return out
76 |
77 |
78 | class Bottleneck(nn.Module):
79 | expansion = 4
80 | __constants__ = ["downsample"]
81 |
82 | def __init__(
83 | self,
84 | inplanes,
85 | planes,
86 | stride=1,
87 | downsample=None,
88 | groups=1,
89 | base_width=64,
90 | dilation=1,
91 | norm_layer=None,
92 | ):
93 | super(Bottleneck, self).__init__()
94 | if norm_layer is None:
95 | norm_layer = nn.BatchNorm2d
96 | width = int(planes * (base_width / 64.0)) * groups
97 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
98 | self.conv1 = conv1x1(inplanes, width)
99 | self.bn1 = norm_layer(width)
100 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
101 | self.bn2 = norm_layer(width)
102 | self.conv3 = conv1x1(width, planes * self.expansion)
103 | self.bn3 = norm_layer(planes * self.expansion)
104 | self.relu = nn.ReLU(inplace=True)
105 | self.downsample = downsample
106 | self.stride = stride
107 |
108 | def forward(self, x):
109 | identity = x
110 |
111 | out = self.conv1(x)
112 | out = self.bn1(out)
113 | out = self.relu(out)
114 |
115 | out = self.conv2(out)
116 | out = self.bn2(out)
117 | out = self.relu(out)
118 |
119 | out = self.conv3(out)
120 | out = self.bn3(out)
121 |
122 | if self.downsample is not None:
123 | identity = self.downsample(x)
124 |
125 | out += identity
126 | out = self.relu(out)
127 |
128 | return out
129 |
130 |
131 | def normalize(x):
132 | return x / x.norm(2, dim=1, keepdim=True)
133 |
134 |
135 | class ResNet(nn.Module):
136 | def __init__(
137 | self,
138 | block,
139 | layers,
140 | zero_init_residual=False,
141 | groups=1,
142 | widen=1,
143 | width_per_group=64,
144 | replace_stride_with_dilation=None,
145 | norm_layer=None,
146 | normalize=False,
147 | output_dim=0,
148 | hidden_mlp=0,
149 | nmb_prototypes=0,
150 | eval_mode=False,
151 | ):
152 | super(ResNet, self).__init__()
153 | if norm_layer is None:
154 | norm_layer = nn.BatchNorm2d
155 | self._norm_layer = norm_layer
156 |
157 | self.eval_mode = eval_mode
158 | self.padding = nn.ConstantPad2d(1, 0.0)
159 |
160 | self.inplanes = width_per_group * widen
161 | self.dilation = 1
162 | if replace_stride_with_dilation is None:
163 | # each element in the tuple indicates if we should replace
164 | # the 2x2 stride with a dilated convolution instead
165 | replace_stride_with_dilation = [False, False, False]
166 | if len(replace_stride_with_dilation) != 3:
167 | raise ValueError(
168 | "replace_stride_with_dilation should be None "
169 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
170 | )
171 | self.groups = groups
172 | self.base_width = width_per_group
173 |
174 |     # change padding 3 -> 2 compared to original torchvision code because we added a padding layer
175 | num_out_filters = width_per_group * widen
176 | self.conv1 = nn.Conv2d(
177 | 3, num_out_filters, kernel_size=7, stride=2, padding=2, bias=False
178 | )
179 | self.bn1 = norm_layer(num_out_filters)
180 | self.relu = nn.ReLU(inplace=True)
181 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
182 | self.layer1 = self._make_layer(block, num_out_filters, layers[0])
183 | num_out_filters *= 2
184 | self.layer2 = self._make_layer(
185 | block, num_out_filters, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
186 | )
187 | num_out_filters *= 2
188 | self.layer3 = self._make_layer(
189 | block, num_out_filters, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
190 | )
191 | num_out_filters *= 2
192 | self.layer4 = self._make_layer(
193 | block, num_out_filters, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
194 | )
195 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
196 |
197 | # normalize output features
198 | self.l2norm = normalize
199 | self.upsample = nn.Upsample((7, 7), mode='bilinear')
200 | self.eps = 1e-8
201 |
202 | # projection head
203 | if output_dim == 0:
204 | self.projection_head = None
205 | elif hidden_mlp == 0:
206 | self.projection_head = nn.Linear(num_out_filters * block.expansion, output_dim)
207 | else:
208 | self.projection_head = nn.Sequential(
209 | nn.Linear(num_out_filters * block.expansion, hidden_mlp),
210 | nn.BatchNorm1d(hidden_mlp),
211 | nn.ReLU(inplace=True),
212 | nn.Linear(hidden_mlp, output_dim),
213 | )
214 |
215 | # prototype layer
216 | self.prototypes = None
217 | if isinstance(nmb_prototypes, list):
218 | self.prototypes = MultiPrototypes(output_dim, nmb_prototypes)
219 | elif nmb_prototypes > 0:
220 | self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False)
221 |
222 | for m in self.modules():
223 | if isinstance(m, nn.Conv2d):
224 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
225 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
226 | nn.init.constant_(m.weight, 1)
227 | nn.init.constant_(m.bias, 0)
228 |
229 | # Zero-initialize the last BN in each residual branch,
230 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
231 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
232 | if zero_init_residual:
233 | for m in self.modules():
234 | if isinstance(m, Bottleneck):
235 | nn.init.constant_(m.bn3.weight, 0)
236 | elif isinstance(m, BasicBlock):
237 | nn.init.constant_(m.bn2.weight, 0)
238 |
239 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
240 | norm_layer = self._norm_layer
241 | downsample = None
242 | previous_dilation = self.dilation
243 | if dilate:
244 | self.dilation *= stride
245 | stride = 1
246 | if stride != 1 or self.inplanes != planes * block.expansion:
247 | downsample = nn.Sequential(
248 | conv1x1(self.inplanes, planes * block.expansion, stride),
249 | norm_layer(planes * block.expansion),
250 | )
251 |
252 | layers = []
253 | layers.append(
254 | block(
255 | self.inplanes,
256 | planes,
257 | stride,
258 | downsample,
259 | self.groups,
260 | self.base_width,
261 | previous_dilation,
262 | norm_layer,
263 | )
264 | )
265 | self.inplanes = planes * block.expansion
266 | for _ in range(1, blocks):
267 | layers.append(
268 | block(
269 | self.inplanes,
270 | planes,
271 | groups=self.groups,
272 | base_width=self.base_width,
273 | dilation=self.dilation,
274 | norm_layer=norm_layer,
275 | )
276 | )
277 |
278 | return nn.Sequential(*layers)
279 |
280 | def forward_backbone(self, x):
281 | x = self.padding(x)
282 |
283 | x = self.conv1(x)
284 | x = self.bn1(x)
285 | x = self.relu(x)
286 | x = self.maxpool(x)
287 | x = self.layer1(x)
288 | x = self.layer2(x)
289 | x = self.layer3(x)
290 | feats = self.layer4(x)
291 |
292 | if self.eval_mode:
293 | return feats
294 |
295 | x = self.avgpool(feats)
296 | x = torch.flatten(x, 1)
297 |
298 | return x, feats
299 |
300 | def forward_head(self, x):
301 | if self.projection_head is not None:
302 | x = self.projection_head(x)
303 |
304 | if self.l2norm:
305 | x = nn.functional.normalize(x, dim=1, p=2)
306 |
307 | if self.prototypes is not None:
308 | return x, self.prototypes(x)
309 | return x
310 |
311 | def _forward_vanilla(self, inputs, return_feats=False):
312 | x, feats = self.forward_backbone(inputs)
313 | x = self.forward_head(x)
314 | if return_feats:
315 | return x, feats
316 | else:
317 | return x
318 |
319 | def _forward(self, lbl_images, vanilla=False, vanilla_with_feats=False):
320 | """
321 | :param lbl_images: Labeled images to be used for computing logits/feats
322 | :param vanilla: If True, return the outputs from a regular forward pass through the model
323 | :param vanilla_with_feats: If True, return the feats as well as outputs
324 | :return:
325 | """
326 | if vanilla:
327 | return self.forward_vanilla(lbl_images)
328 | if vanilla_with_feats:
329 | return self.forward_vanilla(lbl_images, return_feats=True)
330 |
331 |     # Allow for accessing the forward method in an inherited class
332 | forward = _forward
333 | forward_vanilla = _forward_vanilla
334 |
335 |
336 | class MultiPrototypes(nn.Module):
337 | def __init__(self, output_dim, nmb_prototypes):
338 | super(MultiPrototypes, self).__init__()
339 | self.nmb_heads = len(nmb_prototypes)
340 | for i, k in enumerate(nmb_prototypes):
341 | self.add_module("prototypes" + str(i), nn.Linear(output_dim, k, bias=False))
342 |
343 | def forward(self, x):
344 | out = []
345 | for i in range(self.nmb_heads):
346 | out.append(getattr(self, "prototypes" + str(i))(x))
347 | return out
348 |
349 |
350 | def resnet50(**kwargs):
351 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
352 |
353 |
354 | def resnet50w2(**kwargs):
355 | return ResNet(Bottleneck, [3, 4, 6, 3], widen=2, **kwargs)
356 |
357 |
358 | def resnet50w4(**kwargs):
359 | return ResNet(Bottleneck, [3, 4, 6, 3], widen=4, **kwargs)
360 |
361 |
362 | def resnet50w5(**kwargs):
363 | return ResNet(Bottleneck, [3, 4, 6, 3], widen=5, **kwargs)
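
Usage note (not part of the file above): a minimal sketch of instantiating and running the model defined above. The argument values are illustrative only (they are not taken from the training scripts), and the BatchNorm layers use their default running statistics because the model is in eval mode.

    import torch
    from torchray.benchmark.swav_resnet_cgc import resnet50

    # SwAV-style ResNet-50 with a 2-layer projection head and a prototype layer.
    model = resnet50(output_dim=128, hidden_mlp=2048, nmb_prototypes=3000, normalize=True)
    model.eval()

    images = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        # vanilla=True runs backbone -> projection head -> prototypes.
        embedding, scores = model(images, vanilla=True)
        # vanilla_with_feats=True additionally returns the layer4 feature map,
        # which is what the Grad-CAM based benchmarks consume.
        (embedding, scores), feats = model(images, vanilla_with_feats=True)

    print(embedding.shape, scores.shape, feats.shape)  # (2, 128), (2, 3000), (2, 2048, 7, 7)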
--------------------------------------------------------------------------------
/TorchRay/torchray/benchmark/vision.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 |
5 |
6 | class VisionDataset(data.Dataset):
7 | _repr_indent = 4
8 |
9 | def __init__(self, root, transforms=None, transform=None, target_transform=None):
10 | if isinstance(root, torch._six.string_classes):
11 | root = os.path.expanduser(root)
12 | self.root = root
13 |
14 | has_transforms = transforms is not None
15 | has_separate_transform = transform is not None or target_transform is not None
16 | if has_transforms and has_separate_transform:
17 | raise ValueError("Only transforms or transform/target_transform can "
18 | "be passed as argument")
19 |
20 | # for backwards-compatibility
21 | self.transform = transform
22 | self.target_transform = target_transform
23 |
24 | if has_separate_transform:
25 | transforms = StandardTransform(transform, target_transform)
26 | self.transforms = transforms
27 |
28 | def __getitem__(self, index):
29 | raise NotImplementedError
30 |
31 | def __len__(self):
32 | raise NotImplementedError
33 |
34 | def __repr__(self):
35 | head = "Dataset " + self.__class__.__name__
36 | body = ["Number of datapoints: {}".format(self.__len__())]
37 | if self.root is not None:
38 | body.append("Root location: {}".format(self.root))
39 | body += self.extra_repr().splitlines()
40 | if hasattr(self, "transforms") and self.transforms is not None:
41 | body += [repr(self.transforms)]
42 | lines = [head] + [" " * self._repr_indent + line for line in body]
43 | return '\n'.join(lines)
44 |
45 | def _format_transform_repr(self, transform, head):
46 | lines = transform.__repr__().splitlines()
47 | return (["{}{}".format(head, lines[0])] +
48 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
49 |
50 | def extra_repr(self):
51 | return ""
52 |
53 |
54 | class StandardTransform(object):
55 | def __init__(self, transform=None, target_transform=None):
56 | self.transform = transform
57 | self.target_transform = target_transform
58 |
59 | def __call__(self, input, target):
60 | if self.transform is not None:
61 | input = self.transform(input)
62 | if self.target_transform is not None:
63 | target = self.target_transform(target)
64 | return input, target
65 |
66 | def _format_transform_repr(self, transform, head):
67 | lines = transform.__repr__().splitlines()
68 | return (["{}{}".format(head, lines[0])] +
69 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
70 |
71 | def __repr__(self):
72 | body = [self.__class__.__name__]
73 | if self.transform is not None:
74 | body += self._format_transform_repr(self.transform,
75 | "Transform: ")
76 | if self.target_transform is not None:
77 | body += self._format_transform_repr(self.target_transform,
78 | "Target transform: ")
79 |
80 | return '\n'.join(body)
81 |
--------------------------------------------------------------------------------
/datasets/imagefolder_cgc_ssl.py:
--------------------------------------------------------------------------------
1 | from .vision import VisionDataset
2 |
3 | from PIL import Image
4 | import random
5 | import os
6 | import os.path
7 | import torchvision
8 | from torchvision import transforms
9 | from torchvision.transforms import functional as tvf
10 |
11 |
12 | def has_file_allowed_extension(filename, extensions):
13 | """Checks if a file is an allowed extension.
14 |
15 | Args:
16 | filename (string): path to a file
17 | extensions (tuple of strings): extensions to consider (lowercase)
18 |
19 | Returns:
20 | bool: True if the filename ends with one of given extensions
21 | """
22 | return filename.lower().endswith(extensions)
23 |
24 |
25 | def is_image_file(filename):
26 | """Checks if a file is an allowed image extension.
27 |
28 | Args:
29 | filename (string): path to a file
30 |
31 | Returns:
32 | bool: True if the filename ends with a known image extension
33 | """
34 | return has_file_allowed_extension(filename, IMG_EXTENSIONS)
35 |
36 |
37 | def make_dataset(directory, class_to_idx, extensions=None, is_valid_file=None):
38 | instances = []
39 | directory = os.path.expanduser(directory)
40 | both_none = extensions is None and is_valid_file is None
41 | both_something = extensions is not None and is_valid_file is not None
42 | if both_none or both_something:
43 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
44 | if extensions is not None:
45 | def is_valid_file(x):
46 | return has_file_allowed_extension(x, extensions)
47 | for target_class in sorted(class_to_idx.keys()):
48 | class_index = class_to_idx[target_class]
49 | target_dir = os.path.join(directory, target_class)
50 | if not os.path.isdir(target_dir):
51 | continue
52 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
53 | for fname in sorted(fnames):
54 | path = os.path.join(root, fname)
55 | if is_valid_file(path):
56 | item = path, class_index
57 | instances.append(item)
58 | return instances
59 |
60 |
61 | class DatasetFolder(VisionDataset):
62 | """A generic data loader where the samples are arranged in this way: ::
63 |
64 | root/class_x/xxx.ext
65 | root/class_x/xxy.ext
66 | root/class_x/xxz.ext
67 |
68 | root/class_y/123.ext
69 | root/class_y/nsdf3.ext
70 | root/class_y/asd932_.ext
71 |
72 | Args:
73 | root (string): Root directory path.
74 | loader (callable): A function to load a sample given its path.
75 |         extensions (tuple[string]): A tuple of allowed extensions.
76 |             Both extensions and is_valid_file should not be passed.
77 |         transform (callable, optional): A function/transform that takes in
78 |             a sample and returns a transformed version.
79 |             E.g., ``transforms.RandomCrop`` for images.
80 | target_transform (callable, optional): A function/transform that takes
81 | in the target and transforms it.
82 |         is_valid_file (callable, optional): A function that takes the path of a file
83 |             and checks if the file is a valid file (used to check for corrupt files).
84 |             Both extensions and is_valid_file should not be passed.
85 |
86 | Attributes:
87 | classes (list): List of the class names sorted alphabetically.
88 | class_to_idx (dict): Dict with items (class_name, class_index).
89 | samples (list): List of (sample path, class_index) tuples
90 | targets (list): The class_index value for each image in the dataset
91 | """
92 |
93 | def __init__(self, root, loader, extensions=None, transform=None,
94 | target_transform=None, is_valid_file=None):
95 | super(DatasetFolder, self).__init__(root, transform=transform,
96 | target_transform=target_transform)
97 | classes, class_to_idx = self._find_classes(self.root)
98 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
99 | if len(samples) == 0:
100 | msg = "Found 0 files in subfolders of: {}\n".format(self.root)
101 | if extensions is not None:
102 | msg += "Supported extensions are: {}".format(",".join(extensions))
103 | raise RuntimeError(msg)
104 |
105 | self.loader = loader
106 | self.extensions = extensions
107 |
108 | self.classes = classes
109 | self.class_to_idx = class_to_idx
110 | self.samples = samples
111 | self.targets = [s[1] for s in samples]
112 | self.hor_flip = transforms.RandomHorizontalFlip()
113 | self.to_tensor = transforms.ToTensor()
114 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
115 |
116 |
117 | def _find_classes(self, dir):
118 | """
119 | Finds the class folders in a dataset.
120 |
121 | Args:
122 | dir (string): Root directory path.
123 |
124 | Returns:
125 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
126 |
127 | Ensures:
128 | No class is a subdirectory of another.
129 | """
130 | classes = [d.name for d in os.scandir(dir) if d.is_dir()]
131 | classes.sort()
132 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
133 | return classes, class_to_idx
134 |
135 | def __getitem__(self, index):
136 | """
137 | Args:
138 | index (int): Index
139 |
140 | Returns:
141 | tuple: (sample, target) where target is class_index of the target class.
142 | """
143 | path, target = self.samples[index]
144 | sample = self.loader(path)
145 |
146 | # we first obtain an image with the regular augmentations used for cross-entropy loss
147 | xe_sample = transforms.RandomResizedCrop(224)(sample)
148 | xe_sample = transforms.RandomHorizontalFlip()(xe_sample)
149 |
150 | # next, we obtain the pair of image and augmented image to be used for CGC loss
151 | # We resize the image to have 256 on the smallest side and take center crop
152 | sample = transforms.Resize(256)(sample)
153 | sample = transforms.CenterCrop(224)(sample)
154 |
155 | # We manually apply the transformations.
156 |         # Namely, RandomResizedCrop() -> horizontal_flip(0.5) -> ToTensor() -> Normalize()
157 | i, j, h, w = torchvision.transforms.RandomResizedCrop.get_params(sample, scale=(0.08, 1.0),
158 | ratio=(0.75, 1.3333333333333333))
159 | aug_sample = tvf.resized_crop(sample, i, j, h, w, size=(224, 224))
160 |
161 | hor_flip = False
162 | if random.random() > 0.5:
163 |             aug_sample = tvf.hflip(aug_sample)  # deterministic flip so the recorded hor_flip flag matches
164 | hor_flip = True
165 | aug_sample = self.to_tensor(aug_sample)
166 | aug_sample = self.normalize(aug_sample)
167 |
168 | sample = self.to_tensor(sample)
169 | sample = self.normalize(sample)
170 |
171 | xe_sample = self.to_tensor(xe_sample)
172 | xe_sample = self.normalize(xe_sample)
173 |
174 |         return xe_sample, sample, aug_sample, i, j, h, w, hor_flip, target
175 |
176 | def __len__(self):
177 | return len(self.samples)
178 |
179 |
180 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
181 |
182 |
183 | def pil_loader(path):
184 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
185 | with open(path, 'rb') as f:
186 | img = Image.open(f)
187 | return img.convert('RGB')
188 |
189 |
190 | def accimage_loader(path):
191 | import accimage
192 | try:
193 | return accimage.Image(path)
194 | except IOError:
195 | # Potentially a decoding problem, fall back to PIL.Image
196 | return pil_loader(path)
197 |
198 |
199 | def default_loader(path):
200 | from torchvision import get_image_backend
201 | if get_image_backend() == 'accimage':
202 | return accimage_loader(path)
203 | else:
204 | return pil_loader(path)
205 |
206 |
207 | class ImageFolder(DatasetFolder):
208 | """A generic data loader where the images are arranged in this way: ::
209 |
210 | root/dog/xxx.png
211 | root/dog/xxy.png
212 | root/dog/xxz.png
213 |
214 | root/cat/123.png
215 | root/cat/nsdf3.png
216 | root/cat/asd932_.png
217 |
218 | Args:
219 | root (string): Root directory path.
220 |         transform (callable, optional): A function/transform that takes in a PIL image
221 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
222 | target_transform (callable, optional): A function/transform that takes in the
223 | target and transforms it.
224 | loader (callable, optional): A function to load an image given its path.
225 |         is_valid_file (callable, optional): A function that takes the path of an Image file
226 |             and checks if the file is a valid file (used to check for corrupt files)
227 |
228 | Attributes:
229 | classes (list): List of the class names sorted alphabetically.
230 | class_to_idx (dict): Dict with items (class_name, class_index).
231 | imgs (list): List of (image path, class_index) tuples
232 | """
233 |
234 | def __init__(self, root, transform=None, target_transform=None,
235 | loader=default_loader, is_valid_file=None):
236 | super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
237 | transform=transform,
238 | target_transform=target_transform,
239 | is_valid_file=is_valid_file)
240 | self.imgs = self.samples
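
Usage note (not part of the file above): a minimal sketch of consuming the nine-element samples this dataset returns, assuming the repository root is on PYTHONPATH. The data path and loader settings are placeholders; the per-sample crop/flip parameters are the ones later replayed on the Grad-CAM masks by perform_gradcam_aug in models/utils.py.

    import torch
    from datasets.imagefolder_cgc_ssl import ImageFolder

    dataset = ImageFolder('/path/to/imagenet/train')
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    for xe_img, img, aug_img, i, j, h, w, hor_flip, target in loader:
        # xe_img:  randomly cropped and flipped view, used for the cross-entropy loss
        # img:     resized, center-cropped 224x224 reference view for the CGC loss
        # aug_img: randomly resized crop (and possibly flipped view) of img
        # i, j, h, w, hor_flip: crop box and flip flag used to create aug_img
        pass  # training step goes here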
--------------------------------------------------------------------------------
/datasets/vision.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 |
5 |
6 | class VisionDataset(data.Dataset):
7 | _repr_indent = 4
8 |
9 | def __init__(self, root, transforms=None, transform=None, target_transform=None):
10 | if isinstance(root, torch._six.string_classes):
11 | root = os.path.expanduser(root)
12 | self.root = root
13 |
14 | has_transforms = transforms is not None
15 | has_separate_transform = transform is not None or target_transform is not None
16 | if has_transforms and has_separate_transform:
17 | raise ValueError("Only transforms or transform/target_transform can "
18 | "be passed as argument")
19 |
20 | # for backwards-compatibility
21 | self.transform = transform
22 | self.target_transform = target_transform
23 |
24 | if has_separate_transform:
25 | transforms = StandardTransform(transform, target_transform)
26 | self.transforms = transforms
27 |
28 | def __getitem__(self, index):
29 | raise NotImplementedError
30 |
31 | def __len__(self):
32 | raise NotImplementedError
33 |
34 | def __repr__(self):
35 | head = "Dataset " + self.__class__.__name__
36 | body = ["Number of datapoints: {}".format(self.__len__())]
37 | if self.root is not None:
38 | body.append("Root location: {}".format(self.root))
39 | body += self.extra_repr().splitlines()
40 | if hasattr(self, "transforms") and self.transforms is not None:
41 | body += [repr(self.transforms)]
42 | lines = [head] + [" " * self._repr_indent + line for line in body]
43 | return '\n'.join(lines)
44 |
45 | def _format_transform_repr(self, transform, head):
46 | lines = transform.__repr__().splitlines()
47 | return (["{}{}".format(head, lines[0])] +
48 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
49 |
50 | def extra_repr(self):
51 | return ""
52 |
53 |
54 | class StandardTransform(object):
55 | def __init__(self, transform=None, target_transform=None):
56 | self.transform = transform
57 | self.target_transform = target_transform
58 |
59 | def __call__(self, input, target):
60 | if self.transform is not None:
61 | input = self.transform(input)
62 | if self.target_transform is not None:
63 | target = self.target_transform(target)
64 | return input, target
65 |
66 | def _format_transform_repr(self, transform, head):
67 | lines = transform.__repr__().splitlines()
68 | return (["{}{}".format(head, lines[0])] +
69 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
70 |
71 | def __repr__(self):
72 | body = [self.__class__.__name__]
73 | if self.transform is not None:
74 | body += self._format_transform_repr(self.transform,
75 | "Transform: ")
76 | if self.target_transform is not None:
77 | body += self._format_transform_repr(self.target_transform,
78 | "Target transform: ")
79 |
80 | return '\n'.join(body)
81 |
--------------------------------------------------------------------------------
/misc/teaser_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/CGC/a66d87240863e19cc43d11c9e715ca447614042d/misc/teaser_image.png
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | def compute_gradcam(output, feats, target, relu):
7 | """
8 |     Compute the Grad-CAM heatmap for the given target category
9 |     :param output: the model output before softmax
10 |     :param feats: the feature output from the desired layer to be used for computing Grad-CAM
11 |     :param target: the target category to be used for computing Grad-CAM
12 |     :return: the Grad-CAM heatmap of shape (batch, H, W), rectified by ``relu``
13 | """
14 |
15 | one_hot = np.zeros((output.shape[0], output.shape[-1]), dtype=np.float32)
16 | indices_range = np.arange(output.shape[0])
17 | one_hot[indices_range, target[indices_range]] = 1
18 | one_hot = torch.from_numpy(one_hot)
19 | one_hot.requires_grad = True
20 |
21 | # Compute the Grad-CAM for the original image
22 | one_hot_cuda = torch.sum(one_hot.cuda() * output)
23 | dy_dz1, = torch.autograd.grad(one_hot_cuda, feats, grad_outputs=torch.ones(one_hot_cuda.size()).cuda(),
24 | retain_graph=True, create_graph=True)
25 |     # Use the element-wise product of gradients and features (instead of pooled weights) to preserve spatial information
26 | gcam512_1 = dy_dz1 * feats
27 | gradcam = gcam512_1.sum(dim=1)
28 | gradcam = relu(gradcam)
29 |
30 | return gradcam
31 |
32 |
33 | def compute_gradcam_mask(images_outputs, images_feats, target, relu):
34 | """
35 | This function computes the grad-cam, upsamples it to the image size and normalizes the Grad-CAM mask.
36 | """
37 | eps = 1e-8
38 | gradcam_mask = compute_gradcam(images_outputs, images_feats , target, relu)
39 | gradcam_mask = gradcam_mask.unsqueeze(1)
40 | gradcam_mask = F.interpolate(gradcam_mask, size=224, mode='bilinear')
41 | gradcam_mask = gradcam_mask.squeeze()
42 | # normalize the gradcam mask to sum to 1
43 | gradcam_mask_sum = gradcam_mask.sum(dim=[1, 2]).unsqueeze(-1).unsqueeze(-1)
44 | gradcam_mask = (gradcam_mask / (gradcam_mask_sum + eps)) + eps
45 |
46 | return gradcam_mask
47 |
48 |
49 | def perform_gradcam_aug(orig_gradcam_mask, aug_params_dict):
50 | """
51 |     This function takes the per-batch-element augmentation params and manually applies them to the
52 |     Grad-CAM mask to obtain the corresponding augmented Grad-CAM mask.
53 | """
54 | transforms_i = aug_params_dict['transforms_i']
55 | transforms_j = aug_params_dict['transforms_j']
56 | transforms_h = aug_params_dict['transforms_h']
57 | transforms_w = aug_params_dict['transforms_w']
58 | hor_flip = aug_params_dict['hor_flip']
59 | gpu_batch_len = transforms_i.shape[0]
60 | augmented_orig_gradcam_mask = torch.zeros_like(orig_gradcam_mask).cuda()
61 | for b in range(gpu_batch_len):
62 | # convert orig_gradcam_mask to image
63 | orig_gcam = orig_gradcam_mask[b]
64 | orig_gcam = orig_gcam[transforms_i[b]: transforms_i[b] + transforms_h[b],
65 | transforms_j[b]: transforms_j[b] + transforms_w[b]]
66 | # We use torch functional to resize without breaking the graph
67 | orig_gcam = orig_gcam.unsqueeze(0).unsqueeze(0)
68 | orig_gcam = F.interpolate(orig_gcam, size=224, mode='bilinear')
69 | orig_gcam = orig_gcam.squeeze()
70 | if hor_flip[b]:
71 | orig_gcam = orig_gcam.flip(-1)
72 | augmented_orig_gradcam_mask[b, :, :] = orig_gcam[:, :]
73 | return augmented_orig_gradcam_mask
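
Usage note (not part of the file above): a rough sketch of how these helpers combine into a Grad-CAM consistency term. The forward call is a hypothetical helper (any forward that returns class logits together with the layer4 feature map will do), the batch variables come from the ImageFolder dataset above, and the helpers assume CUDA tensors. The actual objective used by the training scripts is not reproduced here; a plain L1 distance is used only to show how the tensors line up.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from models.utils import compute_gradcam_mask, perform_gradcam_aug

    relu = nn.ReLU()

    # Logits and layer4 features for the original and the augmented views.
    output, feats = forward_with_feats(model, img)            # hypothetical helper
    aug_output, aug_feats = forward_with_feats(model, aug_img)

    # Class indices on CPU, since the one-hot vector is built with numpy.
    target = output.argmax(dim=1).cpu()
    orig_mask = compute_gradcam_mask(output, feats, target, relu)        # (B, 224, 224), sums to ~1
    aug_mask = compute_gradcam_mask(aug_output, aug_feats, target, relu)

    # Replay the crop/flip used to create aug_img on the original mask ...
    aug_params = {'transforms_i': i, 'transforms_j': j, 'transforms_h': h,
                  'transforms_w': w, 'hor_flip': hor_flip}
    replayed_mask = perform_gradcam_aug(orig_mask, aug_params)

    # ... and penalise disagreement between the two normalized heatmaps.
    consistency_loss = F.l1_loss(aug_mask, replayed_mask)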
--------------------------------------------------------------------------------