├── utils
│   ├── __init__.py
│   └── vis_utils.py
├── models
│   ├── __init__.py
│   └── vgg16_deconv.py
├── imgs
│   └── bird.jpg
├── README.md
├── vis_layers.py
└── demo.py

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
from .vis_utils import *

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .vgg16_deconv import *

--------------------------------------------------------------------------------
/imgs/bird.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csgwon/pytorch-deconvnet/HEAD/imgs/bird.jpg

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-deconvnet

An example deconvnet in PyTorch for VGG16. `vis_utils.py` is adapted from the assignments of [CS231n](http://cs231n.github.io/).

Work in progress; test code to come.
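## Usage

A sketch of the expected workflow (it assumes a torchvision install that can download the pretrained VGG16 weights, which happens when the models are imported). Run the interactive demo on the bundled test image:

    ./demo.py imgs/bird.jpg

Enter a layer index (0-30) when prompted, then click an activation map to see its projection back to pixel space. `./vis_layers.py imgs/bird.jpg` shows just the activation grids.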
--------------------------------------------------------------------------------
/vis_layers.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

from models import *
from utils import *
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torch
import sys

def vis_layer(activ_map):
    plt.ion()
    plt.imshow(activ_map[:, :, 0], cmap='gray')

if __name__ == '__main__':
    if len(sys.argv) < 2:
        print('Usage: ' + sys.argv[0] + ' img_file')
        sys.exit(0)

    img_filename = sys.argv[1]

    n_classes = 1000  # using ImageNet pretrained weights

    vgg16_c = VGG16_conv(n_classes)

    # HWC uint8 image -> NCHW float tensor
    img = np.asarray(Image.open(img_filename).resize((224, 224)))
    img_var = torch.autograd.Variable(torch.FloatTensor(img.transpose(2, 0, 1)[np.newaxis, :, :, :].astype(float)))

    conv_out = vgg16_c(img_var)
    print('VGG16 model:')
    print(vgg16_c)

    while True:
        layer = input('Layer to view (0-30, negative to exit): ')
        try:
            layer = int(layer)
        except ValueError:
            continue

        if layer < 0:
            sys.exit(0)
        # (1, C, H, W) -> (C, H, W, 1): one grid tile per channel
        activ_map = vgg16_c.feature_outputs[layer].data.numpy()
        vis_layer(vis_grid(activ_map.transpose(1, 2, 3, 0)))

--------------------------------------------------------------------------------
/utils/vis_utils.py:
--------------------------------------------------------------------------------
from math import sqrt, ceil
import numpy as np

def visualize_grid(Xs, ubound=255.0, padding=1):
    """
    Reshape a 4D tensor of image data to a grid for easy visualization.

    Inputs:
    - Xs: Data of shape (N, H, W, C)
    - ubound: Output grid will have values scaled to the range [0, ubound]
    - padding: The number of blank pixels between elements of the grid
    """
    (N, H, W, C) = Xs.shape
    grid_size = int(ceil(sqrt(N)))
    grid_height = H * grid_size + padding * (grid_size - 1)
    grid_width = W * grid_size + padding * (grid_size - 1)
    grid = np.zeros((grid_height, grid_width, C))
    next_idx = 0
    y0, y1 = 0, H
    for y in range(grid_size):
        x0, x1 = 0, W
        for x in range(grid_size):
            if next_idx < N:
                # normalize each image independently to [0, ubound]
                img = Xs[next_idx]
                low, high = np.min(img), np.max(img)
                if high > low:  # guard against constant (e.g. all-zero) maps
                    grid[y0:y1, x0:x1] = ubound * (img - low) / (high - low)
                next_idx += 1
            x0 += W + padding
            x1 += W + padding
        y0 += H + padding
        y1 += H + padding
    return grid

def vis_grid(Xs):
    """ Visualize a grid of images. """
    (N, H, W, C) = Xs.shape
    A = int(ceil(sqrt(N)))
    G = np.ones((A * H + A, A * W + A, C), Xs.dtype)
    G *= np.min(Xs)  # fill unused tiles and padding with the global minimum
    n = 0
    for y in range(A):
        for x in range(A):
            if n < N:
                G[y * H + y:(y + 1) * H + y, x * W + x:(x + 1) * W + x, :] = Xs[n, :, :, :]
                n += 1
    # normalize to [0, 1]
    maxg = G.max()
    ming = G.min()
    G = (G - ming) / (maxg - ming)
    return G

def vis_nn(rows):
    """ Visualize an array of arrays of images. """
    N = len(rows)
    D = len(rows[0])
    H, W, C = rows[0][0].shape
    Xs = rows[0][0]
    G = np.ones((N * H + N, D * W + D, C), Xs.dtype)
    for y in range(N):
        for x in range(D):
            G[y * H + y:(y + 1) * H + y, x * W + x:(x + 1) * W + x, :] = rows[y][x]
    # normalize to [0, 1]
    maxg = G.max()
    ming = G.min()
    G = (G - ming) / (maxg - ming)
    return G
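# A quick, hypothetical smoke test for the grid helpers above (an added
# sketch, not part of the original file): tile ten random 8x8 RGB "images"
# and check the resulting grid shapes and value ranges.
if __name__ == '__main__':
    Xs = np.random.rand(10, 8, 8, 3)
    grid = visualize_grid(Xs)        # each tile scaled to [0, 255]
    print(grid.shape)                # (35, 35, 3): ceil(sqrt(10)) = 4 tiles per side
    tiles = vis_grid(Xs)             # whole grid scaled to [0, 1]
    print(tiles.min(), tiles.max())  # 0.0 1.0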
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

from models import *
from utils import *
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torch
import sys

def vis_layer(activ_map):
    plt.clf()
    plt.subplot(121)
    plt.imshow(activ_map[:, :, 0], cmap='gray')

def decon_img(layer_output):
    # (1, 3, H, W) tensor -> HWC uint8 image scaled to [0, 255]
    raw_img = layer_output.data.numpy()[0].transpose(1, 2, 0)
    img = (raw_img - raw_img.min()) / (raw_img.max() - raw_img.min()) * 255
    img = img.astype(np.uint8)
    return img

if __name__ == '__main__':
    if len(sys.argv) < 2:
        print('Usage: ' + sys.argv[0] + ' img_file')
        sys.exit(0)

    img_filename = sys.argv[1]

    n_classes = 1000  # using ImageNet pretrained weights

    vgg16_c = VGG16_conv(n_classes)
    conv_layer_indices = vgg16_c.get_conv_layer_indices()

    # HWC uint8 image -> NCHW float tensor
    img = np.asarray(Image.open(img_filename).resize((224, 224)))
    img_var = torch.autograd.Variable(torch.FloatTensor(img.transpose(2, 0, 1)[np.newaxis, :, :, :].astype(float)))

    conv_out = vgg16_c(img_var)
    print('VGG16 model:')
    print(vgg16_c)

    plt.ion()  # non-blocking plotting
    plt.figure(figsize=(10, 5))
    vgg16_d = VGG16_deconv()
    while True:
        layer = input('Layer to view (0-30, -1 to exit): ')
        try:
            layer = int(layer)
        except ValueError:
            continue

        if layer < 0:
            sys.exit(0)
        # (1, C, H, W) -> (C, H, W, 1): one grid tile per channel
        activ_map = vgg16_c.feature_outputs[layer].data.numpy()
        activ_map = activ_map.transpose(1, 2, 3, 0)
        activ_map_grid = vis_grid(activ_map)
        vis_layer(activ_map_grid)

        # only transpose-convolve from Conv2d or ReLU layers
        conv_layer = layer
        if conv_layer not in conv_layer_indices:
            conv_layer -= 1  # a ReLU directly follows each Conv2d
            if conv_layer not in conv_layer_indices:
                continue

        n_maps = activ_map.shape[0]

        marker = None
        while True:
            choose_map = input('Select map? (y/[n]): ') == 'y'
            if marker is not None:
                marker.pop(0).remove()

            if not choose_map:
                break

            # recover the map index from the clicked pixel position: each
            # grid tile spans (map dim + 1 px of padding), x_step tiles per row
            _, map_x_dim, map_y_dim, _ = activ_map.shape
            map_img_x_dim, map_img_y_dim, _ = activ_map_grid.shape
            x_step = map_img_x_dim // (map_x_dim + 1)

            print('Click on an activation map to continue')
            x_pos, y_pos = plt.ginput(1)[0]
            x_index = x_pos // (map_x_dim + 1)
            y_index = y_pos // (map_y_dim + 1)
            map_idx = int(x_step * y_index + x_index)

            if map_idx >= n_maps:
                print('Invalid map selected')
                continue

            decon = vgg16_d(vgg16_c.feature_outputs[layer][0][map_idx][None, None, :, :], conv_layer, map_idx, vgg16_c.pool_indices)
            img = decon_img(decon)
            plt.subplot(121)
            marker = plt.plot(x_pos, y_pos, marker='+', color='red')
            plt.subplot(122)
            plt.imshow(img)
--------------------------------------------------------------------------------
/models/vgg16_deconv.py:
--------------------------------------------------------------------------------
import torch
import torchvision.models as models

import numpy as np

# downloads the ImageNet-trained VGG16 weights at import time
vgg16_pretrained = models.vgg16(pretrained=True)

class VGG16_conv(torch.nn.Module):
    def __init__(self, n_classes):
        super(VGG16_conv, self).__init__()
        # VGG16 (using return_indices=True on the MaxPool2d layers so the
        # deconv net can unpool with the same switch locations)
        self.features = torch.nn.Sequential(
            # conv1
            torch.nn.Conv2d(3, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2, return_indices=True),
            # conv2
            torch.nn.Conv2d(64, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2, return_indices=True),
            # conv3
            torch.nn.Conv2d(128, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2, return_indices=True),
            # conv4
            torch.nn.Conv2d(256, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2, return_indices=True),
            # conv5
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2, return_indices=True))
        self.feature_outputs = [0] * len(self.features)
        self.pool_indices = dict()

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(512*7*7, 4096),  # 224x224 image pooled down to 7x7 by features
            torch.nn.ReLU(),
            torch.nn.Dropout(),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(),
            torch.nn.Dropout(),
            torch.nn.Linear(4096, n_classes))

        self._initialize_weights()

    def _initialize_weights(self):
        # initialize conv weights from the ImageNet-trained model shipped
        # with PyTorch; the classifier stays randomly initialized, which is
        # fine here since only the feature activations are visualized
        for i, layer in enumerate(vgg16_pretrained.features):
            if isinstance(layer, torch.nn.Conv2d):
                self.features[i].weight.data = layer.weight.data
                self.features[i].bias.data = layer.bias.data

    def get_conv_layer_indices(self):
        return [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]

    def forward_features(self, x):
        output = x
        for i, layer in enumerate(self.features):
            if isinstance(layer, torch.nn.MaxPool2d):
                output, indices = layer(output)
                self.feature_outputs[i] = output
                self.pool_indices[i] = indices
            else:
                output = layer(output)
                self.feature_outputs[i] = output
        return output

    def forward(self, x):
        output = self.forward_features(x)
        output = output.view(output.size()[0], -1)
        output = self.classifier(output)
        return output

class VGG16_deconv(torch.nn.Module):
    def __init__(self):
        super(VGG16_deconv, self).__init__()
        # maps a Conv2d index in VGG16_conv.features to the mirrored
        # ConvTranspose2d index in deconv_features
        self.conv2DeconvIdx = {0:17, 2:16, 5:14, 7:13, 10:11, 12:10, 14:9, 17:7, 19:6, 21:5, 24:3, 26:2, 28:1}
        # bias mapping is shifted by one deconv layer relative to the
        # weights; an index of 0 (the first unpool) means the bias is unused
        self.conv2DeconvBiasIdx = {0:16, 2:14, 5:13, 7:11, 10:10, 12:9, 14:7, 17:6, 19:5, 21:3, 24:2, 26:1, 28:0}
        # maps a MaxUnpool2d index in deconv_features to its matching
        # MaxPool2d index in VGG16_conv.features
        self.unpool2PoolIdx = {15:4, 12:9, 8:16, 4:23, 0:30}

        self.deconv_features = torch.nn.Sequential(
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(512, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(512, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(512, 512, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(512, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(512, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(512, 256, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(256, 256, 3, padding=1),
            torch.nn.ConvTranspose2d(256, 256, 3, padding=1),
            torch.nn.ConvTranspose2d(256, 128, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(128, 128, 3, padding=1),
            torch.nn.ConvTranspose2d(128, 64, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(64, 64, 3, padding=1),
            torch.nn.ConvTranspose2d(64, 3, 3, padding=1))

        # single-input-channel copies of the layers above, used as entry
        # points when starting the deconv pass from one selected map
        # (not the most elegant, given that the MaxUnpools aren't needed here)
        self.deconv_first_layers = torch.nn.ModuleList([
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(1, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 512, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(1, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 512, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 256, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(1, 256, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 256, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 128, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(1, 128, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 64, 3, padding=1),
            torch.nn.MaxUnpool2d(2, stride=2),
            torch.nn.ConvTranspose2d(1, 64, 3, padding=1),
            torch.nn.ConvTranspose2d(1, 3, 3, padding=1)])

        self._initialize_weights()

    def _initialize_weights(self):
        # initialize weights from the ImageNet-trained model shipped with PyTorch
        for i, layer in enumerate(vgg16_pretrained.features):
            if isinstance(layer, torch.nn.Conv2d):
                self.deconv_features[self.conv2DeconvIdx[i]].weight.data = layer.weight.data
                biasIdx = self.conv2DeconvBiasIdx[i]
                if biasIdx > 0:
                    self.deconv_features[biasIdx].bias.data = layer.bias.data

    def forward(self, x, layer_number, map_number, pool_indices):
        start_idx = self.conv2DeconvIdx[layer_number]
        if not isinstance(self.deconv_first_layers[start_idx], torch.nn.ConvTranspose2d):
            raise ValueError('Layer ' + str(layer_number) + ' is not of type Conv2d')
        # copy the selected filter's weight (and the layer bias) into the
        # single-input-channel entry layer
        self.deconv_first_layers[start_idx].weight.data = self.deconv_features[start_idx].weight[map_number].data[None, :, :, :]
        self.deconv_first_layers[start_idx].bias.data = self.deconv_features[start_idx].bias.data
        # the first layer is single-channel, since one particular filter was picked
        output = self.deconv_first_layers[start_idx](x)

        # transpose-convolve through the rest of the network
        for i in range(start_idx + 1, len(self.deconv_features)):
            if isinstance(self.deconv_features[i], torch.nn.MaxUnpool2d):
                output = self.deconv_features[i](output, pool_indices[self.unpool2PoolIdx[i]])
            else:
                output = self.deconv_features[i](output)
        return output

--------------------------------------------------------------------------------
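For a non-interactive version of the pipeline demo.py drives, a minimal sketch (a hypothetical example, not a file in the repository; the image path and the layer/map indices are arbitrary choices) that projects one activation map of the first conv layer back to pixel space and saves it:

#!/usr/bin/env python3
import numpy as np
import torch
from PIL import Image

from models import VGG16_conv, VGG16_deconv

conv_net = VGG16_conv(1000)              # ImageNet-pretrained conv weights
deconv_net = VGG16_deconv()

# HWC uint8 image -> NCHW float tensor
img = np.asarray(Image.open('imgs/bird.jpg').resize((224, 224)))
img_var = torch.autograd.Variable(
    torch.FloatTensor(img.transpose(2, 0, 1)[np.newaxis].astype(float)))
conv_net(img_var)                        # fills feature_outputs / pool_indices

layer, map_idx = 0, 17                   # first conv layer, arbitrary map
activation = conv_net.feature_outputs[layer][0][map_idx][None, None, :, :]
recon = deconv_net(activation, layer, map_idx, conv_net.pool_indices)

# scale the reconstruction to a uint8 image and save it
out = recon.data.numpy()[0].transpose(1, 2, 0)
out = ((out - out.min()) / (out.max() - out.min()) * 255).astype(np.uint8)
Image.fromarray(out).save('reconstruction.png')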