├── README.md
├── criterion.py
├── data.py
├── imm_model.py
├── main.py
├── tps_sampler.py
├── utils.py
└── vgg.py

/README.md:
--------------------------------------------------------------------------------
  1 | # imm-pytorch
  2 | PyTorch implementation of ["Unsupervised Learning of Object Landmarks through Conditional Image Generation"](http://www.robots.ox.ac.uk/~vgg/research/unsupervised_landmarks/), Tomas Jakab*, Ankush Gupta*, Hakan Bilen, Andrea Vedaldi, Advances in Neural Information Processing Systems (NeurIPS) 2018.
  3 | 
  4 | ## Requirements:
  5 | * pytorch >= 1.2, torchvision >= 0.4
  6 | * visdom, torchnet
  7 | 
--------------------------------------------------------------------------------
/criterion.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | import torch.nn.functional as F
  4 | 
  5 | from vgg import Vgg16
  6 | 
  7 | class LossFunc(nn.Module):
  8 |     """
  9 |     Loss function for landmark prediction
 10 |     """
 11 |     def __init__(self, loss_type='perceptual'):
 12 |         super(LossFunc, self).__init__()
 13 |         self.loss_type = loss_type
 14 |         self.ema = EMA()
 15 |         self.vggnet = Vgg16() if loss_type == 'perceptual' else None
 16 |         self._init_ema()
 17 | 
 18 |     def forward(self, future_im_pred, future_im, mask=None):
 19 |         loss = self._loss(future_im_pred, future_im, mask=mask)
 20 |         return loss
 21 | 
 22 |     def _loss(self, future_im_pred, future_im, mask=None):
 23 |         "loss function"
 24 |         vgg_losses = []
 25 |         w_reconstruct = 1. / 255.
 26 |         if self.loss_type == 'perceptual':
 27 |             w_reconstruct = 1.
 28 |             reconstruction_loss, vgg_losses = self._colorization_reconstruction_loss(
 29 |                 future_im, future_im_pred, mask=mask)
 30 |         elif self.loss_type == 'l2':
 31 |             if mask is not None:
 32 |                 l = F.mse_loss(future_im_pred, future_im, reduction='none')
 33 |                 reconstruction_loss = torch.mean(self._loss_mask(l, mask))
 34 |             else:
 35 |                 reconstruction_loss = F.mse_loss(future_im_pred, future_im)
 36 |         else:
 37 |             raise ValueError('Incorrect loss-type')
 38 | 
 39 |         loss = w_reconstruct * reconstruction_loss
 40 | 
 41 |         return loss, vgg_losses
 42 | 
 43 |     def _loss_mask(self, imap, mask):
 44 |         mask = F.interpolate(mask, imap.shape[-2:])
 45 |         return imap * mask
 46 | 
 47 |     def _colorization_reconstruction_loss(
 48 |             self, gt_image, pred_image, mask=None):
 49 |         "perceptual loss"
 50 |         names = list(self.ema.avgs)
 51 | 
 52 |         #get feature maps from vgg
 53 |         feats_gt = self.vggnet(gt_image)
 54 |         feats_pred = self.vggnet(pred_image)
 55 | 
 56 |         feat_gt, feat_pred = [gt_image], [pred_image]
 57 |         for k in names[1:]: #no need for input
 58 |             feat_gt.append(getattr(feats_gt, k))
 59 |             feat_pred.append(getattr(feats_pred, k))
 60 | 
 61 |         losses = []
 62 |         for k, v in enumerate(names):
 63 |             l = F.mse_loss(feat_pred[k], feat_gt[k], reduction='none')
 64 |             if mask is not None:
 65 |                 l = self._loss_mask(l, mask)
 66 |             #update EMA
 67 |             # wl = self.exp_moving_avg(
 68 |             #     torch.mean(l).item(), name=v, init_val=self.ema[v])
 69 |             l /= self.ema.get(v)
 70 |             l = torch.mean(l)
 71 |             losses.append(l)
 72 |         vgg_losses = [x.item() for x in losses] #for display
 73 |         loss = torch.stack(losses).sum()
 74 |         return loss, vgg_losses
 75 | 
 76 |     # def exp_moving_avg(self, x, name='x', init_val=0.):
 77 |     #     "exponential moving average"
 78 |     #     with torch.no_grad():
 79 |     #         if not self.training:
 80 |     #             return init_val
 81 |     #         x_new = self.ema.update(name, x, init_val)
 82 |     #     return x_new
 83 | 
 84 |     def _init_ema(self, ws=[50., 40., 6., 3., 3., 1.],
 85 |                   names=['input', 'conv1_2', 'conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']):
 86 |         "init weights for perceptual loss/EMA"
 87 |         for k, v in enumerate(names):
 88 |             self.ema.update(v, ws[k], 0.)
 89 | 
 90 | 
 91 | class EMA(object):
 92 |     """Exponential running average
 93 |     """
 94 |     def __init__(self, decay=0.99):
 95 |         self.rho = decay
 96 |         self.avgs = {}
 97 | 
 98 |     def register(self, name, val):
 99 |         "add val to shadow by key=name"
100 |         self.avgs.update({name: val})
101 | 
102 |     def get(self, name):
103 |         "get value with key=name"
104 |         return self.avgs[name]
105 | 
106 |     def update(self, name, x, init_val=0.):
107 |         "update new value for variable x"
108 |         if name not in self.avgs.keys():
109 |             self.register(name, init_val)
110 |             return init_val
111 | 
112 |         x_avg = self.get(name)
113 |         w_update = 1. - self.rho
114 |         x_new = x_avg + w_update * (x - x_avg)
115 |         self.register(name, x_new)
116 |         return x_new
117 | 
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Load CelebA dataset.
  3 | Perform proc_img_pair (crop, resize) and tps_warping
  4 | """
  5 | import os
  6 | from os import path
  7 | import numpy as np
  8 | from PIL import Image
  9 | 
 10 | import torch
 11 | from torch.nn import functional as F
 12 | from torch.utils import data
 13 | 
 14 | from torchvision import transforms as T
 15 | from tps_sampler import TPSRandomSampler
 16 | 
 17 | 
 18 | #------------------------------------------------------------------------------
 19 | #Initial Dataset
 20 | #------------------------------------------------------------------------------
 21 | 
 22 | def load_dataset(data_root, dataset, subset):
 23 |     image_dir = os.path.join(data_root, 'celeba', 'img_align_celeba')
 24 | 
 25 |     with open(os.path.join(data_root, 'celeba', 'list_landmarks_align_celeba.txt'), 'r') as f:
 26 |         lines = f.read().splitlines()
 27 |     # skip header
 28 |     lines = lines[2:]
 29 |     image_files = []
 30 |     keypoints = []
 31 |     for line in lines:
 32 |         image_files.append(line.split()[0])
 33 |         keypoints.append([int(x) for x in line.split()[1:]])
 34 |     keypoints = np.array(keypoints, dtype=np.float32)
 35 |     assert image_files[0] == '000001.jpg'
 36 | 
 37 |     images_set = np.zeros(len(image_files), dtype=np.int32)
 38 | 
 39 |     if dataset == 'celeba':
 40 |         with open(os.path.join(data_root, 'celeba', 'list_eval_partition.txt'), 'r') as f:
 41 |             celeba_set = [int(line.split()[1]) for line in f.readlines()]
 42 |         images_set[:] = celeba_set
 43 |         images_set += 1
 44 | 
 45 |     if dataset == 'celeba':
 46 |         if subset == 'train':
 47 |             label = 1
 48 |         elif subset == 'val':
 49 |             label = 2
 50 |         else:
 51 |             raise ValueError(
 52 |                 'subset = %s for celeba dataset not recognized.'
% subset) 53 | 54 | image_files = np.array(image_files) 55 | images = image_files[images_set == label] 56 | keypoints = keypoints[images_set == label] 57 | 58 | # convert keypoints to 59 | # [[lefteye_x, lefteye_y], [righteye_x, righteye_y], [nose_x, nose_y], 60 | # [leftmouth_x, leftmouth_y], [rightmouth_x, rightmouth_y]] 61 | keypoints = np.reshape(keypoints, [-1, 5, 2]) 62 | 63 | return image_dir, images, keypoints 64 | 65 | 66 | class DatasetFromFolder(data.Dataset): 67 | """Manipulate data from folder 68 | """ 69 | def __init__(self, data_root, dataset, subset, transform): 70 | super(DatasetFromFolder, self).__init__() 71 | self.transform = transform 72 | self.image_dir, self.image_name, self.keypoints = load_dataset(data_root, dataset, subset) 73 | # len = image_name.shape[0] // 5 #use 20% of data 74 | # self.image_name = image_name[:len] 75 | # self.keypoints = keypoints[:len] 76 | 77 | def __getitem__(self, idx): 78 | img = Image.open(path.join(self.image_dir, self.image_name[idx])) 79 | img = self.transform(img) 80 | keypts = torch.from_numpy(self.keypoints[idx]) 81 | return img, keypts 82 | 83 | def __len__(self): 84 | return self.image_name.shape[0] 85 | 86 | 87 | def transforms(size=[128, 128]): 88 | return T.Compose([ 89 | # T.Resize(size), 90 | T.ToTensor(), 91 | ]) 92 | 93 | 94 | def get_dataset(data_root, dataset, subset): 95 | return DatasetFromFolder(data_root, dataset, subset, transform=transforms()) 96 | 97 | #------------------------------------------------------------------------------ 98 | #Get method (used for DataLoader) 99 | #------------------------------------------------------------------------------ 100 | 101 | class BatchTransform(object): 102 | """ Preprocessing batch of pytorch tensors 103 | """ 104 | def __init__(self, image_size=[128, 128], \ 105 | rotsd=[0.0, 5.0], scalesd=[0.0, 0.1], \ 106 | transsd=[0.1, 0.1], warpsd=[0.001, 0.005, 0.001, 0.01]): 107 | self.image_size = image_size 108 | self.target_sampler, self.source_sampler = \ 109 | self._create_tps(image_size, rotsd, scalesd, transsd, warpsd) 110 | 111 | def exe(self, image, landmarks=None): 112 | #call _proc_im_pair 113 | batch = self._proc_im_pair(image, landmarks=landmarks) 114 | 115 | #call _apply_tps 116 | image, future_image, future_mask = self._apply_tps(batch['image'], batch['mask']) 117 | 118 | batch.update({'image': image, 'future_image': future_image, 'mask': future_mask}) 119 | 120 | return batch 121 | 122 | #TPS 123 | def _create_tps(self, image_size, rotsd, scalesd, transsd, warpsd): 124 | """create tps sampler for target and source images""" 125 | target_sampler = TPSRandomSampler( 126 | image_size[1], image_size[0], rotsd=rotsd[0], 127 | scalesd=scalesd[0], transsd=transsd[0], warpsd=warpsd[:2], pad=False) 128 | source_sampler = TPSRandomSampler( 129 | image_size[1], image_size[0], rotsd=rotsd[1], 130 | scalesd=scalesd[1], transsd=transsd[1], warpsd=warpsd[2:], pad=False) 131 | return target_sampler, source_sampler 132 | 133 | def _apply_tps(self, image, mask): 134 | #expand mask to match batch size and n_dim 135 | mask = mask[None, None].expand(image.shape[0], -1, -1, -1) 136 | image = torch.cat([mask, image], dim=1) 137 | # shape = image.shape 138 | 139 | future_image = self.target_sampler.forward(image) 140 | image = self.source_sampler.forward(future_image) 141 | 142 | #reshape -- no need 143 | # image = image.reshape(shape) 144 | # future_image = future_image.reshape(shape) 145 | 146 | future_mask = future_image[:, 0:1, ...] 147 | future_image = future_image[:, 1:, ...] 
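#NOTE (added): the smooth mask is carried along as channel 0 of the warped tensor,
#so it is deformed by exactly the same TPS flow as the image; the slicing above
#(and below, for the source image) simply separates mask and image channels again.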
148 | 149 | mask = image[:, 0:1, ...] 150 | image = image[:, 1:, ...] 151 | 152 | return image, future_image, future_mask 153 | 154 | #Process image pair 155 | def _proc_im_pair(self, image, landmarks=None): 156 | m, M = image.min(), image.max() 157 | 158 | height, width = self.image_size[:2] 159 | 160 | #crop image 161 | crop_percent = 0.8 162 | final_sz = self.image_size[0] 163 | resize_sz = np.round(final_sz / crop_percent).astype(np.int32) 164 | margin = np.round((resize_sz - final_sz) / 2.0).astype(np.int32) 165 | 166 | if landmarks is not None: 167 | original_sz = image.shape[-2:] 168 | landmarks = self._resize_points( 169 | landmarks, original_sz, [resize_sz, resize_sz]) 170 | landmarks -= margin 171 | 172 | image = F.interpolate(image, \ 173 | size=[resize_sz, resize_sz], mode='bilinear', align_corners=True) 174 | 175 | #take center crop 176 | image = image[..., margin:margin + final_sz, margin:margin + final_sz] 177 | image = torch.clamp(image, m, M) 178 | 179 | mask = self._get_smooth_mask(height, width, 10, 20) #shape HxW 180 | mask = mask.to(image.device) 181 | 182 | future_landmarks = landmarks 183 | # future_image = image.clone() 184 | 185 | batch = {} 186 | batch.update({'image': image, 'mask': mask, \ 187 | 'landmarks': landmarks, 'future_landmarks': future_landmarks}) 188 | 189 | return batch 190 | 191 | def _resize_points(self, points, size, new_size): 192 | dtype = points.dtype 193 | device = points.device 194 | 195 | size = torch.tensor(size).to(device).float() 196 | new_size = torch.tensor(new_size).to(device).float() 197 | 198 | ratio = new_size / size 199 | points = (points.float() * ratio[None]).type(dtype) 200 | return points 201 | 202 | def _get_smooth_step(self, n, b): 203 | x = torch.linspace(-1, 1, n) 204 | y = 0.5 + 0.5 * torch.tanh(x / b) 205 | return y 206 | 207 | def _get_smooth_mask(self, h, w, margin, step): 208 | b = 0.4 209 | step_up = self._get_smooth_step(step, b) 210 | step_down = self._get_smooth_step(step, -b) 211 | 212 | def _create_strip(size): 213 | return torch.cat( 214 | [torch.zeros(margin), 215 | step_up, 216 | torch.ones(size - 2 * margin - 2 * step), 217 | step_down, 218 | torch.zeros(margin)], dim=0) 219 | 220 | mask_x = _create_strip(w) 221 | mask_y = _create_strip(h) 222 | mask2d = mask_y[:, None] * mask_x[None] 223 | return mask2d 224 | -------------------------------------------------------------------------------- /imm_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of 3 | Unsupervised Learning of Object Landmarks through Conditional Image Generation 4 | http://www.robots.ox.ac.uk/~vgg/research/unsupervised_landmarks/unsupervised_landmarks.pdf 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import init 10 | from torch.nn import functional as F 11 | 12 | 13 | def _init_weight(modules): 14 | for m in modules: 15 | if isinstance(m, (nn.Conv2d, nn.Conv3d)): 16 | init.kaiming_normal_(m.weight) 17 | if m.bias is not None: 18 | init.constant_(m.bias, 0) 19 | 20 | 21 | def get_coord(x, other_axis, axis_size): 22 | "get x-y coordinates" 23 | g_c_prob = torch.mean(x, dim=other_axis) # B,NMAP,W 24 | g_c_prob = F.softmax(g_c_prob, dim=2) # B,NMAP,W 25 | coord_pt = torch.linspace(-1.0, 1.0, axis_size).to(x.device) # W 26 | coord_pt = coord_pt.view(1, 1, axis_size) # 1,1,W 27 | g_c = torch.sum(g_c_prob * coord_pt, dim=2) # B,NMAP 28 | return g_c, g_c_prob 29 | 30 | 31 | def get_gaussian_maps(mu, shape_hw, inv_std, mode='rot'): 32 | """ 33 | Generates 
[B,NMAPS,SHAPE_H,SHAPE_W] tensor of 2D gaussians,
 34 |     given the gaussian centers: MU [B, NMAPS, 2] tensor.
 35 | 
 36 |     INV_STD: inverse of the fixed standard deviation.
 37 |     """
 38 |     mu_y, mu_x = mu[:, :, 0:1], mu[:, :, 1:2]
 39 | 
 40 |     y = torch.linspace(-1.0, 1.0, shape_hw[0]).to(mu.device)
 41 | 
 42 |     x = torch.linspace(-1.0, 1.0, shape_hw[1]).to(mu.device)
 43 | 
 44 |     if mode in ['rot', 'flat']:
 45 |         mu_y, mu_x = torch.unsqueeze(mu_y, dim=-1), torch.unsqueeze(mu_x, dim=-1)
 46 | 
 47 |         y = y.view(1, 1, shape_hw[0], 1)
 48 |         x = x.view(1, 1, 1, shape_hw[1])
 49 | 
 50 |         g_y = (y - mu_y)**2
 51 |         g_x = (x - mu_x)**2
 52 |         dist = (g_y + g_x) * inv_std**2
 53 | 
 54 |         if mode == 'rot':
 55 |             g_yx = torch.exp(-dist)
 56 |         else:
 57 |             g_yx = torch.exp(-torch.pow(dist + 1e-5, 0.25))
 58 | 
 59 |     elif mode == 'ankush':
 60 |         y = y.view(1, 1, shape_hw[0])
 61 |         x = x.view(1, 1, shape_hw[1])
 62 | 
 63 |         g_y = torch.exp(-torch.sqrt(1e-4 + torch.abs((mu_y - y) * inv_std)))
 64 |         g_x = torch.exp(-torch.sqrt(1e-4 + torch.abs((mu_x - x) * inv_std)))
 65 | 
 66 |         g_y = torch.unsqueeze(g_y, dim=3)
 67 |         g_x = torch.unsqueeze(g_x, dim=2)
 68 |         g_yx = torch.matmul(g_y, g_x)  # [B, NMAPS, H, W]
 69 | 
 70 |     else:
 71 |         raise ValueError('Unknown mode: ' + str(mode))
 72 | 
 73 |     return g_yx
 74 | 
 75 | 
 76 | def conv_block(in_channels, out_channels, kernel_size, stride, dilation=1, bias=True, batch_norm=True, layer_norm=False, activation='ReLU'):
 77 |     padding = (dilation*(kernel_size-1)+2-stride)//2
 78 |     seq_modules = nn.Sequential()
 79 |     seq_modules.add_module('conv', \
 80 |         nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias))
 81 |     if batch_norm:
 82 |         seq_modules.add_module('norm', nn.BatchNorm2d(out_channels))
 83 |     elif layer_norm:
 84 |         seq_modules.add_module('norm', LayerNorm())
 85 |     if activation is not None:
 86 |         seq_modules.add_module('relu', getattr(nn, activation)(inplace=True))
 87 |     return seq_modules
 88 | 
 89 | 
 90 | class LayerNorm(nn.Module):
 91 |     "Wrap the layer_norm function as a module"
 92 |     def __init__(self):
 93 |         super(LayerNorm, self).__init__()
 94 | 
 95 |     def forward(self, x):
 96 |         return F.layer_norm(x, x.shape[1:])
 97 | 
 98 | 
 99 | class Encoder(nn.Module):
100 |     """Phi Net:
101 |     input: target image -- distorted image
102 |     output: confidence maps"""
103 |     def __init__(self, in_channels, n_filters, batch_norm=True, layer_norm=False):
104 |         super(Encoder, self).__init__()
105 |         self.block_layers = nn.ModuleList()
106 |         conv1 = conv_block(in_channels, n_filters, kernel_size=7, stride=1, batch_norm=batch_norm, layer_norm=layer_norm)
107 |         conv2 = conv_block(n_filters, n_filters, kernel_size=3, stride=1, batch_norm=batch_norm, layer_norm=layer_norm)
108 |         self.block_layers.append(conv1)
109 |         self.block_layers.append(conv2)
110 | 
111 |         for _ in range(3):
112 |             filters = n_filters*2
113 |             conv_i0 = conv_block(n_filters, filters, kernel_size=3, stride=2, batch_norm=batch_norm, layer_norm=layer_norm)
114 |             conv_i1 = conv_block(filters, filters, kernel_size=3, stride=1, batch_norm=batch_norm, layer_norm=layer_norm)
115 |             self.block_layers.append(conv_i0)
116 |             self.block_layers.append(conv_i1)
117 |             n_filters = filters
118 | 
119 |     def forward(self, x):
120 |         block_features = []
121 |         for block in self.block_layers:
122 |             x = block(x)
123 |             block_features.append(x)
124 |         return block_features
125 | 
126 | 
127 | class ImageEncoder(nn.Module):
128 |     """Image_Encoder:
129 |     input: source image
130 |     output: features
131 |     """
132 |     def __init__(self, in_channels, n_filters):
133 |         super(ImageEncoder, self).__init__()
134 |         self.image_encoder = Encoder(in_channels, n_filters)
135 | 
136 |     def forward(self, x):
137 |         block_features = self.image_encoder(x)
138 |         block_features = [x] + block_features
139 |         return block_features
140 | 
141 | class PoseEncoder(nn.Module):
142 |     """Pose_Encoder:
143 |     input: target image (transformed image)
144 |     output: gaussian maps of landmarks
145 |     """
146 |     def __init__(self, in_channels, n_filters, n_maps, map_sizes, gauss_std=0.1, gauss_mode='ankush'):
147 |         super(PoseEncoder, self).__init__()
148 |         self.map_sizes = map_sizes
149 |         self.gauss_std = gauss_std
150 |         self.gauss_mode = gauss_mode
151 | 
152 |         self.image_encoder = Encoder(in_channels, n_filters)
153 |         self.conv = conv_block(n_filters*8, n_maps, kernel_size=1, stride=1, batch_norm=False, activation=None)
154 | 
155 |     def forward(self, x):
156 |         block_features = self.image_encoder(x)
157 |         x = block_features[-1]
158 | 
159 |         xshape = x.shape
160 |         x = self.conv(x)
161 | 
162 |         gauss_y, gauss_y_prob = get_coord(x, 3, xshape[2])  # B,NMAP
163 |         gauss_x, gauss_x_prob = get_coord(x, 2, xshape[3])  # B,NMAP
164 |         gauss_mu = torch.stack([gauss_y, gauss_x], dim=2)
165 | 
166 |         gauss_xy = []
167 |         for shape_hw in self.map_sizes:
168 |             gauss_xy_hw = \
169 |                 get_gaussian_maps(gauss_mu, [shape_hw, shape_hw], 1.0 / self.gauss_std, mode=self.gauss_mode)
170 |             gauss_xy.append(gauss_xy_hw)
171 | 
172 |         return gauss_mu, gauss_xy
173 | 
174 | class Renderer(nn.Module):
175 |     """Renderer:
176 |     input: image encoded features + gauss maps
177 |     output: reconstructed image
178 |     """
179 |     def __init__(self, map_size, map_filters, n_filters, n_final_out, n_final_res, batch_norm=True):
180 |         super(Renderer, self).__init__()
181 |         self.seq_renderers = nn.Sequential()
182 |         i = 1
183 |         while map_size[0] <= n_final_res:
184 |             self.seq_renderers.add_module('conv_render{}'.format(i), \
185 |                 conv_block(map_filters, n_filters, kernel_size=3, stride=1, batch_norm=batch_norm))
186 | 
187 |             if map_size[0] == n_final_res:
188 |                 self.seq_renderers.add_module('conv_render_final', \
189 |                     conv_block(n_filters, n_final_out, kernel_size=3, stride=1, batch_norm=False, activation=None))
190 |                 break
191 |             else:
192 |                 self.seq_renderers.add_module('conv_render{}'.format(i+1), \
193 |                     conv_block(n_filters, n_filters, kernel_size=3, stride=1, batch_norm=batch_norm))
194 |                 #upsample
195 |                 map_size = [2 * s for s in map_size]
196 |                 self.seq_renderers.add_module('upsampler_render{}'.format(i+1), nn.Upsample(size=map_size))
197 | 
198 |             map_filters = n_filters
199 |             if n_filters >= 8:
200 |                 n_filters //= 2
201 |             i += 2
202 | 
203 |     def forward(self, x):
204 |         x = self.seq_renderers(x)
205 |         x = torch.sigmoid(x)
206 |         return x
207 | 
208 | class AssembleNet(nn.Module):
209 |     """
210 |     Assembling PhiNet and PsiNet
211 |     """
212 |     def __init__(self, in_channels=3, n_filters=32, n_maps=10, gauss_std=0.1, \
213 |                  renderer_stride=2, n_render_filters=32, n_final_out=3, \
214 |                  max_size=[128, 128], min_size=[16, 16]):
215 |         super(AssembleNet, self).__init__()
216 |         self.gauss_std = gauss_std
217 |         self.render_sizes = self._create_render_sizes(max_size[0], min_size[0], renderer_stride)
218 |         self.map_filters = n_filters*8 + n_maps
219 |         self.image_encoder = ImageEncoder(in_channels, n_filters)
220 |         self.pose_encoder = PoseEncoder(in_channels, n_filters, n_maps, map_sizes=self.render_sizes)
221 |         self.renderer = Renderer(min_size, self.map_filters, n_render_filters, n_final_out, n_final_res=max_size[0])
222 | 
223 |         _init_weight(self.modules())
224 | 
225 |     def _create_render_sizes(self, max_size, min_size, renderer_stride):
226 |         render_sizes = []
227 |         size = max_size
228 |         while size >= min_size:
229 |             render_sizes.append(size)
230 |             size = max_size // renderer_stride
231 |             max_size = size
232 |         return render_sizes
233 | 
234 |     def forward(self, im, future_im):
235 |         embeddings = self.image_encoder(im) #features: [x, out1, ...] with sizes decreasing by 2
236 |         gauss_pt, pose_embeddings = self.pose_encoder(future_im) #gauss_mu, gauss_xy -- gauss_xy = [map1, map2, ...]
237 | 
238 |         #cat last embeddings:
239 |         joint_embedding = torch.cat((embeddings[-1], pose_embeddings[-1]), dim=1)
240 |         future_im_pred = self.renderer(joint_embedding)
241 | 
242 |         return future_im_pred, gauss_pt, pose_embeddings[-1]
243 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Train, Evaluate and Test the Model
  3 | """
  4 | 
  5 | import os, argparse, gc, glob, time, pickle
  6 | from os import path
  7 | 
  8 | import torch
  9 | from torch import nn, optim, cuda
 10 | from torch.utils.data import DataLoader
 11 | from visdom import Visdom
 12 | from torchnet import meter
 13 | from torchnet.logger import VisdomPlotLogger, VisdomSaver
 14 | 
 15 | import data
 16 | import utils
 17 | from imm_model import AssembleNet
 18 | from criterion import LossFunc
 19 | 
 20 | 
 21 | PARSER = argparse.ArgumentParser(description='Options for Conditional Image Generation')
 22 | #------------------------------------------------------------------- data-option
 23 | PARSER.add_argument('--data_root', type=str,
 24 |                     default='../data/',
 25 |                     help='location of root dir')
 26 | PARSER.add_argument('--dataset', type=str,
 27 |                     default='celeba',
 28 |                     help='location of dataset')
 29 | PARSER.add_argument('--testset', type=str,
 30 |                     default='../data/',
 31 |                     help='location of test data')
 32 | PARSER.add_argument('--nthreads', type=int, default=8,
 33 |                     help='number of threads for data loader')
 34 | PARSER.add_argument('--batch_size', type=int, default=64, metavar='N',
 35 |                     help='train batch size')
 36 | PARSER.add_argument('--val_batch_size', type=int, default=64, metavar='N',
 37 |                     help='val batch size')
 38 | #------------------------------------------------------------------ model-option
 39 | PARSER.add_argument('--pretrained_model', type=str, default='',
 40 |                     help='pretrained model location')
 41 | PARSER.add_argument('--loss_type', type=str, default='perceptual',
 42 |                     help='loss type for criterion: perceptual | l2')
 43 | #--------------------------------------------------------------- training-option
 44 | PARSER.add_argument('--seed', type=int, default=1234,
 45 |                     help='random seed')
 46 | PARSER.add_argument('--gpus', type=int, nargs='+', default=[3],
 47 |                     help='list of GPUs in use')
 48 | #optimizer-option
 49 | PARSER.add_argument('--optim_algor', type=str, default='Adam',
 50 |                     help='optimization algorithm')
 51 | PARSER.add_argument('--lr', type=float, default=1e-3,
 52 |                     help='learning rate')
 53 | PARSER.add_argument('--weight_decay', type=float, default=1e-8,
 54 |                     help='weight_decay rate')
 55 | #saving-option
 56 | PARSER.add_argument('--epochs', type=int, default=5000,
 57 |                     help='number of epochs')
 58 | PARSER.add_argument('--checkpoint_interval', type=int, default=1,
 59 |                     help='epoch interval of saving checkpoint')
 60 | PARSER.add_argument('--save_path', type=str, default='checkpoint',
 61 |                     help='directory for saving checkpoint')
 62 | PARSER.add_argument('--resume_checkpoint', type=str, default='',
 63 |                     help='location of saved checkpoint')
64 | #only prediction-option 65 | PARSER.add_argument('--trained_model', type=str, default='', 66 | help='location of trained checkpoint') 67 | 68 | ARGS = PARSER.parse_args() 69 | 70 | DEVICE = torch.device('cuda:{}'.format(ARGS.gpus[0]) if len(ARGS.gpus) > 0 else 'cpu') 71 | # Set the random seed manually for reproducibility. 72 | torch.manual_seed(ARGS.seed) 73 | if DEVICE.type == 'cuda': 74 | cuda.set_device(ARGS.gpus[0]) 75 | cuda.manual_seed(ARGS.seed) 76 | 77 | 78 | def _make_model(opt): 79 | "create model, criterion" 80 | #if use pretrained model (load pretrained weight) 81 | neuralnet = AssembleNet() 82 | 83 | if opt.pretrained_model: 84 | print("Loading pretrained model {} \n".format(opt.pretrained_model)) 85 | pretrained_state = torch.load(opt.pretrained_model, \ 86 | map_location=lambda storage, loc: storage, \ 87 | pickle_module=pickle)['modelstate'] 88 | neuralnet.load_state_dict(pretrained_state) 89 | 90 | model_parameters = filter(lambda p: p.requires_grad, neuralnet.parameters()) 91 | n_params = sum([p.numel() for p in model_parameters]) 92 | print('number of params', n_params) 93 | 94 | return neuralnet 95 | 96 | 97 | def _make_optimizer(opt, neuralnet, param_groups=None): 98 | parameters = filter(lambda p: p.requires_grad, neuralnet.parameters()) 99 | 100 | if param_groups is not None: 101 | lr = param_groups[0]['lr'] 102 | weight_decay = param_groups[0]['weight_decay'] 103 | else: 104 | lr = opt.lr 105 | weight_decay = opt.weight_decay 106 | 107 | optimizer = getattr(optim, opt.optim_algor)( 108 | parameters, lr=lr, weight_decay=weight_decay) 109 | 110 | return optimizer 111 | 112 | 113 | def _make_data(opt, subset='train', shuffle=True): 114 | #get data 115 | split = data.get_dataset(opt.data_root, opt.dataset, subset=subset) 116 | 117 | #dataloader 118 | loader = DataLoader(dataset=split, \ 119 | num_workers=opt.nthreads, batch_size=opt.batch_size, shuffle=shuffle) 120 | 121 | return loader 122 | 123 | 124 | class Main: 125 | """Wrap training and evaluating processes 126 | """ 127 | def __init__(self, opt): 128 | self.opt = opt 129 | os.makedirs(self.opt.save_path, exist_ok=True) 130 | 131 | self.neuralnet = _make_model(opt) 132 | self.optimizer = _make_optimizer(opt, self.neuralnet) 133 | self.train_loader = _make_data(opt) 134 | self.val_loader = _make_data(opt, subset='val', shuffle=False) 135 | 136 | #loss function 137 | self.criterion = LossFunc(opt.loss_type) 138 | 139 | #batch data transform 140 | self.batch_transform = data.BatchTransform() 141 | 142 | #meter 143 | self.loss_meter = meter.AverageValueMeter() 144 | 145 | 146 | #=========================================================================== 147 | # Training and Evaluating 148 | #=========================================================================== 149 | 150 | def _resetmeter(self): 151 | self.loss_meter.reset() 152 | 153 | def _evaluate(self, dataloader): 154 | gc.collect() 155 | self._resetmeter() 156 | 157 | self.neuralnet.eval() 158 | self.criterion.eval() 159 | 160 | for _, batch in enumerate(dataloader): 161 | with torch.no_grad(): 162 | im = batch[0].requires_grad_(False).to(DEVICE) 163 | keypts = batch[1].requires_grad_(False).to(DEVICE) 164 | 165 | deformed_batch = self.batch_transform.exe(im, landmarks=keypts) 166 | im, future_im, mask = deformed_batch['image'], deformed_batch['future_image'], deformed_batch['mask'] 167 | 168 | future_im_pred, _, _ = self.neuralnet(im, future_im) 169 | 170 | #loss 171 | loss, _ = self.criterion(future_im_pred, future_im) 172 | 173 | #log meter 174 | 
self.loss_meter.add(loss.item())
175 | 
176 |         self.neuralnet.train()
177 |         self.criterion.train()
178 | 
179 |         return self.loss_meter.value()[0]
180 | 
181 |     def _train(self, dataloader, epoch):
182 |         self._resetmeter()
183 | 
184 |         self.neuralnet.train()
185 |         self.criterion.train()
186 | 
187 |         for iteration, batch in enumerate(dataloader, 1):
188 |             start_time = time.time()
189 | 
190 |             im = batch[0].to(DEVICE)
191 |             keypts = batch[1].to(DEVICE)
192 | 
193 |             deformed_batch = self.batch_transform.exe(im, landmarks=keypts)
194 |             im, future_im, mask = deformed_batch['image'], deformed_batch['future_image'], deformed_batch['mask']
195 | 
196 |             #zero gradients first, then forward
197 |             self.optimizer.zero_grad()
198 |             future_im_pred, _, _ = self.neuralnet(im, future_im)
199 | 
200 |             #loss
201 |             loss, loss_values = self.criterion(future_im_pred, future_im)
202 | 
203 |             loss.backward()
204 |             self.optimizer.step()
205 | 
206 |             #update weights of perceptual loss by EMA
207 |             for k, new_val in enumerate(loss_values):
208 |                 tmp_name = list(self.criterion.ema.avgs)[k]
209 |                 tmp_init_val = self.criterion.ema.avgs[tmp_name]
210 |                 self.criterion.ema.update(tmp_name, new_val, init_val=tmp_init_val)
211 | 
212 |             #log meter
213 |             self.loss_meter.add(loss.item())
214 | 
215 |             #print
216 |             elapsed = time.time() - start_time
217 |             print('| epoch {:3d} | batch {:3d}/{:3d} | time(s) {:5.2f} | \n loss {:5.2f} | vgg losses {} \n'.format( \
218 |                 epoch, iteration, len(dataloader), elapsed, loss.item(), loss_values))
219 | 
220 |         return self.loss_meter.value()[0]
221 | 
222 |     def exe(self):
223 |         print(self.opt)
224 |         print('\n')
225 |         start_epoch = 1
226 |         best_result = 1.
227 |         best_flag = False
228 | 
229 |         #resume from saved checkpoint
230 |         if self.opt.resume_checkpoint:
231 |             print('Resuming checkpoint at {}'.format(self.opt.resume_checkpoint))
232 |             checkpoint = torch.load(
233 |                 self.opt.resume_checkpoint,
234 |                 map_location=lambda storage, loc: storage, pickle_module=pickle)
235 | 
236 |             model_state = checkpoint['modelstate']
237 |             self.neuralnet.load_state_dict(model_state)
238 | 
239 |             optim_state = checkpoint['optimstate']
240 |             self.optimizer = _make_optimizer(
241 |                 self.opt, self.neuralnet, param_groups=optim_state['param_groups'])
242 | 
243 |             start_epoch = checkpoint['epoch']+1
244 |             best_result = checkpoint['best_result']
245 | 
246 |         #DataParallel for multiple GPUs:
247 |         if len(self.opt.gpus) > 1:
248 |             #dim is always 0 because input data is always in shape N*W
249 |             self.neuralnet = nn.DataParallel(self.neuralnet, device_ids=self.opt.gpus, dim=0)
250 | 
251 |         self.neuralnet.to(DEVICE)
252 |         self.criterion.to(DEVICE)
253 | 
254 |         #visualization
255 |         port = 8097
256 |         viz = Visdom(port=port)
257 |         visdom_saver = VisdomSaver([viz.env])
258 | 
259 |         loss_logger = VisdomPlotLogger('line', port=port, \
260 |             opts={'title': 'Total Loss', 'legend': ['train', 'val']})
261 | 
262 |         print('Start training: optim {}, on device {}'.format( \
263 |             self.opt.optim_algor, DEVICE))
264 | 
265 |         lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
266 |             self.optimizer, T_max=1000, eta_min=5e-5)
267 | 
268 |         for epoch in range(start_epoch, self.opt.epochs+1):
269 |             #let's go
270 |             print('\n')
271 |             print('-' * 65)
272 |             print('{}'.format(time.asctime(time.localtime())))
273 |             print(' **Training epoch {}, lr {}'.format(epoch, self.optimizer.param_groups[0]['lr']))
274 | 
275 |             start_time = time.time()
276 |             train_loss = self._train(self.train_loader, epoch)
277 | 
278 |             print('| finish training on epoch {:3d} | time(s) {:5.2f} | loss {:3.4f}'.format(
279 |                 epoch, time.time() - start_time, train_loss))
280 | 
281 |             print(' **Evaluating on validation set')
282 | 
283 |             start_time = time.time()
284 |             val_loss = self._evaluate(self.val_loader)
285 | 
286 |             print('| finish validating on epoch {:3d} | time(s) {:5.2f} | loss {:3.4f}'.format(
287 |                 epoch, time.time() - start_time, val_loss))
288 | 
289 |             #save checkpoint
290 |             if val_loss < best_result:
291 |                 best_result = val_loss
292 |                 best_flag = True
293 |                 print('*' * 10, 'BEST result {} at epoch {}'.format(best_result, epoch), '*' * 10)
294 | 
295 |             if epoch % self.opt.checkpoint_interval == 0 or epoch == self.opt.epochs or best_flag:
296 |                 print(' **Saving checkpoint {}'.format(epoch))
297 |                 snapshot_prefix = path.join(self.opt.save_path, 'snapshot')
298 |                 snapshot_path = snapshot_prefix + '_{}.pt'.format(epoch)
299 | 
300 |                 model_state = self.neuralnet.module.state_dict() \
301 |                     if len(self.opt.gpus) > 1 else self.neuralnet.state_dict()
302 | 
303 |                 optim_state = self.optimizer.state_dict()
304 |                 checkpoint = {
305 |                     'modelstate':model_state,
306 |                     'optimstate':optim_state,
307 |                     'epoch':epoch,
308 |                     'best_result':best_result,
309 |                 }
310 |                 torch.save(checkpoint, snapshot_path, pickle_module=pickle)
311 | 
312 |                 #delete old checkpoints
313 |                 for f in glob.glob(snapshot_prefix + '*'):
314 |                     if f != snapshot_path:
315 |                         os.remove(f)
316 |                 if best_flag:
317 |                     best_prefix = path.join(self.opt.save_path, 'BEST')
318 |                     best_path = best_prefix + '_{}.pt'.format(epoch)
319 |                     torch.save(checkpoint, best_path, pickle_module=pickle)
320 |                     best_flag = False
321 |                     for f in glob.glob(best_prefix + '*'):
322 |                         if f != best_path:
323 |                             os.remove(f)
324 |                 print('| finish saving checkpoint {}'.format(epoch))
325 | 
326 |             #visualize training and eval progress
327 |             loss_logger.log((epoch, epoch), (train_loss, val_loss))
328 |             visdom_saver.save()
329 | 
330 |             #update learning rate
331 |             lr_scheduler.step()
332 | 
333 |         print('*' * 65)
334 |         print('Finished training and evaluating all epochs')
335 | 
336 | 
337 | #------------------------------------------------------------------------------
338 | #Testing on specific images
339 | #------------------------------------------------------------------------------
340 | 
341 | class Tester():
342 |     """Testing trained model on test data.
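    Invoked by passing --trained_model, e.g.:
        python main.py --trained_model checkpoint/BEST_<epoch>.pt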
343 | """ 344 | @staticmethod 345 | def test(neuralnet, dataloader): 346 | """ 347 | Segment on random image from dataset 348 | Support 2D images only 349 | """ 350 | neuralnet.eval() 351 | batch_transform = data.BatchTransform() 352 | 353 | idx = 0 354 | for iteration, batch in enumerate(dataloader): 355 | with torch.no_grad(): 356 | im = batch[0].requires_grad_(False).to(DEVICE) 357 | keypts = batch[1].requires_grad_(False).to(DEVICE) 358 | 359 | deformed_batch = batch_transform.exe(im, landmarks=keypts) 360 | im, future_im, mask = deformed_batch['image'], deformed_batch['future_image'], deformed_batch['mask'] 361 | 362 | future_im_pred, gauss_mu, _ = neuralnet(im, future_im) 363 | 364 | predict = future_im_pred.data.cpu().numpy().transpose(0, 2, 3, 1) 365 | gauss_mu = gauss_mu.data.cpu().numpy() 366 | # gauss_map = gauss_map.data.cpu().numpy() 367 | future_im = future_im.data.cpu().numpy().transpose(0, 2, 3, 1) 368 | 369 | os.makedirs('testcheck', exist_ok=True) 370 | fig_path = path.join('testcheck', 'fig_{}.png'.format(iteration)) 371 | utils.savegrid(fig_path, future_im, predict, gauss_mu=gauss_mu, name='deform') 372 | 373 | idx += im.shape[0] 374 | 375 | neuralnet.train() 376 | return idx 377 | 378 | @staticmethod 379 | def exe(opt): 380 | #Load trained model 381 | print(opt, '\n') 382 | print('Load checkpoint at {}'.format(opt.trained_model)) 383 | 384 | neuralnet = _make_model(opt) 385 | checkpoint = torch.load(opt.trained_model, \ 386 | map_location=lambda storage, loc: storage, pickle_module=pickle) 387 | model_state = checkpoint['modelstate'] 388 | neuralnet.load_state_dict(model_state) 389 | 390 | #Dataloader 391 | test_loader = _make_data(opt, subset='val', shuffle=False) 392 | 393 | #DataParallel for multiple GPUs: 394 | if len(opt.gpus) > 1: 395 | #dim always is 0 because of input data always is in shape N*W 396 | neuralnet = nn.DataParallel(neuralnet, device_ids=opt.gpus, dim=0) 397 | neuralnet.to(DEVICE) 398 | 399 | print('Start testing on device {}'.format(DEVICE.type)) 400 | start_time = time.time() 401 | total_sample = Tester.test(neuralnet, test_loader) 402 | print('| finish testing on {} samples in {} seconds'.format( 403 | total_sample, time.time() - start_time)) 404 | 405 | if __name__ == "__main__": 406 | if ARGS.trained_model: 407 | Tester.exe(ARGS) 408 | else: 409 | main = Main(ARGS) 410 | main.exe() 411 | -------------------------------------------------------------------------------- /tps_sampler.py: -------------------------------------------------------------------------------- 1 | # ========================================================== 2 | # Author: Ankush Gupta, Tomas Jakab 3 | # ========================================================== 4 | import scipy.spatial.distance as ssd 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import random 10 | 11 | 12 | class TPSRandomSampler(nn.Module): 13 | 14 | def __init__(self, height, width, vertical_points=10, horizontal_points=10, 15 | rotsd=0.0, scalesd=0.0, transsd=0.1, warpsd=(0.001, 0.005), 16 | cache_size=1000, cache_evict_prob=0.01, pad=True): 17 | super(TPSRandomSampler, self).__init__() 18 | 19 | self.input_height = height 20 | self.input_width = width 21 | 22 | self.h_pad = 0 23 | self.w_pad = 0 24 | if pad: 25 | self.h_pad = self.input_height // 2 26 | self.w_pad = self.input_width // 2 27 | 28 | self.height = self.input_height + self.h_pad 29 | self.width = self.input_width + self.w_pad 30 | 31 | self.vertical_points = vertical_points 32 | 
self.horizontal_points = horizontal_points 33 | 34 | self.rotsd = rotsd 35 | self.scalesd = scalesd 36 | self.transsd = transsd 37 | self.warpsd = warpsd 38 | self.cache_size = cache_size 39 | self.cache_evict_prob = cache_evict_prob 40 | 41 | self.tps = TPSGridGen( 42 | self.height, self.width, vertical_points, horizontal_points) 43 | 44 | self.cache = [None] * self.cache_size 45 | 46 | self.pad = pad 47 | 48 | 49 | def _sample_grid(self): 50 | W = sample_tps_w( 51 | self.vertical_points, self.horizontal_points, self.warpsd, 52 | self.rotsd, self.scalesd, self.transsd) 53 | W = torch.from_numpy(W.astype(np.float32)) 54 | # generate grid 55 | grid = self.tps(W[None]) 56 | return grid 57 | 58 | 59 | def _get_grids(self, batch_size): 60 | grids = [] 61 | for i in range(batch_size): 62 | entry = random.randint(0, self.cache_size - 1) 63 | if self.cache[entry] is None or random.random() < self.cache_evict_prob: 64 | grid = self._sample_grid() 65 | self.cache[entry] = grid 66 | else: 67 | grid = self.cache[entry] 68 | grids.append(grid) 69 | grids = torch.cat(grids) 70 | return grids 71 | 72 | 73 | def forward(self, input): 74 | m, M = input.min(), input.max() 75 | with torch.no_grad(): 76 | # get TPS grids 77 | batch_size = input.shape[0] 78 | grids = self._get_grids(batch_size) 79 | grids = grids.to(input.device) 80 | 81 | input = F.pad(input, (self.h_pad, self.h_pad, \ 82 | self.w_pad, self.w_pad), mode='replicate') 83 | input = F.grid_sample(input, grids) 84 | input = F.pad(input, (-self.h_pad, -self.h_pad, \ 85 | -self.w_pad, -self.w_pad)) 86 | 87 | return torch.clamp(input, m, M) 88 | 89 | def forward_py(self, input): 90 | input = torch.from_numpy(input).float() 91 | input = input.permute([0, 3, 1, 2]) 92 | 93 | input = self.forward(input) 94 | 95 | input = input.permute([0, 2, 3, 1]) 96 | input = input.numpy() 97 | 98 | return input 99 | 100 | 101 | class TPSGridGen(nn.Module): 102 | 103 | def __init__(self, Ho, Wo, Hc, Wc): 104 | """ 105 | Ho,Wo: height/width of the output tensor (grid dimensions). 106 | Hc,Wc: height/width of the control-point grid. 107 | 108 | Assumes for simplicity that the control points lie on a regular grid. 109 | Can be made more general. 110 | """ 111 | super(TPSGridGen, self).__init__() 112 | 113 | self._grid_hw = (Ho, Wo) 114 | self._cp_hw = (Hc, Wc) 115 | 116 | # initialize the grid: 117 | xx, yy = np.meshgrid(np.linspace(-1, 1, Wo), np.linspace(-1, 1, Ho)) 118 | self._grid = np.c_[xx.flatten(), yy.flatten()].astype(np.float32) # Nx2 119 | self._n_grid = self._grid.shape[0] 120 | 121 | # initialize the control points: 122 | xx, yy = np.meshgrid(np.linspace(-1, 1, Wc), np.linspace(-1, 1, Hc)) 123 | self._control_pts = np.c_[xx.flatten(), yy.flatten()].astype(np.float32) # Mx2 124 | self._n_cp = self._control_pts.shape[0] 125 | 126 | # compute the pair-wise distances b/w control-points and grid-points: 127 | Dx = ssd.cdist(self._grid, self._control_pts, metric='sqeuclidean') # NxM 128 | 129 | # create the tps kernel: 130 | # real_min = 100 * np.finfo(np.float32).min 131 | real_min = 1e-8 132 | Dx = np.clip(Dx, real_min, None) # avoid log(0) 133 | Kp = np.log(Dx) * Dx 134 | Os = np.ones((self._grid.shape[0])) 135 | L = np.c_[Kp, np.ones((self._n_grid, 1), dtype=np.float32), 136 | self._grid] # Nx(M+3) 137 | self._L = torch.from_numpy(L.astype(np.float32)) # Nx(M+3) 138 | 139 | 140 | def forward(self, w_tps): 141 | """ 142 | W_TPS: Bx(M+3)x2 sized tensor of tps-transformation params. 143 | here `M` is the number of control-points. 144 | `B` is the batch-size. 
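        The warp itself is linear: TFM_GRID = matmul(L, W_TPS), where
        L = [U | 1 | grid] is the Nx(M+3) TPS basis precomputed in __init__
        (U = d^2 * log(d^2), with d^2 the squared distance to each control point).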
145 | 
146 |     Returns a BxHoxWox2 tensor of grid coordinates.
147 |     """
148 |     assert w_tps.shape[1] - 3 == self._n_cp
149 |     batch_size = w_tps.shape[0]
150 |     tfm_grid = torch.matmul(self._L, w_tps)
151 |     tfm_grid = tfm_grid.reshape(
152 |         (batch_size, self._grid_hw[0], self._grid_hw[1], 2))
153 |     return tfm_grid
154 | 
155 | 
156 | def sample_tps_w(Hc, Wc, warpsd, rotsd, scalesd, transsd):
157 |     """
158 |     Returns randomly sampled TPS-grid params of size (Hc*Wc+3)x2.
159 | 
160 |     Params:
161 |         WARPSD: 2-tuple
162 |         {ROT/SCALE/TRANS}-SD: scalar standard devs.
163 |     """
164 |     Nc = Hc * Wc  # number of control points
165 |     # non-linear component:
166 |     mask = (np.random.rand(Nc, 2) > 0.5).astype(np.float32)
167 |     W = warpsd[0] * np.random.randn(Nc, 2) + \
168 |         warpsd[1] * (mask * np.random.randn(Nc, 2))
169 |     # affine component:
170 |     rnd = np.random.randn
171 |     rot = np.deg2rad(rnd() * rotsd)
172 |     sc = 1.0 + rnd() * scalesd
173 |     aff = [[transsd*rnd(), transsd*rnd()],
174 |            [sc * np.cos(rot), sc * -np.sin(rot)],
175 |            [sc * np.sin(rot), sc * np.cos(rot)]]
176 |     W = np.r_[W, aff]
177 |     return W
178 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | import matplotlib.pyplot as plt
  3 | 
  4 | def savegrid(fig_path, images, predictions, \
  5 |     gauss_mu=None, labels=None, nrow=8, ncol=8, name='image'):
  6 |     step = 2
  7 |     ncol = 8
  8 |     fig_width = 20
  9 |     if labels is not None:
 10 |         step = 3
 11 |         ncol = 12
 12 |         fig_width = 30
 13 |     plt.rcParams['figure.figsize'] = (fig_width, 40)
 14 |     j = 0
 15 |     for i in range(0, nrow*ncol, step):
 16 |         if j >= len(images):
 17 |             break
 18 |         img = images[j]
 19 |         plt.subplot(nrow, ncol, i+1)
 20 |         plt.imshow(img) #,interpolation='none', cmap="nipy_spectral")
 21 |         if gauss_mu is not None:
 22 |             for k in range(gauss_mu[j].shape[0]):
 23 |                 y_jk = ((gauss_mu[j, k, 0]+1)*64).astype(int) #map [-1, 1] coord to pixel index (128px images)
 24 |                 x_jk = ((gauss_mu[j, k, 1]+1)*64).astype(int)
 25 |                 plt.plot(x_jk, y_jk, 'bo')
 26 |         plt.title('{}_{}'.format(name, j))
 27 |         plt.axis('off')
 28 | 
 29 |         pred = predictions[j]
 30 |         plt.subplot(nrow, ncol, i+2)
 31 |         plt.imshow(pred)
 32 |         plt.title('predict_{}'.format(j))
 33 |         plt.axis('off')
 34 | 
 35 |         if labels is not None:
 36 |             label = labels[j]
 37 |             plt.subplot(nrow, ncol, i+3)
 38 |             plt.imshow(label)
 39 |             plt.title('label_{}'.format(j))
 40 |             plt.axis('off')
 41 | 
 42 |         j += 1
 43 |     # plt.show()
 44 |     plt.savefig(fig_path, bbox_inches="tight", pad_inches=0)
 45 |     plt.close()
 46 | 
--------------------------------------------------------------------------------
/vgg.py:
--------------------------------------------------------------------------------
  1 | from collections import namedtuple
  2 | 
  3 | import torch
  4 | from torchvision import models
  5 | 
  6 | 
  7 | class Vgg16(torch.nn.Module):
  8 |     def __init__(self, requires_grad=False, \
  9 |         names=['conv1_2', 'conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']):
 10 |         super(Vgg16, self).__init__()
 11 |         self.names = names
 12 |         vgg_pretrained_features = models.vgg16(pretrained=True).features
 13 |         self.slice1 = vgg_pretrained_features[:3]    #conv1_2
 14 |         self.slice2 = vgg_pretrained_features[3:8]   #conv2_2
 15 |         self.slice3 = vgg_pretrained_features[8:13]  #conv3_2
 16 |         self.slice4 = vgg_pretrained_features[13:20] #conv4_2
 17 |         self.slice5 = vgg_pretrained_features[20:27] #conv5_2
 18 |         if not requires_grad:
 19 |             for param in self.parameters():
 20 |                 param.requires_grad = False
 21 | 
 22 |     def forward(self, X):
 23 |         h = self.slice1(X)
 24 |         h_conv1_2 = h
 25 |         h = self.slice2(h)
 26 |         h_conv2_2 = h
 27 |         h = self.slice3(h)
 28 |         h_conv3_2 = h
 29 |         h = self.slice4(h)
 30 |         h_conv4_2 = h
 31 |         h = self.slice5(h)
 32 |         h_conv5_2 = h
 33 |         vgg_outputs = namedtuple("VggOutputs", self.names)
 34 |         out = vgg_outputs(h_conv1_2, h_conv2_2, h_conv3_2, h_conv4_2, h_conv5_2)
 35 |         return out
--------------------------------------------------------------------------------
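
Usage sketch (not part of the original repository): a minimal, hypothetical example of
loading a trained checkpoint and extracting landmarks with AssembleNet. The checkpoint
path and the input tensors are placeholders; the shapes follow the repo defaults
(n_maps=10 landmarks, 128x128 images), and the checkpoint layout matches what main.py saves.

    import pickle
    import torch
    from imm_model import AssembleNet

    net = AssembleNet()  #defaults: in_channels=3, n_maps=10, max_size=[128, 128]
    ckpt = torch.load('checkpoint/BEST_100.pt',  #hypothetical path; see saving logic in main.py
                      map_location='cpu', pickle_module=pickle)
    net.load_state_dict(ckpt['modelstate'])
    net.eval()

    im = torch.rand(1, 3, 128, 128)         #source (appearance) image -- placeholder
    future_im = torch.rand(1, 3, 128, 128)  #target (pose) image -- placeholder
    with torch.no_grad():
        future_im_pred, gauss_mu, _ = net(im, future_im)
    #gauss_mu: [1, 10, 2] landmark coordinates in [-1, 1], ordered (y, x)
    #future_im_pred: [1, 3, 128, 128] reconstruction of future_im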