├── LICENSE ├── README.md ├── arch ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── discriminators.cpython-35.pyc │ ├── generators.cpython-35.pyc │ └── ops.cpython-35.pyc ├── discriminators.py ├── generators.py └── ops.py ├── download_dataset.sh ├── images ├── horse_generated.png ├── horse_real.png ├── horse_reconstructed.png ├── zebra_generated.png ├── zebra_real.png └── zebra_reconstructed.png ├── main.py ├── model.py ├── test.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Arnab Kumar Mondal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CycleGAN using PyTorch 2 | CycleGAN is one of the most interesting works I have read. Although the idea behind cycleGAN looks quite intuitive after you read the paper: [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593), the official PyTorch implementation by [junyanz](https://github.com/junyanz) is difficult to understand for beginners (The code is really well written but it's just that it has multiple things implemented together). As I am writing a simpler version of the code for some other work, I thought of making my version of cycleGAN public for those who are looking for an easier implementation of the paper. 3 | 4 | All the credit goes to the authors of the paper. 5 | This code is inspired by the actual implementation by [junyanz](https://github.com/junyanz) which can be found [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). 6 | 7 | ## Requirements 8 | - The code has been written in Python (3.5.2) and PyTorch (0.4.1) 9 | 10 | ## How to run 11 | * To download datasets (eg. horse2zebra) 12 | ``` 13 | $ sh ./download_dataset.sh horse2zebra 14 | ``` 15 | * To run training 16 | ``` 17 | $ python main.py --training True 18 | ``` 19 | * To run testing 20 | ``` 21 | $ python main.py --testing True 22 | ``` 23 | * Try tweaking the arguments to get best performance according to the dataset. 24 | 25 | ## Results 26 | 27 | * For horse to zebra dataset. ( Real - Generated - Reconstructed) 28 | 29 |

30 | 31 | 32 | 33 |

34 | 35 |

36 | 37 | 38 | 39 |

40 | -------------------------------------------------------------------------------- /arch/__init__.py: -------------------------------------------------------------------------------- 1 | from .generators import define_Gen 2 | from .discriminators import define_Dis 3 | from .ops import set_grad -------------------------------------------------------------------------------- /arch/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/arch/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /arch/__pycache__/discriminators.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/arch/__pycache__/discriminators.cpython-35.pyc -------------------------------------------------------------------------------- /arch/__pycache__/generators.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/arch/__pycache__/generators.cpython-35.pyc -------------------------------------------------------------------------------- /arch/__pycache__/ops.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/arch/__pycache__/ops.cpython-35.pyc -------------------------------------------------------------------------------- /arch/discriminators.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .ops import conv_norm_lrelu, get_norm_layer, init_network 3 | import torch 4 | from torch.nn import functional as F 5 | import functools 6 | 7 | 8 | class NLayerDiscriminator(nn.Module): 9 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_bias=False): 10 | super(NLayerDiscriminator, self).__init__() 11 | dis_model = [nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1), 12 | nn.LeakyReLU(0.2, True)] 13 | nf_mult = 1 14 | nf_mult_prev = 1 15 | for n in range(1, n_layers): 16 | nf_mult_prev = nf_mult 17 | nf_mult = min(2**n, 8) 18 | dis_model += [conv_norm_lrelu(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=2, 19 | norm_layer= norm_layer, padding=1, bias=use_bias)] 20 | nf_mult_prev = nf_mult 21 | nf_mult = min(2**n_layers, 8) 22 | dis_model += [conv_norm_lrelu(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=1, 23 | norm_layer= norm_layer, padding=1, bias=use_bias)] 24 | dis_model += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)] 25 | 26 | self.dis_model = nn.Sequential(*dis_model) 27 | 28 | def forward(self, input): 29 | return self.dis_model(input) 30 | 31 | class PixelDiscriminator(nn.Module): 32 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_bias=False): 33 | super(PixelDiscriminator, self).__init__() 34 | dis_model = [ 35 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 36 | nn.LeakyReLU(0.2, True), 37 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 38 | norm_layer(ndf * 2), 39 | nn.LeakyReLU(0.2, True), 40 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 41 | 42 | self.dis_model = nn.Sequential(*dis_model) 43 | 44 | def forward(self, input): 45 | return self.dis_model(input) 46 | 47 | 48 | 49 | def define_Dis(input_nc, ndf, netD, n_layers_D=3, norm='batch', gpu_ids=[0]): 50 | dis_net = None 51 | norm_layer = get_norm_layer(norm_type=norm) 52 | if type(norm_layer) == functools.partial: 53 | use_bias = norm_layer.func == nn.InstanceNorm2d 54 | else: 55 | use_bias = norm_layer == nn.InstanceNorm2d 56 | 57 | if netD == 'n_layers': 58 | dis_net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_bias=use_bias) 59 | elif netD == 'pixel': 60 | dis_net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_bias=use_bias) 61 | else: 62 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) 63 | 64 | return init_network(dis_net, gpu_ids) -------------------------------------------------------------------------------- /arch/generators.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | from torch import nn 4 | from .ops import conv_norm_relu, dconv_norm_relu, ResidualBlock, get_norm_layer, init_network 5 | 6 | 7 | class UnetSkipConnectionBlock(nn.Module): 8 | def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False, 9 | innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 10 | super(UnetSkipConnectionBlock, self).__init__() 11 | self.outermost = outermost 12 | if type(norm_layer) == functools.partial: 13 | use_bias = norm_layer.func == nn.InstanceNorm2d 14 | else: 15 | use_bias = norm_layer == nn.InstanceNorm2d 16 | if input_nc is None: 17 | input_nc = outer_nc 18 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) 19 | 20 | if outermost: 21 | upconv = nn.ConvTranspose2d(inner_nc*2, outer_nc, kernel_size=4, stride=2, padding=1) 22 | down = [downconv] 23 | up = [nn.ReLU(True), upconv, nn.Tanh()] 24 | model = down + [submodule] + up 25 | elif innermost: 26 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) 27 | down = [nn.LeakyReLU(0.2, True), downconv] 28 | up = [nn.ReLU(True), upconv, norm_layer(outer_nc)] 29 | model = down + up 30 | else: 31 | upconv = nn.ConvTranspose2d(inner_nc*2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) 32 | down = [nn.LeakyReLU(0.2, True), downconv, norm_layer(inner_nc)] 33 | up = [nn.ReLU(True), upconv, norm_layer(outer_nc)] 34 | 35 | if use_dropout: 36 | model = down + [submodule] + up + [nn.Dropout(0.5)] 37 | else: 38 | model = down + [submodule] + up 39 | 40 | self.model = nn.Sequential(*model) 41 | 42 | def forward(self, x): 43 | if self.outermost: 44 | return self.model(x) 45 | else: 46 | return torch.cat([x, self.model(x)], 1) 47 | 48 | class UnetGenerator(nn.Module): 49 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 50 | norm_layer=nn.BatchNorm2d, use_dropout=False): 51 | super(UnetGenerator, self).__init__() 52 | 53 | unet_block = UnetSkipConnectionBlock(ngf*8, ngf*8, submodule=None, norm_layer=norm_layer, innermost=True) 54 | for i in range(num_downs - 5): 55 | unet_block = UnetSkipConnectionBlock(ngf*8, ngf*8, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 56 | unet_block = UnetSkipConnectionBlock(ngf*4, ngf*8, submodule=unet_block, norm_layer=norm_layer) 57 | unet_block = UnetSkipConnectionBlock(ngf*2, ngf*4, submodule=unet_block, norm_layer=norm_layer) 58 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block, norm_layer=norm_layer) 59 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) 60 | self.unet_model = unet_block 61 | 62 | def forward(self, input): 63 | return self.unet_model(input) 64 | 65 | 66 | 67 | class ResnetGenerator(nn.Module): 68 | def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=True, num_blocks=6): 69 | super(ResnetGenerator, self).__init__() 70 | if type(norm_layer) == functools.partial: 71 | use_bias = norm_layer.func == nn.InstanceNorm2d 72 | else: 73 | use_bias = norm_layer == nn.InstanceNorm2d 74 | 75 | res_model = [nn.ReflectionPad2d(3), 76 | conv_norm_relu(input_nc, ngf * 1, 7, norm_layer=norm_layer, bias=use_bias), 77 | conv_norm_relu(ngf * 1, ngf * 2, 3, 2, 1, norm_layer=norm_layer, bias=use_bias), 78 | conv_norm_relu(ngf * 2, ngf * 4, 3, 2, 1, norm_layer=norm_layer, bias=use_bias)] 79 | 80 | for i in range(num_blocks): 81 | res_model += [ResidualBlock(ngf * 4, norm_layer, use_dropout, use_bias)] 82 | 83 | res_model += [dconv_norm_relu(ngf * 4, ngf * 2, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias), 84 | dconv_norm_relu(ngf * 2, ngf * 1, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias), 85 | nn.ReflectionPad2d(3), 86 | nn.Conv2d(ngf, output_nc, 7), 87 | nn.Tanh()] 88 | self.res_model = nn.Sequential(*res_model) 89 | 90 | def forward(self, x): 91 | return self.res_model(x) 92 | 93 | 94 | 95 | 96 | def define_Gen(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, gpu_ids=[0]): 97 | gen_net = None 98 | norm_layer = get_norm_layer(norm_type=norm) 99 | 100 | if netG == 'resnet_9blocks': 101 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=9) 102 | elif netG == 'resnet_6blocks': 103 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=6) 104 | elif netG == 'unet_128': 105 | gen_net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 106 | elif netG == 'unet_256': 107 | gen_net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 108 | else: 109 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 110 | 111 | return init_network(gen_net, gpu_ids) -------------------------------------------------------------------------------- /arch/ops.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import init 3 | import torch.nn as nn 4 | import torch 5 | 6 | def get_norm_layer(norm_type='instance'): 7 | if norm_type == 'batch': 8 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 9 | elif norm_type == 'instance': 10 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 11 | else: 12 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 13 | return norm_layer 14 | 15 | def init_weights(net, init_type='normal', gain=0.02): 16 | def init_func(m): 17 | classname = m.__class__.__name__ 18 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 19 | init.normal(m.weight.data, 0.0, gain) 20 | if hasattr(m, 'bias') and m.bias is not None: 21 | init.constant(m.bias.data, 0.0) 22 | elif classname.find('BatchNorm2d') != -1: 23 | init.normal(m.weight.data, 1.0, gain) 24 | init.constant(m.bias.data, 0.0) 25 | 26 | print('Network initialized with weights sampled from N(0,0.02).') 27 | net.apply(init_func) 28 | 29 | 30 | def init_network(net, gpu_ids=[]): 31 | if len(gpu_ids) > 0: 32 | assert(torch.cuda.is_available()) 33 | net.cuda(gpu_ids[0]) 34 | net = torch.nn.DataParallel(net, gpu_ids) 35 | init_weights(net) 36 | return net 37 | 38 | 39 | def conv_norm_lrelu(in_dim, out_dim, kernel_size, stride = 1, padding=0, 40 | norm_layer = nn.BatchNorm2d, bias = False): 41 | return nn.Sequential( 42 | nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias = bias), 43 | norm_layer(out_dim), nn.LeakyReLU(0.2,True)) 44 | 45 | def conv_norm_relu(in_dim, out_dim, kernel_size, stride = 1, padding=0, 46 | norm_layer = nn.BatchNorm2d, bias = False): 47 | return nn.Sequential( 48 | nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias = bias), 49 | norm_layer(out_dim), nn.ReLU(True)) 50 | 51 | def dconv_norm_relu(in_dim, out_dim, kernel_size, stride = 1, padding=0, output_padding=0, 52 | norm_layer = nn.BatchNorm2d, bias = False): 53 | return nn.Sequential( 54 | nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride, 55 | padding, output_padding, bias = bias), 56 | norm_layer(out_dim), nn.ReLU(True)) 57 | 58 | 59 | class ResidualBlock(nn.Module): 60 | def __init__(self, dim, norm_layer, use_dropout, use_bias): 61 | super(ResidualBlock, self).__init__() 62 | res_block = [nn.ReflectionPad2d(1), 63 | conv_norm_relu(dim, dim, kernel_size=3, 64 | norm_layer= norm_layer, bias=use_bias)] 65 | if use_dropout: 66 | res_block += [nn.Dropout(0.5)] 67 | res_block += [nn.ReflectionPad2d(1), 68 | nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias), 69 | norm_layer(dim)] 70 | 71 | self.res_block = nn.Sequential(*res_block) 72 | 73 | def forward(self, x): 74 | return x + self.res_block(x) 75 | 76 | 77 | def set_grad(nets, requires_grad=False): 78 | for net in nets: 79 | for param in net.parameters(): 80 | param.requires_grad = requires_grad -------------------------------------------------------------------------------- /download_dataset.sh: -------------------------------------------------------------------------------- 1 | mkdir datasets 2 | FILE=$1 3 | 4 | if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then 5 | echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos" 6 | exit 1 7 | fi 8 | 9 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip 10 | ZIP_FILE=./datasets/$FILE.zip 11 | TARGET_DIR=./datasets/$FILE/ 12 | wget -N $URL -O $ZIP_FILE 13 | mkdir $TARGET_DIR 14 | unzip $ZIP_FILE -d ./datasets/ 15 | rm $ZIP_FILE 16 | -------------------------------------------------------------------------------- /images/horse_generated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/images/horse_generated.png -------------------------------------------------------------------------------- /images/horse_real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/images/horse_real.png -------------------------------------------------------------------------------- /images/horse_reconstructed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/images/horse_reconstructed.png -------------------------------------------------------------------------------- /images/zebra_generated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/images/zebra_generated.png -------------------------------------------------------------------------------- /images/zebra_real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/images/zebra_real.png -------------------------------------------------------------------------------- /images/zebra_reconstructed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/images/zebra_reconstructed.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | import model as md 4 | from utils import create_link 5 | import test as tst 6 | 7 | 8 | # To get arguments from commandline 9 | def get_args(): 10 | parser = ArgumentParser(description='cycleGAN PyTorch') 11 | parser.add_argument('--epochs', type=int, default=200) 12 | parser.add_argument('--decay_epoch', type=int, default=100) 13 | parser.add_argument('--batch_size', type=int, default=1) 14 | parser.add_argument('--lr', type=float, default=.0002) 15 | parser.add_argument('--load_height', type=int, default=286) 16 | parser.add_argument('--load_width', type=int, default=286) 17 | parser.add_argument('--gpu_ids', type=str, default='0') 18 | parser.add_argument('--crop_height', type=int, default=256) 19 | parser.add_argument('--crop_width', type=int, default=256) 20 | parser.add_argument('--lamda', type=int, default=10) 21 | parser.add_argument('--idt_coef', type=float, default=0.5) 22 | parser.add_argument('--training', type=bool, default=False) 23 | parser.add_argument('--testing', type=bool, default=False) 24 | parser.add_argument('--results_dir', type=str, default='./results') 25 | parser.add_argument('--dataset_dir', type=str, default='./datasets/horse2zebra') 26 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints/horse2zebra') 27 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 28 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 29 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 30 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 31 | parser.add_argument('--gen_net', type=str, default='resnet_9blocks') 32 | parser.add_argument('--dis_net', type=str, default='n_layers') 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def main(): 38 | args = get_args() 39 | 40 | create_link(args.dataset_dir) 41 | 42 | str_ids = args.gpu_ids.split(',') 43 | args.gpu_ids = [] 44 | for str_id in str_ids: 45 | id = int(str_id) 46 | if id >= 0: 47 | args.gpu_ids.append(id) 48 | print(not args.no_dropout) 49 | if args.training: 50 | print("Training") 51 | model = md.cycleGAN(args) 52 | model.train(args) 53 | if args.testing: 54 | print("Testing") 55 | tst.test(args) 56 | 57 | 58 | if __name__ == '__main__': 59 | main() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import functools 3 | 4 | import os 5 | import torch 6 | from torch import nn 7 | from torch.autograd import Variable 8 | import torchvision.datasets as dsets 9 | import torchvision.transforms as transforms 10 | import utils 11 | from arch import define_Gen, define_Dis, set_grad 12 | from torch.optim import lr_scheduler 13 | 14 | ''' 15 | Class for CycleGAN with train() as a member function 16 | 17 | ''' 18 | class cycleGAN(object): 19 | def __init__(self,args): 20 | 21 | # Define the network 22 | ##################################################### 23 | self.Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm, 24 | use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids) 25 | self.Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm, 26 | use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids) 27 | self.Da = define_Dis(input_nc=3, ndf=args.ndf, netD= args.dis_net, n_layers_D=3, norm=args.norm, gpu_ids=args.gpu_ids) 28 | self.Db = define_Dis(input_nc=3, ndf=args.ndf, netD= args.dis_net, n_layers_D=3, norm=args.norm, gpu_ids=args.gpu_ids) 29 | 30 | utils.print_networks([self.Gab,self.Gba,self.Da,self.Db], ['Gab','Gba','Da','Db']) 31 | 32 | # Define Loss criterias 33 | 34 | self.MSE = nn.MSELoss() 35 | self.L1 = nn.L1Loss() 36 | 37 | # Optimizers 38 | ##################################################### 39 | self.g_optimizer = torch.optim.Adam(itertools.chain(self.Gab.parameters(),self.Gba.parameters()), lr=args.lr, betas=(0.5, 0.999)) 40 | self.d_optimizer = torch.optim.Adam(itertools.chain(self.Da.parameters(),self.Db.parameters()), lr=args.lr, betas=(0.5, 0.999)) 41 | 42 | 43 | self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.g_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step) 44 | self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.d_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step) 45 | 46 | # Try loading checkpoint 47 | ##################################################### 48 | if not os.path.isdir(args.checkpoint_dir): 49 | os.makedirs(args.checkpoint_dir) 50 | 51 | try: 52 | ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir)) 53 | self.start_epoch = ckpt['epoch'] 54 | self.Da.load_state_dict(ckpt['Da']) 55 | self.Db.load_state_dict(ckpt['Db']) 56 | self.Gab.load_state_dict(ckpt['Gab']) 57 | self.Gba.load_state_dict(ckpt['Gba']) 58 | self.d_optimizer.load_state_dict(ckpt['d_optimizer']) 59 | self.g_optimizer.load_state_dict(ckpt['g_optimizer']) 60 | except: 61 | print(' [*] No checkpoint!') 62 | self.start_epoch = 0 63 | 64 | 65 | 66 | def train(self,args): 67 | # For transforming the input image 68 | transform = transforms.Compose( 69 | [transforms.RandomHorizontalFlip(), 70 | transforms.Resize((args.load_height,args.load_width)), 71 | transforms.RandomCrop((args.crop_height,args.crop_width)), 72 | transforms.ToTensor(), 73 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) 74 | 75 | dataset_dirs = utils.get_traindata_link(args.dataset_dir) 76 | 77 | # Pytorch dataloader 78 | a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(dataset_dirs['trainA'], transform=transform), 79 | batch_size=args.batch_size, shuffle=True, num_workers=4) 80 | b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(dataset_dirs['trainB'], transform=transform), 81 | batch_size=args.batch_size, shuffle=True, num_workers=4) 82 | 83 | a_fake_sample = utils.Sample_from_Pool() 84 | b_fake_sample = utils.Sample_from_Pool() 85 | 86 | for epoch in range(self.start_epoch, args.epochs): 87 | 88 | lr = self.g_optimizer.param_groups[0]['lr'] 89 | print('learning rate = %.7f' % lr) 90 | 91 | for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)): 92 | # step 93 | step = epoch * min(len(a_loader), len(b_loader)) + i + 1 94 | 95 | # Generator Computations 96 | ################################################## 97 | 98 | set_grad([self.Da, self.Db], False) 99 | self.g_optimizer.zero_grad() 100 | 101 | a_real = Variable(a_real[0]) 102 | b_real = Variable(b_real[0]) 103 | a_real, b_real = utils.cuda([a_real, b_real]) 104 | 105 | # Forward pass through generators 106 | ################################################## 107 | a_fake = self.Gab(b_real) 108 | b_fake = self.Gba(a_real) 109 | 110 | a_recon = self.Gab(b_fake) 111 | b_recon = self.Gba(a_fake) 112 | 113 | a_idt = self.Gab(a_real) 114 | b_idt = self.Gba(b_real) 115 | 116 | # Identity losses 117 | ################################################### 118 | a_idt_loss = self.L1(a_idt, a_real) * args.lamda * args.idt_coef 119 | b_idt_loss = self.L1(b_idt, b_real) * args.lamda * args.idt_coef 120 | 121 | # Adversarial losses 122 | ################################################### 123 | a_fake_dis = self.Da(a_fake) 124 | b_fake_dis = self.Db(b_fake) 125 | 126 | real_label = utils.cuda(Variable(torch.ones(a_fake_dis.size()))) 127 | 128 | a_gen_loss = self.MSE(a_fake_dis, real_label) 129 | b_gen_loss = self.MSE(b_fake_dis, real_label) 130 | 131 | # Cycle consistency losses 132 | ################################################### 133 | a_cycle_loss = self.L1(a_recon, a_real) * args.lamda 134 | b_cycle_loss = self.L1(b_recon, b_real) * args.lamda 135 | 136 | # Total generators losses 137 | ################################################### 138 | gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss 139 | 140 | # Update generators 141 | ################################################### 142 | gen_loss.backward() 143 | self.g_optimizer.step() 144 | 145 | 146 | # Discriminator Computations 147 | ################################################# 148 | 149 | set_grad([self.Da, self.Db], True) 150 | self.d_optimizer.zero_grad() 151 | 152 | # Sample from history of generated images 153 | ################################################# 154 | a_fake = Variable(torch.Tensor(a_fake_sample([a_fake.cpu().data.numpy()])[0])) 155 | b_fake = Variable(torch.Tensor(b_fake_sample([b_fake.cpu().data.numpy()])[0])) 156 | a_fake, b_fake = utils.cuda([a_fake, b_fake]) 157 | 158 | # Forward pass through discriminators 159 | ################################################# 160 | a_real_dis = self.Da(a_real) 161 | a_fake_dis = self.Da(a_fake) 162 | b_real_dis = self.Db(b_real) 163 | b_fake_dis = self.Db(b_fake) 164 | real_label = utils.cuda(Variable(torch.ones(a_real_dis.size()))) 165 | fake_label = utils.cuda(Variable(torch.zeros(a_fake_dis.size()))) 166 | 167 | # Discriminator losses 168 | ################################################## 169 | a_dis_real_loss = self.MSE(a_real_dis, real_label) 170 | a_dis_fake_loss = self.MSE(a_fake_dis, fake_label) 171 | b_dis_real_loss = self.MSE(b_real_dis, real_label) 172 | b_dis_fake_loss = self.MSE(b_fake_dis, fake_label) 173 | 174 | # Total discriminators losses 175 | a_dis_loss = (a_dis_real_loss + a_dis_fake_loss)*0.5 176 | b_dis_loss = (b_dis_real_loss + b_dis_fake_loss)*0.5 177 | 178 | # Update discriminators 179 | ################################################## 180 | a_dis_loss.backward() 181 | b_dis_loss.backward() 182 | self.d_optimizer.step() 183 | 184 | print("Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" % 185 | (epoch, i + 1, min(len(a_loader), len(b_loader)), 186 | gen_loss,a_dis_loss+b_dis_loss)) 187 | 188 | # Override the latest checkpoint 189 | ####################################################### 190 | utils.save_checkpoint({'epoch': epoch + 1, 191 | 'Da': self.Da.state_dict(), 192 | 'Db': self.Db.state_dict(), 193 | 'Gab': self.Gab.state_dict(), 194 | 'Gba': self.Gba.state_dict(), 195 | 'd_optimizer': self.d_optimizer.state_dict(), 196 | 'g_optimizer': self.g_optimizer.state_dict()}, 197 | '%s/latest.ckpt' % (args.checkpoint_dir)) 198 | 199 | # Update learning rates 200 | ######################## 201 | self.g_lr_scheduler.step() 202 | self.d_lr_scheduler.step() 203 | 204 | 205 | 206 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Variable 5 | import torchvision 6 | import torchvision.datasets as dsets 7 | import torchvision.transforms as transforms 8 | import utils 9 | from arch import define_Gen, define_Dis 10 | 11 | 12 | 13 | 14 | 15 | def test(args): 16 | 17 | transform = transforms.Compose( 18 | [transforms.Resize((args.crop_height,args.crop_width)), 19 | transforms.ToTensor(), 20 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) 21 | 22 | dataset_dirs = utils.get_testdata_link(args.dataset_dir) 23 | 24 | a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform) 25 | b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform) 26 | 27 | 28 | a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4) 29 | b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4) 30 | 31 | Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks', norm=args.norm, 32 | use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids) 33 | Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks', norm=args.norm, 34 | use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids) 35 | 36 | utils.print_networks([Gab,Gba], ['Gab','Gba']) 37 | 38 | try: 39 | ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir)) 40 | Gab.load_state_dict(ckpt['Gab']) 41 | Gba.load_state_dict(ckpt['Gba']) 42 | except: 43 | print(' [*] No checkpoint!') 44 | 45 | 46 | """ run """ 47 | a_real_test = Variable(iter(a_test_loader).next()[0], requires_grad=True) 48 | b_real_test = Variable(iter(b_test_loader).next()[0], requires_grad=True) 49 | a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test]) 50 | 51 | 52 | Gab.eval() 53 | Gba.eval() 54 | 55 | with torch.no_grad(): 56 | a_fake_test = Gab(b_real_test) 57 | b_fake_test = Gba(a_real_test) 58 | a_recon_test = Gab(b_fake_test) 59 | b_recon_test = Gba(a_fake_test) 60 | 61 | pic = (torch.cat([a_real_test, b_fake_test, a_recon_test, b_real_test, a_fake_test, b_recon_test], dim=0).data + 1) / 2.0 62 | 63 | if not os.path.isdir(args.results_dir): 64 | os.makedirs(args.results_dir) 65 | 66 | torchvision.utils.save_image(pic, args.results_dir+'/sample.jpg', nrow=3) 67 | 68 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import shutil 4 | 5 | import numpy as np 6 | import torch 7 | 8 | # To make directories 9 | def mkdir(paths): 10 | for path in paths: 11 | if not os.path.isdir(path): 12 | os.makedirs(path) 13 | 14 | # To make cuda tensor 15 | def cuda(xs): 16 | if torch.cuda.is_available(): 17 | if not isinstance(xs, (list, tuple)): 18 | return xs.cuda() 19 | else: 20 | return [x.cuda() for x in xs] 21 | 22 | # For Pytorch data loader 23 | def create_link(dataset_dir): 24 | dirs = {} 25 | dirs['trainA'] = os.path.join(dataset_dir, 'ltrainA') 26 | dirs['trainB'] = os.path.join(dataset_dir, 'ltrainB') 27 | dirs['testA'] = os.path.join(dataset_dir, 'ltestA') 28 | dirs['testB'] = os.path.join(dataset_dir, 'ltestB') 29 | mkdir(dirs.values()) 30 | 31 | for key in dirs: 32 | try: 33 | os.remove(os.path.join(dirs[key], 'Link')) 34 | except: 35 | pass 36 | os.symlink(os.path.abspath(os.path.join(dataset_dir, key)), 37 | os.path.join(dirs[key], 'Link')) 38 | 39 | return dirs 40 | 41 | def get_traindata_link(dataset_dir): 42 | dirs = {} 43 | dirs['trainA'] = os.path.join(dataset_dir, 'ltrainA') 44 | dirs['trainB'] = os.path.join(dataset_dir, 'ltrainB') 45 | return dirs 46 | 47 | def get_testdata_link(dataset_dir): 48 | dirs = {} 49 | dirs['testA'] = os.path.join(dataset_dir, 'ltestA') 50 | dirs['testB'] = os.path.join(dataset_dir, 'ltestB') 51 | return dirs 52 | 53 | 54 | # To save the checkpoint 55 | def save_checkpoint(state, save_path): 56 | torch.save(state, save_path) 57 | 58 | 59 | # To load the checkpoint 60 | def load_checkpoint(ckpt_path, map_location=None): 61 | ckpt = torch.load(ckpt_path, map_location=map_location) 62 | print(' [*] Loading checkpoint from %s succeed!' % ckpt_path) 63 | return ckpt 64 | 65 | 66 | # To store 50 generated image in a pool and sample from it when it is full 67 | # Shrivastava et al’s strategy 68 | class Sample_from_Pool(object): 69 | def __init__(self, max_elements=50): 70 | self.max_elements = max_elements 71 | self.cur_elements = 0 72 | self.items = [] 73 | 74 | def __call__(self, in_items): 75 | return_items = [] 76 | for in_item in in_items: 77 | if self.cur_elements < self.max_elements: 78 | self.items.append(in_item) 79 | self.cur_elements = self.cur_elements + 1 80 | return_items.append(in_item) 81 | else: 82 | if np.random.ranf() > 0.5: 83 | idx = np.random.randint(0, self.max_elements) 84 | tmp = copy.copy(self.items[idx]) 85 | self.items[idx] = in_item 86 | return_items.append(tmp) 87 | else: 88 | return_items.append(in_item) 89 | return return_items 90 | 91 | class LambdaLR(): 92 | def __init__(self, epochs, offset, decay_epoch): 93 | self.epochs = epochs 94 | self.offset = offset 95 | self.decay_epoch = decay_epoch 96 | 97 | def step(self, epoch): 98 | return 1.0 - max(0, epoch + self.offset - self.decay_epoch)/(self.epochs - self.decay_epoch) 99 | 100 | def print_networks(nets, names): 101 | print('------------Number of Parameters---------------') 102 | i=0 103 | for net in nets: 104 | num_params = 0 105 | for param in net.parameters(): 106 | num_params += param.numel() 107 | print('[Network %s] Total number of parameters : %.3f M' % (names[i], num_params / 1e6)) 108 | i=i+1 109 | print('-----------------------------------------------') --------------------------------------------------------------------------------