├── .gitignore ├── NormNet_Paper.png ├── README.md ├── Sample_Restorations.png └── src ├── custom_losses.py ├── data_loader.py ├── helper_funcs.py ├── inpainting.py ├── latent_models.py ├── models.py └── train_LCM.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints -------------------------------------------------------------------------------- /NormNet_Paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srxdev0619/Latent_Convolutional_Models/7295ecb5c55ee313aaefd6d9cdd4f1b5f70dc5c7/NormNet_Paper.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Latent Convolutional Models 2 | 3 | 4 | ![Img1](Sample_Restorations.png) 5 | *Sample resotrations using a Latent Convolutional Model.* 6 | 7 | 8 | Latent Convolutional Models work by parametrizing the latent space of a generator using convolutional neural networks. A schematic can be found below 9 | 10 | 11 | ![Img2](NormNet_Paper.png) 12 | *The Schematic of a Latent Convolutional Model. The smaller ConvNet **f** (red) is unique to each image is parametrize the latent space of the generator **g_theta** (magenta) which is common to all images. The input **s** is fixed to random noise and is not updated during the training process.* 13 | 14 | 15 | ## Installation Dependencies 16 | 17 | - numpy 1.14.3 18 | - pytorch 0.4.0 19 | - [tensorboard-pytorch](https://github.com/lanpa/tensorboard-pytorch) 20 | - [scikit-image](https://scikit-image.org/) 21 | 22 | # Citation 23 | To cite this work, please use 24 | ``` 25 | @INPROCEEDINGS{LCMAthar2019, 26 | author = {ShahRukh Athar and Evgeny Burnaev and Victor Lempitsky}, 27 | title = {Latent Convolutional Models}, 28 | booktitle = {International Conference on Learning Representations (ICLR)}, 29 | year = {2019} 30 | } 31 | ``` 32 | 33 | # Additional Resources 34 | 35 | - [Webpage](http://shahrukhathar.github.io/2018/06/06/LCM.html) 36 | - [Paper](https://arxiv.org/abs/1806.06284) 37 | - [Supplementary Material](https://drive.google.com/file/d/1K3AceiLhxSPzVdu_-CtZQ7cZnKtxNPCY/view?usp=sharing) 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /Sample_Restorations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srxdev0619/Latent_Convolutional_Models/7295ecb5c55ee313aaefd6d9cdd4f1b5f70dc5c7/Sample_Restorations.png -------------------------------------------------------------------------------- /src/custom_losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import torch 7 | import torch.optim 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torchvision import datasets, transforms 11 | from torch import autograd 12 | import torchvision.utils as vutils 13 | 14 | ''' 15 | Lap Loss adapted from: https://github.com/mtyka/laploss/blob/master/laploss.py 16 | ''' 17 | def gauss_kernel(size=5, sigma=1.0): 18 | grid = np.float32(np.mgrid[0:size,0:size].T) 19 | gaussian = lambda x: np.exp((x - size//2)**2/(-2*sigma**2))**2 20 | kernel = np.sum(gaussian(grid), axis=2) 21 | kernel /= np.sum(kernel) 22 | return kernel 23 | 24 | 25 | def conv_gauss(t_input, stride=1, k_size=5, sigma=1.6, repeats=1): 26 | t_kernel_np = gauss_kernel(size=k_size, sigma=sigma).reshape([1,1, k_size, k_size]) 27 | t_input_device = t_input.device 28 | t_kernel = torch.from_numpy(t_kernel_np).to(t_input_device) 29 | num_channels = t_input.data.shape[1] 30 | t_kernel3 = torch.cat([t_kernel]*num_channels, 0) 31 | t_result = t_input 32 | for r in range(repeats): 33 | t_result = F.conv2d(t_result, t_kernel3, 34 | stride=1, padding=2, groups=num_channels) 35 | return t_result 36 | 37 | def make_laplacian_pyramid(t_img, max_levels): 38 | t_pyr = [] 39 | current = t_img 40 | for level in range(max_levels): 41 | t_gauss = conv_gauss(current, stride=1, k_size=5, sigma=2.0) 42 | t_diff = current - t_gauss 43 | t_pyr.append(t_diff) 44 | current = F.avg_pool2d(t_gauss, 2, 2) 45 | t_pyr.append(current) 46 | return t_pyr 47 | 48 | 49 | def laploss(t_img1, t_img2, max_levels=3): 50 | t_pyr1 = make_laplacian_pyramid(t_img1, max_levels) 51 | t_pyr2 = make_laplacian_pyramid(t_img2, max_levels) 52 | loss = 0.0 53 | for i in range(len(t_pyr1)): 54 | loss += (2**(-2*i))*L1_loss(t_pyr1[i], t_pyr2[i]) 55 | return loss 56 | 57 | 58 | def L1_loss(inputs, targets): 59 | return (inputs - targets).abs().mean() 60 | -------------------------------------------------------------------------------- /src/data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorboardX import SummaryWriter 3 | import scipy.misc 4 | from os import listdir 5 | from os.path import isfile, join 6 | import random 7 | import torch 8 | import latent_models as lm 9 | from skimage.io import imread, imsave 10 | from skimage.transform import resize 11 | import skimage.transform 12 | import time 13 | import os 14 | import warnings 15 | 16 | 17 | 18 | 19 | class DataReader_Disk(): 20 | 21 | """ 22 | This is the class loads a dataset of images and associates with each image a corresponding 23 | latent ConvNet. 24 | """ 25 | 26 | def __init__(self, dataset_folder, 27 | device=torch.device("cuda:0"), 28 | img_size=64, 29 | to_shuffle=False, 30 | model_name=None): 31 | """ 32 | Parameters: 33 | dataset_folder (String): Path to the directory containing images to be trained on. 34 | device (torch.device): Device onto which images and the latent ConvNets are to be loaded 35 | img_size (int, optional): Size to which images are to be resized to. All images would be resized to (img_size, img_size). (Default 64) 36 | to_shuffle (bool, optional): Set to True to shuffle the dataset. (Default: False) 37 | model_name: The model name. This name would be used to create a temporary directory where intermediate latent ConvNets are stored 38 | """ 39 | 40 | self.end = 0 41 | self.start = 0 42 | self.model_name = model_name 43 | self.device=device 44 | self.temp_latent_dir = './latent_temp_dir/' + self.model_name + '_latent_temp/' 45 | if not os.path.exists(self.temp_latent_dir): 46 | os.makedirs(self.temp_latent_dir) 47 | self.source_dir = dataset_folder 48 | self.all_imgs = [join(self.source_dir, f) for f in listdir(self.source_dir) if isfile(join(self.source_dir, f))] 49 | self.all_imgs.sort() 50 | if to_shuffle: 51 | random.seed(42) 52 | random.shuffle(self.all_imgs) 53 | random.seed() 54 | self.img_size = img_size 55 | 56 | def load(self, latent_net_name, 57 | num_epoch=None, 58 | saved_model_name=None, 59 | num_load=None, 60 | same_seed=42, 61 | latent_dir=None): 62 | 63 | """ 64 | This method loads with each image name a corresponding network_id. This network_id is used to load and save the appropriate networks. 65 | 66 | Parameters: 67 | latent_net_name (String): Name of the module to be used as the latent ConvNet. The module must be defined in the 68 | latent_models.py file. 69 | num_epoch (int, optional): Upload the latent nets of a model saved at `num_epoch`. Usually used to continue training (Default: None) 70 | saved_model_name (String, optional): Name to the model whoose latent nets need to be uploaded. (Default: None) 71 | num_load (int, optional): Number of images from `dataset_folder` to be used for training. If None all images are used. (Default: None) 72 | same_seed (int, optional): Seed with which the nets are to be initialized. (Default: 42; Becasue atleast one Deep AI thought this was the answer to 73 | Ultimate Question of Life, the Universe and Everything.) 74 | latent_dir (String, optional): Directory from which to load previously saved latent nets. This must be specified if `num_epoch` is not None. (Default: None) 75 | """ 76 | 77 | self.data_lst = [] 78 | self.data_idx = [] 79 | self.num_samples = 0 80 | self.num_epoch = num_epoch 81 | self.latent_dir = latent_dir 82 | self.latent_net_name = latent_net_name 83 | if num_load is None: 84 | all_imgs_lst = self.all_imgs 85 | else: 86 | all_imgs_lst = self.all_imgs[:num_load] 87 | for img_name in all_imgs_lst: 88 | if same_seed: 89 | torch.manual_seed(same_seed) 90 | latent_net = getattr(lm, self.latent_net_name)() 91 | latent_net = latent_net.to(self.device) 92 | if num_epoch: 93 | if saved_model_name is None: 94 | raise ValueError("\'saved_model_name\' must be specified when loading a saved model") 95 | if latent_dir is None: 96 | raise ValueError("\'latent_dir\' must be specified when loading a saved model") 97 | print("Num prev Loaded: ", self.num_samples) 98 | latent_net.load_state_dict(torch.load(latent_dir + saved_model_name + "_latentnet_" + str(num_epoch) + '_' + str(self.num_samples))) 99 | torch.save(latent_net.state_dict(), self.temp_latent_dir + "temp_latent_net_" + str(self.num_samples)) 100 | self.data_lst.append([img_name, self.num_samples]) 101 | self.data_idx.append(np.random.randint(0, self.img_size - 20, size=(5,2))) 102 | self.num_samples += 1 103 | if num_epoch is None: 104 | print("Networks loaded: ", self.num_samples) 105 | print("Number of samples loaded: ", len(self.data_lst)) 106 | 107 | def get_nets(self, net_ids): 108 | latent_nets = [] 109 | for i in net_ids: 110 | latent_net = getattr(lm, self.latent_net_name)() 111 | latent_net = latent_net.to(self.device) 112 | latent_net.load_state_dict(torch.load(self.temp_latent_dir + "temp_latent_net_" + str(i))) 113 | latent_nets.append(latent_net) 114 | return latent_nets 115 | 116 | def save_nets(self, nets, net_ids): 117 | num_nets = len(net_ids) 118 | for i in range(num_nets): 119 | torch.save(nets[i].state_dict(), self.temp_latent_dir + "temp_latent_net_" + str(net_ids[i])) 120 | 121 | 122 | 123 | def get_imgs(self, img_list): 124 | data_out = None 125 | for img_name in img_list: 126 | inputs = imread(img_name) 127 | inputs = resize(inputs, (self.img_size, self.img_size)) 128 | if len(inputs.shape) != 3: 129 | inputs = np.expand_dims(inputs, -1) 130 | inputs = np.repeat(inputs, 3, -1) 131 | if inputs.shape[-1] != 3: 132 | raise ValueError("Input must have last dimension as 3") 133 | inputs = np.transpose(inputs, [2,0,1]) 134 | inputs = np.expand_dims(inputs, 0) 135 | inputs = (inputs - inputs.min())/(inputs.max() - inputs.min()) 136 | inputs = torch.from_numpy(inputs).float() 137 | if data_out is None: 138 | data_out = inputs 139 | else: 140 | data_out = torch.cat([data_out, inputs]) 141 | return data_out 142 | 143 | def get_batch(self, batch_size=10): 144 | """ 145 | This method gets the next batch of size specified by `batch_size`. If the counter has reached the end of dataset the batch size 146 | returned would be less than or equal to the specified `batch_size`. 147 | Parameters: 148 | batch_size (int, optional): Batch size to be returned. 149 | """ 150 | self.start = self.end 151 | start = self.start 152 | end = min(start + batch_size, self.num_samples) 153 | eff_batch = end - start 154 | if end >= self.num_samples: 155 | end = 0 156 | self.end = end 157 | data_out_lst = [] 158 | latent_net_ids = [] 159 | for i in range(eff_batch): 160 | data_out_lst.append(self.data_lst[start + i][0]) 161 | latent_net_ids.append(self.data_lst[start + i][1]) 162 | 163 | data_out = self.get_imgs(data_out_lst) 164 | latent_nets = self.get_nets(latent_net_ids) 165 | data_out = data_out.to(self.device) 166 | return data_out, latent_nets, latent_net_ids 167 | 168 | 169 | def get_batch_from(self, start, batch_size=10): 170 | """ 171 | This method gets the next batch of size specified by `batch_size` starting from `start`. If the counter has reached the end of dataset the batch size 172 | returned would be less than or equal to the specified `batch_size`. 173 | Parameters: 174 | start (int): The index from which to start the batch. 175 | batch_size (int, optional): Batch size to be returned. 176 | """ 177 | start = start 178 | if start > self.num_samples: 179 | raise ValueError("\'start\' cannot be greater than the number of samples in the dataset") 180 | end = min(start + batch_size, self.num_samples) 181 | eff_batch = end - start 182 | if end >= self.num_samples: 183 | warnings.warn("Warning: End of dataset reached the batch-size will be smaller than what is specified") 184 | end = 0 185 | data_out_lst = [] 186 | latent_net_ids = [] 187 | for i in range(eff_batch): 188 | data_out_lst.append(self.data_lst[start + i][0]) 189 | latent_net_ids.append(self.data_lst[start + i][1]) 190 | 191 | data_out = self.get_imgs(data_out_lst) 192 | latent_nets = self.get_nets(latent_net_ids) 193 | data_out = data_out.to(self.device) 194 | return data_out, latent_nets, latent_net_ids 195 | 196 | def update_state(self, latent_nets, latent_nets_ids): 197 | """ 198 | Updates the latent networks during training. 199 | Parameters: 200 | latent_nets (List): List of latent networks (nn.Modules) to be saved. 201 | latent_nets_ids (List): List of ids of the latent networks to be saved. 202 | """ 203 | self.save_nets(latent_nets, latent_nets_ids) 204 | 205 | def save_latent_net(self, latent_dir ,name): 206 | """ 207 | Saves the latent networks in `latent_dir` to create a model checkpoint. 208 | Parameters: 209 | latent_dir (String): Name of directory where the latent networks are saved. 210 | name (String): Name with which the latent networks must be saved. 211 | """ 212 | if not os.path.exists(latent_dir): 213 | os.makedirs(latent_dir) 214 | for i in range(self.num_samples): 215 | latent_net = self.get_nets([i])[0] 216 | torch.save(latent_net.state_dict(), latent_dir + name + str(i)) 217 | 218 | def get_vec(self, idx): 219 | return self.data_lst[idx][1] 220 | -------------------------------------------------------------------------------- /src/helper_funcs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import sys 5 | 6 | import numpy as np 7 | 8 | """ 9 | Got from the Deep Image Prior Repo: https://github.com/DmitryUlyanov/deep-image-prior 10 | """ 11 | 12 | 13 | 14 | 15 | def get_noise(input_depth, method, spatial_size, noise_type='u', var=1./10, batch_size=1): 16 | """Returns a tensor of size (1 x `input_depth` x `spatial_size[0]` x `spatial_size[1]`) 17 | initialized in a specific way. 18 | Args: 19 | input_depth: number of channels in the tensor 20 | method: `noise` for fillting tensor with noise; `meshgrid` for np.meshgrid 21 | spatial_size: spatial size of the tensor to initialize 22 | noise_type: 'u' for uniform; 'n' for normal 23 | var: a factor, a noise will be multiplicated by. Basically it is standard deviation scaler. 24 | """ 25 | if isinstance(spatial_size, int): 26 | spatial_size = (spatial_size, spatial_size) 27 | if method == 'noise': 28 | shape = [batch_size, input_depth, spatial_size[0], spatial_size[1]] 29 | net_input = torch.zeros(shape) 30 | 31 | fill_noise(net_input.data, noise_type) 32 | net_input.data *= var 33 | net_input = net_input.data 34 | elif method == 'meshgrid': 35 | assert input_depth == 2 36 | X, Y = np.meshgrid(np.arange(0, spatial_size[1])/float(spatial_size[1]-1), np.arange(0, spatial_size[0])/float(spatial_size[0]-1)) 37 | meshgrid = np.concatenate([X[None,:], Y[None,:]]) 38 | meshgrid = meshgrid[None,:] 39 | if batch_size > 1: 40 | meshgrid = np.repeat(meshgrid, batch_size, 0) 41 | net_input = torch.from_numpy(meshgrid).float() 42 | else: 43 | assert False 44 | 45 | return net_input 46 | 47 | 48 | def fill_noise(x, noise_type): 49 | """Fills tensor `x` with noise of type `noise_type`.""" 50 | torch.manual_seed(42) 51 | if noise_type == 'u': 52 | x.uniform_() 53 | elif noise_type == 'n': 54 | x.normal_() 55 | else: 56 | assert False 57 | -------------------------------------------------------------------------------- /src/inpainting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorboardX import SummaryWriter 3 | import torch 4 | import torch.optim 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torchvision.utils as vutils 9 | from data_loader import * 10 | from models import * 11 | from custom_losses import * 12 | from helper_funcs import * 13 | import itertools 14 | import time 15 | import os 16 | import skimage.io 17 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 18 | 19 | 20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | 22 | 23 | 24 | RESTRICT = True 25 | RESTRICT_VAL = 0.01 26 | L2_W = 0.0 27 | LATENT_NET_NAME = 'Latent4LSND' 28 | LOG_DIR = '../runs/CelebA50k_Inpaint_Center' 29 | IMG_DIR = "../" + LOG_DIR.split("/")[-1] + "/" 30 | SEED = 42 31 | 32 | MODEL_NAME = 'CelebA150k_LCM' 33 | 34 | MODEL_NUM = 50000 35 | LATENT_DIR = '../' + MODEL_NAME + '_latentDIR/' 36 | 37 | #Defining and loading the model 38 | G = EncDecCelebA(in_channels=64) 39 | G = G.to(device) 40 | 41 | G.load_state_dict(torch.load('../models/CelebA/' + MODEL_NAME + str(MODEL_NUM))) 42 | s_np = np.load(MODEL_NAME + "_commoninput.npy") 43 | s = torch.from_numpy(s_np) 44 | s = s.to(device) 45 | 46 | 47 | img_size = 128 48 | hole_size = 50 49 | 50 | dataset_folder='../celebA_split/test/' 51 | 52 | 53 | data_reader = DataReader_Disk(dataset_folder=dataset_folder, 54 | device=device, 55 | to_shuffle=True, 56 | img_size=img_size, 57 | model_name=MODEL_NAME) 58 | data_reader.load(latent_net_name=LATENT_NET_NAME, num_load=50, same_seed=SEED) 59 | writer = SummaryWriter(log_dir = LOG_DIR) 60 | 61 | #Generate the mask 62 | def make_mask(shape, hole_size, extreme=False, seed=42): 63 | mask = torch.ones(shape) 64 | w_mask = torch.ones(shape) 65 | batch_size = shape[0] 66 | x_max = shape[2] 67 | y_max = shape[3] 68 | np.random.seed(seed) 69 | for i in range(batch_size): 70 | start_x = int((x_max/2) - (hole_size/2.0)) 71 | start_y = int((y_max/2.0) - (hole_size/2.0)) 72 | w_start_x = max(0, start_x - 5) 73 | w_end_x = min(x_max, start_x+hole_size+5) 74 | w_start_y = max(0, start_y - 5) 75 | w_end_y = min(y_max, start_y+hole_size+5) 76 | if extreme: 77 | mask[:,:,start_x:start_x+hole_size, start_y:start_y+hole_size] = 0.0 78 | w_mask[:,:,w_start_x:w_end_x, w_start_y:w_end_y] = 5.0 79 | else: 80 | mask[i,:,start_x:start_x+hole_size, start_y:start_y+hole_size] = 0.0 81 | w_mask[i,:,w_start_x:w_end_x, w_start_y:w_end_y]=5.0 82 | return mask, w_mask 83 | 84 | 85 | def inpaint(s, num_epochs=100): 86 | G.eval() 87 | for p in G.parameters(): 88 | p.requires_grad=False 89 | count = 0 90 | if not os.path.exists(IMG_DIR): 91 | os.makedirs(IMG_DIR) 92 | batch_size = 50 93 | for _ in range(1): 94 | data_in,net_in,_ = data_reader.get_batch_from(100, batch_size=batch_size) 95 | writer.add_image('Org_image', data_in, count) 96 | 97 | for img_num in range(batch_size): 98 | skimage.io.imsave(IMG_DIR + 'Org_image' + str(img_num)+".png", data_in[img_num].numpy().transpose(1,2,0)) 99 | 100 | gen_mask, w_mask = make_mask(data_in.shape, hole_size, extreme=False, seed=SEED) 101 | gen_mask = gen_mask.to(device) 102 | w_mask = w_mask.to(device) 103 | BATCH_SIZE = batch_size 104 | writer.add_image('Hole_Image', gen_mask*data_in, 0) 105 | nets_params = [] 106 | for i in range(BATCH_SIZE): 107 | for p in net_in[i].parameters(): 108 | p.requires_grad=True 109 | nets_params += list(net_in[i].parameters()) 110 | optim_nets = optim.SGD(nets_params, lr=1.0, weight_decay=L2_W) #1.0, 0.5 111 | for ep in range(num_epochs): 112 | optim_nets.zero_grad() 113 | map_out_lst = [] 114 | for i in range(BATCH_SIZE): 115 | m_out = net_in[i](s) 116 | map_out_lst.append(m_out) 117 | map_out = torch.cat(map_out_lst, 0) 118 | g_out = G(map_out) 119 | lap_loss = laploss(gen_mask*g_out, gen_mask*data_in) 120 | mse_loss = F.mse_loss(w_mask*gen_mask*g_out, w_mask*gen_mask*data_in) 121 | loss = mse_loss + lap_loss 122 | loss.backward() 123 | optim_nets.step() 124 | if RESTRICT: 125 | val = RESTRICT_VAL 126 | for i in range(BATCH_SIZE): 127 | net_in[i].restrict(-val, val) 128 | writer.add_scalar('ep_z_loss', loss.data, ep) 129 | writer.add_scalar('ep_z_loss_Lap', lap_loss.data.item(), ep) 130 | writer.add_scalar('ep_z_loss_MSE', mse_loss.data.item(), ep) 131 | if ep%100 == 0: 132 | writer.add_image('Inpainted_Image_ep', g_out, ep) 133 | writer.add_image('latent_Z', map_out[:10,:3,:,:].data.cpu(), ep) 134 | 135 | map_out_lst = [] 136 | for i in range(BATCH_SIZE): 137 | m_out = net_in[i](s) 138 | map_out_lst.append(m_out) 139 | map_out = torch.cat(map_out_lst, 0) 140 | g_pred = G(map_out) 141 | writer.add_scalar('z_loss', loss.data.item(), count) 142 | writer.add_image('Inpainted_Image', g_pred, count) 143 | for img_num in range(batch_size): 144 | skimage.io.imsave(IMG_DIR + 'Hole_Image' + str(img_num) +".png", (gen_mask*data_in)[img_num].data.cpu().numpy().transpose(1,2,0)) 145 | skimage.io.imsave(IMG_DIR + 'Inpainted_Image' + str(img_num) + ".png", g_pred[img_num].data.cpu().numpy().transpose(1,2,0)) 146 | 147 | 148 | 149 | 150 | inpaint(s, 3001) 151 | writer.close() 152 | -------------------------------------------------------------------------------- /src/latent_models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import torch 6 | import torch.optim 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | 12 | 13 | class Latent4LSND(nn.Module): 14 | 15 | def __init__(self, lr_slope=0.2): 16 | super(Latent4LSND, self).__init__() 17 | self.lr_slope = lr_slope 18 | 19 | 20 | self.conv1 = nn.Conv2d(2, 8, 3, 1, 0, 1, bias=False) #80 21 | self.conv1_bn = nn.BatchNorm2d(8) 22 | self.conv2 = nn.Conv2d(8, 16, 3, 1, 0, 1, bias=False) #76 23 | self.conv2_bn = nn.BatchNorm2d(16) 24 | self.conv3 = nn.Conv2d(16, 32, 3, 1, 0, 1, bias=False) #72 25 | self.conv3_bn = nn.BatchNorm2d(32) 26 | self.conv4 = nn.Conv2d(32, 64, 3, 1, 0, 1, bias=False) #68 27 | 28 | def restrict(self, min_val, max_val): 29 | for m in self.modules(): 30 | if isinstance(m, nn.Conv2d): 31 | for p in m.parameters(): 32 | p.data.clamp_(min_val, max_val) 33 | 34 | def forward(self, x): 35 | 36 | x = F.leaky_relu(self.conv1_bn(self.conv1(x)), self.lr_slope) 37 | x = F.leaky_relu(self.conv2_bn(self.conv2(x)), self.lr_slope) 38 | x = F.leaky_relu(self.conv3_bn(self.conv3(x)), self.lr_slope) 39 | x = F.leaky_relu(self.conv4(x), self.lr_slope) 40 | 41 | return x 42 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | class EncDecCelebA(nn.Module): 17 | 18 | def __init__(self, in_channels=1, lr_slope=0.2, bias=False): 19 | super(EncDecCelebA, self).__init__() 20 | self.lr_slope = lr_slope 21 | 22 | self.conv1 = nn.Conv2d(in_channels, 256, 4, 2, 1, 1, bias=False) #16 23 | self.conv1_bn = nn.BatchNorm2d(256) 24 | self.conv2 = nn.Conv2d(256, 512, 4, 2, 1, 1, bias=False) #8 25 | self.conv2_bn = nn.BatchNorm2d(512) 26 | self.conv3 = nn.Conv2d(512, 1024, 4, 2, 1, 1, bias=False) #4 27 | self.conv3_bn = nn.BatchNorm2d(1024) 28 | 29 | self.conv4 = nn.Conv2d(1024, 1024, 3, 1, 2, 2, groups=512, bias=False) #4 30 | self.conv4_bn = nn.BatchNorm2d(1024) 31 | self.conv5 = nn.Conv2d(1024, 1024, 3, 1, 2, 2, groups=512, bias=False) #4 32 | self.conv5_bn = nn.BatchNorm2d(1024) 33 | self.conv6 = nn.Conv2d(1024, 1024, 3, 1, 2, 2, groups=512, bias=False) #4 34 | self.conv6_bn = nn.BatchNorm2d(1024) 35 | 36 | self.convT1 = nn.ConvTranspose2d(1024 + 1024, 512, 4, 2, 1, bias=bias) 37 | self.convT1_bn = nn.BatchNorm2d(512) #8 38 | self.convT2 = nn.ConvTranspose2d(512+512, 512, 4, 2, 1, bias=bias) 39 | self.convT2_bn = nn.BatchNorm2d(512) #16 40 | self.convT3 = nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=bias) 41 | self.convT3_bn = nn.BatchNorm2d(256) #32 42 | self.convT4 = nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=bias) 43 | self.convT4_bn = nn.BatchNorm2d(128) #64 44 | self.convT5 = nn.Conv2d(128, 64, 3, 1, 1, bias=bias) 45 | self.convT5_bn = nn.BatchNorm2d(64) #128 46 | self.convT6 = nn.Conv2d(64, 32, 3, 1, 1, 1, bias=bias) 47 | self.convT6_bn = nn.BatchNorm2d(32) #128 ## 48 | self.convT7 = nn.Conv2d(32, 3, 3, 1, 1, 1, bias=bias) ## 49 | self.upsamp = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 50 | 51 | 52 | def forward(self, input): 53 | #Encoder 54 | x1 = F.leaky_relu(self.conv1_bn(self.conv1(input)), self.lr_slope) 55 | x2 = F.leaky_relu(self.conv2_bn(self.conv2(x1)), self.lr_slope) 56 | x3 = F.leaky_relu(self.conv3_bn(self.conv3(x2)), self.lr_slope) 57 | 58 | x4 = F.leaky_relu(self.conv4_bn(self.conv4(x3)), self.lr_slope) 59 | x5 = F.leaky_relu(self.conv5_bn(self.conv5(x4)), self.lr_slope) 60 | x6 = F.leaky_relu(self.conv6_bn(self.conv6(x5)), self.lr_slope) 61 | 62 | 63 | #Decoder 64 | x = torch.cat([x6,x3],1) 65 | x = F.leaky_relu(self.convT1_bn(self.convT1(x)), self.lr_slope) #8 66 | x = torch.cat([x,x2], 1) 67 | x = F.leaky_relu(self.convT2_bn(self.convT2(x)), self.lr_slope) #16 68 | x = F.leaky_relu(self.convT3_bn(self.convT3(x)), self.lr_slope) #32 69 | x = F.leaky_relu(self.convT4_bn(self.convT4(x)), self.lr_slope) #64 70 | x = self.upsamp(x) 71 | x = F.leaky_relu(self.convT5_bn(self.convT5(x)), self.lr_slope) #128 72 | x = F.leaky_relu(self.convT6_bn(self.convT6(x)), self.lr_slope) #128 73 | x = F.sigmoid(self.convT7(x)) #128 74 | 75 | return x 76 | -------------------------------------------------------------------------------- /src/train_LCM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorboardX import SummaryWriter 3 | import torch 4 | import torch.optim 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torchvision.utils as vutils 9 | from data_loader import * 10 | from models import * 11 | from custom_losses import * 12 | from helper_funcs import * 13 | import itertools 14 | import time 15 | import os 16 | os.environ["CUDA_VISIBLE_DEVICES"]="1" 17 | 18 | 19 | 20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | 22 | 23 | IMG_SIZE = 128 24 | S_SIZE = 40 25 | BATCH_SIZE = 10 #was 10 26 | LATENT_NET_NAME = 'Latent4LSND' 27 | RESTRICT = True 28 | L2_W = 0.0 29 | RESTRICT_VAL = 0.01 30 | fname = "train_LCM" 31 | NUM_TRAIN_SAMPLES = 150000 32 | 33 | LOG_DIR = '../runs/LCM_CelebA/' 34 | 35 | MODEL_NAME = 'CelebA150k_LCM' 36 | 37 | LATENT_CHECK_DIR = '../' + MODEL_NAME + '_latentDIR/' 38 | 39 | if not os.path.exists(LATENT_CHECK_DIR): 40 | os.makedirs(LATENT_CHECK_DIR) 41 | 42 | PREV_EPOCH = None 43 | NUM_EPOCHS = 150000 44 | SAVE_EVERY = 15000 45 | 46 | #The Generator Network 47 | G = EncDecCelebA(in_channels=64) 48 | G = G.to(device) 49 | 50 | #Uniform noise 51 | s = get_noise(input_depth=2, method='noise', spatial_size=[S_SIZE, S_SIZE], batch_size=1) 52 | 53 | np.save(MODEL_NAME + '_commoninput', s.data.numpy()) 54 | 55 | 56 | writer = SummaryWriter(log_dir=LOG_DIR) 57 | 58 | dataset_folder = '../celebA_split/train/' 59 | 60 | data_reader = DataReader_Disk(dataset_folder=dataset_folder, 61 | device=device, 62 | to_shuffle=True, 63 | img_size=IMG_SIZE, 64 | model_name=MODEL_NAME) 65 | 66 | 67 | if PREV_EPOCH: 68 | s_np = np.load(MODEL_NAME + "_commoninput.npy") 69 | s.data = torch.from_numpy(s_np) 70 | START_EPOCH = PREV_EPOCH + 1 71 | END_EPOCH = START_EPOCH + NUM_EPOCHS 72 | data_reader.load(latent_net_name=LATENT_NET_NAME, 73 | num_epoch=PREV_EPOCH, 74 | saved_model_name=MODEL_NAME, 75 | num_load=NUM_TRAIN_SAMPLES, 76 | latent_dir=LATENT_CHECK_DIR) 77 | G.load_state_dict(torch.load('../models/CelebA/' + MODEL_NAME + str(PREV_EPOCH))) 78 | else: 79 | START_EPOCH = 0 80 | END_EPOCH = START_EPOCH + NUM_EPOCHS 81 | data_reader.load(latent_net_name=LATENT_NET_NAME, 82 | num_load=NUM_TRAIN_SAMPLES) 83 | s = s.to(device) 84 | G_optimizer = optim.SGD(G.parameters(), lr=1.0) 85 | 86 | 87 | 88 | def trainZG(epoch, data_in, net_in, num_epochs=100): 89 | """ 90 | Jointly trains the latent networks and the generator network. 91 | """ 92 | G.train() 93 | for p in G.parameters(): 94 | p.requires_grad=True 95 | BATCH_SIZE = len(net_in) 96 | nets_params = [] 97 | for i in range(BATCH_SIZE): 98 | for p in net_in[i].parameters(): 99 | p.requires_grad=True 100 | nets_params += list(net_in[i].parameters()) 101 | optim_nets = optim.SGD(nets_params, lr=1.0, weight_decay=L2_W) 102 | for ep in range(num_epochs): 103 | G_optimizer.zero_grad() 104 | optim_nets.zero_grad() 105 | map_out_lst = [] 106 | for i in range(BATCH_SIZE): 107 | m_out = net_in[i](s) 108 | map_out_lst.append(m_out) 109 | map_out = torch.cat(map_out_lst, 0) 110 | g_out = G(map_out) 111 | lap_loss = laploss(g_out, data_in) 112 | mse_loss = F.mse_loss(g_out, data_in) 113 | loss = mse_loss + lap_loss 114 | loss.backward() 115 | optim_nets.step() 116 | G_optimizer.step() 117 | if RESTRICT: 118 | val = RESTRICT_VAL 119 | for i in range(BATCH_SIZE): 120 | net_in[i].restrict(-1.0*val, val) 121 | optim_nets.zero_grad() 122 | G_optimizer.zero_grad() 123 | if epoch%10 == 0: 124 | G.eval() 125 | map_out_lst = [] 126 | for i in range(BATCH_SIZE): 127 | m_out = net_in[i](s) 128 | map_out_lst.append(m_out) 129 | map_out = torch.cat(map_out_lst, 0) 130 | g_out = G(map_out) 131 | writer.add_scalar('Z_loss', loss.data.item(), epoch) 132 | writer.add_scalar('Z_loss_lap', lap_loss.data.item(), epoch) 133 | writer.add_scalar('Z_loss_MSE', mse_loss.data.item(), epoch) 134 | if epoch%100 == 0: 135 | writer.add_image('Real_Images', data_in[:5].data.cpu(), epoch) 136 | writer.add_image('Generated_Images_Z', g_out[:5].data.cpu(), epoch) 137 | writer.add_image('latent_Z', map_out[:10,:3,:,:].data.cpu(), epoch) 138 | 139 | return net_in 140 | 141 | 142 | 143 | for epoch in range(START_EPOCH, END_EPOCH + 1): 144 | #Get a batch of images, their latent networks and corresponding network ids 145 | data_in, latent_nets, latent_net_ids = data_reader.get_batch(batch_size=BATCH_SIZE) 146 | 147 | #train the latent networks and generator 148 | latent_nets = trainZG(epoch, data_in, latent_nets, num_epochs=50) 149 | 150 | #update the latent networks 151 | data_reader.update_state(latent_nets, latent_net_ids) 152 | print(fname + " Epoch: ", epoch) 153 | if epoch%SAVE_EVERY == 0: 154 | if epoch > 0: 155 | data_reader.save_latent_net(name=MODEL_NAME + "_latentnet_" +str(epoch) + "_", latent_dir=LATENT_CHECK_DIR) 156 | torch.save(G.state_dict(), '../models/CelebA/' + MODEL_NAME + str(epoch)) 157 | writer.close() 158 | --------------------------------------------------------------------------------