├── .gitignore ├── LICENSE ├── README.md ├── dev-docs ├── docs │ └── index.md └── mkdocs.yml ├── examples ├── GuidedBackProReLUModel.py ├── VanillaBackProModel.py ├── cam.jpg ├── grads.jpg └── imgs │ ├── apple.JPEG │ ├── bird.JPEG │ ├── cat_dog.png │ ├── dd_tree.jpg │ ├── dog.jpg │ ├── eel.JPEG │ ├── snake.jpg │ └── spider.png ├── tests ├── activations │ ├── __init__.py │ ├── test_deconvrelu.py │ └── test_guidedReLU.py └── test.py └── vis ├── activations.py ├── algs.py ├── extractor.py ├── losses.py ├── test.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # dev-env 2 | .idea/ 3 | *.iml 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Albert Sun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visualization Toolkit for PyTorch 2 | Pytorch-vis is a neural network visualization toolkit for PyTorch that aims to provide easy and effective ways to visualize trained PyTorch models. Pytorch-vis works seamlessly with PyTorch, so you can visualize a trained model and gain deep insight into it with little effort. Please see the [documentation](https://sunalbert.github.io/pytorch-vis/) for details. 3 | 4 | 5 | # TODOs 6 | Pytorch-vis is still in development and will support the following visualization techniques: 7 | - Saliency maps 8 | - Class activation maps 9 | - SmoothGrad 10 | - Activation maximization 11 | - ... 12 | 13 | # Related Projects 14 | Pytorch-vis is inspired by the following great projects: 15 | - keras-vis https://raghakot.github.io/keras-vis 16 | - quiver https://keplr-io.github.io/quiver/ 17 | - pytorch-cnn-visualization https://github.com/utkuozbulak/pytorch-cnn-visualizations 18 | Thanks to the authors of these projects. 19 | 20 | # License 21 | This project is released under the MIT License. 22 | -------------------------------------------------------------------------------- /dev-docs/docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to MkDocs 2 | 3 | For full documentation visit [mkdocs.org](http://mkdocs.org). 4 | 5 | ## Commands 6 | 7 | * `mkdocs new [dir-name]` - Create a new project. 8 | * `mkdocs serve` - Start the live-reloading docs server. 9 | * `mkdocs build` - Build the documentation site.
10 | * `mkdocs help` - Print this help message. 11 | 12 | ## Project layout 13 | 14 | mkdocs.yml # The configuration file. 15 | docs/ 16 | index.md # The documentation homepage. 17 | ... # Other markdown pages, images and other files. 18 | -------------------------------------------------------------------------------- /dev-docs/mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: My Docs 2 | -------------------------------------------------------------------------------- /examples/GuidedBackProReLUModel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import torchvision.models as models 6 | from torch.autograd import Variable 7 | 8 | from vis.utils import save_cam_img, save_grad_img 9 | from vis.algs import GuidedBackProReLUModel 10 | 11 | 12 | os.environ['CUDA_VISIBLE_DEVICES'] = '7' 13 | 14 | img_path = './imgs/cat_dog.png' 15 | 16 | def preprocess(img): 17 | """ 18 | Apply preprocess to input 19 | img 20 | :param processed_img: input image 21 | :return: torch tensor wrapping a img 22 | """ 23 | means = np.asarray([0.485, 0.456, 0.406]) 24 | stds = np.asarray([0.229, 0.224, 0.225]) 25 | 26 | processed_img = img.copy()[:,:,::-1] 27 | processed_img -= means 28 | processed_img /= stds 29 | 30 | processed_img = np.transpose(processed_img, (2, 0, 1)) 31 | processed_img = np.ascontiguousarray(processed_img) 32 | img_tensor = torch.from_numpy(processed_img) 33 | img_tensor.unsqueeze_(0) 34 | return img_tensor 35 | 36 | 37 | def visualization(): 38 | vis_model = GuidedBackProReLUModel(models.vgg19(pretrained=True), use_cuda=True) 39 | img = cv2.imread(img_path) 40 | img = np.float32(cv2.resize(img, (224, 224))) / 255 41 | input_img = preprocess(img) 42 | input_img = Variable(input_img, requires_grad=True) 43 | mask = vis_model(input_img) 44 | save_grad_img(mask, 'grads.jpg') 45 | save_cam_img(img, mask, 
'cam.jpg') 46 | 47 | 48 | if __name__ == '__main__': 49 | visualization() -------------------------------------------------------------------------------- /examples/VanillaBackProModel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | from torch.autograd import Variable 6 | import torchvision.models as models 7 | from vis.algs import VanillaBackProModel 8 | from vis.utils import save_cam_img, save_grad_img 9 | 10 | os.environ['CUDA_VISIBLE_DEVICES']='7' 11 | 12 | img_path = './imgs/cat_dog.png' 13 | 14 | 15 | def preprocess(img): 16 | """ 17 | Apply preprocess to input 18 | img 19 | :param processed_img: input image 20 | :return: torch tensor wrapping a img 21 | """ 22 | means = np.asarray([0.485, 0.456, 0.406]) 23 | stds = np.asarray([0.229, 0.224, 0.225]) 24 | 25 | processed_img = img.copy()[:,:,::-1] 26 | processed_img -= means 27 | processed_img /= stds 28 | 29 | processed_img = np.transpose(processed_img, (2, 0, 1)) 30 | processed_img = np.ascontiguousarray(processed_img) 31 | img_tensor = torch.from_numpy(processed_img) 32 | img_tensor.unsqueeze_(0) 33 | return img_tensor 34 | 35 | 36 | def visualization(): 37 | vis_model = VanillaBackProModel(models.vgg19(pretrained=True), use_cuda=True) 38 | img = cv2.imread(img_path) 39 | img = np.float32(cv2.resize(img, (224, 224))) / 255 40 | input_img = preprocess(img) 41 | input_img = Variable(input_img, requires_grad=True) 42 | mask = vis_model(input_img) 43 | save_grad_img(mask, 'grads.jpg') 44 | save_cam_img(img, mask, 'cam.jpg') 45 | 46 | 47 | if __name__ == '__main__': 48 | visualization() 49 | -------------------------------------------------------------------------------- /examples/cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunalbert/lucid.pytorch/1bcc87a41c99bef1d64d37116c8a2440d11b0def/examples/cam.jpg 
-------------------------------------------------------------------------------- /examples/grads.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunalbert/lucid.pytorch/1bcc87a41c99bef1d64d37116c8a2440d11b0def/examples/grads.jpg -------------------------------------------------------------------------------- /examples/imgs/apple.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunalbert/lucid.pytorch/1bcc87a41c99bef1d64d37116c8a2440d11b0def/examples/imgs/apple.JPEG -------------------------------------------------------------------------------- /examples/imgs/bird.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunalbert/lucid.pytorch/1bcc87a41c99bef1d64d37116c8a2440d11b0def/examples/imgs/bird.JPEG -------------------------------------------------------------------------------- /examples/imgs/cat_dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunalbert/lucid.pytorch/1bcc87a41c99bef1d64d37116c8a2440d11b0def/examples/imgs/cat_dog.png -------------------------------------------------------------------------------- /examples/imgs/dd_tree.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunalbert/lucid.pytorch/1bcc87a41c99bef1d64d37116c8a2440d11b0def/examples/imgs/dd_tree.jpg -------------------------------------------------------------------------------- /examples/imgs/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunalbert/lucid.pytorch/1bcc87a41c99bef1d64d37116c8a2440d11b0def/examples/imgs/dog.jpg -------------------------------------------------------------------------------- /examples/imgs/eel.JPEG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunalbert/lucid.pytorch/1bcc87a41c99bef1d64d37116c8a2440d11b0def/examples/imgs/eel.JPEG -------------------------------------------------------------------------------- /examples/imgs/snake.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunalbert/lucid.pytorch/1bcc87a41c99bef1d64d37116c8a2440d11b0def/examples/imgs/snake.jpg -------------------------------------------------------------------------------- /examples/imgs/spider.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunalbert/lucid.pytorch/1bcc87a41c99bef1d64d37116c8a2440d11b0def/examples/imgs/spider.png -------------------------------------------------------------------------------- /tests/activations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunalbert/lucid.pytorch/1bcc87a41c99bef1d64d37116c8a2440d11b0def/tests/activations/__init__.py -------------------------------------------------------------------------------- /tests/activations/test_deconvrelu.py: -------------------------------------------------------------------------------- 1 | import nose 2 | from nose.tools import with_setup 3 | from nose.tools import assert_equal 4 | 5 | import torch 6 | import numpy as np 7 | from torch.autograd import Variable 8 | from vis.activations import DeconvnetRelu 9 | 10 | 11 | def set_up(): 12 | pass 13 | 14 | 15 | def teardown(): 16 | pass 17 | 18 | 19 | @with_setup(set_up, teardown) 20 | def test_deconvrelu_foward(): 21 | pass 22 | 23 | 24 | @with_setup(set_up, teardown) 25 | def test_deconvrelu_backward(): 26 | pass 27 | 28 | 29 | 30 | if __name__ == '__main__': 31 | nose.run() 32 | -------------------------------------------------------------------------------- 
/tests/activations/test_guidedReLU.py: -------------------------------------------------------------------------------- 1 | import nose 2 | from nose.tools import assert_equal 3 | from nose.tools import with_setup 4 | 5 | import torch 6 | import numpy as np 7 | from torch.autograd import Variable 8 | from vis.activations import GuidedBackProRelu 9 | 10 | 11 | def set_up(): 12 | print("Test start") 13 | 14 | 15 | def tear_down(): 16 | print("Test done") 17 | 18 | 19 | @with_setup(set_up, tear_down) 20 | def test_forward(): 21 | x = Variable(torch.randn(2,3)) 22 | 23 | grelu = GuidedBackProRelu() 24 | out = grelu(x) 25 | 26 | x_mask = torch.clamp(x, min=0) 27 | res = torch.sum(x_mask - out) 28 | res = res.data.cpu().numpy() 29 | assert_equal(res, 0) 30 | 31 | 32 | @with_setup(set_up, tear_down) 33 | def test_backward(): 34 | 35 | x = Variable(torch.randn(2,3), requires_grad=True) 36 | 37 | grelu = GuidedBackProRelu() 38 | out = grelu(x) 39 | out = torch.sum(out) 40 | out.backward() 41 | grad = x.grad 42 | 43 | x_mask = torch.gt(x, 0).float() 44 | result = torch.sum(x_mask - grad.float()) 45 | result = result.data.numpy() 46 | 47 | assert_equal(result, 0) 48 | 49 | 50 | if __name__=='__main__': 51 | nose.run() 52 | 53 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunalbert/lucid.pytorch/1bcc87a41c99bef1d64d37116c8a2440d11b0def/tests/test.py -------------------------------------------------------------------------------- /vis/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | 5 | class GuidedBackProRelu(Function): 6 | def forward(self, input): 7 | self.save_for_backward(input) 8 | return input.clamp(min=0) 9 | 10 | def backward(self, grad_outputs): 11 | input, = self.saved_tensors 12 | 
grad_input = grad_outputs.clone() 13 | grad_input[grad_outputs < 0] = 0 14 | grad_input[input < 0] = 0 15 | return grad_input 16 | 17 | def named_parameters(self, memo, submodule_prefix): 18 | return [] 19 | 20 | 21 | class DeconvnetRelu(Function): 22 | def forward(self, input): 23 | return input.clamp(min=0) 24 | 25 | def backward(self, grad_outputs): 26 | grad_input = grad_outputs.clone() 27 | grad_input[grad_outputs < 0] = 0 28 | return grad_input -------------------------------------------------------------------------------- /vis/algs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.autograd import Variable 4 | from vis.activations import GuidedBackProRelu 5 | from vis.extractor import Extractor 6 | 7 | 8 | def replace_relu(module): 9 | """ 10 | Replace all the ReLU activation function 11 | with GuidedBackProRelu 12 | :param module: 13 | :return: 14 | """ 15 | for idx, m in module._modules.items(): 16 | if m.__class__.__name__ == 'ReLU': 17 | module._modules[idx] = GuidedBackProRelu() 18 | else: 19 | replace_relu(m) 20 | 21 | 22 | class VanillaBackProModel(object): 23 | """ 24 | Vanilla Backpropagation Model for 25 | visualiazation 26 | """ 27 | def __init__(self, model, use_cuda): 28 | self.model = model.eval() 29 | if use_cuda: 30 | self.model = self.model.cuda() 31 | self.cuda = use_cuda 32 | 33 | def __call__(self, x, index=None): 34 | """ 35 | Perform vanilla backpropagation visualization model 36 | :param x: input Variable 37 | :param index: 38 | :return: 39 | """ 40 | if self.cuda: 41 | out = self.model(x.cuda()) 42 | else: 43 | out = self.model(x) 44 | out = out.view(-1) 45 | 46 | if index is None: 47 | if self.cuda: 48 | index = int(torch.max(out).data.cpu().numpy()) 49 | else: 50 | index = int(torch.max(out).data.numpy()) 51 | 52 | one_hot_mask = torch.zeros(out.size()) 53 | one_hot_mask[index] = 1 54 | 55 | one_hot_mask = Variable(one_hot_mask, 
requires_grad=True) 56 | 57 | if self.cuda: 58 | one_hot_mask = torch.sum(one_hot_mask.cuda() * out) 59 | else: 60 | one_hot_mask = torch.sum(one_hot_mask * out) 61 | 62 | # backpropagation 63 | self.model.zero_grad() 64 | one_hot_mask.backward() 65 | result = x.grad.data.cpu().numpy() 66 | return result[0] 67 | 68 | 69 | class GuidedBackProReLUModel(object): 70 | def __init__(self, model, use_cuda): 71 | self.model = model 72 | self.model.eval() 73 | if use_cuda: 74 | self.model.cuda() 75 | self.cuda = use_cuda 76 | replace_relu(self.model) 77 | 78 | def __call__(self, x, index=None): 79 | assert x.size()[0] == 1 80 | 81 | if self.cuda: 82 | # x = x.cuda() // This is not the right way 83 | out = self.model(x.cuda()) 84 | else: 85 | out = self.model(x) 86 | out = out.view(-1) 87 | 88 | if index is None: 89 | if self.cuda: 90 | index = int(torch.max(out).data.cpu().numpy()) 91 | else: 92 | index = int(torch.max(out).data.numpy()) 93 | 94 | one_hot_mask = torch.zeros(out.size()) 95 | one_hot_mask[index] = 1 96 | 97 | one_hot_mask = Variable(one_hot_mask, requires_grad=True) 98 | 99 | if self.cuda: 100 | one_hot_mask = torch.sum(one_hot_mask.cuda() * out) 101 | else: 102 | one_hot_mask = torch.sum(one_hot_mask * out) 103 | 104 | # backpropagation 105 | self.model.zero_grad() 106 | one_hot_mask.backward() 107 | result = x.grad.data.cpu().numpy() 108 | return result[0] 109 | 110 | 111 | class GradCam(object): 112 | """GradCam visualization technique 113 | """ 114 | def __init__(self, model, target_layers, use_cuda): 115 | assert len(target_layers) == 1 116 | 117 | self.cuda = use_cuda 118 | if self.cuda: 119 | self.model = model.cuda() 120 | else: 121 | self.model = model 122 | self.model.eval() 123 | 124 | self.extractor = Extractor(model, target_layers) 125 | 126 | def __call__(self, x, index=None): 127 | """ The implementation of GradCam 128 | 129 | :param x: 130 | :param index: 131 | :return: 132 | """ 133 | assert x.size(0) == 1 134 | 135 | if self.cuda: 136 | 
inter_outs, final_out = self.extractor(x.cuda()) 137 | else: 138 | inter_outs, final_out = self.extractor(x) 139 | 140 | final_out = final_out.view(-1) 141 | if index == None: 142 | index = int(torch.max(final_out)) 143 | 144 | one_hot_mask = torch.zeros(final_out.size()) 145 | one_hot_mask[index] = 1 146 | one_hot_mask = Variable(one_hot_mask) 147 | 148 | if self.cuda: 149 | one_hot_mask = torch.sum(one_hot_mask.cuda() * final_out) 150 | else: 151 | one_hot_mask = torch.sum(one_hot_mask * final_out) 152 | 153 | self.model.zero_grad() 154 | one_hot_mask.backward(retain_variables=True) 155 | 156 | target_grads = self.extractor.get_grads() 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /vis/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | 5 | 6 | class Extractor(object): 7 | def __init__(self, base_model, target_layers): 8 | super(Extractor, self).__init__() 9 | self.base_model = base_model 10 | self.target_layers = target_layers 11 | self.grads = [] 12 | 13 | def save_grads(self, grad): 14 | self.grads.append(grad) 15 | 16 | def get_grads(self): 17 | return self.grads 18 | 19 | def __call__(self, x): 20 | inter_outs = [] 21 | self.grads = [] 22 | for name, module in self.base_model._modules.items(): 23 | x = module(x) 24 | if name in self.target_layers: 25 | x.register_hook(self.save_grads) 26 | inter_outs.append(x) 27 | return inter_outs, x 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /vis/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | """ 6 | This file for loss function for 2014 paper from 7 | """ 
-------------------------------------------------------------------------------- /vis/test.py: -------------------------------------------------------------------------------- 1 | from vis.algs import GuidedBackProRelu 2 | import torchvision.models as models 3 | 4 | 5 | def replace_relu(module): 6 | """ 7 | Replace all the ReLU activation function 8 | with GuidedBackProRelu 9 | :param module: 10 | :return: 11 | """ 12 | for idx, m in module._modules.items(): 13 | if m.__class__.__name__ == 'ReLU': 14 | module._modules[idx] = GuidedBackProRelu() 15 | else: 16 | replace_relu(m) 17 | 18 | 19 | def inference(): 20 | model = models.vgg19() 21 | replace_relu(model) 22 | for name, module in model._modules.items(): 23 | print(module) 24 | 25 | 26 | if __name__=='__main__': 27 | inference() -------------------------------------------------------------------------------- /vis/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | # TODO: please add gray image support 6 | 7 | def save_grad_img(grads, out_path): 8 | """ 9 | Save gradients as img 10 | :param grads: gradients obtained from visulaziation model 11 | :param out_path: the path to save gradients image 12 | :return: 13 | """ 14 | grads = grads - grads.min() 15 | grads /= grads.max() 16 | grads = np.transpose(grads, [1, 2, 0]) 17 | grads = grads[:, :, ::-1] 18 | grads = np.uint8(grads * 255)[..., ::-1] 19 | grads = np.squeeze(grads) 20 | cv2.imwrite(out_path, grads) 21 | 22 | 23 | def save_cam_img(img, grads, out_path): 24 | """ 25 | save the activation map on the original img 26 | :param img: original image with three chanels (RGB) in range(0,1) 27 | :param grads: grads w.r.t input image 28 | :param out_path: the path to save the image 29 | :return: 30 | """ 31 | grads = np.transpose(grads, [1, 2, 0]) 32 | grads = grads[:, :, ::-1] 33 | 34 | heat_map = cv2.applyColorMap(np.uint8(grads * 255), cv2.COLORMAP_JET) 35 | heat_map = 
heat_map / 255 36 | 37 | cam = heat_map + img 38 | cam = cam / np.max(cam) 39 | 40 | cv2.imwrite(out_path, np.uint8(255 * cam)) 41 | --------------------------------------------------------------------------------