├── flow.JPG ├── cyclegan.png ├── datasets.py ├── LICENSE ├── getDataLoader.py ├── createArchitecture.py ├── testing.py ├── README.md ├── model.py ├── helper.py └── Training.py /flow.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vineet-Pandey/Cycle-GAN-for-Depth-Map-Generation/HEAD/flow.JPG -------------------------------------------------------------------------------- /cyclegan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vineet-Pandey/Cycle-GAN-for-Depth-Map-Generation/HEAD/cyclegan.png -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | import os 4 | 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import torchvision.transforms as transforms 8 | 9 | class ImageDataset(Dataset): 10 | def __init__(self, root, transforms_=None, unaligned=False, mode='train'): 11 | self.transform = transforms.Compose(transforms_) 12 | self.unaligned = unaligned 13 | 14 | self.files_A = sorted(glob.glob(os.path.join(root, '%s/A' % mode) + '/*.*')) 15 | self.files_B = sorted(glob.glob(os.path.join(root, '%s/B' % mode) + '/*.*')) 16 | 17 | def __getitem__(self, index): 18 | item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)])) 19 | 20 | if self.unaligned: 21 | item_B = self.transform(Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])) 22 | else: 23 | item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)])) 24 | 25 | return {'A': item_A, 'B': item_B} 26 | 27 | def __len__(self): 28 | return max(len(self.files_A), len(self.files_B)) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Vineet Pandey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
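Note: the ImageDataset class in datasets.py above is a generic unpaired A/B loader, but it does not appear to be referenced by Training.py or testing.py (those scripts load data through getDataLoader and torchvision's ImageFolder instead). If you want to use it directly, a minimal sketch could look like the following; the root folder name and the train/A, train/B layout are illustrative assumptions, not paths used by the repository scripts.

```python
# Sketch only: wiring ImageDataset from datasets.py into a DataLoader.
# 'sunRGBD_AB' and the train/A, train/B layout are assumptions, not part of the repo scripts.
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets import ImageDataset

transforms_ = [
    transforms.Resize(128),      # same 128-pixel size used elsewhere in the repo
    transforms.RandomCrop(128),
    transforms.ToTensor(),
]
dataset = ImageDataset('sunRGBD_AB', transforms_=transforms_, unaligned=True, mode='train')
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)

batch = next(iter(loader))
print(batch['A'].shape, batch['B'].shape)   # e.g. torch.Size([16, 3, 128, 128]) each
```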
22 | 
--------------------------------------------------------------------------------
/getDataLoader.py:
--------------------------------------------------------------------------------
1 | 
2 | # coding: utf-8
3 | 
4 | # In[3]:
5 | 
6 | 
7 | import os
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | from torch.utils.data import DataLoader
11 | import torchvision.datasets as datasets
12 | 
13 | 
14 | # In[7]:
15 | 
16 | 
17 | class getDataLoader:
18 | 
19 |     def __init__(self, img_type, img_dir='sunRGBD',
20 |                  img_size=128, batch_size=16, num_workers=0):
21 |         super(getDataLoader, self).__init__()
22 |         self.img_type = img_type
23 |         self.img_dir = img_dir
24 |         self.batch_size = batch_size
25 |         self.img_size = img_size
26 |         self.num_workers = num_workers
27 | 
28 | 
29 |     def load_data(self):
30 |         transform = transforms.Compose([transforms.Resize(self.img_size), transforms.RandomCrop(self.img_size), transforms.ToTensor()])
31 |         image_path = './' + self.img_dir
32 |         train_path = os.path.join(image_path, self.img_type)
33 |         test_path = os.path.join(image_path, 'test_{}'.format(self.img_type))
34 |         train_dataset = datasets.ImageFolder(train_path, transform)
35 |         test_dataset = datasets.ImageFolder(test_path, transform)
36 | 
37 |         train_loader = DataLoader(dataset=train_dataset, batch_size=self.batch_size,
38 |                                   shuffle=True, num_workers=self.num_workers)
39 |         test_loader = DataLoader(dataset=test_dataset, batch_size=self.batch_size,
40 |                                  shuffle=True, num_workers=self.num_workers)
41 |         return train_loader, test_loader
42 | 
43 | 
--------------------------------------------------------------------------------
/createArchitecture.py:
--------------------------------------------------------------------------------
1 | 
2 | # coding: utf-8
3 | 
4 | # In[1]:
5 | 
6 | # The createArchitecture class is used to initialize the cycle-GAN architecture. It takes the number of channels (image channels) from the user.
7 | # The network can be made deeper by inserting more residual blocks; the default is 6 in the code below.
8 | # create_model takes the initial values and outputs 4 different networks: 2 generator networks and 2 discriminator networks.
9 | # Most of the code below is self-explanatory, but a comment has been provided wherever necessary
10 | 
11 | import torch
12 | from model import Generator, Discriminator
13 | from helper import weights_init_normal
14 | 
15 | 
16 | # In[ ]:
17 | 
18 | 
19 | class createArchitecture:
20 | 
21 |     def __init__(self, generator_channels=3, discriminator_channels=3, num_res_blocks=6):
22 |         self.generator_channels = generator_channels
23 |         self.discriminator_channels = discriminator_channels
24 |         self.num_res_blocks = num_res_blocks
25 | 
26 |     def create_model(self):
27 |         generatorNetwork_A2B = Generator(self.generator_channels, num_ResnetBlocks=self.num_res_blocks)
28 |         generatorNetwork_B2A = Generator(self.generator_channels, num_ResnetBlocks=self.num_res_blocks)
29 |         discriminatorNetwork_A = Discriminator(self.discriminator_channels)
30 |         discriminatorNetwork_B = Discriminator(self.discriminator_channels)
31 | 
32 |         if torch.cuda.is_available():
33 |             generatorNetwork_A2B.cuda()
34 |             generatorNetwork_B2A.cuda()
35 |             discriminatorNetwork_A.cuda()
36 |             discriminatorNetwork_B.cuda()
37 |             print("Network moved to GPU")
38 |         else:
39 |             print("Network kept on CPU as no GPU device was identified")
40 |         # Weights are initialized with values drawn from a normal (Gaussian) distribution
41 |         generatorNetwork_A2B.apply(weights_init_normal)
42 |         generatorNetwork_B2A.apply(weights_init_normal)
43 |         discriminatorNetwork_A.apply(weights_init_normal)
44 |         discriminatorNetwork_B.apply(weights_init_normal)
45 | 
46 |         return generatorNetwork_A2B, generatorNetwork_B2A, discriminatorNetwork_A, discriminatorNetwork_B
47 | 
48 | 
49 | 
--------------------------------------------------------------------------------
/testing.py:
--------------------------------------------------------------------------------
1 | 
2 | # coding: utf-8
3 | 
4 | # In[37]:
5 | # This python script has been developed to run inference with the cycle-GAN model created in model.py.
6 | # It uses the trained weights of the generators stored in the checkpoints_cyclegan folder
7 | 
8 | from model import Generator
9 | import torch
10 | from torchvision.utils import save_image
11 | from createArchitecture import createArchitecture
12 | from getDataLoader import getDataLoader
13 | 
14 | 
15 | # In[38]:
16 | 
17 | # Model creation using the createArchitecture class. create_model() creates the model.
18 | # The architecture creation is described in more detail in the createArchitecture.py script
19 | torch.set_default_tensor_type('torch.DoubleTensor')
20 | mod = createArchitecture()
21 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
22 | # Since this script only generates depth maps from RGB images, the discriminators are not used here. However,
23 | # create_model outputs 2 generators and 2 discriminators
24 | gen_A2B, gen_B2A, _, _ = mod.create_model()
25 | # gen_A2B = gen_A2B.cuda()
26 | # gen_B2A = gen_B2A.cuda()
27 | 
28 | 
29 | # In[39]:
30 | 
31 | # Loading the pre-trained weights
32 | gen_A2B.load_state_dict(torch.load('checkpoints_cyclegan/gen_A2B.pkl', map_location=device))
33 | gen_B2A.load_state_dict(torch.load('checkpoints_cyclegan/gen_B2A.pkl', map_location=device))
34 | 
35 | 
36 | # In[40]:
37 | 
38 | # the generators are put in eval mode and moved to the selected device. In the absence of a CUDA device,
39 | # inference runs on the CPU.
40 | 
41 | gen_A2B.eval()
42 | gen_B2A.eval()
43 | 
44 | gen_A2B = gen_A2B.to(device)
45 | gen_B2A = gen_B2A.to(device)
46 | 
47 | 
48 | # In[41]:
49 | 
50 | 
51 | f = getDataLoader('rgb', batch_size=1)
52 | _, test_dataloader_A = f.load_data()
53 | test_iter = iter(test_dataloader_A)
54 | 
55 | 
56 | # In[42]:
57 | # testing the images in the dataset
58 | 
59 | # the test image is moved to the same device and dtype as the generators
60 | test_rgb_image = next(test_iter)[0]
61 | test_rgb_image = test_rgb_image.to(device=device, dtype=torch.double)
62 | 
63 | save_image(test_rgb_image, 'output/depth/test.png')
64 | 
65 | 
66 | # In[43]:
67 | 
68 | 
69 | generated_depth_image = gen_A2B(test_rgb_image).data
70 | generated_depth_image = generated_depth_image * -1.5
71 | generated_depth_image.size()
72 | 
73 | 
74 | # In[44]:
75 | 
76 | 
77 | save_image(generated_depth_image, 'output/depth/depth.png')
78 | 
79 | 
80 | # In[45]:
81 | 
82 | 
83 | if torch.cuda.is_available():
84 |     print(torch.cuda.current_device())
85 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cycle-GAN-for-Depth-Map-Generation
2 | Generative Adversarial Networks have been improved
3 | further to overcome the shortcomings they initially faced. Cycle-GAN is one of the
4 | variations; it combines the properties of conditional constraints and cycle consistency
5 | to effectively use the GAN architecture in image-to-image translation tasks. Hence, it
6 | was chosen as the technique to generate the depth images for our project.
7 | Unlike plain GANs, cycle-GANs can be used to map a given input image to the distribution of a
8 | target domain. Concretely, an image is taken from the input domain D_i and transformed into an image of the target domain D_t without requiring a one-to-one mapping between images from the
9 | input and target domains in the training set. Since this architecture relaxes the one-
10 | to-one mapping requirement, it becomes quite powerful, as the same method can be employed to tackle a
11 | variety of problems by varying the input and output domain pairs. Because this method works on
12 | unpaired datasets, it is more modular than the Pix2Pix architecture. This modularity is achieved by a two-step transformation: first the input image is mapped to the target domain, and then the original image is recovered from
13 | the target domain. Mapping the image to the
14 | target domain is done by a generator network, and the quality of the image is checked by the
15 | discriminator, which constantly pushes the generator to perform better.
16 | ![alt text](https://github.com/Vineet-Pandey/Cycle-GAN-for-Depth-Map-Generation/blob/master/cyclegan.png)
17 | ## Dependencies
18 | 1) Pytorch 0.4.1 or newer
19 | 2) Python 3.6
20 | 3) Matplotlib 3.0
21 | 4) OpenCV 3.4 or higher
22 | 
23 | ## Directory Structure
24 | #### Dataset directory
25 | ![alt text](https://github.com/Vineet-Pandey/Cycle-GAN-for-Depth-Map-Generation/blob/master/flow.JPG)
26 | 
27 | Corresponding images should go in the final directory as shown above.
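For convenience, the folder skeleton the scripts expect can be created with a short sketch like the one below. The directory names are taken from the paths hard-coded in getDataLoader.py, testing.py and helper.py; the `images` class subfolder is an assumption (torchvision's ImageFolder requires at least one subfolder inside each split).

```python
# Sketch: create the folder skeleton the scripts expect.
# Paths taken from the code; the 'images' subfolder name is an assumption for ImageFolder.
import os

for split in ['rgb', 'depth', 'test_rgb', 'test_depth']:
    os.makedirs(os.path.join('sunRGBD', split, 'images'), exist_ok=True)

for out_dir in ['checkpoints_cyclegan', 'output/depth', 'samples_cyclegan']:
    os.makedirs(out_dir, exist_ok=True)
```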
28 | #### checkpoints_cyclegan directory 29 | Create a directory named **checkpoints_cyclegan** to save the .pkl files 30 | #### output directory 31 | Create a directory named **output** to save the inference outputs 32 | #### samples_cyclegan directory 33 | Create a directory named **samples_cyclegan** to save samples while training 34 | 35 | ## Training 36 | `cd` to the folder where the scripts are placed and 37 | run ` python3 Training.py` 38 | 39 | ## Testing 40 | run `python3 testing.py` 41 | ## Pre-Trained weights 42 | Pre-trained weights can be found at [Pre-Trained Weights](https://www.dropbox.com/sh/x8oeg2vvmjlsxzc/AAB_hAvTjzV9fquzu-WA942Ea?dl=0) 43 | 44 | 45 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | # In[2]: 12 | 13 | 14 | class ResnetBlock(nn.Module): 15 | def __init__(self,input_features): 16 | super(ResnetBlock, self).__init__() 17 | Resnet_convolution = [ 18 | nn.ReflectionPad2d(1), 19 | nn.Conv2d(input_features,input_features,3), 20 | nn.InstanceNorm2d(input_features), 21 | nn.LeakyReLU(negative_slope = 0.01), 22 | nn.ReflectionPad2d(1), 23 | nn.Conv2d(input_features, input_features,3), 24 | nn.InstanceNorm2d(input_features) 25 | ] 26 | self.Resnet_convolution = nn.Sequential(*Resnet_convolution) 27 | 28 | def forward(self, x): 29 | return x+self.Resnet_convolution(x) 30 | 31 | 32 | # In[3]: 33 | 34 | 35 | class Generator(nn.Module): 36 | def __init__(self, input_channels, num_ResnetBlocks = 6): 37 | 38 | super(Generator,self).__init__() 39 | 40 | GeneratorArchitecture = [ 41 | nn.ReflectionPad2d(3), 42 | nn.Conv2d(input_channels,64,7), 43 | nn.InstanceNorm2d(64), 44 | nn.LeakyReLU(negative_slope = 0.01, inplace = True) 45 | ] 46 | 47 | input_features = 64 48 | output_features = input_features*2 49 | 50 | for _ in range(2): 51 | GeneratorArchitecture += [ 52 | nn.Conv2d(input_features, output_features,3,stride =2, padding =1), 53 | nn.InstanceNorm2d(output_features), 54 | nn.PReLU(init =0.25) 55 | ] 56 | 57 | input_features = output_features 58 | output_features = input_features*2 59 | 60 | for _ in range(num_ResnetBlocks): 61 | GeneratorArchitecture += [ResnetBlock(input_features)] 62 | 63 | output_features = input_features//2 64 | 65 | for _ in range(2): 66 | GeneratorArchitecture += [ 67 | nn.ConvTranspose2d(input_features, output_features,3, stride =2, padding=1, output_padding =1), 68 | nn.InstanceNorm2d(output_features), 69 | nn.PReLU(init =0.25) 70 | ] 71 | input_features= output_features 72 | output_features = input_features//2 73 | 74 | 75 | GeneratorArchitecture += [nn.ReflectionPad2d(3), 76 | nn.Conv2d(input_features, input_channels,7), 77 | nn.Tanh()] 78 | 79 | self.GeneratorArchitecture = nn.Sequential(*GeneratorArchitecture) 80 | def forward(self, x): 81 | return self.GeneratorArchitecture(x) 82 | 83 | 84 | 85 | # In[4]: 86 | 87 | 88 | class Discriminator(nn.Module): 89 | 90 | def __init__(self, input_channels): 91 | 92 | super(Discriminator,self).__init__() 93 | 94 | DiscriminatorArchitecture = [ 95 | nn.Conv2d(input_channels, 128, 4, stride =2, padding=1), 96 | nn.LeakyReLU(negative_slope = 0.01, inplace = True), 97 | ] 98 | DiscriminatorArchitecture += [ 99 | nn.Conv2d(128, 256,4,stride=2, padding =1), 100 | nn.InstanceNorm2d (256), 101 | nn.LeakyReLU(negative_slope = 0.01, inplace = True) 102 | ] 103 | 
        DiscriminatorArchitecture += [
104 |             nn.Conv2d(256, 512, 3, padding=1),
105 |             nn.InstanceNorm2d(512),
106 |             nn.LeakyReLU(negative_slope = 0.01, inplace = True)
107 |             ]
108 |         DiscriminatorArchitecture += [nn.Conv2d(512, 1, 4, padding=1)]
109 | 
110 |         self.DiscriminatorArchitecture = nn.Sequential(*DiscriminatorArchitecture)
111 | 
112 | 
113 |     def forward(self, x):
114 | 
115 |         x = self.DiscriminatorArchitecture(x)
116 |         return x
117 | 
118 | 
119 | 
120 | 
121 | 
122 | 
123 | 
124 | 
125 | 
126 | 
127 | 
128 | 
129 | 
130 | 
131 | 
132 | 
133 | 
134 | 
--------------------------------------------------------------------------------
/helper.py:
--------------------------------------------------------------------------------
1 | 
2 | # coding: utf-8
3 | 
4 | # In[6]:
5 | # This file contains various helper functions that will be useful in creating the network architecture and in training of the network
6 | 
7 | import torch
8 | import os
9 | import pdb
10 | import pickle
11 | import argparse
12 | import numpy as np
13 | import scipy
14 | import scipy.misc
15 | import warnings
16 | warnings.filterwarnings("ignore")
17 | 
18 | 
19 | # In[12]:
20 | 
21 | # function to return the mean square error of the discriminator output for a real image. MSE was the loss function
22 | # in the cycle-GAN paper and hence the same has been used
23 | def real_mse_loss(D_out):
24 |     return torch.mean((D_out-1)**2)
25 | 
26 | 
27 | # In[ ]:
28 | 
29 | # function to return the mean square error of the discriminator output for a generated (fake) image. MSE was the loss function
30 | # in the cycle-GAN paper and hence the same has been used
31 | def fake_mse_loss(D_out):
32 |     return torch.mean(D_out**2)
33 | 
34 | 
35 | # In[ ]:
36 | # This function checks the consistency of the distribution being learnt by the network
37 | # It takes the absolute mean difference of the real image and the reconstructed image
38 | 
39 | def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight):
40 |     reconstr_loss = torch.mean(torch.abs(real_im - reconstructed_im))
41 |     return lambda_weight*reconstr_loss
42 | 
43 | 
44 | # In[7]:
45 | # Function to save a .pkl checkpoint of the models at the given iteration
46 | 
47 | def checkpoint(iteration, G_A2B, G_B2A, D_A, D_B, checkpoint_dir='checkpoints_cyclegan'):
48 |     """Saves the parameters of both generators G_A2B, G_B2A and discriminators D_A, D_B.
49 |     """
50 |     G_A2B_path = os.path.join(checkpoint_dir, 'gen_A2B.pkl')
51 |     G_B2A_path = os.path.join(checkpoint_dir, 'gen_B2A.pkl')
52 |     D_A_path = os.path.join(checkpoint_dir, 'disc_A.pkl')
53 |     D_B_path = os.path.join(checkpoint_dir, 'disc_B.pkl')
54 |     torch.save(G_A2B.state_dict(), G_A2B_path)
55 |     torch.save(G_B2A.state_dict(), G_B2A_path)
56 |     torch.save(D_A.state_dict(), D_A_path)
57 |     torch.save(D_B.state_dict(), D_B_path)
58 | 
59 | 
60 | 
61 | # In[8]:
62 | 
63 | 
64 | def merge_images(sources, targets, batch_size=16):
65 |     """Creates a grid consisting of pairs of columns, where the first column in
66 |     each pair contains the source images and the second column in each pair
67 |     contains images generated by the CycleGAN from the corresponding images in
68 |     the first column.
69 |     """
70 |     _, _, h, w = sources.shape
71 |     row = int(np.sqrt(batch_size))
72 |     merged = np.zeros([3, row*h, row*w*2])
73 |     for idx, (s, t) in enumerate(zip(sources, targets)):
74 |         i = idx // row
75 |         j = idx % row
76 |         merged[:, i*h:(i+1)*h, (j*2)*w:(j*2+1)*w] = s
77 |         merged[:, i*h:(i+1)*h, (j*2+1)*w:(j*2+2)*w] = t
78 |     merged = merged.transpose(1, 2, 0)
79 |     return merged
80 | 
81 | 
82 | # In[9]:
83 | # Function that converts a torch tensor to a numpy array
84 | 
85 | def to_data(x):
86 | 
87 |     if torch.cuda.is_available():
88 |         x = x.cpu()
89 |     x = x.data.numpy()
90 |     x = ((x +1)*255 / (2)).astype(np.uint8) # rescale to 0-255
91 |     return x
92 | 
93 | 
94 | # In[ ]:
95 | # Scale takes in an image x and returns that image, scaled
96 | # with a feature_range of pixel values from -1 to 1.
97 | # This function assumes that the input x is already scaled from 0-1 (as produced by transforms.ToTensor()).
98 | 
99 | def scale(x, feature_range=(-1, 1)):
100 | 
101 | 
102 |     # scale from 0-1 to feature_range
103 |     min, max = feature_range
104 |     x = x * (max - min) + min
105 |     return x
106 | 
107 | 
108 | # In[10]:
109 | # save_samples saves a sample to check the progress of how well the network has been trained
110 | 
111 | def save_samples(iteration, fixed_B, fixed_A, G_B2A, G_A2B, batch_size=16, sample_dir='samples_cyclegan'):
112 |     """Saves samples from both generators A->B and B->A.
113 |     """
114 |     # move input data to correct device
115 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
116 | 
117 |     fake_A = G_B2A(fixed_B.to(device))
118 |     fake_B = G_A2B(fixed_A.to(device))
119 | 
120 |     A, fake_A = to_data(fixed_A), to_data(fake_A)
121 |     B, fake_B = to_data(fixed_B), to_data(fake_B)
122 | 
123 |     merged = merge_images(A, fake_B, batch_size)
124 |     path = os.path.join(sample_dir, 'sample-{:06d}-A-B.jpg'.format(iteration))
125 |     scipy.misc.imsave(path, merged)
126 |     print('Saved {}'.format(path))
127 | 
128 |     merged = merge_images(B, fake_A, batch_size)
129 |     path = os.path.join(sample_dir, 'sample-{:06d}-B-A.jpg'.format(iteration))
130 |     scipy.misc.imsave(path, merged)
131 |     print('Saved {}'.format(path))
132 | 
133 | 
134 | # In[ ]:
135 | # A function to randomly initialize the weights of each layer using a Gaussian distribution
136 | 
137 | def weights_init_normal(m):
138 |     classname = m.__class__.__name__
139 | 
140 |     if classname.find('Conv') != -1:
141 |         torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
142 |     elif classname.find('BatchNorm2d') != -1:
143 |         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
144 |         torch.nn.init.constant_(m.bias.data, 0.0)
145 | 
146 | 
147 | 
148 | # In[11]:
149 | 
150 | # Class for applying a learning rate decay. This ensures the error does not shoot up when it is about to reach a minimum
151 | class LambdaLR:
152 | 
153 |     def __init__(self, n_epochs, decay_start_epoch):
154 |         assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
155 |         self.n_epochs = n_epochs
156 |         self.decay_start_epoch = decay_start_epoch
157 | 
158 | 
159 |     def step(self, epoch):
160 |         return 1.0 - max(0, epoch - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)
161 | 
162 | 
163 | 
164 | # In[ ]:
165 | 
166 | 
167 | 
168 | 
169 | 
--------------------------------------------------------------------------------
/Training.py:
--------------------------------------------------------------------------------
1 | 
2 | # coding: utf-8
3 | 
4 | # In[ ]:
5 | 
6 | 
7 | """ This file is for training the Cycle-GAN network developed to transform a given RGB image to the corresponding depth image.
8 | The code has been commented verbosely for better understanding of the reader/user.
9 | """
10 | """
11 | The DataLoader part of the code has been taken from the Pytorch website: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
12 | """
13 | 
14 | 
15 | # In[2]:
16 | 
17 | 
18 | import itertools
19 | import torch
20 | from torch.utils.data import DataLoader
21 | from torch.autograd import Variable
22 | from PIL import Image
23 | import torch.optim as optim
24 | import torchvision.transforms as transforms
25 | import torchvision.datasets as datasets
26 | import os
27 | import numpy as np
28 | 
29 | 
30 | # In[3]:
31 | 
32 | 
33 | # The functions and classes being imported below are customized to our needs. Their implementations can be found in their corresponding files
34 | # Generator and Discriminator are the classes developed
35 | # LambdaLR is the class developed for applying a decaying learning rate to the optimizer
36 | # save_samples saves a sample while training every given number of epochs to check the performance of the network
37 | # weights_init_normal initializes the weights of the filters with numbers generated by a normal distribution. The range is given in the file
38 | # scale normalizes the given image to the range -1 to 1. 0 to 1 would also work, and you can change the range while calling the function.
39 | # Look at the function definition for more information
40 | # The createArchitecture class creates the cycleGAN architecture using the generator and discriminator classes developed
41 | # getDataLoader takes in the path of the dataset and creates an iterable dataset object. Transforms applied on the dataset
42 | # are hardcoded in the getDataLoader class. Check the class for more information
43 | 
44 | 
45 | # In[4]:
46 | 
47 | 
48 | from model import Generator
49 | from model import Discriminator
50 | from helper import LambdaLR,real_mse_loss,fake_mse_loss,cycle_consistency_loss
51 | from helper import save_samples,checkpoint,weights_init_normal,scale
52 | from getDataLoader import getDataLoader
53 | from createArchitecture import createArchitecture
54 | 
55 | 
56 | # In[5]:
57 | 
58 | 
59 | # Since we are using RGB images as our input, the number of input channels is three
60 | # We have taken a batch_size of 16. Make sure the batch_size is always a multiple of 16 to utilise all the streaming multiprocessors in the GPU
61 | # The learning rate used is 0.001 because it has shown good performance
62 | # decay_epoch is the epoch at which the learning rate starts to decay so that when the error moves towards
63 | # the global minimum, it doesn't overshoot
64 | # image size is 128 as it trains faster on our GPU. The size can be changed based on the availability of the GPU
65 | # The number of epochs is high because it's a generative model and requires intense training to give good output.
66 | # Make sure the network doesn't overtrain 67 | # beta1 and beta2 are the values used by Adam optimizer. 68 | 69 | 70 | # In[6]: 71 | 72 | 73 | input_numChannels = 3 74 | batch_size = 16 75 | num_epochs =5000 76 | learningRate = 0.001 77 | decay_epoch = int(round(0.6*num_epochs)) 78 | image_size = 128 79 | beta1 = 0.55 80 | beta2 = 0.999 81 | 82 | 83 | # In[7]: 84 | 85 | 86 | # rgb and depth dataset objects are created below. The arguments in the getDataLoader are folder names of the training 87 | # dataset. For more information, check the folder structure described in readme 88 | # .load_data() loads the data from the folder, applies the transforms and returns a training and a testing dataset iterable 89 | 90 | 91 | # In[8]: 92 | 93 | 94 | rgb =getDataLoader('rgb') 95 | depth =getDataLoader('depth') 96 | dataloader_A, test_dataloader_A = rgb.load_data() 97 | dataloader_B, test_dataloader_B = depth.load_data() 98 | 99 | 100 | # In[9]: 101 | 102 | 103 | # createArchitecture creates an instance for cycleGAN architecture as mentioned earlier. .create_model() function creates the cycleGAN model 104 | # and moves it GPU if available 105 | 106 | 107 | # In[10]: 108 | 109 | 110 | mod = createArchitecture() 111 | gen_A2B,gen_B2A,disc_A,disc_B = mod.create_model() 112 | 113 | 114 | # In[12]: 115 | 116 | 117 | print(disc_A) 118 | 119 | 120 | # In[ ]: 121 | 122 | 123 | 124 | 125 | 126 | # In[ ]: 127 | 128 | 129 | # network parameters lists are created because they are easy to operate on 130 | # Adam optimizer has been used, as suggested in the GAN paper, and an optimizer for generators and discriminator_A and discriminator_B are created. 131 | # Separate optimizers are not created for the generator because the parameters are combined in the list. Separate can be created as well. 
But this cannot
132 | # be the case for the discriminators because they have to identify two separate outputs
133 | # A learning rate scheduler has been created for the generators and discriminators to offer learning rate decay when the
134 | # training approaches the global minimum
135 | 
136 | 
137 | # In[6]:
138 | 
139 | 
140 | generator_params = list(gen_A2B.parameters()) + list(gen_B2A.parameters())
141 | 
142 | 
143 | # In[7]:
144 | 
145 | 
146 | optimizer_gen = optim.Adam(generator_params, learningRate, [beta1, beta2])
147 | optimizer_disc_A = torch.optim.Adam(disc_A.parameters(), learningRate, [beta1, beta2])
148 | optimizer_disc_B = torch.optim.Adam(disc_B.parameters(), learningRate, [beta1, beta2])
149 | 
150 | 
151 | # In[8]:
152 | 
153 | 
154 | learningRate_scheduler_gen = torch.optim.lr_scheduler.LambdaLR(optimizer_gen, lr_lambda = LambdaLR(num_epochs,decay_epoch).step)
155 | learningRate_scheduler_disc_A = torch.optim.lr_scheduler.LambdaLR(optimizer_disc_A, lr_lambda = LambdaLR(num_epochs,decay_epoch).step)
156 | learningRate_scheduler_disc_B = torch.optim.lr_scheduler.LambdaLR(optimizer_disc_B, lr_lambda = LambdaLR(num_epochs,decay_epoch).step)
157 | 
158 | 
159 | # In[ ]:
160 | 
161 | 
162 | # Training the Cycle-GAN proceeds in the following way:
163 | 
164 | # We first train the Discriminators in the following way-
165 | # 1) Check the output of the discriminator when a real image is fed
166 | # 2) Using the output, get the MSE loss to quantify the error that the discriminator produces for a real image
167 | # 3) Now get a fake image from the generator [from B-->A]
168 | # 4) Check how the discriminator judges the fake image
169 | # 5) Get the MSE loss produced for the fake image
170 | # 6) Sum both the losses in step (2) and step (5)
171 | # 7) Repeat the same method for the second discriminator as well
172 | 
173 | # Training the Generators-
174 | # 1) A sample from domain-A (RGB images in our case) is given to the generator to convert it from domain A-->B
175 | # 2) The generated output is then fed to the discriminator to check the level of fakeness and the MSE loss is received
176 | # 3) The generated output is then fed to the generator (B-->A) to check if the correct mappings are learnt
177 | # 4) The output of step (3) and the input of step (1) are checked for consistency because the output of step (3) should match the
178 | #    input of step (1) (use cycle_consistency_loss)
179 | # 5) Steps (3) and (4) are repeated for the second generator
180 | # 6) Both the losses are added and propagated backward
181 | 
182 | 
183 | # In[9]:
184 | 
185 | 
186 | def training_loop(dataloader_RGB, dataloader_Depth, test_dataloader_RGB, test_dataloader_Depth, n_epochs=1000):
187 | 
188 |     print_every = 10
189 |     losses = []
190 |     test_iter_RGB = iter(test_dataloader_RGB)
191 |     test_iter_Depth = iter(test_dataloader_Depth)
192 |     fixed_RGB = next(test_iter_RGB)[0]
193 |     fixed_Depth = next(test_iter_Depth)[0]
194 |     fixed_RGB = scale(fixed_RGB)
195 |     fixed_Depth = scale(fixed_Depth)
196 |     iter_RGB = iter(dataloader_RGB)
197 |     iter_Depth = iter(dataloader_Depth)
198 |     batches_per_epoch = min(len(iter_RGB), len(iter_Depth))
199 |     for epoch in range(1, n_epochs+1):
200 | 
201 | 
202 |         if epoch % batches_per_epoch == 0:
203 |             iter_RGB = iter(dataloader_RGB)
204 |             iter_Depth = iter(dataloader_Depth)
205 | 
206 |         images_RGB, _ = next(iter_RGB)
207 |         images_RGB = scale(images_RGB)
208 | 
209 |         images_Depth, _ = next(iter_Depth)
210 |         images_Depth = scale(images_Depth)
211 | 
212 | 
213 |         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
214 |         images_RGB = images_RGB.to(device)
215 |         images_Depth = images_Depth.to(device)
216 |         #-------------------------------------Discriminator Training-----------------------------------------------------#
217 | 
218 |         # gradients are zeroed so that they don't carry forward errors from previous iterations, basically
219 |         # to avoid any accumulation
220 |         optimizer_disc_A.zero_grad()
221 | 
222 |         # Discriminator-A is initially checked for its performance with a real image
223 |         out_RGB = disc_A(images_RGB)
224 |         disc_A_real_loss = real_mse_loss(out_RGB)
225 | 
226 |         # A fake image is generated by the generator to check the Discriminator's performance with a fake image
227 |         fake_RGB = gen_B2A(images_Depth)
228 | 
229 | 
230 |         out_RGB = disc_A(fake_RGB)
231 |         disc_A_fake_loss = fake_mse_loss(out_RGB)
232 | 
233 |         # The losses produced by the above operations are added to give the total loss of Discriminator-A, which is then backpropagated
234 |         disc_A_loss = disc_A_real_loss + disc_A_fake_loss
235 |         disc_A_loss.backward()
236 |         optimizer_disc_A.step()
237 | 
238 |         # The same operations as mentioned above are now performed on Discriminator-B
239 |         optimizer_disc_B.zero_grad()
240 | 
241 |         out_Depth = disc_B(images_Depth)
242 |         disc_B_real_loss = real_mse_loss(out_Depth)
243 | 
244 |         fake_Depth = gen_A2B(images_RGB)
245 | 
246 |         out_Depth = disc_B(fake_Depth)
247 |         disc_B_fake_loss = fake_mse_loss(out_Depth)
248 | 
249 | 
250 |         disc_B_loss = disc_B_real_loss + disc_B_fake_loss
251 |         disc_B_loss.backward()
252 |         optimizer_disc_B.step()
253 | 
254 |         #------------------------------------------Generator Training---------------------------------------------------#
255 | 
256 |         optimizer_gen.zero_grad()
257 | 
258 |         # An image is produced by the generator (in this case generator-B2A is being trained) to check how well it can fool the discriminator
259 |         fake_RGB = gen_B2A(images_Depth)
260 | 
261 |         # The generated image is then fed to the discriminator to check its fakeness, and the MSE loss is received
262 |         out_RGB = disc_A(fake_RGB)
263 |         gen_B2A_loss = real_mse_loss(out_RGB)
264 | 
265 |         # To check the consistency, the generated image is fed to the other generator (A2B in this case) and a reconstructed image is received
266 |         reconstructed_Depth = gen_A2B(fake_RGB)
267 |         # The reconstruction loss is received by comparing the two images as mentioned in step (4) of training the generators
268 |         reconstructed_Depth_loss = cycle_consistency_loss(images_Depth, reconstructed_Depth, lambda_weight=10)
269 | 
270 | 
271 |         # The same operations are done, but this time to train gen_A2B
272 |         fake_Depth = gen_A2B(images_RGB)
273 | 
274 | 
275 |         out_Depth = disc_B(fake_Depth)
276 |         gen_A2B_loss = real_mse_loss(out_Depth)
277 | 
278 | 
279 |         reconstructed_RGB = gen_B2A(fake_Depth)
280 |         reconstructed_RGB_loss = cycle_consistency_loss(images_RGB, reconstructed_RGB, lambda_weight=10)
281 | 
282 |         # Losses from both the generators are received and propagated backwards
283 |         g_total_loss = gen_B2A_loss + gen_A2B_loss + reconstructed_Depth_loss + reconstructed_RGB_loss
284 |         g_total_loss.backward()
285 |         optimizer_gen.step()
286 | 
287 |         # Learning rate decay as mentioned earlier
288 |         learningRate_scheduler_gen.step()
289 |         learningRate_scheduler_disc_A.step()
290 |         learningRate_scheduler_disc_B.step()
291 | 
292 | 
293 | 
294 |         if epoch % print_every == 0:
295 | 
296 |             losses.append((disc_A_loss.item(), disc_B_loss.item(), g_total_loss.item()))
297 |             print('Epoch [{:5d}/{:5d}] | d_X_loss: {:6.4f} | d_Y_loss: {:6.4f} | g_total_loss: {:6.4f}'.format(
298 |                     epoch, n_epochs,
disc_A_loss.item(), disc_B_loss.item(), g_total_loss.item())) 299 | 300 | # A training sample is saved every given time to check the performance of our network 301 | sample_every=100 302 | 303 | if epoch % sample_every == 0: 304 | gen_B2A.eval() 305 | gen_A2B.eval() 306 | save_samples(epoch, fixed_Depth, fixed_RGB, gen_B2A, gen_A2B, batch_size=16) 307 | gen_B2A.train() 308 | gen_A2B.train() 309 | 310 | # Models are saved after every given number of epochs 311 | checkpoint_every=1000 312 | # Save the model parameters 313 | if epoch % checkpoint_every == 0: 314 | checkpoint(epoch, gen_A2B, gen_B2A, disc_A, disc_B) 315 | 316 | 317 | 318 | 319 | return losses 320 | 321 | 322 | # In[10]: 323 | 324 | 325 | losses = training_loop(dataloader_A, dataloader_B, test_dataloader_A, test_dataloader_B, n_epochs=num_epochs) 326 | 327 | 328 | # In[11]: 329 | 330 | 331 | # gen_A2B,gen_B2A,disc_A,disc_B 332 | 333 | 334 | # In[12]: 335 | 336 | 337 | # A plot to track the training losses 338 | import matplotlib.pyplot as plt 339 | fig, ax = plt.subplots(figsize=(12,8)) 340 | losses = np.array(losses) 341 | plt.plot(losses.T[0], label='Discriminator, RGB', alpha=0.5) 342 | plt.plot(losses.T[1], label='Discriminator, Depth', alpha=0.5) 343 | plt.plot(losses.T[2], label='Generators', alpha=0.5) 344 | plt.title("Training Losses") 345 | plt.legend() 346 | plt.savefig('Training_Loss.png') 347 | 348 | 349 | # In[13]: 350 | 351 | 352 | plt.show() 353 | 354 | 355 | # In[ ]: 356 | 357 | 358 | 359 | 360 | 361 | 362 | --------------------------------------------------------------------------------
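Before launching a long run of Training.py, it can be worth sanity-checking the architecture on a dummy batch. The sketch below is illustrative only; it confirms that the generator preserves the 128x128 input size used above and that the discriminator returns a patch-wise map of real/fake scores rather than a single scalar.

```python
# Sketch: verify the CycleGAN building blocks on a dummy batch before training.
import torch
from createArchitecture import createArchitecture

gen_A2B, gen_B2A, disc_A, disc_B = createArchitecture(num_res_blocks=6).create_model()

x = torch.randn(2, 3, 128, 128)   # dummy RGB batch, same spatial size as image_size
if torch.cuda.is_available():
    x = x.cuda()                  # create_model already moved the networks to the GPU

with torch.no_grad():
    fake_depth = gen_A2B(x)
    score_map = disc_B(fake_depth)

print(fake_depth.shape)   # expected: torch.Size([2, 3, 128, 128]) -- generator preserves size
print(score_map.shape)    # a patch-wise map of real/fake scores, not a single scalar
```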