├── EpochVerify.txt
├── CycleGANBD.png
├── results
│   ├── A
│   │   ├── 0032.png
│   │   ├── 0077.png
│   │   ├── 0102.png
│   │   ├── 0110.png
│   │   ├── 0121.png
│   │   ├── 0124.png
│   │   ├── 0126.png
│   │   ├── 0130.png
│   │   ├── 0134.png
│   │   ├── 0135.png
│   │   ├── 0141.png
│   │   ├── 0145.png
│   │   ├── 0166.png
│   │   ├── 0168.png
│   │   ├── 0224.png
│   │   └── 0255.png
│   └── B
│       ├── 0016.png
│       ├── 0048.png
│       ├── 0074.png
│       └── 0187.png
├── requirements.txt
├── configure.py
├── LICENSE
├── datasets.py
├── test.py
├── README.md
├── models.py
├── utils.py
└── train.py

--------------------------------------------------------------------------------
/EpochVerify.txt:
--------------------------------------------------------------------------------

0
--------------------------------------------------------------------------------
/CycleGANBD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/CycleGANBD.png
--------------------------------------------------------------------------------
/results/A/0032.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0032.png
--------------------------------------------------------------------------------
/results/A/0077.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0077.png
--------------------------------------------------------------------------------
/results/A/0102.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0102.png
--------------------------------------------------------------------------------
/results/A/0110.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0110.png
--------------------------------------------------------------------------------
/results/A/0121.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0121.png
--------------------------------------------------------------------------------
/results/A/0124.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0124.png
--------------------------------------------------------------------------------
/results/A/0126.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0126.png
--------------------------------------------------------------------------------
/results/A/0130.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0130.png
--------------------------------------------------------------------------------
/results/A/0134.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0134.png
--------------------------------------------------------------------------------
/results/A/0135.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0135.png
--------------------------------------------------------------------------------
/results/A/0141.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0141.png
--------------------------------------------------------------------------------
/results/A/0145.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0145.png
--------------------------------------------------------------------------------
/results/A/0166.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0166.png
--------------------------------------------------------------------------------
/results/A/0168.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0168.png
--------------------------------------------------------------------------------
/results/A/0224.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0224.png
--------------------------------------------------------------------------------
/results/A/0255.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/A/0255.png
--------------------------------------------------------------------------------
/results/B/0016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/B/0016.png
--------------------------------------------------------------------------------
/results/B/0048.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/B/0048.png
--------------------------------------------------------------------------------
/results/B/0074.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/B/0074.png
--------------------------------------------------------------------------------
/results/B/0187.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekyana/CycleGANs-PyTorch/HEAD/results/B/0187.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.16.2
torch==1.1.0
scipy==1.2.1
torchvision==0.3.0
visdom==0.1.8.8
Pillow==6.1.0
--------------------------------------------------------------------------------
/configure.py:
--------------------------------------------------------------------------------
options = {
    "epoch": 0,
    "nEpochs": 200,
    "batchsize": 2,
    "datapath": "./young2old",
    "outputpath": './output/',
    "learningrate": 2e-4,
    "decayEpoch": 50,
    "inResolution": [256, 256, 3],
    "outResolution": [256, 256, 3],
    "continued": False,
    "nThreads": 10,
    "unpaired": True,
}

testoptions = {
    "batchsize": 1,
    "datapath": './young2old',
    "outputpath": './output/',
    "inResolution": [256, 256, 3],
    "outResolution": [256, 256, 3],
    "nThreads": 10,
    "GEN_AtoB": './output/GEN_AtoB.pth',
    "GEN_BtoA": './output/GEN_BtoA.pth'
}
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Abhishek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
import numpy as np
import torchvision.transforms as transforms
import os
import glob
from torch.utils.data import Dataset
from PIL import Image
from scipy.misc import imresize


class AnBDataset(Dataset):
    def __init__(self, path, mode='train', transforms_arg=None, unpaired=False, size=(256, 256, 3)):
        self.transforms = transforms.Compose(transforms_arg)
        self.unpaired = unpaired
        self.size = size
        self.A_paths = sorted(glob.glob(f'{path}/{mode}/A/*.*'))
        self.B_paths = sorted(glob.glob(f'{path}/{mode}/B/*.*'))
        self.totA_paths = len(self.A_paths)
        self.totB_paths = len(self.B_paths)
        self.cache = {'A': {}, 'B': {}}  # Cache transformed images so each file is decoded only once

    def __len__(self):
        return max(self.totA_paths, self.totB_paths)

    def __getitem__(self, idx):
        idxa = idx % self.totA_paths
        if idxa in self.cache['A']:
            image_A = self.cache['A'][idxa]
        else:
            imageA = Image.open(self.A_paths[idxa])
            imageA = imresize(imageA, self.size)
            imageA = np.array(imageA)
            if len(imageA.shape) == 2:  # For mono-channel images
                imageA = np.stack((imageA,)*3, axis=-1)
            imageA = Image.fromarray(imageA)
            image_A = self.transforms(imageA)
            self.cache['A'][idxa] = image_A

        if self.unpaired:
            idxb = np.random.randint(0, self.totB_paths)
        else:
            idxb = idx % self.totB_paths

        if idxb in self.cache['B']:
            image_B = self.cache['B'][idxb]
        else:
            imageB = Image.open(self.B_paths[idxb])
            imageB = imresize(imageB, self.size)
            imageB = np.array(imageB)
            if len(imageB.shape) == 2:  # For mono-channel images
                imageB = np.stack((imageB,)*3, axis=-1)
            imageB = Image.fromarray(imageB)
            image_B = self.transforms(imageB)
            self.cache['B'][idxb] = image_B
        return {'A': image_A, 'B': image_B}
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os
import torch
from models import Generator
from datasets import AnBDataset
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from configure import testoptions as options
from torchvision.utils import make_grid, save_image


inH, inW, inC = options["inResolution"]
outH, outW, outC = options["outResolution"]

options["cuda"] = torch.cuda.is_available()

def loadModel(options):
    GEN_A2B = Generator(inC, outC)
    GEN_B2A = Generator(outC, inC)
    if options["cuda"]:
        GEN_A2B.cuda()
        GEN_B2A.cuda()
    GEN_A2B.load_state_dict(torch.load(options["GEN_AtoB"]))
    GEN_B2A.load_state_dict(torch.load(options["GEN_BtoA"]))
    GEN_A2B.eval()
    GEN_B2A.eval()
    return GEN_A2B, GEN_B2A

def loadData(options):
    ApplyTransforms = [transforms.ToTensor(),
                       transforms.Normalize((1/2,)*3, (1/2,)*3)]
    dataloader = DataLoader(AnBDataset(options["datapath"], transforms_arg=ApplyTransforms, mode='test'),
                            batch_size=options["batchsize"], shuffle=False, num_workers=options["nThreads"])
    return dataloader

def tester(options):
    GEN_A2B, GEN_B2A = loadModel(options)
    dataloader = loadData(options)

    # Create placeholders for the images
    if options["cuda"]:
        Tensor = torch.cuda.FloatTensor
    else:
        Tensor = torch.Tensor
    batchsize = options["batchsize"]
    imageA = Tensor(batchsize, inC, inH, inW)
    imageB = Tensor(batchsize, outC, outH, outW)

    if not os.path.exists(f'{options["outputpath"]}/A'):
        os.makedirs(f'{options["outputpath"]}/A')
    if not os.path.exists(f'{options["outputpath"]}/B'):
        os.makedirs(f'{options["outputpath"]}/B')

    for batch_id, batch_data in enumerate(dataloader):
        realA = Variable(imageA.copy_(batch_data['A']))
        realB = Variable(imageB.copy_(batch_data['B']))

        fakeB = GEN_A2B(realA).data
        fakeA = GEN_B2A(realB).data

        # Un-normalize from [-1, 1] back to [0, 1] and juxtapose real and translated images
        IMAGEA = torch.cat([0.5*(realA + 1.0), 0.5*(fakeB + 1.0)])
        IMAGEB = torch.cat([0.5*(realB + 1.0), 0.5*(fakeA + 1.0)])

        save_image(IMAGEA, f'{options["outputpath"]}/A/{batch_id+1:04d}.png')
        save_image(IMAGEB, f'{options["outputpath"]}/B/{batch_id+1:04d}.png')
        print(f"Generated image {batch_id+1} of {len(dataloader)}")
    print("DONE!! :D")

if __name__ == "__main__":
    tester(options)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CycleGANs-PyTorch applied to a Young-to-Old image converter
PyTorch implementation of the [CycleGAN paper](https://arxiv.org/pdf/1703.10593.pdf).
* You can find more about this project in my [blog here](http://blog.abhishekyana.ml/implement-your-own-young-to-old-age-converter-app-in-pytorch-using-cyclegans/).
### CycleGAN Block Diagram:
![BD](./CycleGANBD.png)
### RESULTS FIRST: Young to Old converter
![img1.jpg](./results/A/0166.png)
![img2.jpg](./results/A/0168.png)
![img3.jpg](./results/A/0145.png)
![img4.jpg](./results/A/0255.png)
### If you want to replicate these results, perhaps on a different dataset, read on.
1. Clone the repository:
```
git clone https://github.com/abhishekyana/CycleGANs-PyTorch.git
cd CycleGANs-PyTorch
# As this is a big project, I'd suggest creating a conda environment and running the training inside it.
```
1. Install all the requirements from the requirements.txt file: `pip install -r requirements.txt`
1. Download the dataset; it can be grabbed from [here](https://www.kaggle.com/abhishekyana/young2old-dataset).
1. Unzip and move the dataset folder into this project's root directory (the expected layout is sketched below).
1. Adjust configure.py to your taste; these parameters control the training.
1. Run `python train.py` and watch the training happen for yourself (see the monitoring note below).
* The models are saved to and loaded from ./output by default.
* The model trained for around 4 hours on a GTX 1080 and i7 system.
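
### Expected dataset layout
The loaders in `datasets.py` read images with the glob patterns `<datapath>/<mode>/A/*.*` and `<datapath>/<mode>/B/*.*`, so the dataset folder should look roughly like this (a sketch; the image file names are placeholders, and `young2old` is whatever `datapath` points to in `configure.py`):
```
young2old/
├── train/
│   ├── A/        # domain A (young) training images, e.g. 0001.jpg
│   └── B/        # domain B (old) training images
└── test/
    ├── A/        # domain A test images
    └── B/        # domain B test images
```
### Monitoring training
The `Logger` in `utils.py` pushes the losses and intermediate images to Visdom, so start a Visdom server before training and open it in your browser (by default at http://localhost:8097):
```
python -m visdom.server
```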

### If you want to test the model, you can download the pretrained model [from here](./). Sorry, the link is broken; I'll fix it..
* Download the dataset.
* Download the pretrained model. The Generator weights alone are enough.
* Copy these folders into the appropriate directories as mentioned above.
* Run `python test.py`. After the process is done, you can see the juxtaposed results in `./output/A` and `./output/B`.
* If you want to run this on your own images, copy your image into `./directory/test/A` to make your picture old, or into `./directory/test/B` to make it young. Then point `"datapath"` in `testoptions` in `configure.py` at `./directory` and run the code again. Now you can see your image in the output directory.

### Please feel free to fork it, clone it, and do whatever you want.
* Not just this dataset: a CycleGAN can map between any two unpaired domains; I chose this application because it is trending right now.
## With Love on Open Source
### Thank you
This project was inspired by [Aitor Ruano](https://github.com/aitorzip), and I would like to thank him for providing such beautiful code, which I used to clarify my doubts during the implementation.
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    def __init__(self, inFeatures):
        super(ResBlock, self).__init__()
        self.conv = nn.Sequential(nn.ReflectionPad2d(1),
                                  nn.Conv2d(inFeatures, inFeatures, 3),
                                  nn.InstanceNorm2d(inFeatures),
                                  nn.ReLU(inplace=True),
                                  nn.ReflectionPad2d(1),
                                  nn.Conv2d(inFeatures, inFeatures, 3),
                                  nn.InstanceNorm2d(inFeatures))

    def forward(self, X):
        out = X + self.conv(X)
        return out


class Generator(nn.Module):
    def __init__(self, inputnc, outputnc, nResBlocks=9):
        super(Generator, self).__init__()

        layers = [nn.ReflectionPad2d(3),
                  nn.Conv2d(inputnc, 64, 7),
                  nn.InstanceNorm2d(64),
                  nn.ReLU(inplace=True)]

        # To downsample the image
        inFeatures = 64
        outFeatures = 2*inFeatures
        for i in range(2):
            layers += [nn.Conv2d(inFeatures, outFeatures, 3, stride=2, padding=1),
                       nn.InstanceNorm2d(outFeatures),
                       nn.ReLU(inplace=True)]
            inFeatures = outFeatures
            outFeatures = 2*inFeatures

        for i in range(nResBlocks):
            layers += [ResBlock(inFeatures)]

        # To upsample the image
        outFeatures = inFeatures//2
        for i in range(2):
            layers += [nn.ConvTranspose2d(inFeatures, outFeatures, 3, stride=2, padding=1, output_padding=1),
                       nn.InstanceNorm2d(outFeatures),
                       nn.ReLU(inplace=True)]
            inFeatures = outFeatures
            outFeatures = inFeatures//2

        layers += [nn.ReflectionPad2d(3),
                   nn.Conv2d(64, outputnc, 7),
                   nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, X):
        out = self.model(X)
        return out

class Discriminator(nn.Module):
    def __init__(self, inputnc):
        super(Discriminator, self).__init__()
        layers = [nn.Conv2d(inputnc, 64, 4, stride=2, padding=1),
                  nn.LeakyReLU(0.2, inplace=True),
                  nn.Conv2d(64, 128, 4, stride=2, padding=1),
                  nn.InstanceNorm2d(128),
                  nn.LeakyReLU(0.2, inplace=True),
                  nn.Conv2d(128, 256, 4, stride=2, padding=1),
                  nn.InstanceNorm2d(256),
                  nn.LeakyReLU(0.2, inplace=True),
                  nn.Conv2d(256, 512, 4, padding=1),
                  nn.InstanceNorm2d(512),
                  nn.LeakyReLU(0.2, inplace=True),
                  nn.Conv2d(512, 1, 4, padding=1)]
        self.model = nn.Sequential(*layers)

    def forward(self, X):
        out = self.model(X)
        # Average the PatchGAN output map into one score per image
        out = F.avg_pool2d(out, out.size()[2:]).view(out.size()[0], -1)
        return out
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import random
import time
import datetime
import sys

from torch.autograd import Variable
import torch
from visdom import Visdom
import numpy as np

# Map a [-1, 1] tensor back to a uint8 image in [0, 255]
tensor2image = lambda T: (127.5*(T[0].cpu().float().numpy() + 1.0)).astype(np.uint8)

class Logger():
    def __init__(self, n_epochs, batches_epoch):
        self.viz = Visdom()
        self.n_epochs = n_epochs
        self.batches_epoch = batches_epoch
        self.epoch = 1
        self.batch = 1
        self.prev_time = time.time()
        self.mean_period = 0
        self.losses = {}
        self.loss_windows = {}
        self.image_windows = {}

    def log(self, losses=None, images=None):
        self.mean_period += (time.time() - self.prev_time)
        self.prev_time = time.time()

        sys.stdout.write('\rEpoch %03d/%03d [%04d/%04d] -- ' % (self.epoch, self.n_epochs, self.batch, self.batches_epoch))

        for i, loss_name in enumerate(losses.keys()):
            if loss_name not in self.losses:
                self.losses[loss_name] = losses[loss_name].item()
            else:
                self.losses[loss_name] += losses[loss_name].item()

            if (i+1) == len(losses.keys()):
                sys.stdout.write('%s: %.4f -- ' % (loss_name, self.losses[loss_name]/self.batch))
            else:
                sys.stdout.write('%s: %.4f | ' % (loss_name, self.losses[loss_name]/self.batch))

        batches_done = self.batches_epoch*(self.epoch - 1) + self.batch
        batches_left = self.batches_epoch*(self.n_epochs - self.epoch) + self.batches_epoch - self.batch
        sys.stdout.write('ETA: %s' % (datetime.timedelta(seconds=batches_left*self.mean_period/batches_done)))

        # Draw images
        for image_name, tensor in images.items():
            if image_name not in self.image_windows:
                self.image_windows[image_name] = self.viz.image(tensor2image(tensor.data), opts={'title': image_name})
            else:
                self.viz.image(tensor2image(tensor.data), win=self.image_windows[image_name], opts={'title': image_name})

        # End of epoch
        if (self.batch % self.batches_epoch) == 0:
            # Plot losses
            for loss_name, loss in self.losses.items():
                if loss_name not in self.loss_windows:
                    self.loss_windows[loss_name] = self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]),
                                                                 opts={'xlabel': 'epochs', 'ylabel': loss_name, 'title': loss_name})
                else:
                    self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]), win=self.loss_windows[loss_name], update='append')
                # Reset losses for next epoch
                self.losses[loss_name] = 0.0

            self.epoch += 1
            self.batch = 1
            sys.stdout.write('\n')
        else:
            self.batch += 1



class ReplayBuffer():
    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        # Store the newest generated images and, half of the time, hand the
        # discriminator an older sample from the history instead of the newest one
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
import argparse
import itertools
from PIL import Image
from datasets import AnBDataset
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from models import Generator, Discriminator
from utils import ReplayBuffer, Logger, weights_init_normal

from configure import options  # Edit that file once to control the parameters, instead of pasting a huge run command for every run.


inH, inW, inC = options["inResolution"]
outH, outW, outC = options["outResolution"]

options["cuda"] = torch.cuda.is_available()

def LoadData(path):
    ApplyTransforms = [transforms.Resize((int(inH*1.12), int(inW*1.12)), Image.BICUBIC),
                       transforms.RandomCrop((inH, inW)),
                       transforms.RandomHorizontalFlip(),
                       transforms.ToTensor(),
                       transforms.Normalize((1/2,)*3, (1/2,)*3)]
    Data = AnBDataset(path, transforms_arg=ApplyTransforms, unpaired=options["unpaired"])
    dataloader = DataLoader(Data, batch_size=options["batchsize"], shuffle=True, num_workers=options["nThreads"])
    return dataloader


def CycleGANmapper(inC, outC, options, init_weights=True):
    GA2B = Generator(inC, outC)
    DB = Discriminator(outC)
    GB2A = Generator(outC, inC)
    DA = Discriminator(inC)
    if options["cuda"]:
        GA2B.cuda()
        DB.cuda()
        GB2A.cuda()
        DA.cuda()
    if options["continued"]:
        # Resume from the last saved checkpoints instead of re-initialising
        GA2B.load_state_dict(torch.load("output/GEN_AtoB.pth"))
        DB.load_state_dict(torch.load("output/DIS_B.pth"))
        GB2A.load_state_dict(torch.load("output/GEN_BtoA.pth"))
        DA.load_state_dict(torch.load("output/DIS_A.pth"))
    elif init_weights:
        GA2B.apply(weights_init_normal)
        DB.apply(weights_init_normal)
        GB2A.apply(weights_init_normal)
        DA.apply(weights_init_normal)
    return (GA2B, DB), (GB2A, DA)

def SaveModels(GEN_AtoB, DIS_B, GEN_BtoA, DIS_A):
    torch.save(GEN_AtoB.state_dict(), options["outputpath"]+'GEN_AtoB.pth')
    torch.save(DIS_B.state_dict(), options["outputpath"]+'DIS_B.pth')
    torch.save(GEN_BtoA.state_dict(), options["outputpath"]+'GEN_BtoA.pth')
    torch.save(DIS_A.state_dict(), options["outputpath"]+'DIS_A.pth')
    print(f"Models Saved at {options['outputpath']}")

def trainer(options):
    startEpoch = options["epoch"]
    nEpochs = options["nEpochs"]
    decayEpoch = options["decayEpoch"]
    assert nEpochs > decayEpoch, "The decay epoch is not smaller than the total number of epochs, so there would be no decay at all :P Sure?"

    (GEN_AtoB, DIS_B), (GEN_BtoA, DIS_A) = CycleGANmapper(inC, outC, options)

    # LOSSES
    GANLoss = torch.nn.MSELoss()      # Least-squares adversarial loss (LSGAN)
    CycleLoss = torch.nn.L1Loss()     # ImageA->ImageB->ImageA' reconstruction distance
    IdentityLoss = torch.nn.L1Loss()  # A generator should leave images of its target domain unchanged

    optim_GAN = torch.optim.Adam(list(GEN_AtoB.parameters())+list(GEN_BtoA.parameters()), lr=options["learningrate"], betas=[0.5, 0.999])
    optim_DIS_A = torch.optim.Adam(DIS_A.parameters(), lr=options["learningrate"], betas=[0.5, 0.999])
    optim_DIS_B = torch.optim.Adam(DIS_B.parameters(), lr=options["learningrate"], betas=[0.5, 0.999])

    # LR schedule: keep the learning rate constant for the first decayEpoch epochs,
    # then decay it linearly towards zero over the remaining epochs
    lr_scheduler = lambda ep: 1.0 - max(0, ep + startEpoch - decayEpoch)/(nEpochs - decayEpoch)

    lr_scheduler_GAN = torch.optim.lr_scheduler.LambdaLR(optim_GAN, lr_lambda=lr_scheduler)
    lr_scheduler_Dis_A = torch.optim.lr_scheduler.LambdaLR(optim_DIS_A, lr_lambda=lr_scheduler)
    lr_scheduler_Dis_B = torch.optim.lr_scheduler.LambdaLR(optim_DIS_B, lr_lambda=lr_scheduler)

    # Tensors Memory Allocation
    batchsize = options["batchsize"]
    if options["cuda"]:
        Tensor = torch.cuda.FloatTensor
    else:
        Tensor = torch.Tensor
    imageA = Tensor(batchsize, inC, inH, inW)
    imageB = Tensor(batchsize, outC, outH, outW)
    # The discriminators output one score per image, shape (batchsize, 1)
    targetReal = Variable(Tensor(batchsize, 1).fill_(1.0), requires_grad=False)
    targetFake = Variable(Tensor(batchsize, 1).fill_(0.0), requires_grad=False)

    FakeAHolder = ReplayBuffer()  # History of generated images for the discriminator updates
    FakeBHolder = ReplayBuffer()

    dataloader = LoadData(options["datapath"])
    logger = Logger(nEpochs, len(dataloader))

    # Actual Training
    for epoch in range(startEpoch, nEpochs):
        for batch_id, batch_data in enumerate(dataloader):
            realA = Variable(imageA.copy_(batch_data['A']))
            realB = Variable(imageB.copy_(batch_data['B']))

            # Generator GEN_AtoB maps from A to B and GEN_BtoA maps from B to A
            optim_GAN.zero_grad()

            #Identity Loss: GEN_AtoB(realB) should reproduce realB
            synthB = GEN_AtoB(realB)
            lossIden_BtoB = IdentityLoss(synthB, realB)*5.0

            #Identity Loss: GEN_BtoA(realA) should reproduce realA
            synthA = GEN_BtoA(realA)
            lossIden_AtoA = IdentityLoss(synthA, realA)*5.0

            #GAN Loss: DIS_B(GEN_AtoB(realA)) should be closest to the real target.
            fakeB = GEN_AtoB(realA)
            classB = DIS_B(fakeB)
            lossGEN_A2B = GANLoss(classB, targetReal)

            #GAN Loss: DIS_A(GEN_BtoA(realB)) should be closest to the real target.
            fakeA = GEN_BtoA(realB)
            classA = DIS_A(fakeA)
            lossGEN_B2A = GANLoss(classA, targetReal)

            #Cycle Reconstruction: GEN_BtoA(GEN_AtoB(realA)) should give back realA
            reconA = GEN_BtoA(fakeB)
            lossCycle_ABA = CycleLoss(reconA, realA)*10.0

            #Cycle Reconstruction: GEN_AtoB(GEN_BtoA(realB)) should give back realB
            reconB = GEN_AtoB(fakeA)
            lossCycle_BAB = CycleLoss(reconB, realB)*10.0

            # Total generator loss: cycle-consistency, adversarial and identity terms
            lossTotal = lossCycle_BAB + lossCycle_ABA + lossGEN_B2A + lossGEN_A2B + lossIden_AtoA + lossIden_BtoB
            lossTotal.backward()

            optim_GAN.step()

            # Discriminator A update
            optim_DIS_A.zero_grad()

            #Real Loss: when realA is fed in, the Discriminator should predict 1.0
            classAreal = DIS_A(realA)
            lossDreal = GANLoss(classAreal, targetReal)

            #Fake Loss: when a fakeA is fed in, the Discriminator should predict 0.0
            fakeA = FakeAHolder.push_and_pop(fakeA)  # Sample from the history of generated images
            classAfake = DIS_A(fakeA.detach())
            lossDfake = GANLoss(classAfake, targetFake)

            # The Discriminator should perform equally well on both tasks
            lossDA = (lossDreal + lossDfake)/2

            lossDA.backward()
            optim_DIS_A.step()

            # Discriminator B update
            optim_DIS_B.zero_grad()

            #Real Loss: when realB is fed in, the Discriminator should predict 1.0
            classBreal = DIS_B(realB)
            lossDreal = GANLoss(classBreal, targetReal)

            #Fake Loss: when a fakeB is fed in, the Discriminator should predict 0.0
            fakeB = FakeBHolder.push_and_pop(fakeB)  # Sample from the history of generated images
            classBfake = DIS_B(fakeB.detach())
            lossDfake = GANLoss(classBfake, targetFake)

            # The Discriminator should perform equally well on both tasks
            lossDB = (lossDreal + lossDfake)/2

            lossDB.backward()
            optim_DIS_B.step()

            # Progress
            logger.log({'lossTotal': lossTotal, 'lossIdentity': (lossIden_AtoA + lossIden_BtoB), 'lossGAN': (lossGEN_A2B + lossGEN_B2A),
                        'lossCycle': (lossCycle_ABA + lossCycle_BAB), 'lossD': (lossDA + lossDB)},
                       images={'realA': realA, 'realB': realB, 'fakeA': fakeA, 'fakeB': fakeB})

        # Update learning rates
        lr_scheduler_GAN.step()
        lr_scheduler_Dis_A.step()
        lr_scheduler_Dis_B.step()

        # Save model checkpoints and record the last finished epoch
        SaveModels(GEN_AtoB, DIS_B, GEN_BtoA, DIS_A)
        with open("EpochVerify.txt", 'w') as ff:
            ff.write("\n"+str(epoch))



if __name__ == "__main__":
    trainer(options)
--------------------------------------------------------------------------------