├── LICENSE
├── README.md
├── arch
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── discriminators.cpython-35.pyc
│   │   ├── generators.cpython-35.pyc
│   │   └── ops.cpython-35.pyc
│   ├── discriminators.py
│   ├── generators.py
│   └── ops.py
├── download_dataset.sh
├── images
│   ├── horse_generated.png
│   ├── horse_real.png
│   ├── horse_reconstructed.png
│   ├── zebra_generated.png
│   ├── zebra_real.png
│   └── zebra_reconstructed.png
├── main.py
├── model.py
├── test.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Arnab Kumar Mondal
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CycleGAN using PyTorch
2 | CycleGAN is one of the most interesting works I have read. Although the idea behind CycleGAN looks quite intuitive once you read the paper [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593), the official PyTorch implementation by [junyanz](https://github.com/junyanz) can be difficult for beginners to follow (the code is really well written, but it implements several models together). As I was already writing a simpler version of the code for some other work, I decided to make my version of CycleGAN public for those looking for an easier implementation of the paper.
3 |
4 | All the credit goes to the authors of the paper.
5 | This code is inspired by the official implementation by [junyanz](https://github.com/junyanz), which can be found [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
6 |
7 | ## Requirements
8 | - The code was written in Python 3.5.2 and PyTorch 0.4.1
9 |
10 | ## How to run
11 | * To download a dataset (e.g. horse2zebra)
12 | ```
13 | $ bash ./download_dataset.sh horse2zebra
14 | ```
15 | * To run training
16 | ```
17 | $ python main.py --training True
18 | ```
19 | * To run testing
20 | ```
21 | $ python main.py --testing True
22 | ```
23 | * Try tweaking the arguments to get the best performance for your dataset.
24 |
25 | ## Results
26 |
27 | * For the horse2zebra dataset (Real - Generated - Reconstructed):
28 |
29 | ![horse_real](images/horse_real.png) ![zebra_generated](images/zebra_generated.png) ![horse_reconstructed](images/horse_reconstructed.png)
30 | 
31 | ![zebra_real](images/zebra_real.png) ![horse_generated](images/horse_generated.png) ![zebra_reconstructed](images/zebra_reconstructed.png)
32 | 
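33 | ## Objective at a glance
34 | 
35 | For reference, a sketch of the loss that `model.py` optimizes, in the paper's notation (here `Gab: B -> A` and `Gba: A -> B`; the defaults are `lamda = 10` and `idt_coef = 0.5`):
36 | 
37 | $$\mathcal{L} = \mathcal{L}_{LSGAN}(G_{ab}, D_a) + \mathcal{L}_{LSGAN}(G_{ba}, D_b) + \lambda \left( \lVert G_{ab}(G_{ba}(a)) - a \rVert_1 + \lVert G_{ba}(G_{ab}(b)) - b \rVert_1 \right) + \lambda \gamma \left( \lVert G_{ab}(a) - a \rVert_1 + \lVert G_{ba}(b) - b \rVert_1 \right)$$
38 | 
39 | where $\gamma$ is `idt_coef` and the adversarial terms are least-squares (MSE) losses rather than cross-entropy.
40 | 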
--------------------------------------------------------------------------------
/arch/__init__.py:
--------------------------------------------------------------------------------
1 | from .generators import define_Gen
2 | from .discriminators import define_Dis
3 | from .ops import set_grad
--------------------------------------------------------------------------------
/arch/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/arch/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/arch/__pycache__/discriminators.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/arch/__pycache__/discriminators.cpython-35.pyc
--------------------------------------------------------------------------------
/arch/__pycache__/generators.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/arch/__pycache__/generators.cpython-35.pyc
--------------------------------------------------------------------------------
/arch/__pycache__/ops.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/arch/__pycache__/ops.cpython-35.pyc
--------------------------------------------------------------------------------
/arch/discriminators.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from .ops import conv_norm_lrelu, get_norm_layer, init_network
3 | import torch
4 | from torch.nn import functional as F
5 | import functools
6 |
7 |
8 | class NLayerDiscriminator(nn.Module):
9 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_bias=False):
10 | super(NLayerDiscriminator, self).__init__()
11 | dis_model = [nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
12 | nn.LeakyReLU(0.2, True)]
13 | nf_mult = 1
14 | nf_mult_prev = 1
15 | for n in range(1, n_layers):
16 | nf_mult_prev = nf_mult
17 | nf_mult = min(2**n, 8)
18 | dis_model += [conv_norm_lrelu(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=2,
19 | norm_layer= norm_layer, padding=1, bias=use_bias)]
20 | nf_mult_prev = nf_mult
21 | nf_mult = min(2**n_layers, 8)
22 | dis_model += [conv_norm_lrelu(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=1,
23 | norm_layer= norm_layer, padding=1, bias=use_bias)]
24 | dis_model += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)]
25 |
26 | self.dis_model = nn.Sequential(*dis_model)
27 |
28 | def forward(self, input):
29 | return self.dis_model(input)
30 |
31 | class PixelDiscriminator(nn.Module):
32 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_bias=False):
33 | super(PixelDiscriminator, self).__init__()
34 | dis_model = [
35 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
36 | nn.LeakyReLU(0.2, True),
37 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
38 | norm_layer(ndf * 2),
39 | nn.LeakyReLU(0.2, True),
40 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
41 |
42 | self.dis_model = nn.Sequential(*dis_model)
43 |
44 | def forward(self, input):
45 | return self.dis_model(input)
46 |
47 |
48 |
49 | def define_Dis(input_nc, ndf, netD, n_layers_D=3, norm='batch', gpu_ids=[0]):
50 | dis_net = None
51 | norm_layer = get_norm_layer(norm_type=norm)
52 | if type(norm_layer) == functools.partial:
53 | use_bias = norm_layer.func == nn.InstanceNorm2d
54 | else:
55 | use_bias = norm_layer == nn.InstanceNorm2d
56 |
57 | if netD == 'n_layers':
58 | dis_net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_bias=use_bias)
59 | elif netD == 'pixel':
60 | dis_net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_bias=use_bias)
61 | else:
62 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
63 |
64 | return init_network(dis_net, gpu_ids)
--------------------------------------------------------------------------------
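A note on the defaults: with `ndf=64` and `n_layers_D=3`, `NLayerDiscriminator` is the 70x70 PatchGAN used in the paper; it outputs a grid of logits, one per overlapping input patch, rather than a single real/fake scalar. A minimal shape check, assuming a CPU run (`gpu_ids=[]` skips the CUDA path in `init_network`):

```
import torch
from arch.discriminators import define_Dis

# Build the default PatchGAN on the CPU and probe its output resolution.
dis = define_Dis(input_nc=3, ndf=64, netD='n_layers', n_layers_D=3,
                 norm='instance', gpu_ids=[])
x = torch.randn(1, 3, 256, 256)
print(dis(x).shape)  # torch.Size([1, 1, 30, 30]) -- one logit per 70x70 patch
```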
/arch/generators.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | from torch import nn
4 | from .ops import conv_norm_relu, dconv_norm_relu, ResidualBlock, get_norm_layer, init_network
5 |
6 |
7 | class UnetSkipConnectionBlock(nn.Module):
8 | def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False,
9 | innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
10 | super(UnetSkipConnectionBlock, self).__init__()
11 | self.outermost = outermost
12 | if type(norm_layer) == functools.partial:
13 | use_bias = norm_layer.func == nn.InstanceNorm2d
14 | else:
15 | use_bias = norm_layer == nn.InstanceNorm2d
16 | if input_nc is None:
17 | input_nc = outer_nc
18 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
19 |
20 | if outermost:
21 | upconv = nn.ConvTranspose2d(inner_nc*2, outer_nc, kernel_size=4, stride=2, padding=1)
22 | down = [downconv]
23 | up = [nn.ReLU(True), upconv, nn.Tanh()]
24 | model = down + [submodule] + up
25 | elif innermost:
26 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
27 | down = [nn.LeakyReLU(0.2, True), downconv]
28 | up = [nn.ReLU(True), upconv, norm_layer(outer_nc)]
29 | model = down + up
30 | else:
31 | upconv = nn.ConvTranspose2d(inner_nc*2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
32 | down = [nn.LeakyReLU(0.2, True), downconv, norm_layer(inner_nc)]
33 | up = [nn.ReLU(True), upconv, norm_layer(outer_nc)]
34 |
35 | if use_dropout:
36 | model = down + [submodule] + up + [nn.Dropout(0.5)]
37 | else:
38 | model = down + [submodule] + up
39 |
40 | self.model = nn.Sequential(*model)
41 |
42 | def forward(self, x):
43 | if self.outermost:
44 | return self.model(x)
45 | else:
46 | return torch.cat([x, self.model(x)], 1)
47 |
48 | class UnetGenerator(nn.Module):
49 | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
50 | norm_layer=nn.BatchNorm2d, use_dropout=False):
51 | super(UnetGenerator, self).__init__()
52 |
53 | unet_block = UnetSkipConnectionBlock(ngf*8, ngf*8, submodule=None, norm_layer=norm_layer, innermost=True)
54 | for i in range(num_downs - 5):
55 | unet_block = UnetSkipConnectionBlock(ngf*8, ngf*8, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
56 | unet_block = UnetSkipConnectionBlock(ngf*4, ngf*8, submodule=unet_block, norm_layer=norm_layer)
57 | unet_block = UnetSkipConnectionBlock(ngf*2, ngf*4, submodule=unet_block, norm_layer=norm_layer)
58 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block, norm_layer=norm_layer)
59 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
60 | self.unet_model = unet_block
61 |
62 | def forward(self, input):
63 | return self.unet_model(input)
64 |
65 |
66 |
67 | class ResnetGenerator(nn.Module):
68 | def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=True, num_blocks=6):
69 | super(ResnetGenerator, self).__init__()
70 | if type(norm_layer) == functools.partial:
71 | use_bias = norm_layer.func == nn.InstanceNorm2d
72 | else:
73 | use_bias = norm_layer == nn.InstanceNorm2d
74 |
75 | res_model = [nn.ReflectionPad2d(3),
76 | conv_norm_relu(input_nc, ngf * 1, 7, norm_layer=norm_layer, bias=use_bias),
77 | conv_norm_relu(ngf * 1, ngf * 2, 3, 2, 1, norm_layer=norm_layer, bias=use_bias),
78 | conv_norm_relu(ngf * 2, ngf * 4, 3, 2, 1, norm_layer=norm_layer, bias=use_bias)]
79 |
80 | for i in range(num_blocks):
81 | res_model += [ResidualBlock(ngf * 4, norm_layer, use_dropout, use_bias)]
82 |
83 | res_model += [dconv_norm_relu(ngf * 4, ngf * 2, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias),
84 | dconv_norm_relu(ngf * 2, ngf * 1, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias),
85 | nn.ReflectionPad2d(3),
86 | nn.Conv2d(ngf, output_nc, 7),
87 | nn.Tanh()]
88 | self.res_model = nn.Sequential(*res_model)
89 |
90 | def forward(self, x):
91 | return self.res_model(x)
92 |
93 |
94 |
95 |
96 | def define_Gen(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, gpu_ids=[0]):
97 | gen_net = None
98 | norm_layer = get_norm_layer(norm_type=norm)
99 |
100 | if netG == 'resnet_9blocks':
101 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=9)
102 | elif netG == 'resnet_6blocks':
103 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=6)
104 | elif netG == 'unet_128':
105 | gen_net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
106 | elif netG == 'unet_256':
107 | gen_net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
108 | else:
109 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
110 |
111 | return init_network(gen_net, gpu_ids)
--------------------------------------------------------------------------------
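Both generator variants return an image of the input's spatial size, squashed to (-1, 1) by the final `Tanh`, which matches the `Normalize(mean=0.5, std=0.5)` preprocessing in `model.py`. A minimal sketch, again keeping everything on the CPU:

```
import torch
from arch.generators import define_Gen

# resnet_9blocks: 7x7 reflection-padded stem, two stride-2 downsamples,
# nine residual blocks, two transposed-conv upsamples, and a Tanh head.
gen = define_Gen(input_nc=3, output_nc=3, ngf=64, netG='resnet_9blocks',
                 norm='instance', gpu_ids=[])
x = torch.randn(1, 3, 256, 256)
print(gen(x).shape)  # torch.Size([1, 3, 256, 256]); values lie in (-1, 1)
```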
/arch/ops.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from torch.nn import init
3 | import torch.nn as nn
4 | import torch
5 |
6 | def get_norm_layer(norm_type='instance'):
7 | if norm_type == 'batch':
8 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
9 | elif norm_type == 'instance':
10 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
11 | else:
12 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
13 | return norm_layer
14 |
15 | def init_weights(net, init_type='normal', gain=0.02):
16 | def init_func(m):
17 | classname = m.__class__.__name__
18 |         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
19 |             init.normal_(m.weight.data, 0.0, gain)
20 |             if hasattr(m, 'bias') and m.bias is not None:
21 |                 init.constant_(m.bias.data, 0.0)
22 |         elif classname.find('BatchNorm2d') != -1:
23 |             init.normal_(m.weight.data, 1.0, gain)
24 |             init.constant_(m.bias.data, 0.0)
25 |
26 | print('Network initialized with weights sampled from N(0,0.02).')
27 | net.apply(init_func)
28 |
29 |
30 | def init_network(net, gpu_ids=[]):
31 | if len(gpu_ids) > 0:
32 | assert(torch.cuda.is_available())
33 | net.cuda(gpu_ids[0])
34 | net = torch.nn.DataParallel(net, gpu_ids)
35 | init_weights(net)
36 | return net
37 |
38 |
39 | def conv_norm_lrelu(in_dim, out_dim, kernel_size, stride = 1, padding=0,
40 | norm_layer = nn.BatchNorm2d, bias = False):
41 | return nn.Sequential(
42 | nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias = bias),
43 | norm_layer(out_dim), nn.LeakyReLU(0.2,True))
44 |
45 | def conv_norm_relu(in_dim, out_dim, kernel_size, stride = 1, padding=0,
46 | norm_layer = nn.BatchNorm2d, bias = False):
47 | return nn.Sequential(
48 | nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias = bias),
49 | norm_layer(out_dim), nn.ReLU(True))
50 |
51 | def dconv_norm_relu(in_dim, out_dim, kernel_size, stride = 1, padding=0, output_padding=0,
52 | norm_layer = nn.BatchNorm2d, bias = False):
53 | return nn.Sequential(
54 | nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride,
55 | padding, output_padding, bias = bias),
56 | norm_layer(out_dim), nn.ReLU(True))
57 |
58 |
59 | class ResidualBlock(nn.Module):
60 | def __init__(self, dim, norm_layer, use_dropout, use_bias):
61 | super(ResidualBlock, self).__init__()
62 | res_block = [nn.ReflectionPad2d(1),
63 | conv_norm_relu(dim, dim, kernel_size=3,
64 | norm_layer= norm_layer, bias=use_bias)]
65 | if use_dropout:
66 | res_block += [nn.Dropout(0.5)]
67 | res_block += [nn.ReflectionPad2d(1),
68 | nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias),
69 | norm_layer(dim)]
70 |
71 | self.res_block = nn.Sequential(*res_block)
72 |
73 | def forward(self, x):
74 | return x + self.res_block(x)
75 |
76 |
77 | def set_grad(nets, requires_grad=False):
78 | for net in nets:
79 | for param in net.parameters():
80 | param.requires_grad = requires_grad
--------------------------------------------------------------------------------
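Two details here are worth a small illustration: `get_norm_layer` returns a `functools.partial` (which is why `define_Gen`/`define_Dis` inspect `norm_layer.func` when deciding whether convolutions need a bias), and `set_grad` is what `model.py` uses to freeze the discriminators during the generator update. A sketch:

```
import torch.nn as nn
from arch.ops import get_norm_layer, set_grad

# get_norm_layer returns a functools.partial that instantiates like a class.
norm_layer = get_norm_layer('instance')
print(norm_layer.func is nn.InstanceNorm2d)  # True
layer = norm_layer(64)                       # an InstanceNorm2d over 64 channels

# set_grad flips requires_grad for every parameter of each listed network.
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3))
set_grad([net], False)
print(any(p.requires_grad for p in net.parameters()))  # False
```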
/download_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | FILE=$1
3 | 
4 | if [[ $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then
5 |     echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
6 |     exit 1
7 | fi
8 | 
9 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
10 | ZIP_FILE=./datasets/$FILE.zip
11 | TARGET_DIR=./datasets/$FILE/
12 | mkdir -p ./datasets $TARGET_DIR
13 | wget -N $URL -O $ZIP_FILE
14 | unzip $ZIP_FILE -d ./datasets/
15 | rm $ZIP_FILE
16 | 
--------------------------------------------------------------------------------
/images/horse_generated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/images/horse_generated.png
--------------------------------------------------------------------------------
/images/horse_real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/images/horse_real.png
--------------------------------------------------------------------------------
/images/horse_reconstructed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/images/horse_reconstructed.png
--------------------------------------------------------------------------------
/images/zebra_generated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/images/zebra_generated.png
--------------------------------------------------------------------------------
/images/zebra_real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/images/zebra_real.png
--------------------------------------------------------------------------------
/images/zebra_reconstructed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/cycleGAN-PyTorch/ecdc9735e426992056c70e78f3143bb6d538e48c/images/zebra_reconstructed.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | import model as md
4 | from utils import create_link
5 | import test as tst
6 |
7 |
8 | # To get arguments from commandline
9 | def get_args():
10 | parser = ArgumentParser(description='cycleGAN PyTorch')
11 | parser.add_argument('--epochs', type=int, default=200)
12 | parser.add_argument('--decay_epoch', type=int, default=100)
13 | parser.add_argument('--batch_size', type=int, default=1)
14 | parser.add_argument('--lr', type=float, default=.0002)
15 | parser.add_argument('--load_height', type=int, default=286)
16 | parser.add_argument('--load_width', type=int, default=286)
17 | parser.add_argument('--gpu_ids', type=str, default='0')
18 | parser.add_argument('--crop_height', type=int, default=256)
19 | parser.add_argument('--crop_width', type=int, default=256)
20 |     parser.add_argument('--lamda', type=int, default=10, help='weight of the cycle-consistency loss')
21 | parser.add_argument('--idt_coef', type=float, default=0.5)
22 |     parser.add_argument('--training', type=lambda x: str(x).lower() in ('true', '1'), default=False)
23 |     parser.add_argument('--testing', type=lambda x: str(x).lower() in ('true', '1'), default=False)
24 | parser.add_argument('--results_dir', type=str, default='./results')
25 | parser.add_argument('--dataset_dir', type=str, default='./datasets/horse2zebra')
26 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints/horse2zebra')
27 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
28 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
29 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
30 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
31 | parser.add_argument('--gen_net', type=str, default='resnet_9blocks')
32 | parser.add_argument('--dis_net', type=str, default='n_layers')
33 | args = parser.parse_args()
34 | return args
35 |
36 |
37 | def main():
38 | args = get_args()
39 |
40 | create_link(args.dataset_dir)
41 |
42 | str_ids = args.gpu_ids.split(',')
43 | args.gpu_ids = []
44 |     for str_id in str_ids:
45 |         gpu_id = int(str_id)
46 |         if gpu_id >= 0:
47 |             args.gpu_ids.append(gpu_id)
48 | 
49 | if args.training:
50 | print("Training")
51 | model = md.cycleGAN(args)
52 | model.train(args)
53 | if args.testing:
54 | print("Testing")
55 | tst.test(args)
56 |
57 |
58 | if __name__ == '__main__':
59 | main()
--------------------------------------------------------------------------------
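A note on the boolean flags: argparse's plain `type=bool` calls `bool()` on the raw string, and `bool('False')` is `True` because any non-empty string is truthy; that is why `--training` and `--testing` parse the string's meaning instead. A self-contained demonstration of the pitfall:

```
from argparse import ArgumentParser

str2bool = lambda x: str(x).lower() in ('true', '1')

parser = ArgumentParser()
parser.add_argument('--naive', type=bool, default=False)      # the pitfall
parser.add_argument('--fixed', type=str2bool, default=False)  # the fix

args = parser.parse_args(['--naive', 'False', '--fixed', 'False'])
print(args.naive)  # True  -- bool('False') is truthy
print(args.fixed)  # False -- parsed by meaning, not truthiness
```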
/model.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import functools
3 |
4 | import os
5 | import torch
6 | from torch import nn
7 | from torch.autograd import Variable
8 | import torchvision.datasets as dsets
9 | import torchvision.transforms as transforms
10 | import utils
11 | from arch import define_Gen, define_Dis, set_grad
12 | from torch.optim import lr_scheduler
13 |
14 | '''
15 | Class for CycleGAN with train() as a member function
16 |
17 | '''
18 | class cycleGAN(object):
19 | def __init__(self,args):
20 |
21 | # Define the network
22 | #####################################################
23 | self.Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm,
24 | use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids)
25 | self.Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm,
26 | use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids)
27 | self.Da = define_Dis(input_nc=3, ndf=args.ndf, netD= args.dis_net, n_layers_D=3, norm=args.norm, gpu_ids=args.gpu_ids)
28 | self.Db = define_Dis(input_nc=3, ndf=args.ndf, netD= args.dis_net, n_layers_D=3, norm=args.norm, gpu_ids=args.gpu_ids)
29 |
30 | utils.print_networks([self.Gab,self.Gba,self.Da,self.Db], ['Gab','Gba','Da','Db'])
31 |
32 | # Define Loss criterias
33 |
34 | self.MSE = nn.MSELoss()
35 | self.L1 = nn.L1Loss()
36 |
37 | # Optimizers
38 | #####################################################
39 | self.g_optimizer = torch.optim.Adam(itertools.chain(self.Gab.parameters(),self.Gba.parameters()), lr=args.lr, betas=(0.5, 0.999))
40 | self.d_optimizer = torch.optim.Adam(itertools.chain(self.Da.parameters(),self.Db.parameters()), lr=args.lr, betas=(0.5, 0.999))
41 |
42 |
43 | self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.g_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
44 | self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.d_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
45 |
46 | # Try loading checkpoint
47 | #####################################################
48 | if not os.path.isdir(args.checkpoint_dir):
49 | os.makedirs(args.checkpoint_dir)
50 |
51 | try:
52 | ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
53 | self.start_epoch = ckpt['epoch']
54 | self.Da.load_state_dict(ckpt['Da'])
55 | self.Db.load_state_dict(ckpt['Db'])
56 | self.Gab.load_state_dict(ckpt['Gab'])
57 | self.Gba.load_state_dict(ckpt['Gba'])
58 | self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
59 | self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
60 |         except Exception:
61 | print(' [*] No checkpoint!')
62 | self.start_epoch = 0
63 |
64 |
65 |
66 | def train(self,args):
67 | # For transforming the input image
68 | transform = transforms.Compose(
69 | [transforms.RandomHorizontalFlip(),
70 | transforms.Resize((args.load_height,args.load_width)),
71 | transforms.RandomCrop((args.crop_height,args.crop_width)),
72 | transforms.ToTensor(),
73 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
74 |
75 | dataset_dirs = utils.get_traindata_link(args.dataset_dir)
76 |
77 | # Pytorch dataloader
78 | a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(dataset_dirs['trainA'], transform=transform),
79 | batch_size=args.batch_size, shuffle=True, num_workers=4)
80 | b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(dataset_dirs['trainB'], transform=transform),
81 | batch_size=args.batch_size, shuffle=True, num_workers=4)
82 |
83 | a_fake_sample = utils.Sample_from_Pool()
84 | b_fake_sample = utils.Sample_from_Pool()
85 |
86 | for epoch in range(self.start_epoch, args.epochs):
87 |
88 | lr = self.g_optimizer.param_groups[0]['lr']
89 | print('learning rate = %.7f' % lr)
90 |
91 | for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
92 | # step
93 | step = epoch * min(len(a_loader), len(b_loader)) + i + 1
94 |
95 | # Generator Computations
96 | ##################################################
97 |
98 | set_grad([self.Da, self.Db], False)
99 | self.g_optimizer.zero_grad()
100 |
101 | a_real = Variable(a_real[0])
102 | b_real = Variable(b_real[0])
103 | a_real, b_real = utils.cuda([a_real, b_real])
104 |
105 | # Forward pass through generators
106 | ##################################################
107 | a_fake = self.Gab(b_real)
108 | b_fake = self.Gba(a_real)
109 |
110 | a_recon = self.Gab(b_fake)
111 | b_recon = self.Gba(a_fake)
112 |
113 | a_idt = self.Gab(a_real)
114 | b_idt = self.Gba(b_real)
115 |
116 | # Identity losses
117 | ###################################################
118 | a_idt_loss = self.L1(a_idt, a_real) * args.lamda * args.idt_coef
119 | b_idt_loss = self.L1(b_idt, b_real) * args.lamda * args.idt_coef
120 |
121 | # Adversarial losses
122 | ###################################################
123 | a_fake_dis = self.Da(a_fake)
124 | b_fake_dis = self.Db(b_fake)
125 |
126 | real_label = utils.cuda(Variable(torch.ones(a_fake_dis.size())))
127 |
128 | a_gen_loss = self.MSE(a_fake_dis, real_label)
129 | b_gen_loss = self.MSE(b_fake_dis, real_label)
130 |
131 | # Cycle consistency losses
132 | ###################################################
133 | a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
134 | b_cycle_loss = self.L1(b_recon, b_real) * args.lamda
135 |
136 | # Total generators losses
137 | ###################################################
138 | gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss
139 |
140 | # Update generators
141 | ###################################################
142 | gen_loss.backward()
143 | self.g_optimizer.step()
144 |
145 |
146 | # Discriminator Computations
147 | #################################################
148 |
149 | set_grad([self.Da, self.Db], True)
150 | self.d_optimizer.zero_grad()
151 |
152 | # Sample from history of generated images
153 | #################################################
154 | a_fake = Variable(torch.Tensor(a_fake_sample([a_fake.cpu().data.numpy()])[0]))
155 | b_fake = Variable(torch.Tensor(b_fake_sample([b_fake.cpu().data.numpy()])[0]))
156 | a_fake, b_fake = utils.cuda([a_fake, b_fake])
157 |
158 | # Forward pass through discriminators
159 | #################################################
160 | a_real_dis = self.Da(a_real)
161 | a_fake_dis = self.Da(a_fake)
162 | b_real_dis = self.Db(b_real)
163 | b_fake_dis = self.Db(b_fake)
164 | real_label = utils.cuda(Variable(torch.ones(a_real_dis.size())))
165 | fake_label = utils.cuda(Variable(torch.zeros(a_fake_dis.size())))
166 |
167 | # Discriminator losses
168 | ##################################################
169 | a_dis_real_loss = self.MSE(a_real_dis, real_label)
170 | a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
171 | b_dis_real_loss = self.MSE(b_real_dis, real_label)
172 | b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)
173 |
174 | # Total discriminators losses
175 | a_dis_loss = (a_dis_real_loss + a_dis_fake_loss)*0.5
176 | b_dis_loss = (b_dis_real_loss + b_dis_fake_loss)*0.5
177 |
178 | # Update discriminators
179 | ##################################################
180 | a_dis_loss.backward()
181 | b_dis_loss.backward()
182 | self.d_optimizer.step()
183 |
184 | print("Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
185 | (epoch, i + 1, min(len(a_loader), len(b_loader)),
186 | gen_loss,a_dis_loss+b_dis_loss))
187 |
188 | # Override the latest checkpoint
189 | #######################################################
190 | utils.save_checkpoint({'epoch': epoch + 1,
191 | 'Da': self.Da.state_dict(),
192 | 'Db': self.Db.state_dict(),
193 | 'Gab': self.Gab.state_dict(),
194 | 'Gba': self.Gba.state_dict(),
195 | 'd_optimizer': self.d_optimizer.state_dict(),
196 | 'g_optimizer': self.g_optimizer.state_dict()},
197 | '%s/latest.ckpt' % (args.checkpoint_dir))
198 |
199 | # Update learning rates
200 | ########################
201 | self.g_lr_scheduler.step()
202 | self.d_lr_scheduler.step()
203 |
204 |
205 |
206 |
--------------------------------------------------------------------------------
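A quick numeric sanity check of the LSGAN discriminator loss used above (MSE against labels of 1 for real and 0 for fake, then halved): a perfect discriminator scores 0, while an undecided one that outputs 0.5 everywhere scores 0.25. Illustrative only, with a PatchGAN-shaped output assumed:

```
import torch
from torch import nn

# An undecided discriminator pays 0.25 per MSE term, so its halved total
# is 0.25; a perfect discriminator (1 on real, 0 on fake) pays 0.
mse = nn.MSELoss()
out = torch.full((1, 1, 30, 30), 0.5)  # assumed PatchGAN-shaped output
d_loss = (mse(out, torch.ones_like(out)) + mse(out, torch.zeros_like(out))) * 0.5
print(d_loss.item())  # 0.25
```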
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch import nn
4 | from torch.autograd import Variable
5 | import torchvision
6 | import torchvision.datasets as dsets
7 | import torchvision.transforms as transforms
8 | import utils
9 | from arch import define_Gen, define_Dis
10 |
11 |
12 |
13 |
14 |
15 | def test(args):
16 |
17 | transform = transforms.Compose(
18 | [transforms.Resize((args.crop_height,args.crop_width)),
19 | transforms.ToTensor(),
20 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
21 |
22 | dataset_dirs = utils.get_testdata_link(args.dataset_dir)
23 |
24 | a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
25 | b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)
26 |
27 |
28 | a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4)
29 | b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4)
30 |
31 |     Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm,
32 |                      use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids)
33 |     Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm,
34 |                      use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids)
35 |
36 | utils.print_networks([Gab,Gba], ['Gab','Gba'])
37 |
38 | try:
39 | ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
40 | Gab.load_state_dict(ckpt['Gab'])
41 | Gba.load_state_dict(ckpt['Gba'])
42 |     except Exception:
43 | print(' [*] No checkpoint!')
44 |
45 |
46 |     # Run one batch from each domain through both generators
47 |     a_real_test = Variable(next(iter(a_test_loader))[0])
48 |     b_real_test = Variable(next(iter(b_test_loader))[0])
49 | a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])
50 |
51 |
52 | Gab.eval()
53 | Gba.eval()
54 |
55 | with torch.no_grad():
56 | a_fake_test = Gab(b_real_test)
57 | b_fake_test = Gba(a_real_test)
58 | a_recon_test = Gab(b_fake_test)
59 | b_recon_test = Gba(a_fake_test)
60 |
61 | pic = (torch.cat([a_real_test, b_fake_test, a_recon_test, b_real_test, a_fake_test, b_recon_test], dim=0).data + 1) / 2.0
62 |
63 | if not os.path.isdir(args.results_dir):
64 | os.makedirs(args.results_dir)
65 |
66 | torchvision.utils.save_image(pic, args.results_dir+'/sample.jpg', nrow=3)
67 |
68 |
--------------------------------------------------------------------------------
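For orientation when reading `sample.jpg`: with the default `batch_size=1`, the `torch.cat` above stacks six images and `nrow=3` arranges them row-major, so the top row is `a_real, b_fake, a_recon` and the bottom row is `b_real, a_fake, b_recon`. A tiny sketch of the same call on dummy data:

```
import torch
import torchvision

# Six dummy stand-ins for [a_real, b_fake, a_recon, b_real, a_fake, b_recon].
imgs = torch.rand(6, 3, 64, 64)
# nrow=3 fills the grid row-major: three images on top, three below.
torchvision.utils.save_image(imgs, 'sample_demo.jpg', nrow=3)
```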
/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import shutil
4 |
5 | import numpy as np
6 | import torch
7 |
8 | # To make directories
9 | def mkdir(paths):
10 | for path in paths:
11 | if not os.path.isdir(path):
12 | os.makedirs(path)
13 |
14 | # To move data to the GPU when available (and leave it on the CPU otherwise)
15 | def cuda(xs):
16 |     if torch.cuda.is_available():
17 |         if not isinstance(xs, (list, tuple)):
18 |             return xs.cuda()
19 |         return [x.cuda() for x in xs]
20 |     return xs
21 |
22 | # For Pytorch data loader
23 | def create_link(dataset_dir):
24 | dirs = {}
25 | dirs['trainA'] = os.path.join(dataset_dir, 'ltrainA')
26 | dirs['trainB'] = os.path.join(dataset_dir, 'ltrainB')
27 | dirs['testA'] = os.path.join(dataset_dir, 'ltestA')
28 | dirs['testB'] = os.path.join(dataset_dir, 'ltestB')
29 | mkdir(dirs.values())
30 |
31 | for key in dirs:
32 | try:
33 | os.remove(os.path.join(dirs[key], 'Link'))
34 |         except OSError:
35 | pass
36 | os.symlink(os.path.abspath(os.path.join(dataset_dir, key)),
37 | os.path.join(dirs[key], 'Link'))
38 |
39 | return dirs
40 |
41 | def get_traindata_link(dataset_dir):
42 | dirs = {}
43 | dirs['trainA'] = os.path.join(dataset_dir, 'ltrainA')
44 | dirs['trainB'] = os.path.join(dataset_dir, 'ltrainB')
45 | return dirs
46 |
47 | def get_testdata_link(dataset_dir):
48 | dirs = {}
49 | dirs['testA'] = os.path.join(dataset_dir, 'ltestA')
50 | dirs['testB'] = os.path.join(dataset_dir, 'ltestB')
51 | return dirs
52 |
53 |
54 | # To save the checkpoint
55 | def save_checkpoint(state, save_path):
56 | torch.save(state, save_path)
57 |
58 |
59 | # To load the checkpoint
60 | def load_checkpoint(ckpt_path, map_location=None):
61 | ckpt = torch.load(ckpt_path, map_location=map_location)
62 | print(' [*] Loading checkpoint from %s succeed!' % ckpt_path)
63 | return ckpt
64 |
65 |
66 | # To store up to 50 generated images in a pool and sample from it once full
67 | # (Shrivastava et al.'s strategy)
68 | class Sample_from_Pool(object):
69 | def __init__(self, max_elements=50):
70 | self.max_elements = max_elements
71 | self.cur_elements = 0
72 | self.items = []
73 |
74 | def __call__(self, in_items):
75 | return_items = []
76 | for in_item in in_items:
77 | if self.cur_elements < self.max_elements:
78 | self.items.append(in_item)
79 | self.cur_elements = self.cur_elements + 1
80 | return_items.append(in_item)
81 | else:
82 | if np.random.ranf() > 0.5:
83 | idx = np.random.randint(0, self.max_elements)
84 | tmp = copy.copy(self.items[idx])
85 | self.items[idx] = in_item
86 | return_items.append(tmp)
87 | else:
88 | return_items.append(in_item)
89 | return return_items
90 |
91 | class LambdaLR:
92 | def __init__(self, epochs, offset, decay_epoch):
93 | self.epochs = epochs
94 | self.offset = offset
95 | self.decay_epoch = decay_epoch
96 |
97 | def step(self, epoch):
98 | return 1.0 - max(0, epoch + self.offset - self.decay_epoch)/(self.epochs - self.decay_epoch)
99 |
100 | def print_networks(nets, names):
101 |     print('------------Number of Parameters---------------')
102 |     for name, net in zip(names, nets):
103 |         num_params = 0
104 |         for param in net.parameters():
105 |             num_params += param.numel()
106 |         print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
107 |     print('-----------------------------------------------')
--------------------------------------------------------------------------------
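The `LambdaLR` helper returns a multiplicative factor for the optimizer's base learning rate: constant at 1.0 until `decay_epoch`, then decaying linearly to 0 at `epochs`. A worked example with the defaults from `main.py` (`epochs=200`, `decay_epoch=100`):

```
from utils import LambdaLR

# Constant factor of 1.0 until decay_epoch, then linear decay to 0 at epochs.
sched = LambdaLR(epochs=200, offset=0, decay_epoch=100)
print(sched.step(0))    # 1.0
print(sched.step(100))  # 1.0  -- decay begins after this epoch
print(sched.step(150))  # 0.5  -- halfway through the decay window
print(sched.step(199))  # 0.01 -- almost fully decayed on the last epoch
```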