├── .gitignore
├── README.md
├── VBP_results
│   ├── out_0_cat1.jpg
│   ├── out_0_cat2.jpg
│   ├── out_1_cat1.jpg
│   ├── out_1_cat2.jpg
│   ├── out_2_cat1.jpg
│   ├── out_2_cat2.jpg
│   ├── out_3_cat1.jpg
│   ├── out_3_cat2.jpg
│   ├── out_4_cat1.jpg
│   └── out_4_cat2.jpg
├── feat_maps
│   ├── feat_0_cat1.jpg
│   ├── feat_0_cat2.jpg
│   ├── feat_1_cat1.jpg
│   ├── feat_1_cat2.jpg
│   ├── feat_2_cat1.jpg
│   ├── feat_2_cat2.jpg
│   ├── feat_3_cat1.jpg
│   ├── feat_3_cat2.jpg
│   ├── feat_4_cat1.jpg
│   └── feat_4_cat2.jpg
├── image
│   ├── COCO.jpg
│   ├── COCO2.jpg
│   ├── COCO3.jpg
│   ├── cat1.jpg
│   ├── cat2.jpg
│   ├── resized_cat1.jpg
│   └── resized_cat2.jpg
├── net
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── __pycache__
│   │   └── __init__.cpython-36.pyc
│   ├── models
│   │   ├── __init__.py
│   │   ├── __init__.pyc
│   │   ├── __pycache__
│   │   │   ├── alexnet.cpython-36.pyc
│   │   │   └── vgg.cpython-36.pyc
│   │   ├── alexnet.py
│   │   ├── alexnet.pyc
│   │   ├── vgg.py
│   │   └── vgg.pyc
│   └── utility
│       ├── __init__.py
│       ├── __init__.pyc
│       ├── __pycache__
│       │   └── tools.cpython-36.pyc
│       ├── tools.py
│       └── tools.pyc
├── overlay
│   ├── overlay_cat1.jpg
│   └── overlay_cat2.jpg
└── visualbackprop.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pth

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# VisualBackProp in PyTorch
VisualBackProp: a visualization method for convolutional neural networks.

# Description
A detailed description of VisualBackProp can be found at: https://arxiv.org/abs/1611.05418

# Results
## Example 1

Resized Image | Backprop result | Overlay result
:-------------------------:|:-------------------------:|:-------------------------:
![Orig](https://github.com/eugenelet/VisualBackProp-PyTorch/blob/master/image/resized_cat1.jpg) | ![Backprop result](https://github.com/eugenelet/VisualBackProp-PyTorch/blob/master/VBP_results/out_4_cat1.jpg) | ![Overlay](https://github.com/eugenelet/VisualBackProp-PyTorch/blob/master/overlay/overlay_cat1.jpg)

## Example 2

Resized Image | Backprop result | Overlay result
:-------------------------:|:-------------------------:|:-------------------------:
![Orig](https://github.com/eugenelet/VisualBackProp-PyTorch/blob/master/image/resized_cat2.jpg) | ![Backprop result](https://github.com/eugenelet/VisualBackProp-PyTorch/blob/master/VBP_results/out_4_cat2.jpg) | ![Overlay](https://github.com/eugenelet/VisualBackProp-PyTorch/blob/master/overlay/overlay_cat2.jpg)
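# Usage
`visualbackprop.py` hooks every layer of a torchvision-style VGG19, records the feature maps during one forward pass, and back-projects them into a relevance mask. A minimal sketch of the same flow (the weight-file path is an assumption; pretrained `.pth` files are gitignored, so supply your own local copy and run from the repo root):

```python
import cv2
import torchvision.transforms as transforms
from torch.autograd import Variable

from net.models.vgg import vgg19
from net.utility.tools import load_valid
from visualbackprop import (add_hook, save_feature_maps, visualbackprop,
                            layers, maps, save_VBP)

# Load and preprocess an input image exactly as visualbackprop.py does
image = cv2.imread('image/cat1.jpg')
image = cv2.resize(image, (224, 224))
loader = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img = loader(image).contiguous().view(1, 3, 224, 224)

model = vgg19()
load_valid(model, 'snap/vgg19-dcbb9e9d.pth')  # hypothetical local path

add_hook(model, 'all', save_feature_maps)  # record every layer output in `maps`
model.forward(Variable(img))               # the forward pass fills `maps`

masks = visualbackprop(layers, maps)       # one mask per pooling stage, coarse to fine
save_VBP('VBP_results/out_4_cat1.jpg', masks[-1])
```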

--------------------------------------------------------------------------------
/VBP_results/out_0_cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/VBP_results/out_0_cat1.jpg

--------------------------------------------------------------------------------
/VBP_results/out_0_cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/VBP_results/out_0_cat2.jpg

--------------------------------------------------------------------------------
/VBP_results/out_1_cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/VBP_results/out_1_cat1.jpg

--------------------------------------------------------------------------------
/VBP_results/out_1_cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/VBP_results/out_1_cat2.jpg

--------------------------------------------------------------------------------
/VBP_results/out_2_cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/VBP_results/out_2_cat1.jpg

--------------------------------------------------------------------------------
/VBP_results/out_2_cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/VBP_results/out_2_cat2.jpg

--------------------------------------------------------------------------------
/VBP_results/out_3_cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/VBP_results/out_3_cat1.jpg

--------------------------------------------------------------------------------
/VBP_results/out_3_cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/VBP_results/out_3_cat2.jpg

--------------------------------------------------------------------------------
/VBP_results/out_4_cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/VBP_results/out_4_cat1.jpg

--------------------------------------------------------------------------------
/VBP_results/out_4_cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/VBP_results/out_4_cat2.jpg

--------------------------------------------------------------------------------
/feat_maps/feat_0_cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/feat_maps/feat_0_cat1.jpg

--------------------------------------------------------------------------------
/feat_maps/feat_0_cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/feat_maps/feat_0_cat2.jpg

--------------------------------------------------------------------------------
/feat_maps/feat_1_cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/feat_maps/feat_1_cat1.jpg

--------------------------------------------------------------------------------
/feat_maps/feat_1_cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/feat_maps/feat_1_cat2.jpg

--------------------------------------------------------------------------------
/feat_maps/feat_2_cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/feat_maps/feat_2_cat1.jpg

--------------------------------------------------------------------------------
/feat_maps/feat_2_cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/feat_maps/feat_2_cat2.jpg

--------------------------------------------------------------------------------
/feat_maps/feat_3_cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/feat_maps/feat_3_cat1.jpg

--------------------------------------------------------------------------------
/feat_maps/feat_3_cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/feat_maps/feat_3_cat2.jpg

--------------------------------------------------------------------------------
/feat_maps/feat_4_cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/feat_maps/feat_4_cat1.jpg

--------------------------------------------------------------------------------
/feat_maps/feat_4_cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/feat_maps/feat_4_cat2.jpg

--------------------------------------------------------------------------------
/image/COCO.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/image/COCO.jpg

--------------------------------------------------------------------------------
/image/COCO2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/image/COCO2.jpg

--------------------------------------------------------------------------------
/image/COCO3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/image/COCO3.jpg

--------------------------------------------------------------------------------
/image/cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/image/cat1.jpg

--------------------------------------------------------------------------------
/image/cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/image/cat2.jpg

--------------------------------------------------------------------------------
/image/resized_cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/image/resized_cat1.jpg

--------------------------------------------------------------------------------
/image/resized_cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/image/resized_cat2.jpg

--------------------------------------------------------------------------------
/net/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/__init__.py

--------------------------------------------------------------------------------
/net/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/__init__.pyc

--------------------------------------------------------------------------------
/net/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/__pycache__/__init__.cpython-36.pyc

--------------------------------------------------------------------------------
/net/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/models/__init__.py

--------------------------------------------------------------------------------
/net/models/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/models/__init__.pyc

--------------------------------------------------------------------------------
/net/models/__pycache__/alexnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/models/__pycache__/alexnet.cpython-36.pyc

--------------------------------------------------------------------------------
/net/models/__pycache__/vgg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/models/__pycache__/vgg.cpython-36.pyc

--------------------------------------------------------------------------------
/net/models/alexnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import numpy as np
# needed by the __main__ demo below; run from the repo root,
# e.g. `python -m net.models.alexnet`
from net.utility import tools
import cv2


class AlexNet(nn.Module):

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        logit = x
        prob = F.sigmoid(logit)
        return logit, prob


def alexnet(**kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = AlexNet(**kwargs)
    # if pretrained:
    #     model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
    return model


if __name__ == "__main__":

    image = cv2.imread('/Users/Eugene/Documents/Git/pytorch/image/cat1.jpg')
    image = cv2.resize(image, (227, 227))
    loader = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = loader(image)
    image = image.contiguous().view(1, 3, 227, 227)
    pretrained_file = "/Users/Eugene/Documents/Git/pytorch/VisualBackProp/snap/alexnet-owt-4df8aa71.pth"
    model = alexnet()

    tools.load_valid(model, pretrained_file, None)
    logit, prob = model.forward(Variable(image))
    prob = prob.view(-1)
    value, indices = prob.max(0)
    print("Max Value: ", value)
    print("Index: ", indices)

--------------------------------------------------------------------------------
/net/models/alexnet.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/models/alexnet.pyc

--------------------------------------------------------------------------------
/net/models/vgg.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import cv2
import numpy as np
import math

from ..utility import tools


__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}


class VGG(nn.Module):

    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        logit = self.classifier(x)
        prob = F.sigmoid(logit)
        return logit, prob

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
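
# For illustration: configuration 'A' above expands through make_layers() into
# the vgg11 feature extractor, i.e.
#
#   make_layers(cfg['A']) == nn.Sequential(
#       nn.Conv2d(3, 64, kernel_size=3, padding=1),   nn.ReLU(inplace=True),
#       nn.MaxPool2d(kernel_size=2, stride=2),
#       nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True),
#       nn.MaxPool2d(kernel_size=2, stride=2),
#       ...  # and so on for the 256- and 512-channel blocks
#   )
#
# Each 'M' is a 2x2 max-pool that halves the spatial size; five of them take a
# 224x224 input down to the 7x7 maps that the classifier's 512 * 7 * 7 input
# expects.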

def vgg11(**kwargs):
    """VGG 11-layer model (configuration "A")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['A']), **kwargs)
    return model


def vgg11_bn(**kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
    return model


def vgg13(**kwargs):
    """VGG 13-layer model (configuration "B")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['B']), **kwargs)
    return model


def vgg13_bn(**kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
    # `pretrained` and `model_zoo` are not defined in this file, so the loading
    # step is commented out, as in the other constructors:
    # if pretrained:
    #     model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
    return model


def vgg16(**kwargs):
    """VGG 16-layer model (configuration "D")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['D']), **kwargs)
    return model


def vgg16_bn(**kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
    return model


def vgg19(**kwargs):
    """VGG 19-layer model (configuration "E")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['E']), **kwargs)
    return model


def vgg19_bn(**kwargs):
    """VGG 19-layer model (configuration 'E') with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
    return model


if __name__ == "__main__":
    model = vgg19()  # num_classes=17
    pretrained_file = "/Users/Eugene/Documents/Git/pytorch/VisualBackProp/snap/vgg19-dcbb9e9d.pth"
    tools.load_valid(model,
                     pretrained_file,
                     skip_list=None,  # ['classifier.6.weight', 'classifier.6.bias'],
                     )
    # print(model.state_dict().keys())

    image = cv2.imread('/Users/Eugene/Documents/Git/pytorch/image/cat1.jpg')
    image = cv2.resize(image, (227, 227))
    loader = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = loader(image)
    image = image.contiguous().view(1, 3, 227, 227)

    logit, prob = model.forward(Variable(image))
    prob = prob.view(-1)
    value, indices = prob.max(0)
    print("Max Value: ", value)
    print("Index: ", indices)

--------------------------------------------------------------------------------
/net/models/vgg.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/models/vgg.pyc

--------------------------------------------------------------------------------
/net/utility/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/utility/__init__.py

--------------------------------------------------------------------------------
/net/utility/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/utility/__init__.pyc

--------------------------------------------------------------------------------
/net/utility/__pycache__/tools.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/utility/__pycache__/tools.cpython-36.pyc

--------------------------------------------------------------------------------
/net/utility/tools.py:
--------------------------------------------------------------------------------
# stdlib
import os

# PyTorch
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

# 3rd party packages
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# modules
from net.models.vgg import *
from net.models.alexnet import *


def vis_square(image):
    """Take a tensor of shape (n, height, width), normalize it for display,
    and show/save one of the feature maps as a grayscale image."""
    image = image.numpy()

    # normalize data for display
    data = (image - image.min()) / (image.max() - image.min())

    count = 0
    for img in data:
        count += 1
        if count > 30:
            # show and save the first map past index 30
            print(img * 255)
            im = Image.fromarray(img * 255).convert("L")
            im.show()
            im.save('filter' + str(count) + ".jpg")
            break


def vis_single_square(image, savedir):
    """Take a tensor of shape (1, c, height, width), normalize it for display,
    and show/save its first channel as a grayscale image."""
    image = image.numpy()

    # normalize data for display
    data = (image - image.min()) / (image.max() - image.min())
    data = data[0, 0, :, :]

    im = Image.fromarray(data * 255).convert("L")
    im.show()
    im.save(savedir)


def load_valid(model, pretrained_file, skip_list=None):

    pretrained_dict = torch.load(pretrained_file)
    model_dict = model.state_dict()

    # 1. filter out unnecessary keys
    if skip_list is not None:
        pretrained_dict1 = {k: v for k, v in pretrained_dict.items() if k in model_dict and k not in skip_list}
    else:
        pretrained_dict1 = pretrained_dict
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict1)
    model.load_state_dict(model_dict)
    return model
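
# Usage sketch (paths are hypothetical): load ImageNet VGG19 weights into a
# model with a different head by skipping the final classifier layer, so that
# only the matching parameters are overwritten:
#
#   model = vgg19(num_classes=17)
#   load_valid(model, "snap/vgg19-dcbb9e9d.pth",
#              skip_list=['classifier.6.weight', 'classifier.6.bias'])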

if __name__ == "__main__":
    pass

--------------------------------------------------------------------------------
/net/utility/tools.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/net/utility/tools.pyc

--------------------------------------------------------------------------------
/overlay/overlay_cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/overlay/overlay_cat1.jpg

--------------------------------------------------------------------------------
/overlay/overlay_cat2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/VisualBackProp-PyTorch/14a85e9aadd09b780f4fb3ece167d4fe0d9799cd/overlay/overlay_cat2.jpg

--------------------------------------------------------------------------------
/visualbackprop.py:
--------------------------------------------------------------------------------
# stdlib
import os

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision.transforms as transforms

# 3rd party packages
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# modules
from net.models.vgg import *
from net.models.alexnet import *
# from net.utility.tools import *

import math
import pdb

maps = []
layers = []
hooks = []


FEAT_KEEP = 30              # number of feature maps to show
FEAT_SIZE = 224             # size of feature maps to show
FEAT_MAPS_DIR = 'feat_maps' # dir. to save feature maps
VBP_DIR = 'VBP_results'     # dir. to save VBP results
OVERLAY_DIR = "overlay"     # dir. to save overlay results


def normalize_gamma(image, gamma=1.0):
    # normalize data for display, then apply gamma correction
    image = (image - image.min()) / (image.max() - image.min())
    invGamma = 1.0 / gamma
    image = (image ** invGamma) * 255
    return image.astype("uint8")


def visual_feature(self, input, output):
    # Hook function that shows the feature maps during the forward pass.
    # (Requires vis_square from net.utility.tools; its import is commented
    # out above.)
    vis_square(output.data[0, :])


def save_feature_maps(self, input, output):
    # Hook function that saves feature maps during the forward pass
    map = output.data
    maps.append(map)
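
# Note: PyTorch forward hooks have the signature hook(module, input, output)
# and fire after each module's forward(). Hooking a single layer directly
# looks like:
#
#   handle = model.features[0].register_forward_hook(save_feature_maps)
#   ...
#   handle.remove()   # detach the hook when done
#
# add_hook() below wraps this registration for whole groups of layers.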

def add_hook(net, layer_name, func):
    '''
    Add a hook function to the layers you specify.
    The hook will be called on those layers during the forward pass.

    :param net: the model you defined
    :param layer_name: which layers to hook; currently 'all', 'maxpool' or 'relu'
    :param func: the hook function to call during the forward pass
    :return: the model with the given function hooked into the chosen layers
    '''

    if layer_name == 'maxpool':
        for m in net.features:
            if isinstance(m, nn.MaxPool2d):
                m.register_forward_hook(func)
        return net

    if layer_name == 'relu':
        for index, m in enumerate(net.features):
            if isinstance(m, nn.ReLU):
                type_name = str(type(m)).replace("<'", '').replace("'>", '').split('.')[-1]
                name = 'features' + '-' + str(index) + '-' + type_name
                hook = m.register_forward_hook(func)
                layers.append((name, m))
                hooks.append(hook)
        return net

    if layer_name == 'all':
        for index, m in enumerate(net.features):
            type_name = str(type(m)).replace("<'", '').replace("'>", '').split('.')[-1]
            name = 'features' + '-' + str(index) + '-' + type_name
            hook = m.register_forward_hook(func)
            layers.append((name, m))
            hooks.append(hook)
        return net


def visualbackprop(layers, maps):
    '''
    :param layers: the saved (name, module) pairs
    :param maps: the saved feature maps
    :return: the list of intermediate masks, coarsest first
    '''

    num_layers = len(maps)
    avgs = []
    mask = None
    ups = []

    upSample = nn.Upsample(scale_factor=2)

    for n in range(num_layers - 1, 0, -1):
        cur_layer = layers[n][1]
        if type(cur_layer) in [torch.nn.MaxPool2d]:
            print(layers[n][0])
            ##########################
            # Get and set attributes #
            ##########################
            relu = maps[n - 1]
            conv = maps[n - 2]  # conv output below the ReLU (kept for reference, unused)

            ###########################################
            # Average filters and multiply pixel-wise #
            ###########################################

            # Average filters
            avg = relu.mean(dim=1)
            avg = avg.unsqueeze(0)
            avgs.append(avg)

            if mask is not None:
                mask = upSample(Variable(mask)).data
                mask = mask * avg
            else:
                mask = avg

            # smoothing with a 3x3 all-ones kernel (stride 1, padding 1 keeps
            # the spatial size): see http://pytorch.org/docs/nn.html#convtranspose2d
            weight = Variable(torch.ones(1, 1, 3, 3))
            up = F.conv_transpose2d(Variable(mask), weight, stride=1, padding=1)
            mask = up.data
            ups.append(mask)

    return ups
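
# Worked example of the recursion above, for VGG19 ('E') on a 224x224 input.
# Walking the five pooling stages from the deepest to the shallowest:
#
#   pool5: mask = mean over the 512 maps of the 14x14 ReLU output -> 1x1x14x14
#   pool4: upsample mask to 28x28,   multiply by the averaged 28x28 ReLU maps
#   pool3: upsample to 56x56,        multiply by the averaged 56x56 maps
#   pool2: upsample to 112x112,      multiply by the averaged 112x112 maps
#   pool1: upsample to 224x224,      multiply by the averaged 224x224 maps
#
# Each stage also applies the all-ones 3x3 transposed convolution as a
# smoothing step and appends its mask to `ups`, so out_0 ... out_4 in
# VBP_results/ are these progressively sharper masks (out_4 is full-resolution,
# which is why the README shows it).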

def plotFeatMaps(layers, maps):
    '''
    :param layers: the saved (name, module) pairs
    :param maps: the saved feature maps
    :return: top feature maps of the ReLU layers, one grid per pooling stage
    '''

    num_layers = len(maps)
    feat_collection = []
    # Show top FEAT_KEEP feature maps (after ReLU), starting from the bottom layers
    for n in range(num_layers):
        cur_layer = layers[n][1]
        if type(cur_layer) in [torch.nn.MaxPool2d]:
            ##########################
            # Get and set attributes #
            ##########################
            relu = maps[n - 1]

            ###########################################
            # Sort feature maps based on their energy #
            ###########################################
            feat_energy = []
            # Get energy of each channel
            for channel_n in range(relu.shape[1]):
                feat_energy.append(np.sum(relu[0][channel_n].numpy()))
            feat_energy = np.array(feat_energy)
            # Sort energy
            feat_rank = np.argsort(feat_energy)[::-1]

            # Empty background
            back_len = int(math.ceil(math.sqrt(FEAT_SIZE * FEAT_SIZE * FEAT_KEEP * 2)))
            feat = np.zeros((back_len, back_len))
            col = 0
            row = 0
            for feat_n in range(FEAT_KEEP):
                if col * FEAT_SIZE + FEAT_SIZE < back_len:
                    feat[row*FEAT_SIZE:row*FEAT_SIZE + FEAT_SIZE, col*FEAT_SIZE:col*FEAT_SIZE + FEAT_SIZE] = \
                        cv2.resize(normalize_gamma(relu[0][feat_rank[feat_n]].numpy(), 0.1), (FEAT_SIZE, FEAT_SIZE))
                    col = col + 1
                else:
                    row = row + 1
                    col = 0
                    feat[row*FEAT_SIZE:row*FEAT_SIZE + FEAT_SIZE, col*FEAT_SIZE:col*FEAT_SIZE + FEAT_SIZE] = \
                        cv2.resize(normalize_gamma(relu[0][feat_rank[feat_n]].numpy(), 0.1), (FEAT_SIZE, FEAT_SIZE))
                    col = col + 1

            feat_collection.append(feat)

    return feat_collection


# Show VBP result
def show_VBP(label, image):
    """Normalize a (1, 1, height, width) mask for display and show it
    resized to 224x224."""
    image = image.numpy()

    # normalize data for display
    data = (image - image.min()) / (image.max() - image.min())
    data = data[0, 0, :, :]
    data = cv2.resize(data, (224, 224))
    data = (data * 255).astype("uint8")
    cv2.imshow(label, data)


# Save VBP result
def save_VBP(label, image):
    """Same normalization as show_VBP, but write the mask to disk instead."""
    image = image.numpy()

    # normalize data for display
    data = (image - image.min()) / (image.max() - image.min())
    data = data[0, 0, :, :]
    data = cv2.resize(data, (224, 224))
    data = (data * 255).astype("uint8")
    cv2.imwrite(label, data)


def overlay(image, mask):
    """Add the normalized mask to the red channel of a BGR image."""
    # normalize data for display
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = mask[0, 0, :, :]
    mask = cv2.resize(mask, (224, 224))
    mask = (mask * 255).astype("uint8")
    # saturating add into the red (BGR index 2) channel
    image[:, :, 2] = cv2.add(image[:, :, 2], mask)

    return image

if __name__ == "__main__":

    '''
    Load the image, resize it to 224 x 224, then transform the loaded numpy
    array. The transform consists of 1. ToTensor and 2. Normalization.
    After the image transformation, we need to define the model.
    The model used here is VGG19; you can choose whatever model you like.
    Once the model is defined, the pre-trained weights are loaded.

    Before we forward the image through VGG, we need to define our hooks.
    The function "add_hook" provides an easy way to attach a hook to a layer.
    You have to specify:
        1. the model you want to hook
        2. the layers you want to hook
        3. the function you want to hook

    Since we want to save the feature maps to a list, "save_feature_maps" is
    used as the hook function; it extracts each hooked layer's output.
    After the outputs are extracted, the collected feature maps are written
    out as images below.
    '''


    BASE_DIR = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
    IMG_DIR = './image'
    # MODEL_DIR = BASE_DIR + '/pretrained_model'
    IMG_NAME = 'cat1.jpg'


    image = cv2.imread(IMG_DIR + '/' + IMG_NAME)
    image = cv2.resize(image, (224, 224))

    cv2.imwrite('./image/resized_' + IMG_NAME, image)

    loader = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img = loader(image)
    img = img.contiguous().view(1, 3, 224, 224)
    model = vgg19()

    # load_valid(model, MODEL_DIR + "/vgg19-dcbb9e9d.pth", None)
    x = Variable(img)


    add_hook(model, 'all', save_feature_maps)
    logits, probs = model.forward(x)

    feat_collection = plotFeatMaps(layers, maps)

    # Save feature maps (already scaled to 0-255 by normalize_gamma)
    for i in range(len(feat_collection)):
        cv2.imwrite(FEAT_MAPS_DIR + '/feat_' + str(i) + '_' + IMG_NAME, feat_collection[i])
    masks = visualbackprop(layers, maps)

    mask_num = len(masks)

    cv2.imshow('ori', image)
    cv2.moveWindow('ori', 50, 50)
    for i in range(mask_num):
        save_VBP(VBP_DIR + '/out_' + str(i) + '_' + IMG_NAME, masks[i])
        show_VBP('vbp_' + str(i) + '.png', masks[i])
        cv2.moveWindow('vbp_' + str(i) + '.png', i*30 + 100, i*30 + 100)


    overlay_img = overlay(image, masks[mask_num - 1].numpy())
    cv2.imshow('overlay', overlay_img)
    cv2.moveWindow('overlay', 200, 200)
    cv2.imwrite(OVERLAY_DIR + '/overlay_' + IMG_NAME, overlay_img)
    cv2.waitKey(0)

--------------------------------------------------------------------------------