├── LICENSE ├── README.md ├── RISE ├── evaluate_auc_metrics.py ├── evaluation.py ├── pytorch_grad_cam │ ├── __init__.py │ ├── ablation_cam.py │ ├── activations_and_gradients.py │ ├── base_cam.py │ ├── eigen_cam.py │ ├── eigen_grad_cam.py │ ├── grad_cam.py │ ├── grad_cam_plusplus.py │ ├── guided_backprop.py │ ├── score_cam.py │ ├── utils │ │ ├── __init__.py │ │ ├── image.py │ │ └── svd_on_activations.py │ └── xgrad_cam.py └── utils.py ├── TorchRay ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── docs │ ├── attribution.rst │ ├── benchmark.rst │ ├── conf.py │ ├── index.rst │ ├── make-docs │ ├── pointing.csv │ ├── static │ │ └── css │ │ │ └── equations.css │ └── utils.rst ├── examples │ ├── __init__.py │ ├── __main__.py │ ├── attribution_benchmark.py │ ├── attribution_benchmark_debug.py │ ├── contrastive_excitation_backprop.py │ ├── contrastive_excitation_backprop_manual.py │ ├── deconvnet.py │ ├── deconvnet_manual.py │ ├── excitation_backprop.py │ ├── excitation_backprop_manual.py │ ├── extremal_perturbation.py │ ├── grad_cam.py │ ├── grad_cam_manual.py │ ├── gradient.py │ ├── gradient_manual.py │ ├── guided_backprop.py │ ├── guided_backprop_manual.py │ ├── linear_approx.py │ ├── linear_approx_manual.py │ └── rise.py ├── packaging │ └── meta.yaml ├── scripts │ └── torchrayrc ├── setup.py └── torchray │ ├── VERSION │ ├── __init__.py │ ├── attribution │ ├── __init__.py │ ├── common.py │ ├── deconvnet.py │ ├── excitation_backprop.py │ ├── extremal_perturbation.py │ ├── grad_cam.py │ ├── gradient.py │ ├── guided_backprop.py │ ├── linear_approx.py │ ├── resnet_maxpool.py │ └── rise.py │ ├── benchmark │ ├── __init__.py │ ├── datasets.py │ ├── evaluate_finegrained_gradcam_energy_inside_bbox.py │ ├── evaluate_imagenet_excitation_backprop_energy_inside_bbox.py │ ├── evaluate_imagenet_gradcam_energy_inside_bbox.py │ ├── evaluate_swav_imagenet_gradcam_energy_inside_bbox.py │ ├── imagenet_classes.txt │ ├── logging_mongo.py │ ├── models.py │ ├── pointing_game.py │ ├── resnet_multigpu_cgc.py │ ├── server.py │ ├── swav_resnet_cgc.py │ └── vision.py │ └── utils.py ├── baseline_train_eval.py ├── datasets ├── imagefolder_cgc_ssl.py └── vision.py ├── misc └── teaser_image.png ├── models ├── resnet_multigpu_cgc.py ├── swav_resnet_cgc.py └── utils.py ├── train_eval_cgc.py ├── train_eval_gc_l2_no_contrast_baseline.py └── train_imagenet_1pc_swav_cgc_unlabeled.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 UCDvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Consistent-Explanations-by-Contrastive-Learning 2 | Official PyTorch code for CVPR 2022 paper - [Consistent Explanations by Contrastive Learning][1] 3 | 4 | 5 | Post-hoc explanation methods, e.g., Grad-CAM, enable humans to inspect the spatial regions responsible for a particular network decision. However, it is shown that such explanations are not always consistent with human priors, such as consistency across image transformations. Given an interpretation algorithm, e.g., Grad-CAM, we introduce a novel training method to train the model to produce more consistent explanations. Since obtaining the ground truth for a desired model interpretation is not a well-defined task, we adopt ideas from contrastive self-supervised learning, and apply them to the interpretations of the model rather than its embeddings. We show that our method, Contrastive Grad-CAM Consistency (CGC), results in Grad-CAM interpretation heatmaps that are more consistent with human annotations while still achieving comparable classification accuracy. Moreover, our method acts as a regularizer and improves the accuracy on limited-data, fine-grained classification settings. In addition, because our method does not rely on annotations, it allows for the incorporation of unlabeled data into training, which enables better generalization of the model. 6 | 7 | ![Teaser image][teaser] 8 | 9 |
10 | 11 | ## Bibtex 12 | ``` 13 | @InProceedings{Pillai_2022_CVPR, 14 | author = {Pillai, Vipin and Abbasi Koohpayegani, Soroush and Ouligian, Ashley and Fong, Dennis and Pirsiavash, Hamed}, 15 | title = {Consistent Explanations by Contrastive Learning}, 16 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 17 | month = {June}, 18 | year = {2022} 19 | } 20 | ``` 21 | 22 | ## Pre-requisites 23 | - PyTorch 1.3 - Please install [PyTorch](https://pytorch.org/get-started/locally/) and CUDA if you don't have them installed. 24 | 25 | ## Datasets 26 | - [ImageNet - 1K](https://www.image-net.org/download.php) 27 | - [CUB-200](https://vision.cornell.edu/se3/caltech-ucsd-birds-200/) 28 | - [FGVC-Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/) 29 | - [Stanford Cars-196](https://ai.stanford.edu/~jkrause/cars/car_dataset.html) 30 | - [VGG Flowers-102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/) 31 | 32 | ## Training 33 | 34 | #### Train and evaluate a ResNet50 model on the ImageNet dataset using our CGC loss 35 | ``` 36 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_eval_cgc.py /datasets/imagenet -a resnet50 -p 100 -j 8 -b 256 --lr 0.1 --lambda 0.5 -t 0.5 --save_dir --log_dir 37 | ``` 38 | 39 | #### Train and evaluate a ResNet50 model on the 1% labeled subset of ImageNet, using the rest as unlabeled data. The model is initialized from SwAV 40 | For the command below, the SwAV pre-trained checkpoint can be downloaded from the SwAV GitHub repository - https://github.com/facebookresearch/swav 41 | We use the model checkpoint listed on the first row (800 epochs, 75.3% ImageNet top-1 acc.) of the Model Zoo of that repository. 42 | 43 | ``` 44 | CUDA_VISIBLE_DEVICES=0,1 python train_imagenet_1pc_swav_cgc_unlabeled.py -a resnet50 -b 128 -j 8 --lambda 0.25 -t 0.5 --epochs 50 --lr 0.02 --lr_last_layer 5 --resume --save_dir --log_dir 2>&1 | tee 45 | ``` 46 | 47 |
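The `--lambda` and `-t` flags set the weight and the temperature of the Grad-CAM consistency term. As a rough illustration of the idea described above (contrastive learning applied to interpretation heatmaps rather than embeddings), the sketch below treats the Grad-CAM maps of two augmented views of the same image as a positive pair in an InfoNCE-style loss. It is a simplified, hypothetical sketch for intuition only, not the exact objective implemented in `train_eval_cgc.py`.

```python
import torch
import torch.nn.functional as F

def heatmap_consistency_loss(h1: torch.Tensor, h2: torch.Tensor, t: float = 0.5) -> torch.Tensor:
    """InfoNCE-style loss over flattened Grad-CAM maps of two augmented views.

    h1, h2: (B, H, W) or (B, H*W) heatmaps; maps at the same batch index come
    from the same image and form the positive pair, all other maps act as negatives.
    """
    z1 = F.normalize(h1.flatten(1), dim=1)
    z2 = F.normalize(h2.flatten(1), dim=1)
    logits = z1 @ z2.t() / t                          # (B, B) cosine similarities
    targets = torch.arange(z1.size(0), device=z1.device)
    return F.cross_entropy(logits, targets)

# Schematically, the total objective combines this with the usual classification loss:
# loss = cross_entropy(logits, labels) + lambda * heatmap_consistency_loss(cam_view1, cam_view2, t)
```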
48 | 49 | ## Checkpoints 50 | * ResNet50 model pre-trained on ImageNet - [link](https://drive.google.com/drive/folders/1n7lFew0CdWuYCpR1kImMt7UC7_vsO5CT?usp=sharing) 51 | 52 | ## Evaluation 53 | 54 | #### Evaluate model checkpoint using Content Heatmap (CH) evaluation metric 55 | We use evaluation code adapted from the TorchRay framework. 56 | * Change directory to TorchRay and install the library. Please refer to the [TorchRay](https://github.com/facebookresearch/TorchRay) repository for full documentation and instructions. 57 | * `cd TorchRay` 58 | * `python setup.py install` 59 | 60 | * Change directory to TorchRay/torchray/benchmark 61 | * `cd torchray/benchmark` 62 | 63 | * For the ImageNet & CUB-200 datasets, this evaluation requires the following structure for validation images and bounding-box XML annotations: 64 | * /val/*.JPEG - Flat list of validation images 65 | * /annotation/*.xml - Flat list of annotation XML files 66 | 67 | ##### Evaluate ResNet50 models trained on the full ImageNet dataset 68 | ``` 69 | CUDA_VISIBLE_DEVICES=0 python evaluate_imagenet_gradcam_energy_inside_bbox.py -j 0 -b 1 --resume --input_resize 448 -a resnet50 70 | ``` 71 | 72 | ##### Evaluate ResNet50 models trained on the CUB-200 fine-grained dataset 73 | ``` 74 | CUDA_VISIBLE_DEVICES=0 python evaluate_finegrained_gradcam_energy_inside_bbox.py --dataset cub -j 0 -b 1 --resume --input_resize 448 -a resnet50 75 | ``` 76 | 77 | ##### Evaluate ResNet50 models initialized from SwAV and trained on the 1% labeled subset of ImageNet with the rest as unlabeled data 78 | ``` 79 | CUDA_VISIBLE_DEVICES=0 python evaluate_swav_imagenet_gradcam_energy_inside_bbox.py -j 0 -b 1 --resume --input_resize 448 -a resnet50 80 | ``` 81 | 82 |
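The Content Heatmap metric reported by the scripts above measures, per image, the fraction of the (non-negative) Grad-CAM energy that falls inside the ground-truth bounding box. Below is a minimal sketch of that quantity, assuming a heatmap already resized to the input resolution and a binary box mask; the `*_energy_inside_bbox.py` scripts implement the full pipeline, including data loading and CAM extraction.

```python
import numpy as np

def content_heatmap(heatmap: np.ndarray, box_mask: np.ndarray) -> float:
    """Fraction of heatmap energy inside the ground-truth box, in [0, 1]."""
    heatmap = np.maximum(heatmap, 0.0)      # keep only positive evidence
    total = heatmap.sum()
    if total == 0:
        return 0.0
    return float(heatmap[box_mask.astype(bool)].sum() / total)

# Hypothetical usage: a 448x448 Grad-CAM map and a mask that is 1 inside the box.
# score = content_heatmap(grayscale_cam, bbox_mask)
```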
83 | 84 | #### Evaluate model checkpoint using Insertion AUC (IAUC) evaluation metric 85 | Change directory to RISE/ and run the commands below: 86 | 87 | ##### Evaluate pre-trained ResNet50 model 88 | ``` 89 | CUDA_VISIBLE_DEVICES=0 python evaluate_auc_metrics.py --pretrained 90 | ``` 91 | 92 | ##### Evaluate ResNet50 model trained using our CGC method 93 | ``` 94 | CUDA_VISIBLE_DEVICES=0 python evaluate_auc_metrics.py --ckpt-path 95 | ``` 96 | 97 |
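The Insertion AUC metric starts from a blurred copy of each image, reveals pixels in order of decreasing Grad-CAM saliency, records the class probability at every step, and reports the normalized area under that curve. The full batched implementation lives in `RISE/evaluation.py`; the snippet below only sketches the final AUC normalization (trapezoidal rule) on a hypothetical score curve.

```python
import numpy as np

def normalized_auc(scores: np.ndarray) -> float:
    """Area under the probability-vs-fraction-inserted curve, normalized by curve length."""
    return float((scores.sum() - scores[0] / 2 - scores[-1] / 2) / (len(scores) - 1))

# Hypothetical class probabilities recorded over 5 insertion steps.
print(normalized_auc(np.array([0.05, 0.20, 0.55, 0.80, 0.90])))  # ~0.506
```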
98 | 99 | ## License 100 | This project is licensed under the MIT License. 101 | 102 | [1]: https://arxiv.org/pdf/2110.00527.pdf 103 | [teaser]: https://github.com/UMBCvision/Consistent-Explanations-by-Contrastive-Learning/blob/main/misc/teaser_image.png 104 | -------------------------------------------------------------------------------- /RISE/evaluate_auc_metrics.py: -------------------------------------------------------------------------------- 1 | ## Code to evaluate the pre-trained model and our CGC trained model with Insertion AUC score. 2 | 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.backends.cudnn as cudnn 9 | import torchvision.datasets as datasets 10 | import torchvision.models as models 11 | import torch.nn.functional as F 12 | from utils import * 13 | from evaluation import CausalMetric, auc, gkern 14 | from pytorch_grad_cam import GradCAM 15 | import argparse 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch AUC Metric Evaluation') 18 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 19 | help='use pre-trained model') 20 | parser.add_argument('--ckpt-path', dest='ckpt_path', type=str, help='path to checkpoint file') 21 | 22 | def main(): 23 | args = parser.parse_args() 24 | 25 | cudnn.benchmark = True 26 | 27 | scores = {'del': [], 'ins': []} 28 | if args.pretrained: 29 | net = models.resnet50(pretrained=True) 30 | else: 31 | net = models.resnet50() 32 | state_dict = torch.load(args.ckpt_path)['state_dict'] 33 | 34 | # remove the module prefix if model was saved with DataParallel 35 | state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} 36 | # load params 37 | net.load_state_dict(state_dict) 38 | 39 | target_layer = net.layer4[-1] 40 | cam = GradCAM(model=net, target_layer=target_layer, use_cuda=True) 41 | 42 | # we process the imagenet 50k val images in 10 set of 5k each and compute mean 43 | for i in range(10): 44 | auc_score = get_auc_per_data_subset(i, net, cam) 45 | scores['ins'].append(auc_score) 46 | print('Finished evaluating the insertion metrics...') 47 | 48 | print('----------------------------------------------------------------') 49 | print('Final:\nInsertion - {:.5f}'.format(np.mean(scores['ins']))) 50 | 51 | 52 | def get_auc_per_data_subset(range_index, net, cam): 53 | batch_size = 100 54 | data_loader = torch.utils.data.DataLoader( 55 | dataset=datasets.ImageFolder('/nfs3/datasets/imagenet/val/', preprocess), 56 | batch_size=batch_size, shuffle=False, 57 | num_workers=8, pin_memory=True, sampler=RangeSampler(range(5000 * range_index, 5000 * (range_index + 1)))) 58 | 59 | net = net.train() 60 | 61 | images = [] 62 | targets = [] 63 | gcam_exp = [] 64 | 65 | for j, (img, trg) in enumerate(tqdm(data_loader, total=len(data_loader), desc='Loading images')): 66 | grayscale_gradcam = cam(input_tensor=img, target_category=trg) 67 | for k in range(batch_size): 68 | images.append(img[k]) 69 | targets.append(trg[k]) 70 | gcam_exp.append(grayscale_gradcam[k]) 71 | 72 | images = torch.stack(images).cpu().numpy() 73 | gcam_exp = np.stack(gcam_exp) 74 | images = np.asarray(images) 75 | gcam_exp = np.asarray(gcam_exp) 76 | 77 | images = images.reshape((-1, 3, 224, 224)) 78 | gcam_exp = gcam_exp.reshape((-1, 224, 224)) 79 | print('Finished obtaining CAM') 80 | 81 | model = nn.Sequential(net, nn.Softmax(dim=1)) 82 | model = model.eval() 83 | model = model.cuda() 84 | 85 | for p in model.parameters(): 86 | p.requires_grad = False 87 | 88 | # To use 
multiple GPUs 89 | ddp_model = nn.DataParallel(model) 90 | 91 | # we use blur as the substrate function 92 | klen = 11 93 | ksig = 5 94 | kern = gkern(klen, ksig) 95 | # Function that blurs input image 96 | blur = lambda x: F.conv2d(x, kern, padding=klen // 2) 97 | 98 | insertion = CausalMetric(ddp_model, 'ins', 224 * 8, substrate_fn=blur) 99 | 100 | # Evaluate insertion 101 | h = insertion.evaluate(torch.from_numpy(images.astype('float32')), gcam_exp, batch_size) 102 | 103 | model = model.train() 104 | for p in model.parameters(): 105 | p.requires_grad = True 106 | 107 | return auc(h.mean(1)) 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /RISE/evaluation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from tqdm import tqdm 3 | from scipy.ndimage.filters import gaussian_filter 4 | 5 | from utils import * 6 | 7 | HW = 224 * 224 # image area 8 | n_classes = 1000 9 | 10 | def gkern(klen, nsig): 11 | """Returns a Gaussian kernel array. 12 | Convolution with it results in image blurring.""" 13 | # create nxn zeros 14 | inp = np.zeros((klen, klen)) 15 | # set element at the middle to one, a dirac delta 16 | inp[klen//2, klen//2] = 1 17 | # gaussian-smooth the dirac, resulting in a gaussian filter mask 18 | k = gaussian_filter(inp, nsig) 19 | kern = np.zeros((3, 3, klen, klen)) 20 | kern[0, 0] = k 21 | kern[1, 1] = k 22 | kern[2, 2] = k 23 | return torch.from_numpy(kern.astype('float32')) 24 | 25 | def auc(arr): 26 | """Returns normalized Area Under Curve of the array.""" 27 | return (arr.sum() - arr[0] / 2 - arr[-1] / 2) / (arr.shape[0] - 1) 28 | 29 | class CausalMetric(): 30 | 31 | def __init__(self, model, mode, step, substrate_fn): 32 | r"""Create deletion/insertion metric instance. 33 | 34 | Args: 35 | model (nn.Module): Black-box model being explained. 36 | mode (str): 'del' or 'ins'. 37 | step (int): number of pixels modified per one iteration. 38 | substrate_fn (func): a mapping from old pixels to new pixels. 39 | """ 40 | assert mode in ['del', 'ins'] 41 | self.model = model 42 | self.mode = mode 43 | self.step = step 44 | self.substrate_fn = substrate_fn 45 | 46 | def single_run(self, img_tensor, explanation, verbose=0, save_to=None): 47 | r"""Run metric on one image-saliency pair. 48 | 49 | Args: 50 | img_tensor (Tensor): normalized image tensor. 51 | explanation (np.ndarray): saliency map. 52 | verbose (int): in [0, 1, 2]. 53 | 0 - return list of scores. 54 | 1 - also plot final step. 55 | 2 - also plot every step and print 2 top classes. 56 | save_to (str): directory to save every step plots to. 57 | 58 | Return: 59 | scores (nd.array): Array containing scores at every step. 
60 | """ 61 | pred = self.model(img_tensor.cuda()) 62 | top, c = torch.max(pred, 1) 63 | c = c.cpu().numpy()[0] 64 | n_steps = (HW + self.step - 1) // self.step 65 | 66 | if self.mode == 'del': 67 | title = 'Deletion game' 68 | ylabel = 'Pixels deleted' 69 | start = img_tensor.clone() 70 | finish = self.substrate_fn(img_tensor) 71 | elif self.mode == 'ins': 72 | title = 'Insertion game' 73 | ylabel = 'Pixels inserted' 74 | start = self.substrate_fn(img_tensor) 75 | finish = img_tensor.clone() 76 | 77 | scores = np.empty(n_steps + 1) 78 | # Coordinates of pixels in order of decreasing saliency 79 | salient_order = np.flip(np.argsort(explanation.reshape(-1, HW), axis=1), axis=-1) 80 | for i in range(n_steps+1): 81 | pred = self.model(start.cuda()) 82 | pr, cl = torch.topk(pred, 2) 83 | if verbose == 2: 84 | print('{}: {:.3f}'.format(get_class_name(cl[0][0]), float(pr[0][0]))) 85 | print('{}: {:.3f}'.format(get_class_name(cl[0][1]), float(pr[0][1]))) 86 | scores[i] = pred[0, c] 87 | # Render image if verbose, if it's the last step or if save is required. 88 | if verbose == 2 or (verbose == 1 and i == n_steps) or save_to: 89 | plt.figure(figsize=(10, 5)) 90 | plt.subplot(121) 91 | plt.title('{} {:.1f}%, P={:.4f}'.format(ylabel, 100 * i / n_steps, scores[i])) 92 | plt.axis('off') 93 | tensor_imshow(start[0]) 94 | 95 | plt.subplot(122) 96 | plt.plot(np.arange(i+1) / n_steps, scores[:i+1]) 97 | plt.xlim(-0.1, 1.1) 98 | plt.ylim(0, 1.05) 99 | plt.fill_between(np.arange(i+1) / n_steps, 0, scores[:i+1], alpha=0.4) 100 | plt.title(title) 101 | plt.xlabel(ylabel) 102 | plt.ylabel(get_class_name(c)) 103 | if save_to: 104 | plt.savefig(save_to + '/{:03d}.png'.format(i)) 105 | plt.close() 106 | else: 107 | plt.show() 108 | if i < n_steps: 109 | coords = salient_order[:, self.step * i:self.step * (i + 1)] 110 | start.cpu().numpy().reshape(1, 3, HW)[0, :, coords] = finish.cpu().numpy().reshape(1, 3, HW)[0, :, coords] 111 | return scores 112 | 113 | def evaluate(self, img_batch, exp_batch, batch_size): 114 | r"""Efficiently evaluate big batch of images. 115 | 116 | Args: 117 | img_batch (Tensor): batch of images. 118 | exp_batch (np.ndarray): batch of explanations. 119 | batch_size (int): number of images for one small batch. 120 | 121 | Returns: 122 | scores (nd.array): Array containing scores at every step for every image. 
123 | """ 124 | n_samples = img_batch.shape[0] 125 | predictions = torch.FloatTensor(n_samples, n_classes) 126 | # assert n_samples % batch_size == 0 127 | for i in tqdm(range(n_samples // batch_size), desc='Predicting labels'): 128 | preds = self.model(img_batch[i*batch_size:(i+1)*batch_size].cuda()).cpu() 129 | predictions[i*batch_size:(i+1)*batch_size] = preds 130 | if n_samples % batch_size != 0: 131 | start_index = n_samples // batch_size 132 | preds = self.model(img_batch[start_index*batch_size:].cuda()).cpu() 133 | predictions[start_index*batch_size:] = preds 134 | 135 | top = np.argmax(predictions, -1) 136 | n_steps = (HW + self.step - 1) // self.step 137 | scores = np.empty((n_steps + 1, n_samples)) 138 | salient_order = np.flip(np.argsort(exp_batch.reshape(-1, HW), axis=1), axis=-1) 139 | r = np.arange(n_samples).reshape(n_samples, 1) 140 | 141 | substrate = torch.zeros_like(img_batch) 142 | for j in tqdm(range(n_samples // batch_size), desc='Substrate'): 143 | substrate[j*batch_size:(j+1)*batch_size] = self.substrate_fn(img_batch[j*batch_size:(j+1)*batch_size]) 144 | 145 | if n_samples % batch_size != 0: 146 | start_index = n_samples // batch_size 147 | substrate[start_index*batch_size:] = self.substrate_fn(img_batch[start_index*batch_size:]) 148 | 149 | if self.mode == 'del': 150 | caption = 'Deleting ' 151 | start = img_batch.clone() 152 | finish = substrate 153 | elif self.mode == 'ins': 154 | caption = 'Inserting ' 155 | start = substrate 156 | finish = img_batch.clone() 157 | 158 | # While not all pixels are changed 159 | for i in tqdm(range(n_steps+1), desc=caption + 'pixels'): 160 | # Iterate over batches 161 | for j in range(n_samples // batch_size): 162 | # Compute new scores 163 | preds = self.model(start[j*batch_size:(j+1)*batch_size].cuda()) 164 | preds = preds.cpu().numpy()[range(batch_size), top[j*batch_size:(j+1)*batch_size]] 165 | scores[i, j*batch_size:(j+1)*batch_size] = preds 166 | if n_samples % batch_size != 0: 167 | start_index = n_samples // batch_size 168 | preds = self.model(start[start_index*batch_size:].cuda()) 169 | preds = preds.cpu().numpy()[range(n_samples%batch_size), top[start_index*batch_size:]] 170 | scores[i, start_index*batch_size:] = preds 171 | 172 | # Change specified number of most salient pixels to substrate pixels 173 | coords = salient_order[:, self.step * i:self.step * (i + 1)] 174 | start.cpu().numpy().reshape(n_samples, 3, HW)[r, :, coords] = finish.cpu().numpy().reshape(n_samples, 3, HW)[r, :, coords] 175 | print('AUC: {}'.format(auc(scores.mean(1)))) 176 | return scores 177 | -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.grad_cam import GradCAM 2 | from pytorch_grad_cam.ablation_cam import AblationCAM 3 | from pytorch_grad_cam.xgrad_cam import XGradCAM 4 | from pytorch_grad_cam.grad_cam_plusplus import GradCAMPlusPlus 5 | from pytorch_grad_cam.score_cam import ScoreCAM 6 | from pytorch_grad_cam.eigen_cam import EigenCAM 7 | from pytorch_grad_cam.eigen_grad_cam import EigenGradCAM 8 | from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/ablation_cam.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import tqdm 5 | from pytorch_grad_cam.base_cam 
import BaseCAM 6 | 7 | class AblationLayer(torch.nn.Module): 8 | def __init__(self, layer, reshape_transform, indices): 9 | super(AblationLayer, self).__init__() 10 | 11 | self.layer = layer 12 | self.reshape_transform = reshape_transform 13 | # The channels to zero out: 14 | self.indices = indices 15 | 16 | def forward(self, x): 17 | self.__call__(x) 18 | 19 | def __call__(self, x): 20 | output = self.layer(x) 21 | 22 | # Hack to work with ViT, 23 | # Since the activation channels are last and not first like in CNNs 24 | # Probably should remove it? 25 | if self.reshape_transform is not None: 26 | output = output.transpose(1, 2) 27 | 28 | for i in range(output.size(0)): 29 | 30 | # Commonly the minimum activation will be 0, 31 | # And then it makes sense to zero it out. 32 | # However depending on the architecture, 33 | # If the values can be negative, we use very negative values 34 | # to perform the ablation, deviating from the paper. 35 | if torch.min(output) == 0: 36 | output[i, self.indices[i], :] = 0 37 | else: 38 | ABLATION_VALUE = 1e5 39 | output[i, self.indices[i], :] = torch.min(output) - ABLATION_VALUE 40 | 41 | if self.reshape_transform is not None: 42 | output = output.transpose(2, 1) 43 | 44 | return output 45 | 46 | def replace_layer_recursive(model, old_layer, new_layer): 47 | for name, layer in model._modules.items(): 48 | if layer == old_layer: 49 | model._modules[name] = new_layer 50 | return True 51 | elif replace_layer_recursive(layer, old_layer, new_layer): 52 | return True 53 | return False 54 | 55 | class AblationCAM(BaseCAM): 56 | def __init__(self, model, target_layer, use_cuda=False, 57 | reshape_transform=None): 58 | super(AblationCAM, self).__init__(model, target_layer, use_cuda, 59 | reshape_transform) 60 | 61 | def get_cam_weights(self, 62 | input_tensor, 63 | target_category, 64 | activations, 65 | grads): 66 | with torch.no_grad(): 67 | outputs = self.model(input_tensor).cpu().numpy() 68 | original_scores = [] 69 | for i in range(input_tensor.size(0)): 70 | original_scores.append(outputs[i, target_category[i]]) 71 | original_scores = np.float32(original_scores) 72 | 73 | ablation_layer = AblationLayer(self.target_layer, 74 | self.reshape_transform, indices=[]) 75 | replace_layer_recursive(self.model, self.target_layer, ablation_layer) 76 | 77 | 78 | if hasattr(self, "batch_size"): 79 | BATCH_SIZE = self.batch_size 80 | else: 81 | BATCH_SIZE = 32 82 | 83 | number_of_channels = activations.shape[1] 84 | weights = [] 85 | 86 | with torch.no_grad(): 87 | 88 | # Iterate over the input batch 89 | for tensor, category in zip(input_tensor, target_category): 90 | batch_tensor = tensor.repeat(BATCH_SIZE, 1, 1, 1) 91 | for i in tqdm.tqdm(range(0, number_of_channels, BATCH_SIZE)): 92 | ablation_layer.indices = list(range(i, i + BATCH_SIZE)) 93 | 94 | if i + BATCH_SIZE > number_of_channels: 95 | keep = number_of_channels - i 96 | batch_tensor = batch_tensor[:keep] 97 | ablation_layer.indices = ablation_layer.indices[:keep] 98 | score = self.model(batch_tensor)[:, category].cpu().numpy() 99 | weights.extend(score) 100 | 101 | weights = np.float32(weights) 102 | weights = weights.reshape(activations.shape[:2]) 103 | original_scores = original_scores[:, None] 104 | weights = (original_scores - weights) / original_scores 105 | 106 | #replace the model back to the original state 107 | replace_layer_recursive(self.model, ablation_layer, self.target_layer) 108 | return weights -------------------------------------------------------------------------------- 
/RISE/pytorch_grad_cam/activations_and_gradients.py: -------------------------------------------------------------------------------- 1 | class ActivationsAndGradients: 2 | """ Class for extracting activations and 3 | registering gradients from targetted intermediate layers """ 4 | 5 | def __init__(self, model, target_layer, reshape_transform): 6 | self.model = model 7 | self.gradients = [] 8 | self.activations = [] 9 | self.reshape_transform = reshape_transform 10 | 11 | target_layer.register_forward_hook(self.save_activation) 12 | target_layer.register_backward_hook(self.save_gradient) 13 | 14 | def save_activation(self, module, input, output): 15 | activation = output 16 | if self.reshape_transform is not None: 17 | activation = self.reshape_transform(activation) 18 | self.activations.append(activation.cpu().detach()) 19 | 20 | def save_gradient(self, module, grad_input, grad_output): 21 | # Gradients are computed in reverse order 22 | grad = grad_output[0] 23 | if self.reshape_transform is not None: 24 | grad = self.reshape_transform(grad) 25 | self.gradients = [grad.cpu().detach()] + self.gradients 26 | 27 | def __call__(self, x): 28 | self.gradients = [] 29 | self.activations = [] 30 | return self.model(x) -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/base_cam.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import ttach as tta 5 | from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients 6 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 7 | 8 | 9 | class BaseCAM: 10 | def __init__(self, 11 | model, 12 | target_layer, 13 | use_cuda=False, 14 | reshape_transform=None): 15 | self.model = model.eval() 16 | self.target_layer = target_layer 17 | self.cuda = use_cuda 18 | if self.cuda: 19 | self.model = model.cuda() 20 | self.reshape_transform = reshape_transform 21 | self.activations_and_grads = ActivationsAndGradients(self.model, 22 | target_layer, reshape_transform) 23 | 24 | def forward(self, input_img): 25 | return self.model(input_img) 26 | 27 | def get_cam_weights(self, 28 | input_tensor, 29 | target_category, 30 | activations, 31 | grads): 32 | raise Exception("Not Implemented") 33 | 34 | def get_loss(self, output, target_category): 35 | loss = 0 36 | for i in range(len(target_category)): 37 | loss = loss + output[i, target_category[i]] 38 | return loss 39 | 40 | def get_cam_image(self, 41 | input_tensor, 42 | target_category, 43 | activations, 44 | grads, 45 | eigen_smooth=False): 46 | weights = self.get_cam_weights(input_tensor, target_category, activations, grads) 47 | weighted_activations = weights[:, :, None, None] * activations 48 | if eigen_smooth: 49 | cam = get_2d_projection(weighted_activations) 50 | else: 51 | # import pdb 52 | # pdb.set_trace() 53 | cam = weighted_activations.sum(axis=1) 54 | return cam 55 | 56 | def forward(self, input_tensor, target_category=None, eigen_smooth=False): 57 | 58 | if self.cuda: 59 | input_tensor = input_tensor.cuda() 60 | 61 | output = self.activations_and_grads(input_tensor) 62 | 63 | if type(target_category) is int: 64 | target_category = [target_category] * input_tensor.size(0) 65 | 66 | if target_category is None: 67 | target_category = np.argmax(output.cpu().data.numpy(), axis=-1) 68 | else: 69 | assert(len(target_category) == input_tensor.size(0)) 70 | 71 | self.model.zero_grad() 72 | loss = self.get_loss(output, target_category) 73 | 
loss.backward(retain_graph=True) 74 | 75 | activations = self.activations_and_grads.activations[-1].cpu().data.numpy() 76 | grads = self.activations_and_grads.gradients[-1].cpu().data.numpy() 77 | 78 | cam = self.get_cam_image(input_tensor, target_category, 79 | activations, grads, eigen_smooth) 80 | 81 | cam = np.maximum(cam, 0) 82 | 83 | result = [] 84 | for img in cam: 85 | img = cv2.resize(img, input_tensor.shape[-2:][::-1]) 86 | img = img - np.min(img) 87 | img = img / (np.max(img) + 1e-8) 88 | result.append(img) 89 | result = np.float32(result) 90 | return result 91 | 92 | def forward_augmentation_smoothing(self, 93 | input_tensor, 94 | target_category=None, 95 | eigen_smooth=False): 96 | transforms = tta.Compose( 97 | [ 98 | tta.HorizontalFlip(), 99 | tta.Multiply(factors=[0.9, 1, 1.1]), 100 | ] 101 | ) 102 | cams = [] 103 | for transform in transforms: 104 | augmented_tensor = transform.augment_image(input_tensor) 105 | cam = self.forward(augmented_tensor, 106 | target_category, eigen_smooth) 107 | 108 | # The ttach library expects a tensor of size BxCxHxW 109 | cam = cam[:, None, :, :] 110 | cam = torch.from_numpy(cam) 111 | cam = transform.deaugment_mask(cam) 112 | 113 | # Back to numpy float32, HxW 114 | cam = cam.numpy() 115 | cam = cam[:, 0, :, :] 116 | cams.append(cam) 117 | 118 | cam = np.mean(np.float32(cams), axis=0) 119 | return cam 120 | 121 | def __call__(self, 122 | input_tensor, 123 | target_category=None, 124 | aug_smooth=False, 125 | eigen_smooth=False): 126 | if aug_smooth is True: 127 | return self.forward_augmentation_smoothing(input_tensor, 128 | target_category, eigen_smooth) 129 | 130 | return self.forward(input_tensor, 131 | target_category, eigen_smooth) -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/eigen_cam.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from pytorch_grad_cam.base_cam import BaseCAM 5 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 6 | 7 | # https://arxiv.org/abs/2008.00299 8 | class EigenCAM(BaseCAM): 9 | def __init__(self, model, target_layer, use_cuda=False, 10 | reshape_transform=None): 11 | super(EigenCAM, self).__init__(model, target_layer, use_cuda, 12 | reshape_transform) 13 | 14 | def get_cam_image(self, 15 | input_tensor, 16 | target_category, 17 | activations, 18 | grads, 19 | eigen_smooth): 20 | return get_2d_projection(activations) 21 | -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/eigen_grad_cam.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from pytorch_grad_cam.base_cam import BaseCAM 5 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 6 | 7 | # Like Eigen CAM: https://arxiv.org/abs/2008.00299 8 | # But multiply the activations x gradients 9 | class EigenGradCAM(BaseCAM): 10 | def __init__(self, model, target_layer, use_cuda=False, 11 | reshape_transform=None): 12 | super(EigenGradCAM, self).__init__(model, target_layer, use_cuda, 13 | reshape_transform) 14 | 15 | def get_cam_image(self, 16 | input_tensor, 17 | target_category, 18 | activations, 19 | grads, 20 | eigen_smooth): 21 | return get_2d_projection(grads*activations) -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/grad_cam.py: 
-------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from pytorch_grad_cam.base_cam import BaseCAM 5 | 6 | 7 | class GradCAM(BaseCAM): 8 | def __init__(self, model, target_layer, use_cuda=False, reshape_transform=None): 9 | super(GradCAM, self).__init__(model, target_layer, use_cuda, reshape_transform) 10 | 11 | def get_cam_weights(self, 12 | input_tensor, 13 | target_category, 14 | activations, 15 | grads): 16 | return np.mean(grads, axis=(2, 3)) 17 | -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/grad_cam_plusplus.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from pytorch_grad_cam.base_cam import BaseCAM 5 | 6 | class GradCAMPlusPlus(BaseCAM): 7 | def __init__(self, model, target_layer, use_cuda=False, 8 | reshape_transform=None): 9 | super(GradCAMPlusPlus, self).__init__(model, target_layer, use_cuda, 10 | reshape_transform) 11 | 12 | def get_cam_weights(self, input_tensor, 13 | target_category, 14 | activations, 15 | grads): 16 | grads_power_2 = grads**2 17 | grads_power_3 = grads_power_2*grads 18 | # Equation 19 in https://arxiv.org/abs/1710.11063 19 | sum_activations = np.sum(activations, axis=(2, 3)) 20 | eps = 0.000001 21 | aij = grads_power_2 / (2*grads_power_2 + 22 | sum_activations[:, :, None, None]*grads_power_3 + eps) 23 | # Now bring back the ReLU from eq.7 in the paper, 24 | # And zero out aijs where the activations are 0 25 | aij = np.where(grads != 0, aij, 0) 26 | 27 | weights = np.maximum(grads, 0)*aij 28 | weights = np.sum(weights, axis=(2, 3)) 29 | return weights -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/guided_backprop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Function 4 | 5 | 6 | class GuidedBackpropReLU(Function): 7 | @staticmethod 8 | def forward(self, input_img): 9 | positive_mask = (input_img > 0).type_as(input_img) 10 | output = torch.addcmul(torch.zeros(input_img.size()).type_as(input_img), input_img, positive_mask) 11 | self.save_for_backward(input_img, output) 12 | return output 13 | 14 | @staticmethod 15 | def backward(self, grad_output): 16 | input_img, output = self.saved_tensors 17 | grad_input = None 18 | 19 | positive_mask_1 = (input_img > 0).type_as(grad_output) 20 | positive_mask_2 = (grad_output > 0).type_as(grad_output) 21 | grad_input = torch.addcmul(torch.zeros(input_img.size()).type_as(input_img), 22 | torch.addcmul(torch.zeros(input_img.size()).type_as(input_img), grad_output, 23 | positive_mask_1), positive_mask_2) 24 | return grad_input 25 | 26 | 27 | class GuidedBackpropReLUModel: 28 | def __init__(self, model, use_cuda): 29 | self.model = model 30 | self.model.eval() 31 | self.cuda = use_cuda 32 | if self.cuda: 33 | self.model = self.model.cuda() 34 | 35 | def forward(self, input_img): 36 | return self.model(input_img) 37 | 38 | def recursive_replace_relu_with_guidedrelu(self, module_top): 39 | for idx, module in module_top._modules.items(): 40 | self.recursive_replace_relu_with_guidedrelu(module) 41 | if module.__class__.__name__ == 'ReLU': 42 | module_top._modules[idx] = GuidedBackpropReLU.apply 43 | 44 | def recursive_replace_guidedrelu_with_relu(self, module_top): 45 | try: 46 | for idx, module in module_top._modules.items(): 47 
| self.recursive_replace_guidedrelu_with_relu(module) 48 | if module == GuidedBackpropReLU.apply: 49 | module_top._modules[idx] = torch.nn.ReLU() 50 | except: 51 | pass 52 | 53 | 54 | def __call__(self, input_img, target_category=None): 55 | # replace ReLU with GuidedBackpropReLU 56 | self.recursive_replace_relu_with_guidedrelu(self.model) 57 | 58 | if self.cuda: 59 | input_img = input_img.cuda() 60 | 61 | input_img = input_img.requires_grad_(True) 62 | 63 | output = self.forward(input_img) 64 | 65 | if target_category is None: 66 | target_category = np.argmax(output.cpu().data.numpy()) 67 | 68 | loss = output[0, target_category] 69 | loss.backward(retain_graph=True) 70 | 71 | output = input_img.grad.cpu().data.numpy() 72 | output = output[0, :, :, :] 73 | output = output.transpose((1, 2, 0)) 74 | 75 | # replace GuidedBackpropReLU back with ReLU 76 | self.recursive_replace_guidedrelu_with_relu(self.model) 77 | 78 | return output 79 | -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/score_cam.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import tqdm 5 | from pytorch_grad_cam.base_cam import BaseCAM 6 | 7 | class ScoreCAM(BaseCAM): 8 | def __init__(self, model, target_layer, use_cuda=False, reshape_transform=None): 9 | super(ScoreCAM, self).__init__(model, target_layer, use_cuda, 10 | reshape_transform=reshape_transform) 11 | 12 | def get_cam_weights(self, 13 | input_tensor, 14 | target_category, 15 | activations, 16 | grads): 17 | with torch.no_grad(): 18 | upsample = torch.nn.UpsamplingBilinear2d(size=input_tensor.shape[-2 : ]) 19 | activation_tensor = torch.from_numpy(activations) 20 | if self.cuda: 21 | activation_tensor = activation_tensor.cuda() 22 | 23 | upsampled = upsample(activation_tensor) 24 | 25 | maxs = upsampled.view(upsampled.size(0), 26 | upsampled.size(1), -1).max(dim=-1)[0] 27 | mins = upsampled.view(upsampled.size(0), 28 | upsampled.size(1), -1).min(dim=-1)[0] 29 | maxs, mins = maxs[:, :, None, None], mins[:, :, None, None] 30 | upsampled = (upsampled - mins) / (maxs - mins) 31 | 32 | input_tensors = input_tensor[:, None, :, :]*upsampled[:, :, None, :, :] 33 | 34 | if hasattr(self, "batch_size"): 35 | BATCH_SIZE = self.batch_size 36 | else: 37 | BATCH_SIZE = 16 38 | 39 | scores = [] 40 | for batch_index, tensor in enumerate(input_tensors): 41 | category = target_category[batch_index] 42 | for i in tqdm.tqdm(range(0, tensor.size(0), BATCH_SIZE)): 43 | batch = tensor[i : i + BATCH_SIZE, :] 44 | outputs = self.model(batch).cpu().numpy()[:, category] 45 | scores.extend(outputs) 46 | scores = torch.Tensor(scores) 47 | scores = scores.view(activations.shape[0], activations.shape[1]) 48 | 49 | weights = torch.nn.Softmax(dim=-1)(scores).numpy() 50 | return weights 51 | -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_grad_cam.utils.image import deprocess_image 2 | from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection 3 | from pytorch_grad_cam.utils.image import preprocess_image 4 | 5 | -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/utils/image.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 
from torchvision.transforms import Compose, Normalize, ToTensor 5 | 6 | 7 | def preprocess_image(img: np.ndarray, mean=None, std=None) -> torch.Tensor: 8 | if std is None: 9 | std = [0.5, 0.5, 0.5] 10 | if mean is None: 11 | mean = [0.5, 0.5, 0.5] 12 | 13 | preprocessing = Compose([ 14 | ToTensor(), 15 | Normalize(mean=mean, std=std) 16 | ]) 17 | 18 | return preprocessing(img.copy()).unsqueeze(0) 19 | 20 | 21 | def deprocess_image(img): 22 | """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """ 23 | img = img - np.mean(img) 24 | img = img / (np.std(img) + 1e-5) 25 | img = img * 0.1 26 | img = img + 0.5 27 | img = np.clip(img, 0, 1) 28 | return np.uint8(img * 255) 29 | 30 | 31 | def show_cam_on_image(img: np.ndarray, 32 | mask: np.ndarray, 33 | use_rgb: bool = False, 34 | colormap: int = cv2.COLORMAP_JET) -> np.ndarray: 35 | """ This function overlays the cam mask on the image as an heatmap. 36 | By default the heatmap is in BGR format. 37 | 38 | :param img: The base image in RGB or BGR format. 39 | :param mask: The cam mask. 40 | :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format. 41 | :param colormap: The OpenCV colormap to be used. 42 | :returns: The default image with the cam overlay. 43 | """ 44 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap) 45 | if use_rgb: 46 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) 47 | heatmap = np.float32(heatmap) / 255 48 | 49 | if np.max(img) > 1: 50 | raise Exception("The input image should np.float32 in the range [0, 1]") 51 | 52 | cam = heatmap + img 53 | cam = cam / np.max(cam) 54 | return np.uint8(255 * cam) 55 | -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/utils/svd_on_activations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def get_2d_projection(activation_batch): 4 | # TBD: use pytorch batch svd implementation 5 | projections = [] 6 | for activations in activation_batch: 7 | reshaped_activations = (activations).reshape(activations.shape[0], -1).transpose() 8 | # Centering before the SVD seems to be important here, 9 | # Otherwise the image returned is negative 10 | reshaped_activations = reshaped_activations - reshaped_activations.mean(axis=0) 11 | U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True) 12 | projection = reshaped_activations @ VT[0, :] 13 | projection = projection.reshape(activations.shape[1 : ]) 14 | projections.append(projection) 15 | return np.float32(projections) -------------------------------------------------------------------------------- /RISE/pytorch_grad_cam/xgrad_cam.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from pytorch_grad_cam.base_cam import BaseCAM 5 | 6 | class XGradCAM(BaseCAM): 7 | def __init__(self, model, target_layer, use_cuda=False, reshape_transform=None): 8 | super(XGradCAM, self).__init__(model, target_layer, use_cuda, reshape_transform) 9 | 10 | def get_cam_weights(self, 11 | input_tensor, 12 | target_category, 13 | activations, 14 | grads): 15 | sum_activations = np.sum(activations, axis=(2, 3)) 16 | eps = 1e-7 17 | weights = grads * activations / (sum_activations[:, :, None, None] + eps) 18 | weights = weights.sum(axis=(2, 3)) 19 | return weights -------------------------------------------------------------------------------- /RISE/utils.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | import torch 4 | from torch.utils.data.sampler import Sampler 5 | from torchvision import transforms, datasets 6 | from PIL import Image 7 | 8 | 9 | # Dummy class to store arguments 10 | class Dummy(): 11 | pass 12 | 13 | 14 | # Function that opens image from disk, normalizes it and converts to tensor 15 | read_tensor = transforms.Compose([ 16 | lambda x: Image.open(x), 17 | transforms.Resize((224, 224)), 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 20 | std=[0.229, 0.224, 0.225]), 21 | lambda x: torch.unsqueeze(x, 0) 22 | ]) 23 | 24 | 25 | # Plots image from tensor 26 | def tensor_imshow(inp, title=None, **kwargs): 27 | """Imshow for Tensor.""" 28 | inp = inp.numpy().transpose((1, 2, 0)) 29 | # Mean and std for ImageNet 30 | mean = np.array([0.485, 0.456, 0.406]) 31 | std = np.array([0.229, 0.224, 0.225]) 32 | inp = std * inp + mean 33 | inp = np.clip(inp, 0, 1) 34 | plt.imshow(inp, **kwargs) 35 | if title is not None: 36 | plt.title(title) 37 | 38 | 39 | # Given label number returns class name 40 | def get_class_name(c): 41 | labels = np.loadtxt('synset_words.txt', str, delimiter='\t') 42 | return ' '.join(labels[c].split(',')[0].split()[1:]) 43 | 44 | 45 | # Image preprocessing function 46 | preprocess = transforms.Compose([ 47 | transforms.Resize((224, 224)), 48 | transforms.ToTensor(), 49 | # Normalization for ImageNet 50 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 51 | std=[0.229, 0.224, 0.225]), 52 | ]) 53 | 54 | 55 | # Sampler for pytorch loader. Given range r loader will only 56 | # return dataset[r] instead of whole dataset. 57 | class RangeSampler(Sampler): 58 | def __init__(self, r): 59 | self.r = r 60 | 61 | def __iter__(self): 62 | return iter(self.r) 63 | 64 | def __len__(self): 65 | return len(self.r) 66 | -------------------------------------------------------------------------------- /TorchRay/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # 1.0.0 (October 2019) 2 | 3 | * Initial public release. 4 | -------------------------------------------------------------------------------- /TorchRay/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /TorchRay/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to TorchRay 2 | 3 | We want to make contributing to this project as easy and transparent as possible. 4 | 5 | ## Pull Requests 6 | 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `master`. 10 | 2. If you've added code that should be tested, add tests. 11 | 3. If you've changed APIs, update the documentation. 12 | 4. Ensure the test passes (for basic testing use `examples.run_all_examples()`) 13 | 5. Make sure your code lints. 14 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 15 | 16 | ## Contributor License Agreement ("CLA") 17 | 18 | In order to accept your pull request, we need you to submit a CLA. 
You only need to do this once to work on any of Facebook's open source projects. 19 | 20 | Complete your CLA here: 21 | 22 | ## Issues 23 | 24 | We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue. 25 | 26 | ## Coding Style 27 | 28 | We follow the PEP8 standard using flake8 for linting. For the documentation, we use Sphinx and Napoleon using the Google style for the comments, and LaTeX math for equations. 29 | 30 | ## License 31 | 32 | By contributing to TorchRay, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree. 33 | -------------------------------------------------------------------------------- /TorchRay/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CHANGELOG.md 2 | include CODE_OF_CONDUCT.md 3 | include CONTRIBUTING.md 4 | include LICENSE 5 | include README.md 6 | include torchray/VERSION 7 | include docs 8 | include docs/html 9 | include scripts 10 | include scripts/torchrayrc 11 | recursive-include docs/html *.html *.png *.gif *.js *.css *.eot *.ttf *.woff *.woff2 *.svg *.inv *.txt 12 | recursive-include torchray *.py *.txt 13 | recursive-include scripts *.sh 14 | -------------------------------------------------------------------------------- /TorchRay/Makefile: -------------------------------------------------------------------------------- 1 | VER=$(shell head -n 1 torchray/VERSION) 2 | 3 | define get_conda_env 4 | $(shell which python | xargs dirname | xargs dirname) 5 | endef 6 | 7 | .PHONY: all 8 | all: 9 | 10 | 11 | .PHONY: dist 12 | dist: 13 | rm -rf dist 14 | docs/make-docs 15 | python3 setup.py sdist 16 | tar -tzvf dist/torchray-*.tar.gz 17 | 18 | 19 | .PHONY: conda 20 | conda: 21 | mkdir -p dist 22 | VER=$(VER) conda build -c defaults -c pytorch -c conda-forge \ 23 | --no-anaconda-upload --python 3.7 packaging/meta.yaml 24 | 25 | 26 | .PHONY: conda-install 27 | conda-install: 28 | conda install --use-local torchray 29 | 30 | 31 | .PHONY: docs 32 | docs: 33 | rm -rf docs/html 34 | docs/make-docs 35 | 36 | 37 | .PHONY: pub 38 | pub: docs 39 | git push pub master 40 | git push -f pub gh-pages 41 | git push -f --tags pub 42 | 43 | touch docs/html/.nojekyll 44 | git -C docs/html init 45 | git -C docs/html remote add pub git@github.com:facebookresearch/TorchRay.git 46 | git -C docs/html add . 47 | git -C docs/html commit -m "add documentation" 48 | git -C docs/html push -f pub master:gh-pages 49 | 50 | 51 | .PHONY: tag 52 | tag: 53 | git tag -f v$(VER) 54 | 55 | 56 | .PHONY: distclean 57 | distclean: 58 | rm -rf dist 59 | -------------------------------------------------------------------------------- /TorchRay/README.md: -------------------------------------------------------------------------------- 1 | # TorchRay 2 | 3 | The *TorchRay* package implements several visualization methods for deep 4 | convolutional neural networks using PyTorch. In this release, TorchRay focuses 5 | on *attribution*, namely the problem of determining which part of the input, 6 | usually an image, is responsible for the value computed by a neural network. 7 | 8 | *TorchRay* is research oriented: in addition to implementing well known 9 | techniques form the literature, it provides code for reproducing results that 10 | appear in several papers, in order to support *reproducible research*. 
11 | 12 | *TorchRay* was initially developed to support the paper: 13 | 14 | * *Understanding deep networks via extremal perturbations and smooth masks.* 15 | Fong, Patrick, Vedaldi. 16 | Proceedings of the International Conference on Computer Vision (ICCV), 2019. 17 | 18 | ## Examples 19 | 20 | The package contains several usage examples in the 21 | [`examples`](https://github.com/facebookresearch/TorchRay/tree/master/examples) 22 | subdirectory. 23 | 24 | Here is a complete example for using GradCAM: 25 | 26 | ```python 27 | from torchray.attribution.grad_cam import grad_cam 28 | from torchray.benchmark import get_example_data, plot_example 29 | 30 | # Obtain example data. 31 | model, x, category_id, _ = get_example_data() 32 | 33 | # Grad-CAM backprop. 34 | saliency = grad_cam(model, x, category_id, saliency_layer='features.29') 35 | 36 | # Plots. 37 | plot_example(x, saliency, 'grad-cam backprop', category_id) 38 | ``` 39 | 40 | ## Requirements 41 | 42 | TorchRay requires: 43 | 44 | * Python 3.4 or greater 45 | * pytorch 1.1.0 or greater 46 | * matplotlib 47 | 48 | For benchmarking, it also requires: 49 | 50 | * torchvision 0.3.0 or greater 51 | * pycocotools 52 | * mongodb (suggested) 53 | * pymongo (suggested) 54 | 55 | On Linux/macOS, using conda you can install 56 | 57 | ```bash 58 | while read requirement; do conda install \ 59 | -c defaults -c pytorch -c conda-forge --yes $requirement; done <<EOF 60 | pytorch>=1.1.0 61 | pycocotools 62 | torchvision>=0.3.0 63 | mongodb 64 | pymongo 65 | EOF 66 | ``` 67 | 68 | ## Installing TorchRay 69 | 70 | Using `pip`: 71 | 72 | ```shell 73 | pip install torchray 74 | ``` 75 | 76 | From source: 77 | 78 | ```shell 79 | python setup.py install 80 | ``` 81 | 82 | or 83 | 84 | ```shell 85 | pip install . 86 | ``` 87 | 88 | ## Full documentation 89 | 90 | The full documentation can be found 91 | [here](https://facebookresearch.github.io/TorchRay). 92 | 93 | ## Changes 94 | 95 | See the [CHANGELOG](CHANGELOG.md). 96 | 97 | ## Join the TorchRay community 98 | 99 | * Website: https://github.com/facebookresearch/TorchRay 100 | 101 | See the [CONTRIBUTING](CONTRIBUTING.md) file for how to help out. 102 | 103 | ## The team 104 | 105 | TorchRay has been primarily developed by Ruth C. Fong and Andrea Vedaldi. 106 | 107 | ## License 108 | 109 | TorchRay is CC-BY-NC licensed, as found in the [LICENSE](LICENSE) file. 110 | -------------------------------------------------------------------------------- /TorchRay/docs/attribution.rst: -------------------------------------------------------------------------------- 1 | .. _backprop: 2 | 3 | Attribution 4 | =========== 5 | 6 | *Attribution* is the problem of determining which part of the input, 7 | e.g. an image, is responsible for the value computed by a predictor 8 | such as a neural network. 9 | 10 | Formally, let :math:`\mathbf{x}` be the input to a convolutional neural 11 | network, e.g., a :math:`N \times C \times H \times W` real tensor. The neural 12 | network is a function :math:`\Phi` mapping :math:`\mathbf{x}` to a scalar 13 | output :math:`z \in \mathbb{R}`. Thus the goal is to find which of the 14 | elements of :math:`\mathbf{x}` are "most responsible" for the outcome 15 | :math:`z`. 16 | 17 | Some attribution methods are "black box" approaches, in the sense that they 18 | ignore the nature of the function :math:`\Phi` (however, most assume that it is 19 | at least possible to compute the gradient of :math:`\Phi` efficiently).
Most 20 | attribution methods, however, are "white box" approaches, in the sense that 21 | they exploit the knowledge of the structure of :math:`\Phi`. 22 | 23 | :ref:`Backpropagation methods ` are "white box" visualization 24 | approaches that build on backpropagation, thus leveraging the functionality 25 | already implemented in standard deep learning packages toolboxes such as 26 | PyTorch. 27 | 28 | :ref:`Perturbation methods ` are "black box" visualization 29 | approaches that generate attribution visualizations by perturbing the input 30 | and observing the changes in a model's output. 31 | 32 | TorchRay implements the following methods: 33 | 34 | * Backpropagation methods 35 | 36 | * Deconvolution (:mod:`.deconvnet`) 37 | * Excitation backpropagation (:mod:`.excitation_backprop`) 38 | * Gradient [1]_ (:mod:`.gradient`) 39 | * Grad-CAM (:mod:`.grad_cam`) 40 | * Guided backpropagation (:mod:`.guided_backprop`) 41 | * Linear approximation (:mod:`.linear_approx`) 42 | 43 | * Perturbation methods 44 | 45 | * Extremal perturbation [1]_ (:mod:`.extremal_perturbation`) 46 | * RISE (:mod:`.rise`) 47 | 48 | .. rubric:: Footnotes 49 | 50 | .. [1] The :mod:`.gradient` and :mod:`.extremal_perturbation` methods actually 51 | straddle the boundaries between white and black box methods, as they 52 | only require the ability to compute the gradient of the predictor, 53 | which does not necessarily require to know the predictor internals. 54 | However, in TorchRay both are implemented using backpropagation. 55 | 56 | .. _backpropagation: 57 | 58 | Backpropagation methods 59 | ----------------------- 60 | 61 | Backpropagation methods work by tweaking the backpropagation algorithm that, on 62 | its own, computes the gradient of tensor functions. Formally, a neural network 63 | :math:`\Phi` is a collection :math:`\Phi_1,\dots,\Phi_n` of :math:`n` layers. 64 | Each layer is in itself a "smaller" function inputting and outputting tensors, 65 | called *activations* (for simplicity, we call activations the network input and 66 | parameter tensors as well). Layers are interconnected in a *Directed Acyclic 67 | Graph* (DAG). The DAG is bipartite with some nodes representing the activation 68 | tensors and the other nodes representing the layers, with interconnections 69 | between layers and input/output tensors in the obvious way. The DAG sources are 70 | the network's input and parameter tensors and the DAG sinks are the network's 71 | output tensors. 72 | 73 | The main goal of a deep neural network toolbox such as PyTorch is to evaluate 74 | the function :math:`\Phi` implemented by the DAG as well as its gradients with 75 | respect to various tensors (usually the model parameters). The calculation of 76 | the gradients, which uses backpropagation, associates to the forward DAG a 77 | backward DAG, obtained as follows: 78 | 79 | * Activation tensors :math:`\mathbf{x}_j` become gradient tensors 80 | :math:`d\mathbf{x}_j` (preserving their shape). 81 | * Forward layers :math:`\Phi_i` become backward layers :math:`\Phi_i^*`. 82 | * All arrows are reversed. 83 | * Additional arrows connecting the activation tensors :math:`\mathbf{x}_i` 84 | as inputs to the corresponding backward function :math:`\Phi_i^*` are added 85 | as well. 86 | 87 | Backpropagation methods modify the backward graph in order to generate a 88 | visualization of the network forward pass. Additionally, inputs as well as 89 | intermediate activations can be inspected to obtain different visualizations. 
90 | These two concepts are explained next. 91 | 92 | Changing the backward propagation rules 93 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 94 | 95 | Changing the backward propagation rules amounts to redefining the functions 96 | :math:`\Phi_i^*`. After doing so, the "gradients" computed by backpropagation 97 | change their meaning into something useful for visualization. We call these 98 | modified gradients *pseudo-gradients*. 99 | 100 | TorchRay provides a number of context managers that enable patching PyTorch 101 | functions on the fly in order to change the backward propagation rules for 102 | a segment of code. For example, let ``x`` be an input tensor and ``model`` 103 | a deep classification network. Furthermore, let ``category_id`` be the 104 | index of the class for which we want to attribute input regions. The following 105 | code uses :mod:`.guided_backprop` to compute and store the pseudo gradient in 106 | ``x.grad``. 107 | 108 | .. code-block:: python 109 | 110 | from torchray.attribution.guided_backprop import GuidedBackpropContext 111 | 112 | x.requires_grad_(True) 113 | 114 | with GuidedBackpropContext(): 115 | y = model(x) 116 | z = y[0, category_id] 117 | z.backward() 118 | 119 | At this point, ``x.grad`` contains the "guided gradient" computed by this 120 | method. This gradient is usually flattened along the channel dimension to 121 | produce a saliency map for visualization: 122 | 123 | .. code-block:: python 124 | 125 | from torchray.attribution.common import gradient_to_saliency 126 | 127 | saliency = gradient_to_saliency(x) 128 | 129 | TorchRay contains also some wrapper code, such as 130 | :func:`.guided_backprop.guided_backprop`, that combine these steps in a way 131 | that would work for common networks. 132 | 133 | Probing intermediate activations and gradients 134 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 135 | 136 | Most visualization methods are based on inspecting the activations when the 137 | network is evaluated and the pseudo-gradients during backpropagation. This is 138 | generally easy for input tensors. For intermediate tensors, when using PyTorch 139 | functional interface, this is also easy: simply use ``retain_grad_(True)`` in 140 | order to retain the gradient of an intermediate tensor: 141 | 142 | .. code-block:: python 143 | 144 | from torch.nn.functional import relu, conv2d 145 | from torchray.attribution import GuidedBackpropContext 146 | 147 | with GuidedBackpropContext(): 148 | y = conv2d(x, weight) 149 | y.requires_grad_(True) 150 | y.retain_grad_(True) 151 | z = relu(y)[0, class_index] 152 | z.backward() 153 | 154 | # Now y and y.grad contain the activation and guided gradient, 155 | # respectively. 156 | 157 | However, in PyTorch most network components are implemented as 158 | :class:`torch.nn.Module` objects. In this case, is not obvious how to access a 159 | specific layer's information. In order to simplify this process, the library 160 | provides the :class:`Probe` class: 161 | 162 | .. code-block:: python 163 | 164 | from torch.nn.functional import relu, conv2d 165 | from torchray.attribution.guided_backprop import GuidedBackpropContext 166 | import torchray.attribution.Probe 167 | 168 | # Attach a probe to the last conv layer. 169 | probe = Probe(alexnet.features[11]) 170 | 171 | with GuidedBackpropContext(): 172 | y = alexnet(x) 173 | z = y[0, class_index] 174 | z.backward() 175 | 176 | # Now probe.data[0] and probe.data[0].grad contain 177 | # the activations and guided gradients. 
178 | 179 | The probe automatically applies :func:`torch.Tensor.requires_grad_` and 180 | :func:`torch.Tensor.retain_grad_` as needed. You can use ``probe.remove()`` to 181 | remove the probe from the network once you are done. 182 | 183 | Limitations 184 | ^^^^^^^^^^^ 185 | 186 | Except for the gradient method, backpropagation methods require modifying 187 | the backward function of each layer. TorchRay implements the rules 188 | necessary to do so as originally defined by each authors' method. 189 | However, as new neural network layers are introduced, it is possible 190 | that the default behavior, which is to not change backpropagation, may 191 | be inappropriate or suboptimal for them. 192 | 193 | .. _perturbation: 194 | 195 | Perturbation methods 196 | -------------------- 197 | 198 | Perturbation methods work by changing the input to the neural network in a 199 | controlled manner, observing the outcome on the output generated by the 200 | network. Attribution can be achieved by occluding (setting to zero) specific 201 | parts of the image and checking whether this has a strong effect on the output. 202 | This can be thought of as a form of sensitivity analysis which is still 203 | specific to a given input, but is not differential as for the gradient method. 204 | 205 | 206 | DeConvNet 207 | --------- 208 | 209 | .. automodule:: torchray.attribution.deconvnet 210 | :members: 211 | :show-inheritance: 212 | 213 | 214 | Excitation backprop 215 | ------------------- 216 | 217 | .. automodule:: torchray.attribution.excitation_backprop 218 | :members: 219 | :show-inheritance: 220 | 221 | Extremal perturbation 222 | --------------------- 223 | 224 | .. automodule:: torchray.attribution.extremal_perturbation 225 | :members: 226 | :show-inheritance: 227 | 228 | Gradient 229 | -------- 230 | 231 | .. automodule:: torchray.attribution.gradient 232 | :members: 233 | :show-inheritance: 234 | 235 | Grad-CAM 236 | -------- 237 | 238 | .. automodule:: torchray.attribution.grad_cam 239 | :members: 240 | :show-inheritance: 241 | 242 | Guided backprop 243 | --------------- 244 | 245 | .. automodule:: torchray.attribution.guided_backprop 246 | :members: 247 | :show-inheritance: 248 | 249 | Linear approximation 250 | -------------------- 251 | 252 | .. automodule:: torchray.attribution.linear_approx 253 | :members: 254 | :show-inheritance: 255 | 256 | RISE 257 | ---- 258 | 259 | .. automodule:: torchray.attribution.rise 260 | :members: 261 | :show-inheritance: 262 | 263 | Common code 264 | ----------- 265 | 266 | .. automodule:: torchray.attribution.common 267 | :members: 268 | :show-inheritance: -------------------------------------------------------------------------------- /TorchRay/docs/benchmark.rst: -------------------------------------------------------------------------------- 1 | .. _benchmark: 2 | 3 | Benchmarking 4 | ============ 5 | 6 | This module contains code for benchmarking attribution methods, including 7 | reproducing several published results. In addition to implementations of 8 | benchmarking protocols (:mod:`.pointing_game`), the module also provides 9 | implementations of *reference datasets* and *reference models* used in prior 10 | research work, properly converted to PyTorch. Overall, this implementations 11 | closely reproduces prior results, notably the ones in the [EBP]_ paper. 12 | 13 | A standard benchmarking suite is included in this library as 14 | :mod:`examples.standard_suite`. 
For slow methods, a computer cluster may be 15 | required for evaluation (we do not include explicit support for clusters, but 16 | it is easy to add on top of this example code). 17 | 18 | It is also recommended to turn on logging (see 19 | :mod:`torchray.benchmark.logging`), which allows the driver to 20 | uses MongoDB to store partial benchmarking results as it goes. 21 | Computations can then be cached and reused to resume the calculations 22 | after a crash or other issue. In order to start the logging server, use 23 | 24 | .. code:: shell 25 | 26 | $ python -m torchray.benchmark.server 27 | 28 | The server parameters (address, port, etc) can be configured by writing 29 | a ``.torchrayrc`` file in your current or home directory. The package 30 | contains an example configuration file. The server creates a regular 31 | MongoDB database (by default in ``./data/db``) which can be manually 32 | explored by means of the MongoDB shell. 33 | 34 | By default, the driver writes data in the ``./data/`` subfolder. 35 | You can change that via the configuration file, or, possibly more easily, 36 | add a symbolic link to where you want to store the data. 37 | 38 | The data include the *datasets* (PASCAL VOC, COCO, ImageNet; see 39 | :mod:`torchray.benchmark.datasets`). These must be downloaded manually and 40 | stored in ``./data/datasets/{voc,coco,imagenet}`` unless this is changed via 41 | the configuration file. Note that these datasets can be very large (many GBs). 42 | 43 | The data also include *reference models* (see 44 | :mod:`torchray.benchmark.models`). 45 | 46 | .. automodule:: torchray.benchmark 47 | :members: 48 | :show-inheritance: 49 | 50 | Pointing Game 51 | ------------- 52 | 53 | The *Pointing Game* [EBP]_ assesses the quality of an attribution method by 54 | testing how well it can extract from a predictor a response correlated with the 55 | presence of known object categories in the image. 56 | 57 | Given an input image :math:`x` containing an object of category :math:`c`, the 58 | attribution method is applied to the predictor in order to find the part of the 59 | images responsible for predicting :math:`c`. The attribution method usually 60 | returns a saliency heatmap. The latter must then be converted in a single point 61 | :math:`(u,v)` that is "most likely" to be contained by an object of that class. 62 | The specific way the point is obtained is method-dependent. 63 | 64 | The attribution method then scores a hit if the point is within a *tolerance* 65 | :math:`\tau` (set to 15 pixels by default) to the image region :math:`\Omega` 66 | containing that object: 67 | 68 | .. math:: 69 | \operatorname{hit}(u,v|\Omega) 70 | = [ \exists (u',v') \in \Omega : \|(u,v) - (u',v')\| \leq \tau]. 71 | 72 | The point coordinates :math:`(u,v)` are also indices :math:`x_{ncvu}` in the 73 | input image tensor :math:`x`. 74 | 75 | RISE [RISE]_ and Extremal Perturbation [EP]_ results are averaged over 3 runs. 76 | 77 | .. csv-table:: Pointing game results 78 | :widths: auto 79 | :header-rows: 2 80 | :stub-columns: 1 81 | :file: pointing.csv 82 | 83 | 84 | .. automodule:: torchray.benchmark.pointing_game 85 | :members: 86 | :show-inheritance: 87 | 88 | Datasets 89 | -------- 90 | 91 | .. automodule:: torchray.benchmark.datasets 92 | :members: 93 | :show-inheritance: 94 | 95 | .. autodata:: IMAGENET_CLASSES 96 | :annotation: 97 | 98 | .. autodata:: VOC_CLASSES 99 | :annotation: 100 | 101 | .. 
autodata:: COCO_CLASSES 102 | :annotation: 103 | 104 | Reference models 105 | ---------------- 106 | 107 | .. automodule:: torchray.benchmark.models 108 | :members: 109 | :show-inheritance: 110 | 111 | Logging with MongoDB 112 | -------------------- 113 | 114 | .. automodule:: torchray.benchmark.logging 115 | :members: 116 | :show-inheritance: 117 | 118 | -------------------------------------------------------------------------------- /TorchRay/docs/conf.py: -------------------------------------------------------------------------------- 1 | # conf.py - Sphinx configuration 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import os 5 | import sys 6 | 7 | 8 | def setup(app): 9 | app.add_stylesheet('css/equations.css') 10 | 11 | 12 | author = 'TorchRay Contributors' 13 | copyright = 'TorchRay Contributors' 14 | project = 'TorchRay' 15 | release = 'beta' 16 | version = '1.0' 17 | 18 | extensions = [ 19 | 'sphinx.ext.autodoc', 20 | 'sphinx.ext.mathjax', 21 | 'sphinx.ext.ifconfig', 22 | 'sphinx.ext.viewcode', 23 | 'sphinx.ext.napoleon', 24 | ] 25 | 26 | exclude_patterns = ['html'] 27 | master_doc = 'index' 28 | pygments_style = None 29 | source_suffix = ['.rst', '.md'] 30 | 31 | # HTML documentation. 32 | html_theme = 'sphinx_rtd_theme' 33 | html_theme_options = { 34 | 'analytics_id': '', 35 | 'canonical_url': '', 36 | 'display_version': True, 37 | 'logo_only': False, 38 | 'prev_next_buttons_location': 'bottom', 39 | 'style_external_links': False, 40 | 41 | # Toc options 42 | 'collapse_navigation': True, 43 | 'includehidden': True, 44 | 'navigation_depth': 4, 45 | 'sticky_navigation': True, 46 | 'titles_only': False 47 | } 48 | html_static_path = ['static'] 49 | -------------------------------------------------------------------------------- /TorchRay/docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to TorchRay 2 | =================== 3 | 4 | The *TorchRay* package implements several visualization methods for deep 5 | convolutional neural networks using PyTorch. In this release, TorchRay focuses 6 | on *attribution*, namely the problem of determining which part of the input, 7 | e.g., an image, is responsible for the value computed by a neural network. 8 | 9 | *TorchRay* is research-oriented. In addition to implementing well known 10 | techniques form the literature, it provides code for reproducing results that 11 | appear in several papers, and can thus be a tool for *reproducible research*. 12 | 13 | For downloads, installation instructions, and access to the source code 14 | use the 15 | `GitHub repository `_. 16 | 17 | .. 
toctree:: 18 | attribution 19 | benchmark 20 | utils 21 | 22 | Indices 23 | ======= 24 | 25 | * :ref:`genindex` 26 | * :ref:`modindex` 27 | * :ref:`search` -------------------------------------------------------------------------------- /TorchRay/docs/make-docs: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PYTHON=${PYTHON:-python} 3 | ${PYTHON} -m sphinx -b html docs docs/html 4 | -------------------------------------------------------------------------------- /TorchRay/docs/pointing.csv: -------------------------------------------------------------------------------- 1 | ,voc_2007,voc_2007,voc_2007,voc_2007,coco,coco,coco,coco 2 | ,vgg16,vgg16,resnet50,resnet50,vgg16,vgg16,resnet50,resnet50 3 | center,69.6,42.4,69.6,42.4,27.8,19.5,27.8,19.5 4 | gradient,76.3,56.9,72.3,56.8,37.7,31.4,35.0,29.4 5 | deconvnet,67.5,44.2,68.6,44.7,30.7,23.0,30.0,21.9 6 | guided_backprop,75.9,53.0,77.2,59.4,39.1,31.4,42.1,35.3 7 | excitation_backprop,77.1,56.6,84.5,70.8,39.8,32.8,49.6,43.9 8 | contrastive_excitation_backprop,79.9,66.5,90.7,82.1,49.7,44.3,58.5,53.6 9 | rise,86.9,75.1,86.4,78.8,50.8,45.3,54.7,50.0 10 | grad_cam,86.6,74.0,90.4,82.3,54.2,49.0,57.3,52.3 11 | extremal_perturbation,88.0,76.1,88.9,78.7,51.5,45.9,56.5,51.5 -------------------------------------------------------------------------------- /TorchRay/docs/static/css/equations.css: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. */ 2 | 3 | .math { 4 | text-align: left; 5 | } 6 | 7 | .eqno { 8 | float: right; 9 | } -------------------------------------------------------------------------------- /TorchRay/docs/utils.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========= 3 | 4 | .. contents:: :local: 5 | 6 | .. automodule:: torchray.utils 7 | :members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /TorchRay/examples/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define :func:`run_all_examples` to run all examples of saliency methods 3 | (excluding :mod:`examples.standard_suite`). 
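
Each example module is imported purely for its side effects; a fresh matplotlib
figure is created before every import so that successive examples draw into
separate figures.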
4 | """ 5 | 6 | __all__ = ['run_all_examples'] 7 | 8 | from matplotlib import pyplot as plt 9 | 10 | 11 | def run_all_examples(): 12 | """Run all examples.""" 13 | 14 | plt.figure() 15 | import examples.extremal_perturbation 16 | plt.draw() 17 | plt.pause(0.001) 18 | 19 | plt.figure() 20 | import examples.deconvnet_manual 21 | plt.draw() 22 | plt.pause(0.001) 23 | 24 | plt.figure() 25 | import examples.deconvnet 26 | plt.draw() 27 | plt.pause(0.001) 28 | 29 | plt.figure() 30 | import examples.grad_cam_manual 31 | plt.draw() 32 | plt.pause(0.001) 33 | 34 | plt.figure() 35 | import examples.grad_cam 36 | plt.draw() 37 | plt.pause(0.001) 38 | 39 | plt.figure() 40 | import examples.contrastive_excitation_backprop_manual 41 | plt.draw() 42 | plt.pause(0.001) 43 | 44 | plt.figure() 45 | import examples.contrastive_excitation_backprop 46 | plt.draw() 47 | plt.pause(0.001) 48 | 49 | plt.figure() 50 | import examples.excitation_backprop_manual 51 | plt.draw() 52 | plt.pause(0.001) 53 | 54 | plt.figure() 55 | import examples.excitation_backprop 56 | plt.draw() 57 | plt.pause(0.001) 58 | 59 | plt.figure() 60 | import examples.guided_backprop_manual 61 | plt.draw() 62 | plt.pause(0.001) 63 | 64 | plt.figure() 65 | import examples.guided_backprop 66 | plt.draw() 67 | plt.pause(0.001) 68 | 69 | plt.figure() 70 | import examples.gradient_manual 71 | plt.draw() 72 | plt.pause(0.001) 73 | 74 | plt.figure() 75 | import examples.gradient 76 | plt.draw() 77 | plt.pause(0.001) 78 | 79 | plt.figure() 80 | import examples.rise 81 | plt.draw() 82 | plt.pause(0.001) 83 | -------------------------------------------------------------------------------- /TorchRay/examples/__main__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run all examples of saliency methods (excluding 3 | :mod:`examples.standard_suite`). 4 | """ 5 | from . import run_all_examples 6 | 7 | if __name__ == "__main__": 8 | run_all_examples() 9 | -------------------------------------------------------------------------------- /TorchRay/examples/contrastive_excitation_backprop.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.excitation_backprop import contrastive_excitation_backprop 2 | from torchray.benchmark import get_example_data, plot_example 3 | 4 | # Obtain example data. 5 | model, x, category_id, _ = get_example_data() 6 | 7 | # Contrastive excitation backprop. 8 | saliency = contrastive_excitation_backprop( 9 | model, 10 | x, 11 | category_id, 12 | saliency_layer='features.9', 13 | contrast_layer='features.30', 14 | classifier_layer='classifier.6', 15 | ) 16 | 17 | # Plots. 18 | plot_example(x, saliency, 'contrastive excitation backprop', category_id) 19 | -------------------------------------------------------------------------------- /TorchRay/examples/contrastive_excitation_backprop_manual.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.common import Probe, get_module 2 | from torchray.attribution.excitation_backprop import ExcitationBackpropContext 3 | from torchray.attribution.excitation_backprop import gradient_to_contrastive_excitation_backprop_saliency 4 | from torchray.benchmark import get_example_data, plot_example 5 | 6 | # Obtain example data. 7 | model, x, category_id, _ = get_example_data() 8 | 9 | # Contrastive excitation backprop. 
10 | input_layer = get_module(model, 'features.9') 11 | contrast_layer = get_module(model, 'features.30') 12 | classifier_layer = get_module(model, 'classifier.6') 13 | 14 | input_probe = Probe(input_layer, target='output') 15 | contrast_probe = Probe(contrast_layer, target='output') 16 | 17 | with ExcitationBackpropContext(): 18 | y = model(x) 19 | z = y[0, category_id] 20 | classifier_layer.weight.data.neg_() 21 | z.backward() 22 | 23 | classifier_layer.weight.data.neg_() 24 | 25 | contrast_probe.contrast = [contrast_probe.data[0].grad] 26 | 27 | y = model(x) 28 | z = y[0, category_id] 29 | z.backward() 30 | 31 | saliency = gradient_to_contrastive_excitation_backprop_saliency(input_probe.data[0]) 32 | 33 | input_probe.remove() 34 | contrast_probe.remove() 35 | 36 | # Plots. 37 | plot_example(x, saliency, 'contrastive excitation backprop', category_id) 38 | -------------------------------------------------------------------------------- /TorchRay/examples/deconvnet.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.deconvnet import deconvnet 2 | from torchray.benchmark import get_example_data, plot_example 3 | 4 | # Obtain example data. 5 | model, x, category_id, _ = get_example_data() 6 | 7 | # DeConvNet method. 8 | saliency = deconvnet(model, x, category_id) 9 | 10 | # Plots. 11 | plot_example(x, saliency, 'deconvnet', category_id) 12 | -------------------------------------------------------------------------------- /TorchRay/examples/deconvnet_manual.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.common import gradient_to_saliency 2 | from torchray.attribution.deconvnet import DeConvNetContext 3 | from torchray.benchmark import get_example_data, plot_example 4 | 5 | # Obtain example data. 6 | model, x, category_id, _ = get_example_data() 7 | 8 | # DeConvNet method. 9 | x.requires_grad_(True) 10 | 11 | with DeConvNetContext(): 12 | y = model(x) 13 | z = y[0, category_id] 14 | z.backward() 15 | 16 | saliency = gradient_to_saliency(x) 17 | 18 | # Plots. 19 | plot_example(x, saliency, 'deconvnet', category_id) 20 | -------------------------------------------------------------------------------- /TorchRay/examples/excitation_backprop.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.excitation_backprop import excitation_backprop 2 | from torchray.benchmark import get_example_data, plot_example 3 | 4 | # Obtain example data. 5 | model, x, category_id, _ = get_example_data() 6 | 7 | # Contrastive excitation backprop. 8 | saliency = excitation_backprop( 9 | model, 10 | x, 11 | category_id, 12 | saliency_layer='features.9', 13 | ) 14 | 15 | # Plots. 16 | plot_example(x, saliency, 'excitation backprop', category_id) 17 | -------------------------------------------------------------------------------- /TorchRay/examples/excitation_backprop_manual.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.common import Probe, get_module 2 | from torchray.attribution.excitation_backprop import ExcitationBackpropContext 3 | from torchray.attribution.excitation_backprop import gradient_to_excitation_backprop_saliency 4 | from torchray.benchmark import get_example_data, plot_example 5 | 6 | # Obtain example data. 7 | model, x, category_id, _ = get_example_data() 8 | 9 | # Contrastive excitation backprop. 
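# (This is the plain, non-contrastive variant: a single probe on an
# intermediate layer, one backward pass inside ExcitationBackpropContext, and
# a conversion of the probed gradient into a saliency map.)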
10 | saliency_layer = get_module(model, 'features.9') 11 | saliency_probe = Probe(saliency_layer, target='output') 12 | 13 | with ExcitationBackpropContext(): 14 | y = model(x) 15 | z = y[0, category_id] 16 | z.backward() 17 | 18 | saliency = gradient_to_excitation_backprop_saliency(saliency_probe.data[0]) 19 | 20 | saliency_probe.remove() 21 | 22 | # Plots. 23 | plot_example(x, saliency, 'excitation backprop', category_id) 24 | -------------------------------------------------------------------------------- /TorchRay/examples/extremal_perturbation.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.extremal_perturbation import extremal_perturbation, contrastive_reward 2 | from torchray.benchmark import get_example_data, plot_example 3 | from torchray.utils import get_device 4 | 5 | # Obtain example data. 6 | model, x, category_id_1, category_id_2 = get_example_data() 7 | 8 | # Run on GPU if available. 9 | device = get_device() 10 | model.to(device) 11 | x = x.to(device) 12 | 13 | # Extremal perturbation backprop. 14 | masks_1, _ = extremal_perturbation( 15 | model, x, category_id_1, 16 | reward_func=contrastive_reward, 17 | debug=True, 18 | areas=[0.12], 19 | ) 20 | 21 | masks_2, _ = extremal_perturbation( 22 | model, x, category_id_2, 23 | reward_func=contrastive_reward, 24 | debug=True, 25 | areas=[0.05], 26 | ) 27 | 28 | # Plots. 29 | plot_example(x, masks_1, 'extremal perturbation', category_id_1) 30 | plot_example(x, masks_2, 'extremal perturbation', category_id_2) 31 | -------------------------------------------------------------------------------- /TorchRay/examples/grad_cam.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.grad_cam import grad_cam 2 | from torchray.benchmark import get_example_data, plot_example 3 | 4 | # Obtain example data. 5 | model, x, category_id, _ = get_example_data() 6 | 7 | # Grad-CAM backprop. 8 | saliency = grad_cam(model, x, category_id, saliency_layer='features.29') 9 | 10 | # Plots. 11 | plot_example(x, saliency, 'grad-cam backprop', category_id) 12 | -------------------------------------------------------------------------------- /TorchRay/examples/grad_cam_manual.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.common import Probe, get_module 2 | from torchray.attribution.grad_cam import gradient_to_grad_cam_saliency 3 | from torchray.benchmark import get_example_data, plot_example 4 | 5 | # Obtain example data. 6 | model, x, category_id, _ = get_example_data() 7 | 8 | # Grad-CAM backprop. 9 | saliency_layer = get_module(model, 'features.29') 10 | 11 | probe = Probe(saliency_layer, target='output') 12 | 13 | y = model(x) 14 | z = y[0, category_id] 15 | z.backward() 16 | 17 | saliency = gradient_to_grad_cam_saliency(probe.data[0]) 18 | 19 | # Plots. 20 | plot_example(x, saliency, 'grad-cam backprop', category_id) 21 | -------------------------------------------------------------------------------- /TorchRay/examples/gradient.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.gradient import gradient 2 | from torchray.benchmark import get_example_data, plot_example 3 | 4 | # Obtain example data. 5 | model, x, category_id, _ = get_example_data() 6 | 7 | # Gradient method. 8 | saliency = gradient(model, x, category_id) 9 | 10 | # Plots. 
11 | plot_example(x, saliency, 'gradient', category_id) 12 | -------------------------------------------------------------------------------- /TorchRay/examples/gradient_manual.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.common import gradient_to_saliency 2 | from torchray.benchmark import get_example_data, plot_example 3 | 4 | # Obtain example data. 5 | model, x, category_id, _ = get_example_data() 6 | 7 | # Gradient method. 8 | 9 | x.requires_grad_(True) 10 | y = model(x) 11 | z = y[0, category_id] 12 | z.backward() 13 | 14 | saliency = gradient_to_saliency(x) 15 | 16 | # Plots. 17 | plot_example(x, saliency, 'gradient', category_id) 18 | -------------------------------------------------------------------------------- /TorchRay/examples/guided_backprop.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.guided_backprop import guided_backprop 2 | from torchray.benchmark import get_example_data, plot_example 3 | 4 | # Obtain example data. 5 | model, x, category_id, _ = get_example_data() 6 | 7 | # Guided backprop. 8 | saliency = guided_backprop(model, x, category_id) 9 | 10 | # Plots. 11 | plot_example(x, saliency, 'guided backprop', category_id) 12 | -------------------------------------------------------------------------------- /TorchRay/examples/guided_backprop_manual.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.common import gradient_to_saliency 2 | from torchray.attribution.guided_backprop import GuidedBackpropContext 3 | from torchray.benchmark import get_example_data, plot_example 4 | 5 | # Obtain example data. 6 | model, x, category_id, _ = get_example_data() 7 | 8 | # Guided backprop. 9 | x.requires_grad_(True) 10 | 11 | with GuidedBackpropContext(): 12 | y = model(x) 13 | z = y[0, category_id] 14 | z.backward() 15 | 16 | saliency = gradient_to_saliency(x) 17 | 18 | # Plots. 19 | plot_example(x, saliency, 'guided backprop', category_id) 20 | -------------------------------------------------------------------------------- /TorchRay/examples/linear_approx.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.linear_approx import linear_approx 2 | from torchray.benchmark import get_example_data, plot_example 3 | 4 | # Obtain example data. 5 | model, x, category_id, _ = get_example_data() 6 | 7 | # Linear approximation backprop. 8 | saliency = linear_approx(model, x, category_id, saliency_layer='features.29') 9 | 10 | # Plots. 11 | plot_example(x, saliency, 'linear approx', category_id) 12 | -------------------------------------------------------------------------------- /TorchRay/examples/linear_approx_manual.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.common import Probe, get_module 2 | from torchray.attribution.linear_approx import gradient_to_linear_approx_saliency 3 | from torchray.benchmark import get_example_data, plot_example 4 | 5 | # Obtain example data. 6 | model, x, category_id, _ = get_example_data() 7 | 8 | # Linear approximation. 9 | saliency_layer = get_module(model, 'features.29') 10 | 11 | probe = Probe(saliency_layer, target='output') 12 | 13 | y = model(x) 14 | z = y[0, category_id] 15 | z.backward() 16 | 17 | saliency = gradient_to_linear_approx_saliency(probe.data[0]) 18 | 19 | # Plots. 
20 | plot_example(x, saliency, 'linear approx', category_id) 21 | -------------------------------------------------------------------------------- /TorchRay/examples/rise.py: -------------------------------------------------------------------------------- 1 | from torchray.attribution.rise import rise 2 | from torchray.benchmark import get_example_data, plot_example 3 | from torchray.utils import get_device 4 | 5 | # Obtain example data. 6 | model, x, category_id, _ = get_example_data() 7 | 8 | # Run on GPU if available. 9 | device = get_device() 10 | model.to(device) 11 | x = x.to(device) 12 | 13 | # RISE method. 14 | saliency = rise(model, x) 15 | saliency = saliency[:, category_id].unsqueeze(0) 16 | 17 | # Plots. 18 | plot_example(x, saliency, 'RISE', category_id) 19 | -------------------------------------------------------------------------------- /TorchRay/packaging/meta.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: torchray 3 | version: {{ VER }} 4 | 5 | source: 6 | git_rev: v{{ VER }} 7 | git_url: git@github.com:facebookresearch/TorchRay.git 8 | 9 | requirements: 10 | build: 11 | - python >=3.4 12 | 13 | run: 14 | - python >=3.4 15 | - pytorch >=1.1.0 16 | - torchvision >=0.3.0 17 | - pycocotools >=2.0.0 18 | - matplotlib 19 | 20 | 21 | build: 22 | noarch: python 23 | number: 0 24 | script: python setup.py install --single-version-externally-managed --record=record.txt 25 | 26 | about: 27 | home: https://github.com/facebookresearch/TorchRay 28 | license: Attribution-NonCommercial 4.0 International 29 | summary: 'A PyTorch library for visualzing convnets.' 30 | description: | 31 | The *TorchRay* package implements several visualization methods for deep 32 | convolutional neural networks using PyTorch. In this release, TorchRay focuses 33 | on *attribution*, namely the problem of determining which part of the input, 34 | usually an image, is responsible for the value computed by a neural network. 35 | 36 | *TorchRay* is research oriented: in addition to implementing well known 37 | techniques form the literature, it provides code for reproducing results that 38 | appear in several papers, and can thus be a tool for *reproducible research*. 39 | dev_url: https://github.com/facebookresearch/TorchRay 40 | doc_url: https://facebookresearch.github.io/TorchRay 41 | doc_source_url: https://github.com/facebookresearch/TorchRay -------------------------------------------------------------------------------- /TorchRay/scripts/torchrayrc: -------------------------------------------------------------------------------- 1 | { 2 | "mongo": { 3 | "server": "mongod", 4 | "hostname": "localhost", 5 | "port": 27017, 6 | "database": "./data/db" 7 | }, 8 | "benchmark": { 9 | "imagenet_dir": "./data/datasets/imagenet", 10 | "models_dir": "./data/models", 11 | "experiments_dir": "./data" 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /TorchRay/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | 3 | from setuptools import setup 4 | 5 | setup( 6 | name='torchray', 7 | version=open('torchray/VERSION').readline(), 8 | packages=[ 9 | 'torchray', 10 | 'torchray.attribution', 11 | 'torchray.benchmark' 12 | ], 13 | package_data={ 14 | 'torchray': ['VERSION'], 15 | 'torchray.benchmark': ['*.txt'] 16 | }, 17 | url='http://pypi.python.org/pypi/torchray/', 18 | author='Andrea Vedaldi', 19 | author_email='vedaldi@fb.com', 20 | license='Creative Commons Attribution-Noncommercial 4.0 International', 21 | description='TorchRay is a PyTorch library of visualization methods for convnets.', 22 | long_description=open('README.md').read(), 23 | long_description_content_type="text/markdown", 24 | classifiers=[ 25 | 'Development Status :: 5 - Production/Stable', 26 | 'Intended Audience :: Developers', 27 | 'Topic :: Machine Learning :: Neural Networks', 28 | 'License :: OSI Approved :: Creative Commons Attribution-Noncommercial 4.0 International', 29 | 'Programming Language :: Python :: 3', 30 | ], 31 | install_requires=[ 32 | 'importlib_resources', 33 | 'matplotlib', 34 | 'packaging', 35 | 'pycocotools >= 2.0.0', 36 | 'pymongo', 37 | 'requests', 38 | 'torch >= 1.1', 39 | 'torchvision >= 0.3.0', 40 | ], 41 | setup_requires=[ 42 | 'cython', 43 | 'numpy', 44 | ] 45 | ) 46 | -------------------------------------------------------------------------------- /TorchRay/torchray/VERSION: -------------------------------------------------------------------------------- 1 | 1.0.0 2 | -------------------------------------------------------------------------------- /TorchRay/torchray/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import importlib.resources as resources 3 | except ImportError: 4 | import importlib_resources as resources 5 | 6 | with resources.open_text('torchray', 'VERSION') as f: 7 | __version__ = f.readlines()[0].rstrip() 8 | -------------------------------------------------------------------------------- /TorchRay/torchray/attribution/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/CGC/a66d87240863e19cc43d11c9e715ca447614042d/TorchRay/torchray/attribution/__init__.py -------------------------------------------------------------------------------- /TorchRay/torchray/attribution/deconvnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | r""" 4 | This module implements the *deconvolution* method of [DECONV]_ for visualizing 5 | deep networks. The simplest interface is given by the :func:`deconvnet` 6 | function: 7 | 8 | .. literalinclude:: ../examples/deconvnet.py 9 | :language: python 10 | :linenos: 11 | 12 | Alternatively, it is possible to run the method "manually". DeConvNet is a 13 | backpropagation method, and thus works by changing the definition of the 14 | backward functions of some layers. The modified ReLU is implemented by class 15 | :class:`DeConvNetReLU`; however, this is rarely used directly; instead, one 16 | uses the :class:`DeConvNetContext` context instead, as follows: 17 | 18 | .. literalinclude:: ../examples/deconvnet_manual.py 19 | :language: python 20 | :linenos: 21 | 22 | See also :ref:`Backprogation methods ` for further examples 23 | and discussion. 24 | 25 | Theory 26 | ~~~~~~ 27 | 28 | The only change is a modified definition of the backward ReLU function: 29 | 30 | .. 
math:: 31 | \operatorname{ReLU}^*(x,p) = 32 | \begin{cases} 33 | p, & \mathrm{if}~ p > 0,\\ 34 | 0, & \mathrm{otherwise} \\ 35 | \end{cases} 36 | 37 | Warning: 38 | 39 | DeConvNets are defined for "standard" networks that use ReLU operations. 40 | Further modifications may be required for more complex or new networks 41 | that use other type of non-linearities. 42 | 43 | References: 44 | 45 | .. [DECONV] Zeiler and Fergus, 46 | *Visualizing and Understanding Convolutional Networks*, 47 | ECCV 2014, 48 | ``__. 49 | """ 50 | 51 | __all__ = ["DeConvNetContext", "deconvnet"] 52 | 53 | import torch 54 | 55 | from .common import ReLUContext, saliency 56 | 57 | 58 | class DeConvNetReLU(torch.autograd.Function): 59 | """DeConvNet ReLU autograd function. 60 | 61 | This is an autograd function that redefines the ``relu`` function 62 | to match the DeConvNet ReLU definition. 63 | """ 64 | 65 | @staticmethod 66 | def forward(ctx, input): 67 | """DeConvNet ReLU forward function.""" 68 | return input.clamp(min=0) 69 | 70 | @staticmethod 71 | def backward(ctx, grad_output): 72 | """DeConvNet ReLU backward function.""" 73 | return grad_output.clamp(min=0) 74 | 75 | 76 | class DeConvNetContext(ReLUContext): 77 | """DeConvNet context. 78 | 79 | This context modifies the computation of gradient to match the DeConvNet 80 | definition. 81 | 82 | See :mod:`torchray.attribution.deconvnet` for how to use it. 83 | """ 84 | 85 | def __init__(self): 86 | super(DeConvNetContext, self).__init__(DeConvNetReLU) 87 | 88 | 89 | def deconvnet(*args, context_builder=DeConvNetContext, **kwargs): 90 | """DeConvNet method. 91 | 92 | The function takes the same arguments as :func:`.common.saliency`, with 93 | the defaults required to apply the DeConvNet method, and supports the 94 | same arguments and return values. 95 | """ 96 | return saliency(*args, context_builder=context_builder, **kwargs) 97 | -------------------------------------------------------------------------------- /TorchRay/torchray/attribution/grad_cam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """ 4 | This module provides an implementation of the *Grad-CAM* method of [GRADCAM]_ 5 | for saliency visualization. The simplest interface is given by the 6 | :func:`grad_cam` function: 7 | 8 | .. literalinclude:: ../examples/grad_cam.py 9 | :language: python 10 | :linenos: 11 | 12 | Alternatively, it is possible to run the method "manually". Grad-CAM backprop 13 | is a variant of the gradient method, applied at an intermediate layer: 14 | 15 | .. literalinclude:: ../examples/grad_cam_manual.py 16 | :language: python 17 | :linenos: 18 | 19 | Note that the function :func:`gradient_to_grad_cam_saliency` is used to convert 20 | activations and gradients to a saliency map. 21 | 22 | See also :ref:`backprop` for further examples and discussion. 23 | 24 | Theory 25 | ~~~~~~ 26 | 27 | Grad-CAM can be seen as a variant of the *gradient* method 28 | (:mod:`torchray.attribution.gradient`) with two differences: 29 | 30 | 1. The saliency is measured at an intermediate layer of the network, usually at 31 | the output of the last convolutional layer. 32 | 33 | 2. Saliency is defined as the clamped product of forward activation and 34 | backward gradient at that layer. 35 | 36 | References: 37 | 38 | .. [GRADCAM] Ramprasaath R. 
Selvaraju, Abhishek Das, Ramakrishna Vedantam, 39 | Michael Cogswell, Devi Parikh and Dhruv Batra, 40 | *Visual Explanations from Deep Networks via Gradient-based 41 | Localization,* 42 | ICCV 2017, 43 | ``__. 44 | """ 45 | 46 | __all__ = ["grad_cam"] 47 | 48 | import torch 49 | from .common import saliency 50 | 51 | 52 | def gradient_to_grad_cam_saliency(x): 53 | r"""Convert activation and gradient to a Grad-CAM saliency map. 54 | 55 | The tensor :attr:`x` must have a valid gradient ``x.grad``. 56 | The function then computes the saliency map :math:`s`: given by: 57 | 58 | .. math:: 59 | 60 | s_{n1u} = \max\{0, \sum_{c}x_{ncu}\cdot dx_{ncu}\} 61 | 62 | Args: 63 | x (:class:`torch.Tensor`): activation tensor with a valid gradient. 64 | 65 | Returns: 66 | :class:`torch.Tensor`: saliency map. 67 | """ 68 | # Apply global average pooling (GAP) to gradient. 69 | grad_weight = torch.mean(x.grad, (2, 3), keepdim=True) 70 | 71 | # Linearly combine activations and GAP gradient weights. 72 | saliency_map = torch.sum(x * grad_weight, 1, keepdim=True) 73 | 74 | # Apply ReLU to visualization. 75 | saliency_map = torch.clamp(saliency_map, min=0) 76 | 77 | return saliency_map 78 | 79 | 80 | def grad_cam(*args, 81 | saliency_layer, 82 | gradient_to_saliency=gradient_to_grad_cam_saliency, 83 | **kwargs): 84 | r"""Grad-CAM method. 85 | 86 | The function takes the same arguments as :func:`.common.saliency`, with 87 | the defaults required to apply the Grad-CAM method, and supports the 88 | same arguments and return values. 89 | """ 90 | return saliency(*args, 91 | saliency_layer=saliency_layer, 92 | gradient_to_saliency=gradient_to_saliency, 93 | **kwargs,) 94 | -------------------------------------------------------------------------------- /TorchRay/torchray/attribution/gradient.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | r""" 4 | This module implements the *gradient* method of [GRAD]_ for visualizing a deep 5 | network. It is a backpropagation method, and in fact the simplest of them all 6 | as it coincides with standard backpropagation. The simplest way to use this 7 | method is via the :func:`gradient` function: 8 | 9 | .. literalinclude:: ../examples/gradient.py 10 | :language: python 11 | :linenos: 12 | 13 | Alternatively, one can do so manually, as follows 14 | 15 | .. literalinclude:: ../examples/gradient_manual.py 16 | :language: python 17 | :linenos: 18 | 19 | Note that in this example, for visualization, the gradient is 20 | convernted into an image by postprocessing by using the function 21 | :func:`torchray.attribution.common.saliency`. 22 | 23 | See also :ref:`backprop` for further examples. 24 | 25 | References: 26 | 27 | .. [GRAD] Karen Simonyan, Andrea Vedaldi and Andrew Zisserman, 28 | *Deep Inside Convolutional Networks: 29 | Visualising Image Classification Models and Saliency Maps,* 30 | ICLR workshop, 2014, 31 | ``__. 32 | """ 33 | 34 | __all__ = ["gradient"] 35 | 36 | from .common import saliency 37 | 38 | 39 | def gradient(*args, context_builder=None, **kwargs): 40 | r"""Gradient method 41 | 42 | The function takes the same arguments as :func:`.common.saliency`, with 43 | the defaults required to apply the gradient method, and supports the 44 | same arguments and return values. 
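
    Example (illustrative; ``model``, ``x`` and ``category_id`` as returned by
    :func:`torchray.benchmark.get_example_data`)::

        saliency = gradient(model, x, category_id)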
45 | """ 46 | assert context_builder is None 47 | return saliency(*args, **kwargs) 48 | -------------------------------------------------------------------------------- /TorchRay/torchray/attribution/guided_backprop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | r""" 4 | This module implements *guided backpropagation* method of [GUIDED]_ or 5 | visualizing deep networks. The simplest interface is given by the 6 | :func:`guided_backprop` function: 7 | 8 | .. literalinclude:: ../examples/guided_backprop.py 9 | :language: python 10 | :linenos: 11 | 12 | Alternatively, it is possible to run the method "manually". Guided backprop is 13 | a backpropagation method, and thus works by changing the definition of the 14 | backward functions of some layers. This can be done using the 15 | :class:`GuidedBackpropContext` context: 16 | 17 | .. literalinclude:: ../examples/guided_backprop_manual.py 18 | :language: python 19 | :linenos: 20 | 21 | See also :ref:`backprop` for further examples. 22 | 23 | Theory 24 | ~~~~~~ 25 | 26 | Guided backprop is a backpropagation method, and thus it works by changing the 27 | definition of the backward functions of some layers. The only change is a 28 | modified definition of the backward ReLU function: 29 | 30 | .. math:: 31 | \operatorname{ReLU}^*(x,p) = 32 | \begin{cases} 33 | p, & \mathrm{if}~p > 0 ~\mathrm{and}~ x > 0,\\ 34 | 0, & \mathrm{otherwise} \\ 35 | \end{cases} 36 | 37 | The modified ReLU is implemented by class :class:`GuidedBackpropReLU`. 38 | 39 | References: 40 | 41 | .. [GUIDED] Springenberg et al., 42 | *Striving for simplicity: The all convolutional net*, 43 | ICLR Workshop 2015, 44 | ``__. 45 | """ 46 | 47 | __all__ = ['GuidedBackpropContext', 'guided_backprop'] 48 | 49 | import torch 50 | 51 | from .common import ReLUContext, saliency 52 | 53 | 54 | class GuidedBackpropReLU(torch.autograd.Function): 55 | """This class implements a ReLU function with the guided backprop rules.""" 56 | @staticmethod 57 | def forward(ctx, input): 58 | """Guided backprop ReLU forward function.""" 59 | ctx.save_for_backward(input) 60 | return input.clamp(min=0) 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | """Guided backprop ReLU backward function.""" 65 | input, = ctx.saved_tensors 66 | grad_input = grad_output.clone() 67 | grad_input[input < 0] = 0 68 | grad_input = grad_input.clamp(min=0) 69 | return grad_input 70 | 71 | 72 | class GuidedBackpropContext(ReLUContext): 73 | r"""GuidedBackprop context. 74 | 75 | This context modifies the computation of gradients 76 | to match the guided backpropagaton definition. 77 | 78 | See :mod:`torchray.attribution.guided_backprop` for how to use it. 79 | """ 80 | 81 | def __init__(self): 82 | super(GuidedBackpropContext, self).__init__(GuidedBackpropReLU) 83 | 84 | 85 | def guided_backprop(*args, context_builder=GuidedBackpropContext, **kwargs): 86 | r"""Guided backprop. 87 | 88 | The function takes the same arguments as :func:`.common.saliency`, with 89 | the defaults required to apply the guided backprop method, and supports the 90 | same arguments and return values. 
91 | """ 92 | return saliency(*args, 93 | context_builder=context_builder, 94 | **kwargs) 95 | -------------------------------------------------------------------------------- /TorchRay/torchray/attribution/linear_approx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | r""" 4 | This module provides an implementation of the *linear approximation* method 5 | for saliency visualization. The simplest interface is given by the 6 | :func:`linear_approx` function: 7 | 8 | .. literalinclude:: ../examples/linear_approx.py 9 | :language: python 10 | :linenos: 11 | 12 | Alternatively, it is possible to run the method "manually". Linear 13 | approximation is a variant of the gradient method, applied at an intermediate 14 | layer: 15 | 16 | .. literalinclude:: ../examples/linear_approx_manual.py 17 | :language: python 18 | :linenos: 19 | 20 | Note that the function :func:`gradient_to_linear_approx_saliency` is used to 21 | convert activations and gradients to a saliency map. 22 | """ 23 | 24 | __all__ = ['gradient_to_linear_approx_saliency', 'linear_approx'] 25 | 26 | 27 | import torch 28 | from .common import saliency 29 | 30 | 31 | def gradient_to_linear_approx_saliency(x): 32 | """Returns the linear approximation of a tensor. 33 | 34 | The tensor :attr:`x` must have a valid gradient ``x.grad``. 35 | The function then computes the saliency map :math:`s`: given by: 36 | 37 | .. math:: 38 | 39 | s_{n1u} = \sum_{c} x_{ncu} \cdot dx_{ncu} 40 | 41 | Args: 42 | x (:class:`torch.Tensor`): activation tensor with a valid gradient. 43 | 44 | Returns: 45 | :class:`torch.Tensor`: Saliency map. 46 | """ 47 | viz = torch.sum(x * x.grad, 1, keepdim=True) 48 | return viz 49 | 50 | 51 | def linear_approx(*args, 52 | gradient_to_saliency=gradient_to_linear_approx_saliency, 53 | **kwargs): 54 | """Linear approximation. 55 | 56 | The function takes the same arguments as :func:`.common.saliency`, with 57 | the defaults required to apply the linear approximation method, and 58 | supports the same arguments and return values. 
59 | """ 60 | return saliency(*args, 61 | gradient_to_saliency=gradient_to_saliency, 62 | **kwargs) 63 | -------------------------------------------------------------------------------- /TorchRay/torchray/attribution/resnet_maxpool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | try: 4 | from torch.hub import load_state_dict_from_url 5 | except ImportError: 6 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 7 | 8 | ''' 9 | This version of resnet replaces AdaptiveAvgPool2d with AdaptiveMaxPool2d 10 | ''' 11 | 12 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 13 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 14 | 'wide_resnet50_2', 'wide_resnet101_2'] 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 27 | } 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 31 | """3x3 convolution with padding""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=dilation, groups=groups, bias=False, dilation=dilation) 34 | 35 | 36 | def conv1x1(in_planes, out_planes, stride=1): 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | expansion = 1 43 | __constants__ = ['downsample'] 44 | 45 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 46 | base_width=64, dilation=1, norm_layer=None): 47 | super(BasicBlock, self).__init__() 48 | if norm_layer is None: 49 | norm_layer = nn.BatchNorm2d 50 | if groups != 1 or base_width != 64: 51 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 52 | if dilation > 1: 53 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 54 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 55 | self.conv1 = conv3x3(inplanes, planes, stride) 56 | self.bn1 = norm_layer(planes) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.conv2 = conv3x3(planes, planes) 59 | self.bn2 = norm_layer(planes) 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | identity = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | 73 | if self.downsample is not None: 74 | identity = self.downsample(x) 75 | 76 | out += identity 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class Bottleneck(nn.Module): 83 | expansion = 4 84 | __constants__ = ['downsample'] 85 | 86 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 87 | base_width=64, dilation=1, norm_layer=None): 88 | super(Bottleneck, self).__init__() 
89 | if norm_layer is None: 90 | norm_layer = nn.BatchNorm2d 91 | width = int(planes * (base_width / 64.)) * groups 92 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 93 | self.conv1 = conv1x1(inplanes, width) 94 | self.bn1 = norm_layer(width) 95 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 96 | self.bn2 = norm_layer(width) 97 | self.conv3 = conv1x1(width, planes * self.expansion) 98 | self.bn3 = norm_layer(planes * self.expansion) 99 | self.relu = nn.ReLU(inplace=True) 100 | self.downsample = downsample 101 | self.stride = stride 102 | 103 | def forward(self, x): 104 | identity = x 105 | 106 | out = self.conv1(x) 107 | out = self.bn1(out) 108 | out = self.relu(out) 109 | 110 | out = self.conv2(out) 111 | out = self.bn2(out) 112 | out = self.relu(out) 113 | 114 | out = self.conv3(out) 115 | out = self.bn3(out) 116 | 117 | if self.downsample is not None: 118 | identity = self.downsample(x) 119 | 120 | out += identity 121 | out = self.relu(out) 122 | 123 | return out 124 | 125 | 126 | class ResNet(nn.Module): 127 | 128 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 129 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 130 | norm_layer=None): 131 | super(ResNet, self).__init__() 132 | if norm_layer is None: 133 | norm_layer = nn.BatchNorm2d 134 | self._norm_layer = norm_layer 135 | 136 | self.inplanes = 64 137 | self.dilation = 1 138 | if replace_stride_with_dilation is None: 139 | # each element in the tuple indicates if we should replace 140 | # the 2x2 stride with a dilated convolution instead 141 | replace_stride_with_dilation = [False, False, False] 142 | if len(replace_stride_with_dilation) != 3: 143 | raise ValueError("replace_stride_with_dilation should be None " 144 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 145 | self.groups = groups 146 | self.base_width = width_per_group 147 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 148 | bias=False) 149 | self.bn1 = norm_layer(self.inplanes) 150 | self.relu = nn.ReLU(inplace=True) 151 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 152 | self.layer1 = self._make_layer(block, 64, layers[0]) 153 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 154 | dilate=replace_stride_with_dilation[0]) 155 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 156 | dilate=replace_stride_with_dilation[1]) 157 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 158 | dilate=replace_stride_with_dilation[2]) 159 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 160 | self.maxpool_last = nn.AdaptiveMaxPool2d((1, 1)) 161 | self.fc = nn.Linear(512 * block.expansion, num_classes) 162 | 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv2d): 165 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 166 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 167 | nn.init.constant_(m.weight, 1) 168 | nn.init.constant_(m.bias, 0) 169 | 170 | # Zero-initialize the last BN in each residual branch, 171 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
172 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 173 | if zero_init_residual: 174 | for m in self.modules(): 175 | if isinstance(m, Bottleneck): 176 | nn.init.constant_(m.bn3.weight, 0) 177 | elif isinstance(m, BasicBlock): 178 | nn.init.constant_(m.bn2.weight, 0) 179 | 180 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 181 | norm_layer = self._norm_layer 182 | downsample = None 183 | previous_dilation = self.dilation 184 | if dilate: 185 | self.dilation *= stride 186 | stride = 1 187 | if stride != 1 or self.inplanes != planes * block.expansion: 188 | downsample = nn.Sequential( 189 | conv1x1(self.inplanes, planes * block.expansion, stride), 190 | norm_layer(planes * block.expansion), 191 | ) 192 | 193 | layers = [] 194 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 195 | self.base_width, previous_dilation, norm_layer)) 196 | self.inplanes = planes * block.expansion 197 | for _ in range(1, blocks): 198 | layers.append(block(self.inplanes, planes, groups=self.groups, 199 | base_width=self.base_width, dilation=self.dilation, 200 | norm_layer=norm_layer)) 201 | 202 | return nn.Sequential(*layers) 203 | 204 | def _forward_impl(self, x): 205 | # See note [TorchScript super()] 206 | x = self.conv1(x) 207 | x = self.bn1(x) 208 | x = self.relu(x) 209 | x = self.maxpool(x) 210 | 211 | x = self.layer1(x) 212 | x = self.layer2(x) 213 | x = self.layer3(x) 214 | x = self.layer4(x) 215 | 216 | # x = self.avgpool(x) 217 | x = self.maxpool_last(x) 218 | x = torch.flatten(x, 1) 219 | x = self.fc(x) 220 | 221 | return x 222 | 223 | def forward(self, x): 224 | return self._forward_impl(x) 225 | 226 | 227 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 228 | model = ResNet(block, layers, **kwargs) 229 | if pretrained: 230 | state_dict = load_state_dict_from_url(model_urls[arch], 231 | progress=progress) 232 | model.load_state_dict(state_dict) 233 | return model 234 | 235 | 236 | def resnet18(pretrained=False, progress=True, **kwargs): 237 | r"""ResNet-18 model from 238 | `"Deep Residual Learning for Image Recognition" `_ 239 | 240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | progress (bool): If True, displays a progress bar of the download to stderr 243 | """ 244 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 245 | **kwargs) 246 | 247 | 248 | def resnet34(pretrained=False, progress=True, **kwargs): 249 | r"""ResNet-34 model from 250 | `"Deep Residual Learning for Image Recognition" `_ 251 | 252 | Args: 253 | pretrained (bool): If True, returns a model pre-trained on ImageNet 254 | progress (bool): If True, displays a progress bar of the download to stderr 255 | """ 256 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 257 | **kwargs) 258 | 259 | 260 | def resnet50(pretrained=False, progress=True, **kwargs): 261 | r"""ResNet-50 model from 262 | `"Deep Residual Learning for Image Recognition" `_ 263 | 264 | Args: 265 | pretrained (bool): If True, returns a model pre-trained on ImageNet 266 | progress (bool): If True, displays a progress bar of the download to stderr 267 | """ 268 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 269 | **kwargs) 270 | 271 | 272 | def resnet101(pretrained=False, progress=True, **kwargs): 273 | r"""ResNet-101 model from 274 | `"Deep Residual Learning for Image Recognition" `_ 275 | 276 | Args: 277 | pretrained (bool): If True, returns a 
model pre-trained on ImageNet 278 | progress (bool): If True, displays a progress bar of the download to stderr 279 | """ 280 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 281 | **kwargs) 282 | 283 | 284 | def resnet152(pretrained=False, progress=True, **kwargs): 285 | r"""ResNet-152 model from 286 | `"Deep Residual Learning for Image Recognition" `_ 287 | 288 | Args: 289 | pretrained (bool): If True, returns a model pre-trained on ImageNet 290 | progress (bool): If True, displays a progress bar of the download to stderr 291 | """ 292 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 293 | **kwargs) 294 | 295 | 296 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 297 | r"""ResNeXt-50 32x4d model from 298 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 299 | 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | kwargs['groups'] = 32 305 | kwargs['width_per_group'] = 4 306 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 307 | pretrained, progress, **kwargs) 308 | 309 | 310 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 311 | r"""ResNeXt-101 32x8d model from 312 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 313 | 314 | Args: 315 | pretrained (bool): If True, returns a model pre-trained on ImageNet 316 | progress (bool): If True, displays a progress bar of the download to stderr 317 | """ 318 | kwargs['groups'] = 32 319 | kwargs['width_per_group'] = 8 320 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 321 | pretrained, progress, **kwargs) 322 | 323 | 324 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 325 | r"""Wide ResNet-50-2 model from 326 | `"Wide Residual Networks" `_ 327 | 328 | The model is the same as ResNet except for the bottleneck number of channels 329 | which is twice larger in every block. The number of channels in outer 1x1 330 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 331 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 332 | 333 | Args: 334 | pretrained (bool): If True, returns a model pre-trained on ImageNet 335 | progress (bool): If True, displays a progress bar of the download to stderr 336 | """ 337 | kwargs['width_per_group'] = 64 * 2 338 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 339 | pretrained, progress, **kwargs) 340 | 341 | 342 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 343 | r"""Wide ResNet-101-2 model from 344 | `"Wide Residual Networks" `_ 345 | 346 | The model is the same as ResNet except for the bottleneck number of channels 347 | which is twice larger in every block. The number of channels in outer 1x1 348 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 349 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
350 | 351 | Args: 352 | pretrained (bool): If True, returns a model pre-trained on ImageNet 353 | progress (bool): If True, displays a progress bar of the download to stderr 354 | """ 355 | kwargs['width_per_group'] = 64 * 2 356 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 357 | pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /TorchRay/torchray/attribution/rise.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | r""" 4 | This module provides an implementation of the *RISE* method of [RISE]_ for 5 | saliency visualization. This is given by the :func:`rise` function, which 6 | can be used as follows: 7 | 8 | .. literalinclude:: ../examples/rise.py 9 | :language: python 10 | :linenos: 11 | 12 | References: 13 | 14 | .. [RISE] V. Petsiuk, A. Das and K. Saenko 15 | *RISE: Randomized Input Sampling for Explanation of Black-box 16 | Models,* 17 | BMVC 2018, 18 | ``__. 19 | """ 20 | 21 | __all__ = ['rise', 'rise_class'] 22 | 23 | import numpy as np 24 | 25 | import torch 26 | import torch.nn.functional as F 27 | from .common import resize_saliency 28 | 29 | 30 | def _upsample_reflect(x, size, interpolate_mode="bilinear"): 31 | r"""Upsample 4D :class:`torch.Tensor` with reflection padding. 32 | 33 | Args: 34 | x (:class:`torch.Tensor`): 4D tensor to interpolate. 35 | size (int or list or tuple of ints): target size 36 | interpolate_mode (str): mode to pass to 37 | :function:`torch.nn.functional.interpolate` function call 38 | (default: "bilinear"). 39 | 40 | Returns: 41 | :class:`torch.Tensor`: upsampled tensor. 42 | """ 43 | # Check and get input size. 44 | assert len(x.shape) == 4 45 | orig_size = x.shape[2:] 46 | 47 | # Check target size. 48 | if not isinstance(size, tuple) and not isinstance(size, list): 49 | assert isinstance(size, int) 50 | size = (size, size) 51 | assert len(size) == 2 52 | 53 | # Ensure upsampling. 54 | for i, o_s in enumerate(orig_size): 55 | assert o_s <= size[i] 56 | 57 | # Get size of input cell when interpolated. 58 | cell_size = [int(np.ceil(s / orig_size[i])) for i, s in enumerate(size)] 59 | 60 | # Get size of interpolated input with padding. 61 | pad_size = [int(cell_size[i] * (orig_size[i] + 2)) 62 | for i in range(len(orig_size))] 63 | 64 | # Pad input with reflection padding. 65 | x_padded = F.pad(x, (1, 1, 1, 1), mode="reflect") 66 | 67 | # Interpolated padded input. 68 | x_up = F.interpolate(x_padded, 69 | pad_size, 70 | mode=interpolate_mode, 71 | align_corners=False) 72 | 73 | # Slice interpolated input to size. 74 | x_new = x_up[:, 75 | :, 76 | cell_size[0]:cell_size[0] + size[0], 77 | cell_size[1]:cell_size[1] + size[1]] 78 | 79 | return x_new 80 | 81 | 82 | def rise_class(*args, target, **kwargs): 83 | r"""Class-specific RISE. 84 | 85 | This function has the all the arguments of :func:`rise` with the following 86 | additional argument and returns a class-specific saliency map for the 87 | given :attr:`target` class(es). 88 | 89 | Args: 90 | target (int, :class:`torch.Tensor`, list, or :class:`np.ndarray`): 91 | target label(s) that can be cast to :class:`torch.long`. 
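        Example (illustrative sketch; ``model`` and ``x`` are as returned by
        :func:`torchray.benchmark.get_example_data`, and 245 is the bulldog
        category id used there)::

            saliency = rise_class(model, x, target=[245], num_masks=4000)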
92 | """ 93 | saliency = rise(*args, **kwargs) 94 | assert len(saliency.shape) == 4 95 | if not isinstance(target, torch.Tensor): 96 | target = torch.tensor(target, dtype=torch.long, device=saliency.device) 97 | assert isinstance(target, torch.Tensor) 98 | assert target.dtype == torch.long 99 | assert len(target) == len(saliency) 100 | 101 | class_saliency = torch.cat([saliency[i, t].unsqueeze(0).unsqueeze(1) 102 | for i, t in enumerate(target)], dim=0) 103 | output_shape = list(saliency.shape) 104 | output_shape[1] = 1 105 | assert list(class_saliency.shape) == output_shape 106 | 107 | return class_saliency 108 | 109 | 110 | def rise(model, 111 | input, 112 | target=None, 113 | seed=0, 114 | num_masks=8000, 115 | num_cells=7, 116 | filter_masks=None, 117 | batch_size=32, 118 | p=0.5, 119 | resize=False, 120 | resize_mode='bilinear'): 121 | r"""RISE. 122 | 123 | Args: 124 | model (:class:`torch.nn.Module`): a model. 125 | input (:class:`torch.Tensor`): input tensor. 126 | seed (int, optional): manual seed used to generate random numbers. 127 | Default: ``0``. 128 | num_masks (int, optional): number of RISE random masks to use. 129 | Default: ``8000``. 130 | num_cells (int, optional): number of cells for one spatial dimension 131 | in low-res RISE random mask. Default: ``7``. 132 | filter_masks (:class:`torch.Tensor`, optional): If given, use the 133 | provided pre-computed filter masks. Default: ``None``. 134 | batch_size (int, optional): batch size to use. Default: ``128``. 135 | p (float, optional): with prob p, a low-res cell is set to 0; 136 | otherwise, it's 1. Default: ``0.5``. 137 | resize (bool or tuple of ints, optional): If True, resize saliency map 138 | to size of :attr:`input`. If False, don't resize. If (width, 139 | height) tuple, resize to (width, height). Default: ``False``. 140 | resize_mode (str, optional): If resize is not None, use this mode for 141 | the resize function. Default: ``'bilinear'``. 142 | 143 | Returns: 144 | :class:`torch.Tensor`: RISE saliency map. 145 | """ 146 | with torch.no_grad(): 147 | # Get device of input (i.e., GPU). 148 | dev = input.device 149 | 150 | # Initialize saliency mask and mask normalization term. 151 | input_shape = input.shape 152 | saliency_shape = list(input_shape) 153 | 154 | height = input_shape[2] 155 | width = input_shape[3] 156 | 157 | out = model(input) 158 | num_classes = out.shape[1] 159 | 160 | saliency_shape[1] = num_classes 161 | saliency = torch.zeros(saliency_shape, device=dev) 162 | 163 | # Number of spatial dimensions. 164 | nsd = len(input.shape) - 2 165 | assert nsd == 2 166 | 167 | # Spatial size of low-res grid cell. 168 | cell_size = tuple([int(np.ceil(s / num_cells)) 169 | for s in input_shape[2:]]) 170 | 171 | # Spatial size of upsampled mask with buffer (input size + cell size). 172 | up_size = tuple([input_shape[2 + i] + cell_size[i] 173 | for i in range(nsd)]) 174 | 175 | # Save current random number generator state. 176 | state = torch.get_rng_state() 177 | 178 | # Set seed. 179 | torch.manual_seed(seed) 180 | 181 | if filter_masks is not None: 182 | assert len(filter_masks) == num_masks 183 | 184 | num_chunks = (num_masks + batch_size - 1) // batch_size 185 | for chunk in range(num_chunks): 186 | # Generate RISE random masks on the fly. 187 | mask_bs = min(num_masks - batch_size * chunk, batch_size) 188 | 189 | if filter_masks is None: 190 | # Generate low-res, random binary masks. 
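                # Sketch of the mask recipe implemented below (sizes assume
                # the defaults num_cells=7 and a 224x224 input):
                #   1. draw a num_cells x num_cells grid of i.i.d. Bernoulli
                #      cells (threshold of torch.rand against p);
                #   2. reflect-pad and bilinearly upsample it to the input
                #      size plus one cell of buffer (_upsample_reflect);
                #   3. crop an input-sized window at a random per-mask shift,
                #      so mask edges are not aligned with the cell grid.
                # The per-class saliency accumulated further below is the
                # model-score-weighted sum of these masks divided by num_masks.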
191 | grid = (torch.rand(mask_bs, 1, *((num_cells,) * nsd), 192 | device=dev) < p).float() 193 | 194 | # Upsample low-res masks to input shape + buffer. 195 | masks_up = _upsample_reflect(grid, up_size) 196 | 197 | # Save final RISE masks with random shift. 198 | masks = torch.empty(mask_bs, 1, *input_shape[2:], device=dev) 199 | shift_x = torch.randint(0, 200 | cell_size[0], 201 | (mask_bs,), 202 | device='cpu') 203 | shift_y = torch.randint(0, 204 | cell_size[1], 205 | (mask_bs,), 206 | device='cpu') 207 | for i in range(mask_bs): 208 | masks[i] = masks_up[i, 209 | :, 210 | shift_x[i]:shift_x[i] + height, 211 | shift_y[i]:shift_y[i] + width] 212 | else: 213 | masks = filter_masks[ 214 | chunk * batch_size:chunk * batch_size + mask_bs] 215 | 216 | # Accumulate saliency mask. 217 | for i, inp in enumerate(input): 218 | out = torch.sigmoid(model(inp.unsqueeze(0) * masks)) 219 | if len(out.shape) == 4: 220 | # TODO: Consider handling FC outputs more flexibly. 221 | assert out.shape[2] == 1 222 | assert out.shape[3] == 1 223 | out = out[:, :, 0, 0] 224 | sal = torch.matmul(out.data.transpose(0, 1), 225 | masks.view(mask_bs, height * width)) 226 | sal = sal.view((num_classes, height, width)) 227 | saliency[i] = saliency[i] + sal 228 | 229 | # Normalize saliency mask. 230 | saliency /= num_masks 231 | 232 | # Restore original random number generator state. 233 | torch.set_rng_state(state) 234 | 235 | # Resize saliency mask if needed. 236 | saliency = resize_saliency(input, 237 | saliency, 238 | resize, 239 | mode=resize_mode) 240 | return saliency 241 | -------------------------------------------------------------------------------- /TorchRay/torchray/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | r"""This script provides a few functions for getting and plotting example data. 2 | """ 3 | import os 4 | import torchvision 5 | from matplotlib import pyplot as plt 6 | 7 | from .datasets import * # noqa 8 | from .models import * # noqa 9 | 10 | 11 | def get_example_data(arch='vgg16', shape=224): 12 | """Get example data to demonstrate visualization techniques. 13 | 14 | Args: 15 | arch (str, optional): name of torchvision.models architecture. 16 | Default: ``'vgg16'``. 17 | shape (int or tuple of int, optional): shape to resize input image to. 18 | Default: ``224``. 19 | 20 | Returns: 21 | (:class:`torch.nn.Module`, :class:`torch.Tensor`, int, int): a tuple 22 | containing 23 | 24 | - a convolutional neural network model in evaluation mode. 25 | - a sample input tensor image. 26 | - the ImageNet category id of an object in the image. 27 | - the ImageNet category id of another object in the image. 28 | 29 | """ 30 | 31 | # Get a network pre-trained on ImageNet. 32 | model = torchvision.models.__dict__[arch](pretrained=True) 33 | 34 | # Switch to eval mode to make the visualization deterministic. 35 | model.eval() 36 | 37 | # We do not need grads for the parameters. 38 | for param in model.parameters(): 39 | param.requires_grad_(False) 40 | 41 | # Download an example image from wikimedia. 
42 | import requests 43 | from io import BytesIO 44 | from PIL import Image 45 | 46 | url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/7/7f/Arthur_Heyer_-_Dog_and_Cats.jpg/592px-Arthur_Heyer_-_Dog_and_Cats.jpg' 47 | response = requests.get(url) 48 | img = Image.open(BytesIO(response.content)) 49 | 50 | # Pre-process the image and convert into a tensor 51 | transform = torchvision.transforms.Compose([ 52 | torchvision.transforms.Resize(shape), 53 | torchvision.transforms.CenterCrop(shape), 54 | torchvision.transforms.ToTensor(), 55 | torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], 56 | std=[0.229, 0.224, 0.225]), 57 | ]) 58 | 59 | x = transform(img).unsqueeze(0) 60 | 61 | # bulldog category id. 62 | category_id_1 = 245 63 | 64 | # persian cat category id. 65 | category_id_2 = 285 66 | 67 | # Move model and input to device. 68 | from torchray.utils import get_device 69 | dev = get_device() 70 | model = model.to(dev) 71 | x = x.to(dev) 72 | 73 | return model, x, category_id_1, category_id_2 74 | 75 | 76 | def plot_example(input, 77 | saliency, 78 | method, 79 | category_id, 80 | show_plot=False, 81 | save_path=None): 82 | """Plot an example. 83 | 84 | Args: 85 | input (:class:`torch.Tensor`): 4D tensor containing input images. 86 | saliency (:class:`torch.Tensor`): 4D tensor containing saliency maps. 87 | method (str): name of saliency method. 88 | category_id (int): ID of ImageNet category. 89 | show_plot (bool, optional): If True, show plot. Default: ``False``. 90 | save_path (str, optional): Path to save figure to. Default: ``None``. 91 | """ 92 | from torchray.utils import imsc 93 | from torchray.benchmark.datasets import IMAGENET_CLASSES 94 | 95 | if isinstance(category_id, int): 96 | category_id = [category_id] 97 | 98 | batch_size = len(input) 99 | 100 | plt.clf() 101 | for i in range(batch_size): 102 | class_i = category_id[i % len(category_id)] 103 | 104 | plt.subplot(batch_size, 2, 1 + 2 * i) 105 | imsc(input[i]) 106 | plt.title('input image', fontsize=8) 107 | 108 | plt.subplot(batch_size, 2, 2 + 2 * i) 109 | imsc(saliency[i], interpolation='none') 110 | plt.title('{} for category {} ({})'.format( 111 | method, IMAGENET_CLASSES[class_i], class_i), fontsize=8) 112 | 113 | # Save figure if path is specified. 114 | if save_path: 115 | save_dir = os.path.dirname(os.path.abspath(save_path)) 116 | # Create directory if necessary. 117 | if not os.path.exists(save_dir): 118 | os.makedirs(save_dir) 119 | ext = os.path.splitext(save_path)[1].strip('.') 120 | plt.savefig(save_path, format=ext, bbox_inches='tight') 121 | 122 | # Show plot if desired. 123 | if show_plot: 124 | plt.show() 125 | -------------------------------------------------------------------------------- /TorchRay/torchray/benchmark/evaluate_finegrained_gradcam_energy_inside_bbox.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim 9 | import torch.utils.data.distributed 10 | import torchvision.transforms as transforms 11 | import resnet_multigpu_cgc as resnet 12 | import cv2 13 | import datasets as pointing_datasets 14 | 15 | """ 16 | Here, we evaluate the content heatmap (Grad-CAM heatmap within object bounding box) on the fine-grained datasets. 
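The content heatmap is the fraction of the Grad-CAM energy (with the heatmap normalized to sum to one over the image) that falls inside the ground-truth object mask; higher is better.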
17 | """ 18 | 19 | model_names = ['resnet18', 'resnet50'] 20 | 21 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 22 | parser.add_argument('data', metavar='DIR', help='path to dataset') 23 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 24 | choices=model_names, 25 | help='model architecture: ' + 26 | ' | '.join(model_names) + 27 | ' (default: resnet18)') 28 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 29 | help='number of data loading workers (default: 16)') 30 | parser.add_argument('-b', '--batch-size', default=256, type=int, 31 | metavar='N', help='mini-batch size (default: 96)') 32 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 33 | help='use pre-trained model') 34 | parser.add_argument('-g', '--num-gpus', default=1, type=int, 35 | metavar='N', help='number of GPUs to match (default: 4)') 36 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 37 | help='path to latest checkpoint (default: none)') 38 | parser.add_argument('--input_resize', default=224, type=int, 39 | metavar='N', help='Resize for smallest side of input (default: 224)') 40 | parser.add_argument('--maxpool', dest='maxpool', action='store_true', 41 | help='use maxpool version of the model') 42 | parser.add_argument('--dataset', type=str, default='imagenet', 43 | help='dataset to use: [imagenet, cub, aircraft, flowers, cars]') 44 | 45 | 46 | def main(): 47 | global args 48 | args = parser.parse_args() 49 | 50 | if args.dataset == 'cub': 51 | num_classes = 200 52 | elif args.dataset == 'aircraft': 53 | num_classes = 90 54 | elif args.dataset == 'flowers': 55 | num_classes = 102 56 | elif args.dataset == 'cars': 57 | num_classes = 196 58 | 59 | print("=> creating model '{}' for '{}'".format(args.arch, args.dataset)) 60 | if args.arch.startswith('resnet'): 61 | model = resnet.__dict__[args.arch](num_classes=num_classes) 62 | else: 63 | print('Other archs not supported') 64 | exit() 65 | model = torch.nn.DataParallel(model).cuda() 66 | 67 | if args.resume: 68 | print("=> loading checkpoint '{}'".format(args.resume)) 69 | checkpoint = torch.load(args.resume) 70 | if 'state_dict' in checkpoint: 71 | model.load_state_dict(checkpoint['state_dict']) 72 | elif 'model' in checkpoint: 73 | model.load_state_dict(checkpoint['model']) 74 | else: 75 | print('Checkpoint format not supported') 76 | exit() 77 | 78 | if (not args.resume) and (not args.pretrained): 79 | assert False, "Please specify either the pre-trained model or checkpoint for evaluation" 80 | 81 | cudnn.benchmark = True 82 | 83 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 84 | std=[0.229, 0.224, 0.225]) 85 | 86 | # In the version, we will not resize the images. We feed the full image and use AdaptivePooling before FC. 87 | # We will resize Gradcam heatmap to image size and compare the actual bbox co-ordinates 88 | val_dataset = pointing_datasets.ImageNetDetection(args.data, 89 | transform=transforms.Compose([ 90 | transforms.Resize(args.input_resize), 91 | transforms.ToTensor(), 92 | normalize, 93 | ])) 94 | 95 | # we set batch size=1 since we are loading full resolution images. 
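    # For each image, validate_multi below essentially computes (illustrative
    # sketch; `cam` is the Grad-CAM map resized to the full image and `mask`
    # is the binary ground-truth box mask):
    #   cam = cam / cam.sum()                 # treat the heatmap as a distribution
    #   energy_inside = (mask * cam).sum()    # reported as a percentage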
96 | val_loader = torch.utils.data.DataLoader( 97 | val_dataset, batch_size=1, shuffle=False, 98 | num_workers=args.workers, pin_memory=True) 99 | 100 | validate_multi(val_loader, val_dataset, model) 101 | 102 | 103 | def validate_multi(val_loader, val_dataset, model): 104 | batch_time = AverageMeter() 105 | heatmap_inside_bbox = AverageMeter() 106 | 107 | # switch to evaluate mode 108 | model.eval() 109 | 110 | zero_count = 0 111 | total_count = 0 112 | end = time.time() 113 | for i, (images, annotation, targets) in enumerate(val_loader): 114 | total_count += 1 115 | images = images.cuda(non_blocking=True) 116 | targets = targets.cuda(non_blocking=True) 117 | 118 | # we assume batch size == 1 and unwrap the first elem of every list in annotation object 119 | annotation = unwrap_dict(annotation) 120 | image_size = val_dataset.as_image_size(annotation) 121 | 122 | output, feats = model(images, return_feats=True) 123 | output_gradcam = compute_gradcam(output, feats, targets) 124 | output_gradcam_np = output_gradcam.data.cpu().numpy()[0] # since we have batch size==1 125 | resized_output_gradcam = cv2.resize(output_gradcam_np, image_size) 126 | spatial_sum = resized_output_gradcam.sum() 127 | if spatial_sum <= 0: 128 | zero_count += 1 129 | continue 130 | 131 | # resized_output_gradcam is now normalized and can be considered as probabilities 132 | resized_output_gradcam = resized_output_gradcam / spatial_sum 133 | 134 | mask = pointing_datasets.imagenet_as_mask(annotation, targets[0].item()) 135 | 136 | mask = mask.type(torch.ByteTensor) 137 | mask = mask.cpu().data.numpy() 138 | 139 | gcam_inside_gt_mask = mask * resized_output_gradcam 140 | # Now we sum the heatmap inside the object bounding box 141 | total_gcam_inside_gt_mask = gcam_inside_gt_mask.sum() 142 | heatmap_inside_bbox.update(total_gcam_inside_gt_mask) 143 | 144 | if i % 1000 == 0: 145 | print('\nResults after {} examples: '.format(i+1)) 146 | print('Curr % of heatmap inside bbox: {:.4f} ({:.4f})'.format(heatmap_inside_bbox.val * 100, 147 | heatmap_inside_bbox.avg * 100)) 148 | 149 | # measure elapsed time 150 | batch_time.update(time.time() - end) 151 | end = time.time() 152 | 153 | print('\nFinal Results - ') 154 | print('\n\n% of heatmap inside bbox: {:.4f}'.format(heatmap_inside_bbox.avg * 100)) 155 | print('Zero GC found for {}/{} samples'.format(zero_count, total_count)) 156 | 157 | return 158 | 159 | 160 | def compute_gradcam(output, feats, target): 161 | """ 162 | Compute the gradcam for the top predicted category 163 | :param output: 164 | :param feats: 165 | :return: 166 | """ 167 | eps = 1e-8 168 | relu = nn.ReLU(inplace=True) 169 | 170 | target = target.cpu().numpy() 171 | # target = np.argmax(output.cpu().data.numpy(), axis=-1) 172 | one_hot = np.zeros((output.shape[0], output.shape[-1]), dtype=np.float32) 173 | indices_range = np.arange(output.shape[0]) 174 | one_hot[indices_range, target[indices_range]] = 1 175 | one_hot = torch.from_numpy(one_hot) 176 | one_hot.requires_grad = True 177 | 178 | # Compute the Grad-CAM for the original image 179 | one_hot_cuda = torch.sum(one_hot.cuda() * output) 180 | dy_dz1, = torch.autograd.grad(one_hot_cuda, feats, grad_outputs=torch.ones(one_hot_cuda.size()).cuda(), 181 | retain_graph=True, create_graph=True) 182 | dy_dz_sum1 = dy_dz1.sum(dim=2).sum(dim=2) 183 | gcam512_1 = dy_dz_sum1.unsqueeze(-1).unsqueeze(-1) * feats 184 | gradcam = gcam512_1.sum(dim=1) 185 | gradcam = relu(gradcam) 186 | spatial_sum1 = gradcam.sum(dim=[1, 2]).unsqueeze(-1).unsqueeze(-1) 187 | gradcam = 
(gradcam / (spatial_sum1 + eps)) + eps 188 | 189 | return gradcam 190 | 191 | 192 | def unwrap_dict(dict_object): 193 | new_dict = {} 194 | for k, v in dict_object.items(): 195 | if k == 'object': 196 | new_v_list = [] 197 | for elem in v: 198 | new_v_list.append(unwrap_dict(elem)) 199 | new_dict[k] = new_v_list 200 | continue 201 | if isinstance(v, dict): 202 | new_v = unwrap_dict(v) 203 | elif isinstance(v, list) and len(v) == 1: 204 | new_v = v[0] 205 | else: 206 | new_v = v 207 | new_dict[k] = new_v 208 | return new_dict 209 | 210 | 211 | class AverageMeter(object): 212 | """Computes and stores the average and current value""" 213 | def __init__(self): 214 | self.reset() 215 | 216 | def reset(self): 217 | self.val = 0 218 | self.avg = 0 219 | self.sum = 0 220 | self.count = 0 221 | 222 | def update(self, val, n=1): 223 | self.val = val 224 | self.sum += val * n 225 | self.count += n 226 | self.avg = self.sum / self.count 227 | 228 | 229 | if __name__ == '__main__': 230 | main() 231 | -------------------------------------------------------------------------------- /TorchRay/torchray/benchmark/evaluate_imagenet_excitation_backprop_energy_inside_bbox.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim 9 | import torch.utils.data.distributed 10 | import torchvision.transforms as transforms 11 | import torchvision.models.resnet as resnet 12 | import datasets as pointing_datasets 13 | from torchray.attribution.excitation_backprop import contrastive_excitation_backprop 14 | from torchray.attribution.excitation_backprop import update_resnet 15 | from models import resnet_to_fc 16 | from torchray.attribution.common import get_pointing_gradient 17 | 18 | 19 | """ 20 | Here, we evaluate the content heatmap (Excitation Backprop heatmap within object bounding box) on imagenet dataset. 
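The saliency maps are obtained with TorchRay's contrastive excitation backprop (saliency read at layer3, contrasted at the average-pooling layer) and are normalized to sum to one before the energy inside the ground-truth mask is measured.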
21 | """ 22 | 23 | model_names = ['resnet18', 'resnet50'] 24 | 25 | parser = argparse.ArgumentParser(description='Pointing game evaluation for ImageNet using Contrastive Excitation Backprop') 26 | parser.add_argument('data', metavar='DIR', help='path to dataset') 27 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 28 | choices=model_names, 29 | help='model architecture: ' + 30 | ' | '.join(model_names) + 31 | ' (default: resnet18)') 32 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 33 | help='number of data loading workers (default: 16)') 34 | parser.add_argument('-b', '--batch-size', default=256, type=int, 35 | metavar='N', help='mini-batch size (default: 96)') 36 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 37 | help='use pre-trained model') 38 | parser.add_argument('-g', '--num-gpus', default=1, type=int, 39 | metavar='N', help='number of GPUs to match (default: 4)') 40 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 41 | help='path to latest checkpoint (default: none)') 42 | parser.add_argument('--input_resize', default=224, type=int, 43 | metavar='N', help='Resize for smallest side of input (default: 224)') 44 | 45 | 46 | def main(): 47 | global args 48 | args = parser.parse_args() 49 | 50 | if args.pretrained: 51 | print("=> using pre-trained model '{}'".format(args.arch)) 52 | model = resnet.__dict__[args.arch](pretrained=True) 53 | else: 54 | print("=> creating model '{}'".format(args.arch)) 55 | model = resnet.__dict__[args.arch]() 56 | model = torch.nn.DataParallel(model) 57 | 58 | if args.resume: 59 | print("=> loading checkpoint '{}'".format(args.resume)) 60 | checkpoint = torch.load(args.resume) 61 | model.load_state_dict(checkpoint['state_dict']) 62 | 63 | if (not args.resume) and (not args.pretrained): 64 | assert False, "Please specify either the pre-trained model or checkpoint for evaluation" 65 | 66 | model = model._modules['module'] 67 | 68 | model = resnet_to_fc(model) 69 | model.avgpool = torch.nn.AvgPool2d((7, 7), stride=1) 70 | model = update_resnet(model, debug=True) 71 | model = model.cuda() 72 | cudnn.benchmark = False 73 | 74 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 75 | std=[0.229, 0.224, 0.225]) 76 | 77 | # In the first version, we will not resize the images. We feed the full image and use AdaptivePooling before FC. 78 | # We will resize Gradcam heatmap to image size and compare the actual bbox co-ordinates 79 | val_dataset = pointing_datasets.ImageNetDetection(args.data, 80 | transform=transforms.Compose([ 81 | # transforms.Resize((224, 224)), 82 | transforms.Resize(args.input_resize), 83 | transforms.ToTensor(), 84 | normalize, 85 | ])) 86 | 87 | # we set batch size=1 since we are loading full resolution images. 88 | val_loader = torch.utils.data.DataLoader( 89 | val_dataset, batch_size=1, shuffle=False, 90 | num_workers=args.workers, pin_memory=True) 91 | 92 | validate_multi(val_loader, val_dataset, model) 93 | 94 | 95 | def validate_multi(val_loader, val_dataset, model): 96 | batch_time = AverageMeter() 97 | heatmap_inside_gt_mask = AverageMeter() 98 | 99 | # switch to evaluate mode 100 | model.eval() 101 | # We do not need grads for the parameters. 
102 | for param in model.parameters(): 103 | param.requires_grad_(False) 104 | 105 | # prepare vis layer, contrast layer and probes 106 | contrast_layer = 'avgpool' 107 | 108 | zero_saliency_count = 0 109 | total_count = 0 110 | end = time.time() 111 | for i, (images, annotation, targets) in enumerate(val_loader): 112 | total_count += 1 113 | images = images.cuda(non_blocking=True) 114 | targets = targets.cuda(non_blocking=True) 115 | # we assume batch size == 1 and unwrap the first elem of every list in annotation object 116 | annotation = unwrap_dict(annotation) 117 | image_size = val_dataset.as_image_size(annotation) 118 | 119 | class_id = targets[0].item() 120 | 121 | saliency = contrastive_excitation_backprop(model, images, class_id, 122 | saliency_layer='layer3', 123 | contrast_layer=contrast_layer, 124 | resize=image_size, 125 | get_backward_gradient=get_pointing_gradient 126 | ) 127 | saliency = saliency.squeeze() # since we have batch size==1 128 | 129 | resized_saliency = saliency.data.cpu().numpy() 130 | 131 | if np.isnan(resized_saliency).any(): 132 | zero_saliency_count += 1 133 | continue 134 | spatial_sum = resized_saliency.sum() 135 | if spatial_sum <= 0: 136 | zero_saliency_count += 1 137 | continue 138 | resized_saliency = resized_saliency / spatial_sum 139 | 140 | # Now, we obtain the mask corresponding to the ground truth bounding boxes 141 | # Skip if all boxes for class_id are marked difficult. 142 | objs = annotation['annotation']['object'] 143 | if not isinstance(objs, list): 144 | objs = [objs] 145 | objs = [obj for obj in objs if pointing_datasets._IMAGENET_CLASS_TO_INDEX[obj['name']] == class_id] 146 | if all([bool(int(obj['difficult'])) for obj in objs]): 147 | continue 148 | gt_mask = pointing_datasets.imagenet_as_mask(annotation, class_id) 149 | gt_mask = gt_mask.type(torch.ByteTensor) 150 | gt_mask = gt_mask.cpu().data.numpy() 151 | gcam_inside_gt_mask = gt_mask * resized_saliency 152 | total_gcam_inside_gt_mask = gcam_inside_gt_mask.sum() 153 | heatmap_inside_gt_mask.update(total_gcam_inside_gt_mask) 154 | 155 | if i % 1000 == 0: 156 | print('\nCurr % of heatmap inside GT mask: {:.4f} ({:.4f})'.format(heatmap_inside_gt_mask.val * 100, 157 | heatmap_inside_gt_mask.avg * 100)) 158 | 159 | # measure elapsed time 160 | batch_time.update(time.time() - end) 161 | end = time.time() 162 | 163 | print('\n\n% of heatmap inside GT mask: {:.4f}'.format(heatmap_inside_gt_mask.avg * 100)) 164 | print('\n Zero Saliency found for {} / {} images.'.format(zero_saliency_count, total_count)) 165 | 166 | return 167 | 168 | 169 | def compute_gradcam(output, feats, target): 170 | """ 171 | Compute the gradcam for the top predicted category 172 | :param output: 173 | :param feats: 174 | :return: 175 | """ 176 | eps = 1e-8 177 | relu = nn.ReLU(inplace=True) 178 | 179 | target = target.cpu().numpy() 180 | # target = np.argmax(output.cpu().data.numpy(), axis=-1) 181 | one_hot = np.zeros((output.shape[0], output.shape[-1]), dtype=np.float32) 182 | indices_range = np.arange(output.shape[0]) 183 | one_hot[indices_range, target[indices_range]] = 1 184 | one_hot = torch.from_numpy(one_hot) 185 | one_hot.requires_grad = True 186 | 187 | # Compute the Grad-CAM for the original image 188 | one_hot_cuda = torch.sum(one_hot.cuda() * output) 189 | dy_dz1, = torch.autograd.grad(one_hot_cuda, feats, grad_outputs=torch.ones(one_hot_cuda.size()).cuda(), 190 | retain_graph=True, create_graph=True) 191 | gcam512_1 = dy_dz1 * feats 192 | gradcam = gcam512_1.sum(dim=1) 193 | gradcam = relu(gradcam) 194 
| spatial_sum1 = gradcam.sum(dim=[1, 2]).unsqueeze(-1).unsqueeze(-1) 195 | gradcam = (gradcam / (spatial_sum1 + eps)) + eps 196 | 197 | return gradcam 198 | 199 | 200 | def unwrap_dict(dict_object): 201 | new_dict = {} 202 | for k, v in dict_object.items(): 203 | if k == 'object': 204 | new_v_list = [] 205 | for elem in v: 206 | new_v_list.append(unwrap_dict(elem)) 207 | new_dict[k] = new_v_list 208 | continue 209 | if isinstance(v, dict): 210 | new_v = unwrap_dict(v) 211 | elif isinstance(v, list) and len(v) == 1: 212 | new_v = v[0] 213 | # if isinstance(new_v, dict): 214 | # new_v = unwrap_dict(new_v) 215 | else: 216 | new_v = v 217 | new_dict[k] = new_v 218 | return new_dict 219 | 220 | 221 | class AverageMeter(object): 222 | """Computes and stores the average and current value""" 223 | def __init__(self): 224 | self.reset() 225 | 226 | def reset(self): 227 | self.val = 0 228 | self.avg = 0 229 | self.sum = 0 230 | self.count = 0 231 | 232 | def update(self, val, n=1): 233 | self.val = val 234 | self.sum += val * n 235 | self.count += n 236 | self.avg = self.sum / self.count 237 | 238 | 239 | if __name__ == '__main__': 240 | main() 241 | -------------------------------------------------------------------------------- /TorchRay/torchray/benchmark/evaluate_imagenet_gradcam_energy_inside_bbox.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim 9 | import torch.utils.data.distributed 10 | import torchvision.transforms as transforms 11 | import resnet_multigpu_cgc as resnet 12 | import cv2 13 | import datasets as pointing_datasets 14 | 15 | """ 16 | Here, we evaluate the content heatmap (Grad-CAM heatmap within object bounding box) on the imagenet dataset. 
17 | """ 18 | 19 | model_names = ['resnet18', 'resnet50'] 20 | 21 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 22 | parser.add_argument('data', metavar='DIR', help='path to dataset') 23 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 24 | choices=model_names, 25 | help='model architecture: ' + 26 | ' | '.join(model_names) + 27 | ' (default: resnet18)') 28 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 29 | help='number of data loading workers (default: 16)') 30 | parser.add_argument('-b', '--batch-size', default=256, type=int, 31 | metavar='N', help='mini-batch size (default: 96)') 32 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 33 | help='use pre-trained model') 34 | parser.add_argument('-g', '--num-gpus', default=1, type=int, 35 | metavar='N', help='number of GPUs to match (default: 4)') 36 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 37 | help='path to latest checkpoint (default: none)') 38 | parser.add_argument('--input_resize', default=224, type=int, 39 | metavar='N', help='Resize for smallest side of input (default: 224)') 40 | 41 | 42 | def main(): 43 | global args 44 | args = parser.parse_args() 45 | 46 | if args.pretrained: 47 | print("=> using pre-trained model '{}'".format(args.arch)) 48 | if args.arch.startswith('resnet'): 49 | model = resnet.__dict__[args.arch](pretrained=True) 50 | else: 51 | assert False, 'Unsupported architecture: {}'.format(args.arch) 52 | else: 53 | print("=> creating model '{}'".format(args.arch)) 54 | if args.arch.startswith('resnet'): 55 | model = resnet.__dict__[args.arch]() 56 | 57 | model = torch.nn.DataParallel(model).cuda() 58 | 59 | if args.resume: 60 | print("=> loading checkpoint '{}'".format(args.resume)) 61 | checkpoint = torch.load(args.resume) 62 | model.load_state_dict(checkpoint['state_dict']) 63 | 64 | if (not args.resume) and (not args.pretrained): 65 | assert False, "Please specify either the pre-trained model or checkpoint for evaluation" 66 | 67 | cudnn.benchmark = True 68 | 69 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 70 | 71 | # Here, we don't resize the images. We feed the full image and use AdaptivePooling before FC. 72 | # We will resize Gradcam heatmap to image size and compare the actual bbox co-ordinates 73 | val_dataset = pointing_datasets.ImageNetDetection(args.data, 74 | transform=transforms.Compose([ 75 | transforms.Resize(args.input_resize), 76 | transforms.ToTensor(), 77 | normalize, 78 | ])) 79 | 80 | # we set batch size=1 since we are loading full resolution images. 
81 | val_loader = torch.utils.data.DataLoader( 82 | val_dataset, batch_size=1, shuffle=False, 83 | num_workers=args.workers, pin_memory=True) 84 | 85 | validate_multi(val_loader, val_dataset, model) 86 | 87 | 88 | def validate_multi(val_loader, val_dataset, model): 89 | batch_time = AverageMeter() 90 | heatmap_inside_bbox = AverageMeter() 91 | 92 | # switch to evaluate mode 93 | model.eval() 94 | 95 | end = time.time() 96 | for i, (images, annotation, targets) in enumerate(val_loader): 97 | images = images.cuda(non_blocking=True) 98 | targets = targets.cuda(non_blocking=True) 99 | 100 | # we assume batch size == 1 and unwrap the first elem of every list in annotation object 101 | annotation = unwrap_dict(annotation) 102 | image_size = val_dataset.as_image_size(annotation) 103 | 104 | output, feats = model(images, vanilla_with_feats=True) 105 | output_gradcam = compute_gradcam(output, feats, targets) 106 | output_gradcam_np = output_gradcam.data.cpu().numpy()[0] # since we have batch size==1 107 | resized_output_gradcam = cv2.resize(output_gradcam_np, image_size) 108 | spatial_sum = resized_output_gradcam.sum() 109 | if spatial_sum <= 0: 110 | # We ignore images with zero Grad-CAM 111 | continue 112 | 113 | # resized_output_gradcam is now normalized and can be considered as probabilities 114 | resized_output_gradcam = resized_output_gradcam / spatial_sum 115 | 116 | mask = pointing_datasets.imagenet_as_mask(annotation, targets[0].item()) 117 | 118 | mask = mask.type(torch.ByteTensor) 119 | mask = mask.cpu().data.numpy() 120 | 121 | gcam_inside_gt_mask = mask * resized_output_gradcam 122 | 123 | # Now we sum the heatmap inside the object bounding box 124 | total_gcam_inside_gt_mask = gcam_inside_gt_mask.sum() 125 | heatmap_inside_bbox.update(total_gcam_inside_gt_mask) 126 | 127 | if i % 1000 == 0: 128 | print('\nResults after {} examples: '.format(i+1)) 129 | print('Curr % of heatmap inside bbox: {:.4f} ({:.4f})'.format(heatmap_inside_bbox.val * 100, 130 | heatmap_inside_bbox.avg * 100)) 131 | 132 | # measure elapsed time 133 | batch_time.update(time.time() - end) 134 | end = time.time() 135 | 136 | print('\nFinal Results - ') 137 | print('\n\n% of heatmap inside bbox: {:.4f}'.format(heatmap_inside_bbox.avg * 100)) 138 | 139 | return 140 | 141 | 142 | def compute_gradcam(output, feats, target): 143 | """ 144 | Compute the gradcam for the top predicted category 145 | :param output: 146 | :param feats: 147 | :param target: 148 | :return: 149 | """ 150 | eps = 1e-8 151 | relu = nn.ReLU(inplace=True) 152 | 153 | target = target.cpu().numpy() 154 | one_hot = np.zeros((output.shape[0], output.shape[-1]), dtype=np.float32) 155 | indices_range = np.arange(output.shape[0]) 156 | one_hot[indices_range, target[indices_range]] = 1 157 | one_hot = torch.from_numpy(one_hot) 158 | one_hot.requires_grad = True 159 | 160 | # Compute the Grad-CAM for the original image 161 | one_hot_cuda = torch.sum(one_hot.cuda() * output) 162 | dy_dz1, = torch.autograd.grad(one_hot_cuda, feats, grad_outputs=torch.ones(one_hot_cuda.size()).cuda(), 163 | retain_graph=True, create_graph=True) 164 | # Changing to dot product of grad and features to preserve grad spatial locations 165 | gcam512_1 = dy_dz1 * feats 166 | gradcam = gcam512_1.sum(dim=1) 167 | gradcam = relu(gradcam) 168 | spatial_sum1 = gradcam.sum(dim=[1, 2]).unsqueeze(-1).unsqueeze(-1) 169 | gradcam = (gradcam / (spatial_sum1 + eps)) + eps 170 | 171 | return gradcam 172 | 173 | 174 | def unwrap_dict(dict_object): 175 | new_dict = {} 176 | for k, v in 
dict_object.items(): 177 | if k == 'object': 178 | new_v_list = [] 179 | for elem in v: 180 | new_v_list.append(unwrap_dict(elem)) 181 | new_dict[k] = new_v_list 182 | continue 183 | if isinstance(v, dict): 184 | new_v = unwrap_dict(v) 185 | elif isinstance(v, list) and len(v) == 1: 186 | new_v = v[0] 187 | else: 188 | new_v = v 189 | new_dict[k] = new_v 190 | return new_dict 191 | 192 | 193 | class AverageMeter(object): 194 | """Computes and stores the average and current value""" 195 | def __init__(self): 196 | self.reset() 197 | 198 | def reset(self): 199 | self.val = 0 200 | self.avg = 0 201 | self.sum = 0 202 | self.count = 0 203 | 204 | def update(self, val, n=1): 205 | self.val = val 206 | self.sum += val * n 207 | self.count += n 208 | self.avg = self.sum / self.count 209 | 210 | 211 | if __name__ == '__main__': 212 | main() 213 | -------------------------------------------------------------------------------- /TorchRay/torchray/benchmark/evaluate_swav_imagenet_gradcam_energy_inside_bbox.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim 9 | import torch.utils.data.distributed 10 | import torchvision.transforms as transforms 11 | import swav_resnet_cgc as resnet 12 | import os 13 | import cv2 14 | import datasets as pointing_datasets 15 | 16 | 17 | """ 18 | Here, we evaluate the content heatmap (Grad-CAM heatmap within object bounding box) on the imagenet dataset. 19 | """ 20 | 21 | model_names = ['resnet18', 'resnet50'] 22 | 23 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 24 | parser.add_argument('data', metavar='DIR', help='path to dataset') 25 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 26 | choices=model_names, 27 | help='model architecture: ' + 28 | ' | '.join(model_names) + 29 | ' (default: resnet18)') 30 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 31 | help='number of data loading workers (default: 16)') 32 | parser.add_argument('-b', '--batch-size', default=256, type=int, 33 | metavar='N', help='mini-batch size (default: 96)') 34 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 35 | help='use pre-trained model') 36 | parser.add_argument('-g', '--num-gpus', default=1, type=int, 37 | metavar='N', help='number of GPUs to match (default: 4)') 38 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 39 | help='path to latest checkpoint (default: none)') 40 | parser.add_argument('--input_resize', default=224, type=int, 41 | metavar='N', help='Resize for smallest side of input (default: 224)') 42 | parser.add_argument('--maxpool', dest='maxpool', action='store_true', 43 | help='use maxpool version of the model') 44 | 45 | 46 | def main(): 47 | global args 48 | args = parser.parse_args() 49 | 50 | if args.pretrained: 51 | print("=> using pre-trained model '{}'".format(args.arch)) 52 | if args.arch.startswith('resnet'): 53 | model = resnet.__dict__[args.arch](pretrained=True) 54 | else: 55 | assert False, 'Unsupported architecture: {}'.format(args.arch) 56 | else: 57 | print("=> creating model '{}'".format(args.arch)) 58 | if args.arch.startswith('resnet'): 59 | model = resnet.__dict__[args.arch]() 60 | 61 | if args.resume: 62 | if os.path.isfile(args.resume): 63 | print("=> loading checkpoint '{}'".format(args.resume)) 
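            # The block below normalizes SwAV-style checkpoints before loading:
            # unwrap an optional 'state_dict' key, strip the DataParallel
            # 'module.' prefix, and back-fill missing or shape-mismatched keys
            # from the current model so that load_state_dict(strict=False)
            # succeeds.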
64 | state_dict = torch.load(args.resume) 65 | if 'state_dict' in state_dict: 66 | state_dict = state_dict['state_dict'] 67 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 68 | # pdb.set_trace() 69 | for k, v in model.state_dict().items(): 70 | if k not in list(state_dict): 71 | print('key "{}" could not be found in provided state dict'.format(k)) 72 | elif state_dict[k].shape != v.shape: 73 | print('key "{}" is of different shape in model and provided state dict'.format(k)) 74 | state_dict[k] = v 75 | model.load_state_dict(state_dict, strict=False) 76 | else: 77 | print("=> no checkpoint found at '{}'".format(args.resume)) 78 | 79 | model = torch.nn.DataParallel(model).cuda() 80 | 81 | if (not args.resume) and (not args.pretrained): 82 | assert False, "Please specify either the pre-trained model or checkpoint for evaluation" 83 | 84 | cudnn.benchmark = True 85 | 86 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 87 | std=[0.229, 0.224, 0.225]) 88 | 89 | # In the version, we will not resize the images. We feed the full image and use AdaptivePooling before FC. 90 | # We will resize Gradcam heatmap to image size and compare the actual bbox co-ordinates 91 | val_dataset = pointing_datasets.ImageNetDetection(args.data, 92 | transform=transforms.Compose([ 93 | transforms.Resize(args.input_resize), 94 | transforms.ToTensor(), 95 | normalize, 96 | ])) 97 | 98 | # we set batch size=1 since we are loading full resolution images. 99 | val_loader = torch.utils.data.DataLoader( 100 | val_dataset, batch_size=1, shuffle=False, 101 | num_workers=args.workers, pin_memory=True) 102 | 103 | validate_multi(val_loader, val_dataset, model) 104 | 105 | 106 | def validate_multi(val_loader, val_dataset, model): 107 | batch_time = AverageMeter() 108 | heatmap_inside_bbox = AverageMeter() 109 | 110 | # switch to evaluate mode 111 | model.eval() 112 | 113 | zero_count = 0 114 | total_count = 0 115 | end = time.time() 116 | for i, (images, annotation, targets) in enumerate(val_loader): 117 | total_count += 1 118 | images = images.cuda(non_blocking=True) 119 | targets = targets.cuda(non_blocking=True) 120 | 121 | # we assume batch size == 1 and unwrap the first elem of every list in annotation object 122 | annotation = unwrap_dict(annotation) 123 | image_size = val_dataset.as_image_size(annotation) 124 | 125 | output, feats = model(images, vanilla_with_feats=True) 126 | output_gradcam = compute_gradcam(output, feats, targets) 127 | output_gradcam_np = output_gradcam.data.cpu().numpy()[0] # since we have batch size==1 128 | resized_output_gradcam = cv2.resize(output_gradcam_np, image_size) 129 | spatial_sum = resized_output_gradcam.sum() 130 | if spatial_sum <= 0: 131 | zero_count += 1 132 | continue 133 | 134 | # resized_output_gradcam is now normalized and can be considered as probabilities 135 | resized_output_gradcam = resized_output_gradcam / spatial_sum 136 | 137 | mask = pointing_datasets.imagenet_as_mask(annotation, targets[0].item()) 138 | 139 | mask = mask.type(torch.ByteTensor) 140 | mask = mask.cpu().data.numpy() 141 | 142 | gcam_inside_gt_mask = mask * resized_output_gradcam 143 | # Now we sum the heatmap inside the object bounding box 144 | total_gcam_inside_gt_mask = gcam_inside_gt_mask.sum() 145 | heatmap_inside_bbox.update(total_gcam_inside_gt_mask) 146 | 147 | if i % 1000 == 0: 148 | print('\nResults after {} examples: '.format(i+1)) 149 | print('Curr % of heatmap inside bbox: {:.4f} ({:.4f})'.format(heatmap_inside_bbox.val * 100, 150 | 
heatmap_inside_bbox.avg * 100)) 151 | 152 | # measure elapsed time 153 | batch_time.update(time.time() - end) 154 | end = time.time() 155 | 156 | print('\nFinal Results - ') 157 | print('\n\n% of heatmap inside bbox: {:.4f}'.format(heatmap_inside_bbox.avg * 100)) 158 | print('Zero GC found for {}/{} samples'.format(zero_count, total_count)) 159 | 160 | return 161 | 162 | 163 | def compute_gradcam(output, feats, target): 164 | """ 165 | Compute the gradcam for the top predicted category 166 | :param output: 167 | :param feats: 168 | :return: 169 | """ 170 | eps = 1e-8 171 | relu = nn.ReLU(inplace=True) 172 | 173 | target = target.cpu().numpy() 174 | # target = np.argmax(output.cpu().data.numpy(), axis=-1) 175 | one_hot = np.zeros((output.shape[0], output.shape[-1]), dtype=np.float32) 176 | indices_range = np.arange(output.shape[0]) 177 | one_hot[indices_range, target[indices_range]] = 1 178 | one_hot = torch.from_numpy(one_hot) 179 | one_hot.requires_grad = True 180 | 181 | # Compute the Grad-CAM for the original image 182 | one_hot_cuda = torch.sum(one_hot.cuda() * output) 183 | dy_dz1, = torch.autograd.grad(one_hot_cuda, feats, grad_outputs=torch.ones(one_hot_cuda.size()).cuda(), 184 | retain_graph=True, create_graph=True) 185 | dy_dz_sum1 = dy_dz1.sum(dim=2).sum(dim=2) 186 | gcam512_1 = dy_dz_sum1.unsqueeze(-1).unsqueeze(-1) * feats 187 | # Comment the above 2 lines and uncomment the below one to change to dot product of grad and features to preserve grad spatial locations 188 | # gcam512_1 = dy_dz1 * feats 189 | gradcam = gcam512_1.sum(dim=1) 190 | gradcam = relu(gradcam) 191 | spatial_sum1 = gradcam.sum(dim=[1, 2]).unsqueeze(-1).unsqueeze(-1) 192 | gradcam = (gradcam / (spatial_sum1 + eps)) + eps 193 | 194 | return gradcam 195 | 196 | 197 | def unwrap_dict(dict_object): 198 | new_dict = {} 199 | for k, v in dict_object.items(): 200 | if k == 'object': 201 | new_v_list = [] 202 | for elem in v: 203 | new_v_list.append(unwrap_dict(elem)) 204 | new_dict[k] = new_v_list 205 | continue 206 | if isinstance(v, dict): 207 | new_v = unwrap_dict(v) 208 | elif isinstance(v, list) and len(v) == 1: 209 | new_v = v[0] 210 | # if isinstance(new_v, dict): 211 | # new_v = unwrap_dict(new_v) 212 | else: 213 | new_v = v 214 | new_dict[k] = new_v 215 | return new_dict 216 | 217 | 218 | class AverageMeter(object): 219 | """Computes and stores the average and current value""" 220 | def __init__(self): 221 | self.reset() 222 | 223 | def reset(self): 224 | self.val = 0 225 | self.avg = 0 226 | self.sum = 0 227 | self.count = 0 228 | 229 | def update(self, val, n=1): 230 | self.val = val 231 | self.sum += val * n 232 | self.count += n 233 | self.avg = self.sum / self.count 234 | 235 | 236 | if __name__ == '__main__': 237 | main() 238 | -------------------------------------------------------------------------------- /TorchRay/torchray/benchmark/logging_mongo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | r""" 4 | This module provides function that to be log information (e.g., benchmark 5 | results) to a MongoDB database. 6 | 7 | See :mod:`examples.standard_suite` for an example of how to use MongoDB for 8 | logging benchmark results. 9 | 10 | To start a MongoDB server, use 11 | 12 | .. 
code:: shell 13 | 14 | $ python -m torchray.benchmark.server 15 | 16 | """ 17 | import io 18 | import pickle 19 | 20 | import bson 21 | import numpy as np 22 | import pymongo 23 | import torch 24 | 25 | from torchray.utils import get_config 26 | 27 | __all__ = [ 28 | 'mongo_connect', 29 | 'mongo_save', 30 | 'mongo_load', 31 | 'data_to_mongo', 32 | 'data_from_mongo', 33 | 'last_lines' 34 | ] 35 | 36 | _MONGO_MAX_TRIES = 10 37 | 38 | 39 | def mongo_connect(database): 40 | """ 41 | Connect to MongoDB server and and return a 42 | :class:`pymongo.database.Database` object. 43 | 44 | Args: 45 | database (str): name of database. 46 | 47 | Returns: 48 | :class:`pymongo.database.Database`: database. 49 | """ 50 | try: 51 | config = get_config() 52 | hostname = f"{config['mongo']['hostname']}:{config['mongo']['port']}" 53 | client = pymongo.MongoClient(hostname) 54 | client.server_info() 55 | database = client[database] 56 | return database 57 | except pymongo.errors.ServerSelectionTimeoutError as error: 58 | raise Exception( 59 | f"Cannot connect MonogDB at {hostname}") from error 60 | 61 | 62 | def mongo_save(database, collection_key, id_key, data): 63 | """Save results to MongoDB database. 64 | 65 | Args: 66 | database (:class:`pymongo.database.Database`): MongoDB database to save 67 | results to. 68 | collection_key (str): name of collection. 69 | id_key (str): id key with which to store :attr:`data`. 70 | data (:class:`bson.binary.Binary` or dict): data to store in 71 | :attr:`db`. 72 | """ 73 | collection = database[collection_key].with_options( 74 | write_concern=pymongo.WriteConcern(w=1)) 75 | tries_left = _MONGO_MAX_TRIES 76 | while tries_left > 0: 77 | tries_left -= 1 78 | try: 79 | collection.replace_one( 80 | {'_id': id_key}, 81 | data, 82 | upsert=True 83 | ) 84 | return 85 | except (pymongo.errors.WriteConcernError, pymongo.errors.WriteError): 86 | if tries_left == 0: 87 | print( 88 | f"Warning: could not write entry to mongodb after" 89 | f" {_MONGO_MAX_TRIES} attempts." 90 | ) 91 | raise 92 | 93 | 94 | def mongo_load(database, collection_key, id_key): 95 | """Load data from MongoDB database. 96 | 97 | Args: 98 | database (:class:`pymongo.database.Database`): MongoDB database to save 99 | results to. 100 | collection_key (str): name of collection. 101 | id_key (str): id key to look up data. 102 | 103 | Returns: 104 | retrieved data (returns None if no data with :attr:`id_key` is found). 105 | """ 106 | return database[collection_key].find_one({'_id': id_key}) 107 | 108 | 109 | def data_to_mongo(data): 110 | """Prepare data to be stored in a MongoDB database. 111 | 112 | Args: 113 | data (dict, :class:`torch.Tensor`, or :class:`np.ndarray`): data to 114 | prepare for storage in a MongoDB dataset (if dict, items are 115 | recursively prepared for storage). If the underlying data is 116 | not :class:`torch.Tensor` or :class:`np.ndarray`, then :attr:`data` 117 | is returned as is. 118 | 119 | Returns: 120 | :class:`bson.binary.Binary` or dict of :class:`bson.binary.Binary`: 121 | correctly formatted data to store in a MongoDB database. 
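        Example (illustrative sketch; the database, collection and id names
        are placeholders, and a server started with
        ``python -m torchray.benchmark.server`` is assumed)::

            db = mongo_connect('experiments')
            mongo_save(db, 'results', 'run-0',
                       data_to_mongo({'saliency': torch.rand(1, 1, 7, 7)}))
            entry = data_from_mongo(mongo_load(db, 'results', 'run-0'))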
122 | """ 123 | if isinstance(data, dict): 124 | return {k: data_to_mongo(v) for k, v in data.items()} 125 | if isinstance(data, torch.Tensor): 126 | bytes_data = io.BytesIO() 127 | torch.save(data, bytes_data) 128 | bytes_data.seek(0) 129 | binary = bson.binary.Binary(bytes_data.read()) 130 | return binary 131 | if isinstance(data, np.ndarray): 132 | return bson.binary.Binary(pickle.dumps(data, protocol=2), 133 | subtype=128) 134 | return data 135 | 136 | 137 | def data_from_mongo(mongo_data, map_location=None): 138 | """Decode data stored in a MongoDB database. 139 | 140 | Args: 141 | mongo_data (:class:`bson.binary.Binary` or dict): 142 | data to decode (if dict, items are recursively decoded). If 143 | the underlying data type is not `:class:torch.Tensor` or 144 | something stored using :mod:`pickle`, then :attr:`mongo_data` 145 | is returned as is. 146 | map_location (function, :class:`torch.device`, str or dict): where to 147 | remap storage locations (see :func:`torch.load` for more details). 148 | Default: ``None``. 149 | 150 | Returns: 151 | decoded data. 152 | """ 153 | 154 | if isinstance(mongo_data, dict): 155 | return {k: data_from_mongo(v) for k, v in mongo_data.items()} 156 | if isinstance(mongo_data, bson.binary.Binary): 157 | try: 158 | bytes_data = io.BytesIO(mongo_data) 159 | return torch.load(bytes_data, map_location=map_location) 160 | # If the underlying data is a numpy array, it throws a ValueError here. 161 | except Exception: 162 | pass 163 | try: 164 | return pickle.loads(mongo_data) 165 | except Exception: 166 | pass 167 | return mongo_data 168 | 169 | 170 | def last_lines(string, num_lines): 171 | """Extract the last few lines from a string. 172 | 173 | The function extracts the last attr:`n` lines from the string attr:`str`. 174 | If attr:`n` is a negative number, then it extracts the first lines 175 | instead. It also skips lines beginning with ``'Figure('``. 176 | 177 | Args: 178 | string (str): string. 179 | num_lines (int): number of lines to extract. 180 | 181 | Returns: 182 | str: substring. 183 | """ 184 | if string is None: 185 | return '' 186 | lines = string.strip().split('\n') 187 | lines = [l for l in lines if not l.startswith('Figure(')] 188 | if not lines: 189 | return '' 190 | if num_lines > 0: 191 | min_lines = min(num_lines, len(lines)) 192 | lines_ = lines[-min_lines:] 193 | if num_lines < len(lines): 194 | lines_ = ['[...]'] + lines_ 195 | if num_lines < 0: 196 | num_lines = -num_lines 197 | min_lines = min(num_lines, len(lines)) 198 | lines_ = lines[:min_lines] 199 | if num_lines < len(lines): 200 | lines_ = lines_ + ['[...]'] 201 | 202 | return '\n'.join(lines_) 203 | -------------------------------------------------------------------------------- /TorchRay/torchray/benchmark/pointing_game.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | r""" 4 | The :mod:`pointing_game` modules implements the pointing game benchmark. 5 | The basic benchmark is implemented by the :class:`PointingGame` class. However, 6 | for benchmarking purposes it is recommended to use the wrapper class 7 | :class:`PointingGameBenchmark` instead. This class supports *PASCAL VOC 2007 8 | test* and *COCO 2014 val* with the modifications used in [EBP]_, including the 9 | ability to run on their "difficult" subsets as defined in the original paper. 10 | 11 | The class can be used as follows: 12 | 13 | 1. 
Obtain a dataset (usually COCO, PASCAL VOC or ImageNet detection) and choose a subset. 14 | 2. Initialize an instance of :class:`PointingGameBenchmark`. 15 | 3. For each image in the dataset: 16 | 17 | 1. For each class in the image: 18 | 19 | 1. Run the attribution method, usually resulting in a saliency map for 20 | class :math:`c`. 21 | 2. Convert the result to a point, usually by finding the maximizer of the 22 | saliency map. 23 | 3. Use the :func:`PointingGameBenchmark.evaluate` function to run the 24 | test and accumulate the statistics. 25 | 4. Extract the :attr:`PointingGame.hits` and :attr:`PointingGame.misses` or 26 | ``print`` the instance to display the results. 27 | """ 28 | 29 | import torch 30 | from torchvision import datasets as ds 31 | import pdb 32 | import datasets as ads 33 | import numpy as np 34 | from PIL import Image 35 | 36 | 37 | class PointingGame: 38 | r"""Pointing game. 39 | 40 | Args: 41 | num_classes (int): number of classes in the dataset. 42 | tolerance (int, optional): tolerance (in pixels) of margin around 43 | ground truth annotation. Default: 15. 44 | 45 | Attributes: 46 | hits (:class:`torch.Tensor`): :attr:`num_classes`-dimensional vector of 47 | hits counts. 48 | misses (:class:`torch.Tensor`): :attr:`num_classes`-dimensional vector 49 | of misses counts. 50 | """ 51 | 52 | def __init__(self, num_classes, tolerance=15): 53 | assert isinstance(num_classes, int) 54 | assert isinstance(tolerance, int) 55 | self.num_classes = num_classes 56 | self.tolerance = tolerance 57 | self.hits = torch.zeros((num_classes,), dtype=torch.float64) 58 | self.misses = torch.zeros((num_classes,), dtype=torch.float64) 59 | 60 | def evaluate(self, mask, point): 61 | r"""Evaluate a point prediction. 62 | 63 | The function tests whether the prediction :attr:`point` is within a 64 | certain tolerance of the object ground-truth region :attr:`mask` 65 | expressed as a boolean occupancy map. 66 | 67 | Use the :func:`reset` method to clear all counters. 68 | 69 | Args: 70 | mask (:class:`torch.Tensor`): :math:`\{0,1\}^{H\times W}`. 71 | point (tuple of ints): predicted point :math:`(u,v)`. 72 | 73 | Returns: 74 | int: +1 if the point hits the object; otherwise -1. 75 | """ 76 | # Get an acceptance region around the point. There is a hit whenever 77 | # the acceptance region collides with the class mask. 78 | # pdb.set_trace() 79 | # v, u = torch.meshgrid(( 80 | # (torch.arange(mask.shape[0], 81 | # dtype=torch.float32) - point[1])**2, 82 | # (torch.arange(mask.shape[1], 83 | # dtype=torch.float32) - point[0])**2, 84 | # )) 85 | 86 | v, u = torch.meshgrid(( 87 | (torch.arange(mask.shape[0], 88 | dtype=torch.float32) - point[0]) ** 2, 89 | (torch.arange(mask.shape[1], 90 | dtype=torch.float32) - point[1]) ** 2, 91 | )) 92 | accept = (v + u) < self.tolerance**2 93 | 94 | # Test for a hit with the corresponding class. 
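        # In other words: `accept` is a disc of radius `tolerance` (15 px by
        # default) centered on the predicted point, and the prediction is a
        # hit iff that disc overlaps at least one pixel of the ground-truth
        # class mask.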
95 | hit = (mask & accept).view(-1).any() 96 | 97 | # code to debug GT mask and accept mask 98 | # mask_np = mask.numpy().astype(np.uint8)*255 99 | # mask_im = Image.fromarray(mask_np) 100 | # mask_im.save('mask_coco_val_im_03.png') 101 | # accept_np = accept.numpy().astype(np.uint8) * 255 102 | # accept_im = Image.fromarray(accept_np) 103 | # accept_im.save('accept_mask_coco_val_03.png') 104 | 105 | return +1 if hit else -1 106 | 107 | def aggregate(self, hit, class_id): 108 | """Add pointing result from one example.""" 109 | if hit == 0: 110 | return 111 | if hit == 1: 112 | self.hits[class_id] += 1 113 | elif hit == -1: 114 | self.misses[class_id] += 1 115 | else: 116 | assert False 117 | 118 | def reset(self): 119 | """Reset hits and misses.""" 120 | self.hits = torch.zeros_like(self.hits) 121 | self.misses = torch.zeros_like(self.misses) 122 | 123 | @property 124 | def class_accuracies(self): 125 | """ 126 | (:class:`torch.Tensor`): :attr:`num_classes`-dimensional vector 127 | containing per-class accuracy. 128 | """ 129 | return self.hits / (self.hits + self.misses).clamp(min=1) 130 | 131 | @property 132 | def accuracy(self): 133 | """ 134 | (:class:`torch.Tensor`): mean accuracy, computed by averaging 135 | :attr:`class_accuracies`. 136 | """ 137 | return self.class_accuracies.mean() 138 | 139 | # def __str__(self): 140 | # class_accuracies = self.class_accuracies 141 | # return '{:4.1f}% ['.format(100 * class_accuracies.mean()) + " ".join([ 142 | # '{}:{:4.1f}%'.format(c, 100 * a) 143 | # for c, a in enumerate(class_accuracies) 144 | # ]) + ']' 145 | 146 | def __str__(self): 147 | class_accuracies = self.class_accuracies 148 | return 'Pointing game mean accuracy: {:4.1f}% '.format(100 * class_accuracies.mean()) 149 | 150 | 151 | class PointingGameBenchmark(PointingGame): 152 | """Pointing game benchmark on standard datasets. 153 | 154 | The pointing game should be initialized with a dataset, set to either: 155 | 156 | * (:class:`torchvision.VOCDetection`) VOC 2007 *test* subset. 157 | * (:class:`torchvision.CocoDetection`) COCO *val2014* subset. 158 | 159 | Args: 160 | dataset (:class:`torchvision.VisionDataset`): The dataset. 161 | tolerance (int): the tolerance for the pointing game. Default: ``15``. 162 | difficult (bool): whether to use the difficult subset. 163 | Default: ``False``. 164 | """ 165 | 166 | def __init__(self, dataset, tolerance=15, difficult=False): 167 | if isinstance(dataset, ds.VOCDetection): 168 | num_classes = 20 169 | elif isinstance(dataset, ads.ImageNetDetection): 170 | num_classes = 1000 171 | elif isinstance(dataset, ds.CocoDetection): 172 | num_classes = 80 173 | else: 174 | assert False, 'Only VOCDetection, ImageNetDetection and CocoDetection are supported.' 
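        # Concrete sketch of the workflow described in the module docstring
        # (`label`, `class_id` and `point` are placeholders):
        #   game = PointingGameBenchmark(dataset, tolerance=15)
        #   hit = game.evaluate(label, class_id, point)   # +1, -1, or 0 (skipped)
        #   game.aggregate(hit, class_id)
        #   print(game)                                   # mean pointing accuracy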
175 | 176 | super(PointingGameBenchmark, self).__init__( 177 | num_classes=num_classes, tolerance=tolerance) 178 | self.dataset = dataset 179 | self.difficult = difficult 180 | 181 | if difficult: 182 | def load_flags(name): 183 | try: 184 | import importlib.resources as res 185 | except ImportError: 186 | import importlib_resources as res 187 | with res.open_text('torchray.benchmark', name) as file: 188 | rows = [[x for x in row.split('\t')] for row in file] 189 | return { 190 | row[0]: [bool(int(x)) for x in row[1:]] 191 | for row in rows 192 | } 193 | if isinstance(self.dataset, ds.VOCDetection): 194 | self.difficult_flags = load_flags( 195 | 'pointing_game_ebp_voc07_difficult.txt') 196 | elif isinstance(self.dataset, ds.CocoDetection): 197 | self.difficult_flags = load_flags( 198 | 'pointing_game_ebp_coco_difficult.txt') 199 | else: 200 | assert False, 'Difficult set is supported only for VOC and COCO datasets respectively.' 201 | 202 | def evaluate(self, label, class_id, point): 203 | """Evaluate an label-class-point triplet. 204 | 205 | Args: 206 | label (dict): a label in VOC or Coco detection format. 207 | class_id (int): a class id. 208 | point (iterable): a point specified as a pair of u, v coordinates. 209 | 210 | Returns: 211 | int: +1 if the point hits the object, -1 if the point misses the 212 | object, and 0 if the point is skipped during evaluation. 213 | """ 214 | 215 | # Skip if testing on the EBP difficult subset and the image/class pair 216 | # is an easy one. 217 | if self.difficult: 218 | if isinstance(self.dataset, ds.VOCDetection): 219 | image_name = label['annotation']['filename'].split('.')[0] 220 | elif isinstance(self.dataset, ds.CocoDetection): 221 | image_id = label[0]['image_id'] 222 | image = self.dataset.coco.loadImgs(image_id)[0] 223 | image_name = image['file_name'].split('.')[0] 224 | else: 225 | assert False, 'Only VOC and COCO datasets are supported for the difficult subset' 226 | 227 | if image_name in self.difficult_flags: 228 | if not self.difficult_flags[image_name][class_id]: 229 | return 0 230 | 231 | # Get the mask for all occurrences of class_id. 232 | if isinstance(self.dataset, ds.VOCDetection): 233 | # Skip if all boxes for class_id are PASCAL difficult. 234 | objs = label['annotation']['object'] 235 | if not isinstance(objs, list): 236 | objs = [objs] 237 | objs = [obj for obj in objs if 238 | ads.VOC_CLASSES.index(obj['name']) == class_id 239 | ] 240 | if all([bool(int(obj['difficult'])) for obj in objs]): 241 | return 0 242 | mask = ads.voc_as_mask(label, class_id) 243 | 244 | elif isinstance(self.dataset, ads.ImageNetDetection): 245 | # Skip if all boxes for class_id are marked difficult. 246 | objs = label['annotation']['object'] 247 | if not isinstance(objs, list): 248 | objs = [objs] 249 | objs = [obj for obj in objs if 250 | ads._IMAGENET_CLASS_TO_INDEX[obj['name']] == class_id # we assume single name elem per object 251 | ] 252 | if all([bool(int(obj['difficult'])) for obj in objs]): 253 | return 0 254 | mask = ads.imagenet_as_mask(label, class_id) 255 | 256 | elif isinstance(self.dataset, ds.CocoDetection): 257 | mask = ads.coco_as_mask(self.dataset, label, class_id) 258 | 259 | assert mask is not None 260 | return super(PointingGameBenchmark, self).evaluate(mask, point) 261 | -------------------------------------------------------------------------------- /TorchRay/torchray/benchmark/server.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved. 2 | 3 | r""" 4 | This module is used to start and run a MongoDB server. 5 | 6 | To start a MongoDB server, use 7 | 8 | .. code:: shell 9 | 10 | $ python -m torchray.benchmark.server 11 | 12 | """ 13 | import subprocess 14 | from torchray.utils import get_config 15 | 16 | 17 | def run_server(): 18 | """Runs an instance of MongoDB as a logging server.""" 19 | config = get_config() 20 | command = [ 21 | config['mongo']['server'], 22 | '--dbpath', config['mongo']['database'], 23 | '--bind_ip', config['mongo']['hostname'], 24 | '--port', str(config['mongo']['port']) 25 | ] 26 | print(f"Command: {' '.join(command)}.") 27 | code = subprocess.call(command, cwd=".") 28 | print(f"Return code {code}") 29 | 30 | 31 | if __name__ == '__main__': 32 | run_server() 33 | -------------------------------------------------------------------------------- /TorchRay/torchray/benchmark/swav_resnet_cgc.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | # 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d( 12 | in_planes, 13 | out_planes, 14 | kernel_size=3, 15 | stride=stride, 16 | padding=dilation, 17 | groups=groups, 18 | bias=False, 19 | dilation=dilation, 20 | ) 21 | 22 | 23 | def conv1x1(in_planes, out_planes, stride=1): 24 | """1x1 convolution""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | __constants__ = ["downsample"] 31 | 32 | def __init__( 33 | self, 34 | inplanes, 35 | planes, 36 | stride=1, 37 | downsample=None, 38 | groups=1, 39 | base_width=64, 40 | dilation=1, 41 | norm_layer=None, 42 | ): 43 | super(BasicBlock, self).__init__() 44 | if norm_layer is None: 45 | norm_layer = nn.BatchNorm2d 46 | if groups != 1 or base_width != 64: 47 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 48 | if dilation > 1: 49 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 50 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 51 | self.conv1 = conv3x3(inplanes, planes, stride) 52 | self.bn1 = norm_layer(planes) 53 | self.relu = nn.ReLU(inplace=True) 54 | self.conv2 = conv3x3(planes, planes) 55 | self.bn2 = norm_layer(planes) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | identity = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | 69 | if self.downsample is not None: 70 | identity = self.downsample(x) 71 | 72 | out += identity 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | 78 | class Bottleneck(nn.Module): 79 | expansion = 4 80 | __constants__ = ["downsample"] 81 | 82 | def __init__( 83 | self, 84 | inplanes, 85 | planes, 86 | stride=1, 87 | downsample=None, 88 | groups=1, 89 | base_width=64, 90 | dilation=1, 91 | norm_layer=None, 92 | ): 93 | super(Bottleneck, self).__init__() 94 | if norm_layer is None: 95 | norm_layer = nn.BatchNorm2d 96 | width = int(planes * (base_width / 64.0)) * groups 97 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 98 | self.conv1 = conv1x1(inplanes, width) 99 | self.bn1 = norm_layer(width) 100 | 
self.conv2 = conv3x3(width, width, stride, groups, dilation) 101 | self.bn2 = norm_layer(width) 102 | self.conv3 = conv1x1(width, planes * self.expansion) 103 | self.bn3 = norm_layer(planes * self.expansion) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.downsample = downsample 106 | self.stride = stride 107 | 108 | def forward(self, x): 109 | identity = x 110 | 111 | out = self.conv1(x) 112 | out = self.bn1(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv2(out) 116 | out = self.bn2(out) 117 | out = self.relu(out) 118 | 119 | out = self.conv3(out) 120 | out = self.bn3(out) 121 | 122 | if self.downsample is not None: 123 | identity = self.downsample(x) 124 | 125 | out += identity 126 | out = self.relu(out) 127 | 128 | return out 129 | 130 | 131 | def normalize(x): 132 | return x / x.norm(2, dim=1, keepdim=True) 133 | 134 | 135 | class ResNet(nn.Module): 136 | def __init__( 137 | self, 138 | block, 139 | layers, 140 | zero_init_residual=False, 141 | groups=1, 142 | widen=1, 143 | width_per_group=64, 144 | replace_stride_with_dilation=None, 145 | norm_layer=None, 146 | normalize=False, 147 | output_dim=0, 148 | hidden_mlp=0, 149 | nmb_prototypes=0, 150 | eval_mode=False, 151 | ): 152 | super(ResNet, self).__init__() 153 | if norm_layer is None: 154 | norm_layer = nn.BatchNorm2d 155 | self._norm_layer = norm_layer 156 | 157 | self.eval_mode = eval_mode 158 | self.padding = nn.ConstantPad2d(1, 0.0) 159 | 160 | self.inplanes = width_per_group * widen 161 | self.dilation = 1 162 | if replace_stride_with_dilation is None: 163 | # each element in the tuple indicates if we should replace 164 | # the 2x2 stride with a dilated convolution instead 165 | replace_stride_with_dilation = [False, False, False] 166 | if len(replace_stride_with_dilation) != 3: 167 | raise ValueError( 168 | "replace_stride_with_dilation should be None " 169 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation) 170 | ) 171 | self.groups = groups 172 | self.base_width = width_per_group 173 | 174 | # change padding 3 -> 2 compared to original torchvision code because added a padding layer 175 | num_out_filters = width_per_group * widen 176 | self.conv1 = nn.Conv2d( 177 | 3, num_out_filters, kernel_size=7, stride=2, padding=2, bias=False 178 | ) 179 | self.bn1 = norm_layer(num_out_filters) 180 | self.relu = nn.ReLU(inplace=True) 181 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 182 | self.layer1 = self._make_layer(block, num_out_filters, layers[0]) 183 | num_out_filters *= 2 184 | self.layer2 = self._make_layer( 185 | block, num_out_filters, layers[1], stride=2, dilate=replace_stride_with_dilation[0] 186 | ) 187 | num_out_filters *= 2 188 | self.layer3 = self._make_layer( 189 | block, num_out_filters, layers[2], stride=2, dilate=replace_stride_with_dilation[1] 190 | ) 191 | num_out_filters *= 2 192 | self.layer4 = self._make_layer( 193 | block, num_out_filters, layers[3], stride=2, dilate=replace_stride_with_dilation[2] 194 | ) 195 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 196 | 197 | # normalize output features 198 | self.l2norm = normalize 199 | self.upsample = nn.Upsample((7, 7), mode='bilinear') 200 | self.eps = 1e-8 201 | 202 | # projection head 203 | if output_dim == 0: 204 | self.projection_head = None 205 | elif hidden_mlp == 0: 206 | self.projection_head = nn.Linear(num_out_filters * block.expansion, output_dim) 207 | else: 208 | self.projection_head = nn.Sequential( 209 | nn.Linear(num_out_filters * block.expansion, hidden_mlp), 210 | nn.BatchNorm1d(hidden_mlp), 211 
| nn.ReLU(inplace=True), 212 | nn.Linear(hidden_mlp, output_dim), 213 | ) 214 | 215 | # prototype layer 216 | self.prototypes = None 217 | if isinstance(nmb_prototypes, list): 218 | self.prototypes = MultiPrototypes(output_dim, nmb_prototypes) 219 | elif nmb_prototypes > 0: 220 | self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False) 221 | 222 | for m in self.modules(): 223 | if isinstance(m, nn.Conv2d): 224 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 225 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 226 | nn.init.constant_(m.weight, 1) 227 | nn.init.constant_(m.bias, 0) 228 | 229 | # Zero-initialize the last BN in each residual branch, 230 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 231 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 232 | if zero_init_residual: 233 | for m in self.modules(): 234 | if isinstance(m, Bottleneck): 235 | nn.init.constant_(m.bn3.weight, 0) 236 | elif isinstance(m, BasicBlock): 237 | nn.init.constant_(m.bn2.weight, 0) 238 | 239 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 240 | norm_layer = self._norm_layer 241 | downsample = None 242 | previous_dilation = self.dilation 243 | if dilate: 244 | self.dilation *= stride 245 | stride = 1 246 | if stride != 1 or self.inplanes != planes * block.expansion: 247 | downsample = nn.Sequential( 248 | conv1x1(self.inplanes, planes * block.expansion, stride), 249 | norm_layer(planes * block.expansion), 250 | ) 251 | 252 | layers = [] 253 | layers.append( 254 | block( 255 | self.inplanes, 256 | planes, 257 | stride, 258 | downsample, 259 | self.groups, 260 | self.base_width, 261 | previous_dilation, 262 | norm_layer, 263 | ) 264 | ) 265 | self.inplanes = planes * block.expansion 266 | for _ in range(1, blocks): 267 | layers.append( 268 | block( 269 | self.inplanes, 270 | planes, 271 | groups=self.groups, 272 | base_width=self.base_width, 273 | dilation=self.dilation, 274 | norm_layer=norm_layer, 275 | ) 276 | ) 277 | 278 | return nn.Sequential(*layers) 279 | 280 | def forward_backbone(self, x): 281 | x = self.padding(x) 282 | 283 | x = self.conv1(x) 284 | x = self.bn1(x) 285 | x = self.relu(x) 286 | x = self.maxpool(x) 287 | x = self.layer1(x) 288 | x = self.layer2(x) 289 | x = self.layer3(x) 290 | feats = self.layer4(x) 291 | 292 | if self.eval_mode: 293 | return feats 294 | 295 | x = self.avgpool(feats) 296 | x = torch.flatten(x, 1) 297 | 298 | return x, feats 299 | 300 | def forward_head(self, x): 301 | if self.projection_head is not None: 302 | x = self.projection_head(x) 303 | 304 | if self.l2norm: 305 | x = nn.functional.normalize(x, dim=1, p=2) 306 | 307 | if self.prototypes is not None: 308 | return x, self.prototypes(x) 309 | return x 310 | 311 | def _forward_vanilla(self, inputs, return_feats=False): 312 | x, feats = self.forward_backbone(inputs) 313 | x = self.forward_head(x) 314 | if return_feats: 315 | return x, feats 316 | else: 317 | return x 318 | 319 | def _forward(self, lbl_images, vanilla=False, vanilla_with_feats=False): 320 | """ 321 | :param lbl_images: Labeled images to be used for computing logits/feats 322 | :param vanilla: If True, return the outputs from a regular forward pass through the model 323 | :param vanilla_with_feats: If True, return the feats as well as outputs 324 | :return: 325 | """ 326 | if vanilla: 327 | return self.forward_vanilla(lbl_images) 328 | if vanilla_with_feats: 329 | return self.forward_vanilla(lbl_images, 
return_feats=True) 330 | 331 | # Allow for accessing forward method in a inherited class 332 | forward = _forward 333 | forward_vanilla = _forward_vanilla 334 | 335 | 336 | class MultiPrototypes(nn.Module): 337 | def __init__(self, output_dim, nmb_prototypes): 338 | super(MultiPrototypes, self).__init__() 339 | self.nmb_heads = len(nmb_prototypes) 340 | for i, k in enumerate(nmb_prototypes): 341 | self.add_module("prototypes" + str(i), nn.Linear(output_dim, k, bias=False)) 342 | 343 | def forward(self, x): 344 | out = [] 345 | for i in range(self.nmb_heads): 346 | out.append(getattr(self, "prototypes" + str(i))(x)) 347 | return out 348 | 349 | 350 | def resnet50(**kwargs): 351 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 352 | 353 | 354 | def resnet50w2(**kwargs): 355 | return ResNet(Bottleneck, [3, 4, 6, 3], widen=2, **kwargs) 356 | 357 | 358 | def resnet50w4(**kwargs): 359 | return ResNet(Bottleneck, [3, 4, 6, 3], widen=4, **kwargs) 360 | 361 | 362 | def resnet50w5(**kwargs): 363 | return ResNet(Bottleneck, [3, 4, 6, 3], widen=5, **kwargs) -------------------------------------------------------------------------------- /TorchRay/torchray/benchmark/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if hasattr(self, "transforms") and self.transforms is not None: 41 | body += [repr(self.transforms)] 42 | lines = [head] + [" " * self._repr_indent + line for line in body] 43 | return '\n'.join(lines) 44 | 45 | def _format_transform_repr(self, transform, head): 46 | lines = transform.__repr__().splitlines() 47 | return (["{}{}".format(head, lines[0])] + 48 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 49 | 50 | def extra_repr(self): 51 | return "" 52 | 53 | 54 | class StandardTransform(object): 55 | def __init__(self, transform=None, target_transform=None): 56 | self.transform = transform 57 | self.target_transform = target_transform 58 | 59 | def __call__(self, input, target): 60 | if self.transform is not None: 61 | input = self.transform(input) 62 | if self.target_transform is not None: 63 | target = self.target_transform(target) 64 | return input, target 65 | 66 | def _format_transform_repr(self, transform, head): 67 | lines = 
transform.__repr__().splitlines() 68 | return (["{}{}".format(head, lines[0])] + 69 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 70 | 71 | def __repr__(self): 72 | body = [self.__class__.__name__] 73 | if self.transform is not None: 74 | body += self._format_transform_repr(self.transform, 75 | "Transform: ") 76 | if self.target_transform is not None: 77 | body += self._format_transform_repr(self.target_transform, 78 | "Target transform: ") 79 | 80 | return '\n'.join(body) 81 | -------------------------------------------------------------------------------- /datasets/imagefolder_cgc_ssl.py: -------------------------------------------------------------------------------- 1 | from .vision import VisionDataset 2 | 3 | from PIL import Image 4 | import random 5 | import os 6 | import os.path 7 | import torchvision 8 | from torchvision import transforms 9 | from torchvision.transforms import functional as tvf 10 | 11 | 12 | def has_file_allowed_extension(filename, extensions): 13 | """Checks if a file is an allowed extension. 14 | 15 | Args: 16 | filename (string): path to a file 17 | extensions (tuple of strings): extensions to consider (lowercase) 18 | 19 | Returns: 20 | bool: True if the filename ends with one of given extensions 21 | """ 22 | return filename.lower().endswith(extensions) 23 | 24 | 25 | def is_image_file(filename): 26 | """Checks if a file is an allowed image extension. 27 | 28 | Args: 29 | filename (string): path to a file 30 | 31 | Returns: 32 | bool: True if the filename ends with a known image extension 33 | """ 34 | return has_file_allowed_extension(filename, IMG_EXTENSIONS) 35 | 36 | 37 | def make_dataset(directory, class_to_idx, extensions=None, is_valid_file=None): 38 | instances = [] 39 | directory = os.path.expanduser(directory) 40 | both_none = extensions is None and is_valid_file is None 41 | both_something = extensions is not None and is_valid_file is not None 42 | if both_none or both_something: 43 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 44 | if extensions is not None: 45 | def is_valid_file(x): 46 | return has_file_allowed_extension(x, extensions) 47 | for target_class in sorted(class_to_idx.keys()): 48 | class_index = class_to_idx[target_class] 49 | target_dir = os.path.join(directory, target_class) 50 | if not os.path.isdir(target_dir): 51 | continue 52 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): 53 | for fname in sorted(fnames): 54 | path = os.path.join(root, fname) 55 | if is_valid_file(path): 56 | item = path, class_index 57 | instances.append(item) 58 | return instances 59 | 60 | 61 | class DatasetFolder(VisionDataset): 62 | """A generic data loader where the samples are arranged in this way: :: 63 | 64 | root/class_x/xxx.ext 65 | root/class_x/xxy.ext 66 | root/class_x/xxz.ext 67 | 68 | root/class_y/123.ext 69 | root/class_y/nsdf3.ext 70 | root/class_y/asd932_.ext 71 | 72 | Args: 73 | root (string): Root directory path. 74 | loader (callable): A function to load a sample given its path. 75 | extensions (tuple[string]): A list of allowed extensions. 76 | both extensions and is_valid_file should not be passed. 77 | transform (callable, optional): A function/transform that takes in 78 | a sample and returns a transformed version. 79 | E.g, ``transforms.RandomCrop`` for images. 80 | target_transform (callable, optional): A function/transform that takes 81 | in the target and transforms it. 
82 | is_valid_file (callable, optional): A function that takes path of a file 83 | and check if the file is a valid file (used to check of corrupt files) 84 | both extensions and is_valid_file should not be passed. 85 | 86 | Attributes: 87 | classes (list): List of the class names sorted alphabetically. 88 | class_to_idx (dict): Dict with items (class_name, class_index). 89 | samples (list): List of (sample path, class_index) tuples 90 | targets (list): The class_index value for each image in the dataset 91 | """ 92 | 93 | def __init__(self, root, loader, extensions=None, transform=None, 94 | target_transform=None, is_valid_file=None): 95 | super(DatasetFolder, self).__init__(root, transform=transform, 96 | target_transform=target_transform) 97 | classes, class_to_idx = self._find_classes(self.root) 98 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) 99 | if len(samples) == 0: 100 | msg = "Found 0 files in subfolders of: {}\n".format(self.root) 101 | if extensions is not None: 102 | msg += "Supported extensions are: {}".format(",".join(extensions)) 103 | raise RuntimeError(msg) 104 | 105 | self.loader = loader 106 | self.extensions = extensions 107 | 108 | self.classes = classes 109 | self.class_to_idx = class_to_idx 110 | self.samples = samples 111 | self.targets = [s[1] for s in samples] 112 | self.hor_flip = transforms.RandomHorizontalFlip() 113 | self.to_tensor = transforms.ToTensor() 114 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 115 | 116 | 117 | def _find_classes(self, dir): 118 | """ 119 | Finds the class folders in a dataset. 120 | 121 | Args: 122 | dir (string): Root directory path. 123 | 124 | Returns: 125 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 126 | 127 | Ensures: 128 | No class is a subdirectory of another. 129 | """ 130 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 131 | classes.sort() 132 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 133 | return classes, class_to_idx 134 | 135 | def __getitem__(self, index): 136 | """ 137 | Args: 138 | index (int): Index 139 | 140 | Returns: 141 | tuple: (sample, target) where target is class_index of the target class. 142 | """ 143 | path, target = self.samples[index] 144 | sample = self.loader(path) 145 | 146 | # we first obtain an image with the regular augmentations used for cross-entropy loss 147 | xe_sample = transforms.RandomResizedCrop(224)(sample) 148 | xe_sample = transforms.RandomHorizontalFlip()(xe_sample) 149 | 150 | # next, we obtain the pair of image and augmented image to be used for CGC loss 151 | # We resize the image to have 256 on the smallest side and take center crop 152 | sample = transforms.Resize(256)(sample) 153 | sample = transforms.CenterCrop(224)(sample) 154 | 155 | # We manually apply the transformations. 
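        # Note that the crop box (i, j, h, w) and the hor_flip flag are returned
        # to the caller below so that exactly the same augmentation can later be
        # re-applied to the Grad-CAM heatmap of the un-augmented image (see
        # perform_gradcam_aug in models/utils.py).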
156 | # Namely, RandomResizedCrop()->horizontal_flip(0.5)->Totensor()->Normalize() 157 | i, j, h, w = torchvision.transforms.RandomResizedCrop.get_params(sample, scale=(0.08, 1.0), 158 | ratio=(0.75, 1.3333333333333333)) 159 | aug_sample = tvf.resized_crop(sample, i, j, h, w, size=(224, 224)) 160 | 161 | hor_flip = False 162 | if random.random() > 0.5: 163 | aug_sample = self.hor_flip(aug_sample) 164 | hor_flip = True 165 | aug_sample = self.to_tensor(aug_sample) 166 | aug_sample = self.normalize(aug_sample) 167 | 168 | sample = self.to_tensor(sample) 169 | sample = self.normalize(sample) 170 | 171 | xe_sample = self.to_tensor(xe_sample) 172 | xe_sample = self.normalize(xe_sample) 173 | 174 | return xe_sample , sample, aug_sample, i, j, h, w, hor_flip, target 175 | 176 | def __len__(self): 177 | return len(self.samples) 178 | 179 | 180 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 181 | 182 | 183 | def pil_loader(path): 184 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 185 | with open(path, 'rb') as f: 186 | img = Image.open(f) 187 | return img.convert('RGB') 188 | 189 | 190 | def accimage_loader(path): 191 | import accimage 192 | try: 193 | return accimage.Image(path) 194 | except IOError: 195 | # Potentially a decoding problem, fall back to PIL.Image 196 | return pil_loader(path) 197 | 198 | 199 | def default_loader(path): 200 | from torchvision import get_image_backend 201 | if get_image_backend() == 'accimage': 202 | return accimage_loader(path) 203 | else: 204 | return pil_loader(path) 205 | 206 | 207 | class ImageFolder(DatasetFolder): 208 | """A generic data loader where the images are arranged in this way: :: 209 | 210 | root/dog/xxx.png 211 | root/dog/xxy.png 212 | root/dog/xxz.png 213 | 214 | root/cat/123.png 215 | root/cat/nsdf3.png 216 | root/cat/asd932_.png 217 | 218 | Args: 219 | root (string): Root directory path. 220 | transform (callable, optional): A function/transform that takes in an PIL image 221 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 222 | target_transform (callable, optional): A function/transform that takes in the 223 | target and transforms it. 224 | loader (callable, optional): A function to load an image given its path. 225 | is_valid_file (callable, optional): A function that takes path of an Image file 226 | and check if the file is a valid file (used to check of corrupt files) 227 | 228 | Attributes: 229 | classes (list): List of the class names sorted alphabetically. 230 | class_to_idx (dict): Dict with items (class_name, class_index). 
231 | imgs (list): List of (image path, class_index) tuples 232 | """ 233 | 234 | def __init__(self, root, transform=None, target_transform=None, 235 | loader=default_loader, is_valid_file=None): 236 | super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, 237 | transform=transform, 238 | target_transform=target_transform, 239 | is_valid_file=is_valid_file) 240 | self.imgs = self.samples -------------------------------------------------------------------------------- /datasets/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if hasattr(self, "transforms") and self.transforms is not None: 41 | body += [repr(self.transforms)] 42 | lines = [head] + [" " * self._repr_indent + line for line in body] 43 | return '\n'.join(lines) 44 | 45 | def _format_transform_repr(self, transform, head): 46 | lines = transform.__repr__().splitlines() 47 | return (["{}{}".format(head, lines[0])] + 48 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 49 | 50 | def extra_repr(self): 51 | return "" 52 | 53 | 54 | class StandardTransform(object): 55 | def __init__(self, transform=None, target_transform=None): 56 | self.transform = transform 57 | self.target_transform = target_transform 58 | 59 | def __call__(self, input, target): 60 | if self.transform is not None: 61 | input = self.transform(input) 62 | if self.target_transform is not None: 63 | target = self.target_transform(target) 64 | return input, target 65 | 66 | def _format_transform_repr(self, transform, head): 67 | lines = transform.__repr__().splitlines() 68 | return (["{}{}".format(head, lines[0])] + 69 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 70 | 71 | def __repr__(self): 72 | body = [self.__class__.__name__] 73 | if self.transform is not None: 74 | body += self._format_transform_repr(self.transform, 75 | "Transform: ") 76 | if self.target_transform is not None: 77 | body += self._format_transform_repr(self.target_transform, 78 | "Target transform: ") 79 | 80 | return '\n'.join(body) 81 | -------------------------------------------------------------------------------- /misc/teaser_image.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/CGC/a66d87240863e19cc43d11c9e715ca447614042d/misc/teaser_image.png -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def compute_gradcam(output, feats , target, relu): 7 | """ 8 | Compute the gradcam for the top predicted category 9 | :param output: the model output before softmax 10 | :param feats: the feature output from the desired layer to be used for computing Grad-CAM 11 | :param target: The target category to be used for computing Grad-CAM 12 | :return: 13 | """ 14 | 15 | one_hot = np.zeros((output.shape[0], output.shape[-1]), dtype=np.float32) 16 | indices_range = np.arange(output.shape[0]) 17 | one_hot[indices_range, target[indices_range]] = 1 18 | one_hot = torch.from_numpy(one_hot) 19 | one_hot.requires_grad = True 20 | 21 | # Compute the Grad-CAM for the original image 22 | one_hot_cuda = torch.sum(one_hot.cuda() * output) 23 | dy_dz1, = torch.autograd.grad(one_hot_cuda, feats, grad_outputs=torch.ones(one_hot_cuda.size()).cuda(), 24 | retain_graph=True, create_graph=True) 25 | # Changing to dot product of grad and features to preserve grad spatial locations 26 | gcam512_1 = dy_dz1 * feats 27 | gradcam = gcam512_1.sum(dim=1) 28 | gradcam = relu(gradcam) 29 | 30 | return gradcam 31 | 32 | 33 | def compute_gradcam_mask(images_outputs, images_feats , target, relu): 34 | """ 35 | This function computes the grad-cam, upsamples it to the image size and normalizes the Grad-CAM mask. 36 | """ 37 | eps = 1e-8 38 | gradcam_mask = compute_gradcam(images_outputs, images_feats , target, relu) 39 | gradcam_mask = gradcam_mask.unsqueeze(1) 40 | gradcam_mask = F.interpolate(gradcam_mask, size=224, mode='bilinear') 41 | gradcam_mask = gradcam_mask.squeeze() 42 | # normalize the gradcam mask to sum to 1 43 | gradcam_mask_sum = gradcam_mask.sum(dim=[1, 2]).unsqueeze(-1).unsqueeze(-1) 44 | gradcam_mask = (gradcam_mask / (gradcam_mask_sum + eps)) + eps 45 | 46 | return gradcam_mask 47 | 48 | 49 | def perform_gradcam_aug(orig_gradcam_mask, aug_params_dict): 50 | """ 51 | This function uses the augmentation params per batch element and manually applies to the 52 | grad-cam mask to obtain the corresponding augmented grad-cam mask. 
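    The aug_params_dict is expected to hold batched values under the keys
    'transforms_i', 'transforms_j', 'transforms_h', 'transforms_w' (the
    RandomResizedCrop box recorded by the data loader) and 'hor_flip'
    (whether the augmented view was horizontally flipped).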
53 | """ 54 | transforms_i = aug_params_dict['transforms_i'] 55 | transforms_j = aug_params_dict['transforms_j'] 56 | transforms_h = aug_params_dict['transforms_h'] 57 | transforms_w = aug_params_dict['transforms_w'] 58 | hor_flip = aug_params_dict['hor_flip'] 59 | gpu_batch_len = transforms_i.shape[0] 60 | augmented_orig_gradcam_mask = torch.zeros_like(orig_gradcam_mask).cuda() 61 | for b in range(gpu_batch_len): 62 | # convert orig_gradcam_mask to image 63 | orig_gcam = orig_gradcam_mask[b] 64 | orig_gcam = orig_gcam[transforms_i[b]: transforms_i[b] + transforms_h[b], 65 | transforms_j[b]: transforms_j[b] + transforms_w[b]] 66 | # We use torch functional to resize without breaking the graph 67 | orig_gcam = orig_gcam.unsqueeze(0).unsqueeze(0) 68 | orig_gcam = F.interpolate(orig_gcam, size=224, mode='bilinear') 69 | orig_gcam = orig_gcam.squeeze() 70 | if hor_flip[b]: 71 | orig_gcam = orig_gcam.flip(-1) 72 | augmented_orig_gradcam_mask[b, :, :] = orig_gcam[:, :] 73 | return augmented_orig_gradcam_mask --------------------------------------------------------------------------------