├── .gitignore ├── requirements.txt ├── cyclegan ├── utils │ ├── ops.py │ ├── tester.py │ ├── logger.py │ └── data.py ├── config.py ├── network.py └── model.py ├── train.py ├── transfer.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | .idea 4 | outputs 5 | datasets 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow==4.3.0 2 | scipy==1.0.0 3 | torch==0.3.0.post4 4 | torchvision==0.2.0 5 | -------------------------------------------------------------------------------- /cyclegan/utils/ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for tensor operations 3 | """ 4 | import torch 5 | 6 | 7 | def optimize(loss, optimizer): 8 | optimizer.zero_grad() 9 | loss.backward() 10 | optimizer.step() 11 | 12 | 13 | def variable(data): 14 | if torch.cuda.is_available(): 15 | return torch.autograd.Variable(data.cuda()) 16 | else: 17 | return torch.autograd.Variable(data) 18 | 19 | 20 | def normal_weights_initializer(m, mean, std): 21 | if isinstance(m, torch.nn.ConvTranspose2d) or isinstance(m, torch.nn.Conv2d): 22 | m.weight.data.normal_(mean, std) 23 | m.bias.data.zero_() 24 | 25 | 26 | class Activation: 27 | ReLU = 1 28 | LeakyReLU = 2 29 | Tanh = 3 30 | 31 | def __init__(self, activation_fn=None, batch_norm=True, output_size=0): 32 | self.activation_fn = self._init_activation_fn(activation_fn) 33 | self.batch_norm = self._init_batch_norm(batch_norm, output_size) 34 | 35 | def _init_activation_fn(self, fn): 36 | if fn == self.ReLU: 37 | return torch.nn.ReLU(True) 38 | elif fn == self.LeakyReLU: 39 | return torch.nn.LeakyReLU(0.2, True) 40 | elif fn == self.Tanh: 41 | return torch.nn.Tanh() 42 | else: 43 | return None 44 | 45 | @staticmethod 46 | def _init_batch_norm(batch_norm, output_size): 47 | if batch_norm: 48 | if torch.cuda.is_available(): 49 | return torch.nn.InstanceNorm2d(output_size).cuda() 50 | else: 51 | return torch.nn.InstanceNorm2d(output_size) 52 | else: 53 | return None 54 | 55 | def apply(self, output): 56 | if self.batch_norm: 57 | output = self.batch_norm(output) 58 | 59 | if self.activation_fn: 60 | output = self.activation_fn(output) 61 | 62 | return output 63 | -------------------------------------------------------------------------------- /cyclegan/utils/tester.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions especially for testing for now 3 | """ 4 | import os 5 | 6 | import scipy.misc as misc 7 | 8 | from cyclegan.utils.data import DataLoader 9 | 10 | 11 | def get_test_output_file_path(prefix, subdir): 12 | path_real = os.path.join(subdir, '{0}_real.png'.format(prefix)) 13 | path_fake = os.path.join(subdir, '{0}_fake.png'.format(prefix)) 14 | path_recon = os.path.join(subdir, '{0}_recon.png'.format(prefix)) 15 | 16 | return path_real, path_fake, path_recon 17 | 18 | 19 | def prepare_test_output_dir_by_epoch(epoch, base_dir): 20 | subdir_A_to_B = os.path.join(base_dir, str(epoch), 'A_to_B') 21 | subdir_B_to_A = os.path.join(base_dir, str(epoch), 'B_to_A') 22 | 23 | os.makedirs(subdir_A_to_B, exist_ok=True) 24 | os.makedirs(subdir_B_to_A, exist_ok=True) 25 | 26 | return subdir_A_to_B, subdir_B_to_A 27 | 28 | 29 | def generate_one(generator, real_data, output_dir, prefix): 30 | path_real, path_fake, path_recon = get_test_output_file_path(prefix, output_dir) 31 | 32 | fake_data, recon_data = generator(real_data, recon=True) 33 | 34 | misc.imsave(path_real, real_data[0].numpy().transpose(1, 2, 0)) 35 | misc.imsave(path_fake, fake_data[0].cpu().data.numpy().transpose(1, 2, 0)) 36 | misc.imsave(path_recon, recon_data[0].cpu().data.numpy().transpose(1, 2, 0)) 37 | 38 | 39 | def generate_testset(epoch, model, args, max=10, shuffle=True): 40 | subdir_A_to_B, subdir_B_to_A = prepare_test_output_dir_by_epoch(epoch, args.test_output_dir) 41 | 42 | test_data_loader_A = DataLoader(args.test_data_A_dir, args.input_size, args.batch_size, shuffle=shuffle) 43 | test_data_loader_B = DataLoader(args.test_data_B_dir, args.input_size, args.batch_size, shuffle=shuffle) 44 | 45 | for i, (real_A, real_B) in enumerate(zip(test_data_loader_A, test_data_loader_B)): 46 | generate_one(model.generate_A_to_B, real_A, subdir_A_to_B, i) 47 | generate_one(model.generate_B_to_A, real_B, subdir_B_to_A, i) 48 | 49 | if max and i >= max - 1: 50 | break 51 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script for CycleGAN training 3 | """ 4 | import os 5 | 6 | import cyclegan.config as config 7 | from cyclegan.utils import tester 8 | from cyclegan.utils.logger import Logger 9 | from cyclegan.model import CycleGAN 10 | from cyclegan.utils.data import DataLoader 11 | 12 | 13 | class Trainer: 14 | def __init__(self, args): 15 | self.args = args 16 | 17 | self.data_loader_A = DataLoader(args.data_A_dir, args.input_size, args.batch_size, shuffle=True) 18 | self.data_loader_B = DataLoader(args.data_B_dir, args.input_size, args.batch_size, shuffle=True) 19 | 20 | self.cycleGAN = CycleGAN(args.ngf, args.ndf, args.num_resnet, 21 | args.lrG, args.lrD, args.beta1, args.beta2, 22 | args.lambdaA, args.lambdaB, args.num_pool) 23 | 24 | self.logger = Logger() 25 | 26 | def run_epoch(self): 27 | for i, (real_A, real_B) in enumerate(zip(self.data_loader_A, self.data_loader_B)): 28 | losses = self.cycleGAN.train(real_A, real_B) 29 | 30 | self.logger.add_losses(losses) 31 | self.logger.print_last_loss() 32 | 33 | def run(self): 34 | for epoch in range(self.args.num_epochs): 35 | if (epoch + 1) > self.args.decay_epoch: 36 | self.cycleGAN.decay_optimizer(self.args.num_epochs, self.args.decay_epoch) 37 | 38 | self.run_epoch() 39 | 40 | self.logger.next_epoch() 41 | self.logger.print_avg_loss() 42 | 43 | if self.args.test_data_A_dir and self.args.test_data_B_dir: 44 | tester.generate_testset(epoch, self.cycleGAN, self.args) 45 | 46 | self.cycleGAN.save(args.model_dir) 47 | 48 | self.logger.print_status() 49 | 50 | 51 | def prepare_output_dir(args): 52 | args.log_dir = os.path.join(args.output_dir, 'log') 53 | args.model_dir = os.path.join(args.output_dir, 'model') 54 | args.test_output_dir = os.path.join(args.output_dir, 'test') 55 | 56 | os.makedirs(args.log_dir, exist_ok=True) 57 | os.makedirs(args.model_dir, exist_ok=True) 58 | os.makedirs(args.test_output_dir, exist_ok=True) 59 | 60 | 61 | if __name__ == '__main__': 62 | args = config.parse_args(is_training=True) 63 | 64 | prepare_output_dir(args) 65 | 66 | trainer = Trainer(args) 67 | 68 | trainer.run() 69 | -------------------------------------------------------------------------------- /transfer.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script for transferring 3 | """ 4 | import os 5 | 6 | import scipy.misc as misc 7 | 8 | import cyclegan.config as config 9 | from cyclegan.model import CycleGAN 10 | from cyclegan.utils.data import load_file 11 | 12 | 13 | class Transfer: 14 | def __init__(self, args): 15 | self.args = args 16 | 17 | self.cycleGAN = CycleGAN(args.ngf, args.ndf, args.num_resnet, 18 | args.lrG, args.lrD, args.beta1, args.beta2, 19 | args.lambdaA, args.lambdaB, args.num_pool) 20 | 21 | self.cycleGAN.load(args.model, args.type) 22 | 23 | def generate(self, data_src): 24 | if self.args.type is 'BtoA': 25 | generator = self.cycleGAN.generate_B_to_A 26 | else: 27 | generator = self.cycleGAN.generate_A_to_B 28 | 29 | generated, _ = generator(data_src, recon=False) 30 | 31 | generated = generated[0].cpu().data.numpy().transpose(1, 2, 0) 32 | 33 | return generated 34 | 35 | def run_dir(self, src_dir, out_dir): 36 | for filename in os.listdir(src_dir): 37 | src_path = os.path.join(src_dir, filename) 38 | out_path = os.path.join(out_dir, filename) 39 | 40 | self.run_file(src_path, out_path) 41 | 42 | def run_file(self, src_path, out_path): 43 | print('{0} -> {1}'.format(src_path, out_path)) 44 | 45 | data_src = load_file(src_path, self.args.input_size) 46 | 47 | data_out = self.generate(data_src) 48 | 49 | misc.imsave(out_path, data_out) 50 | 51 | 52 | def prepare_output_dir(args): 53 | args.log_dir = os.path.join(args.output_dir, 'log') 54 | args.model_dir = os.path.join(args.output_dir, 'model') 55 | args.test_output_dir = os.path.join(args.output_dir, 'test') 56 | 57 | os.makedirs(args.log_dir, exist_ok=True) 58 | os.makedirs(args.model_dir, exist_ok=True) 59 | os.makedirs(args.test_output_dir, exist_ok=True) 60 | 61 | 62 | if __name__ == '__main__': 63 | args = config.parse_args() 64 | 65 | transfer = Transfer(args) 66 | 67 | print('Running {0} Transfer'.format(args.type)) 68 | 69 | if args.src_dir and args.out_dir: 70 | os.makedirs(args.out_dir, exist_ok=False) 71 | transfer.run_dir(args.src_dir, args.out_dir) 72 | elif args.src and args.out: 73 | transfer.run_file(args.src, args.out) 74 | else: 75 | config.print_help() 76 | -------------------------------------------------------------------------------- /cyclegan/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for logging especially for loss 3 | """ 4 | import datetime 5 | 6 | import torch 7 | 8 | from cyclegan.model import LossType 9 | 10 | 11 | class Logger: 12 | def __init__(self): 13 | self.epoch = 0 14 | self.step = 0 15 | self.losses_avg = self.init_losses() 16 | self.losses_in_epoch = self.init_losses() 17 | self.start_time = datetime.datetime.now() 18 | 19 | @staticmethod 20 | def init_losses(): 21 | losses = { 22 | LossType.D_A: [], 23 | LossType.D_B: [], 24 | LossType.G_AB: [], 25 | LossType.G_BA: [], 26 | LossType.cycle_A: [], 27 | LossType.cycle_B: [] 28 | } 29 | 30 | return losses 31 | 32 | def add_losses(self, losses): 33 | """ 34 | Add loss for logging 35 | :param losses: dictionary with LosType keys 36 | :return: 37 | """ 38 | for key, loss in losses.items(): 39 | if key in self.losses_in_epoch: 40 | self.losses_in_epoch[key].append(loss.data[0]) 41 | 42 | self.step += 1 43 | 44 | def next_epoch(self): 45 | for key, losses in self.losses_in_epoch.items(): 46 | avg_loss = torch.mean(torch.FloatTensor(losses)) 47 | self.losses_avg[key].append(avg_loss) 48 | 49 | self.losses_in_epoch = self.init_losses() 50 | 51 | self.epoch += 1 52 | 53 | def print_last_loss(self): 54 | print('Epoch [%d], Step [%d] : D_A_loss = %.4f, D_B_loss = %.4f, G_AB_loss = %.4f, G_BA_loss = %.4f' 55 | % (self.epoch, self.step, 56 | self.losses_in_epoch[LossType.D_A][-1], 57 | self.losses_in_epoch[LossType.D_B][-1], 58 | self.losses_in_epoch[LossType.G_AB][-1], 59 | self.losses_in_epoch[LossType.G_BA][-1])) 60 | 61 | def print_avg_loss(self): 62 | print('Avg. Epoch [%d], Step [%d] : D_A_loss = %.4f, D_B_loss = %.4f, G_AB_loss = %.4f, G_BA_loss = %.4f' 63 | % (self.epoch, self.step, 64 | self.losses_avg[LossType.D_A][-1], 65 | self.losses_avg[LossType.D_B][-1], 66 | self.losses_avg[LossType.G_AB][-1], 67 | self.losses_avg[LossType.G_BA][-1])) 68 | 69 | def print_status(self): 70 | print('Epoch [%d], Step [%d], Total time : %s' 71 | % (self.epoch, self.step, datetime.datetime.now() - self.start_time)) 72 | -------------------------------------------------------------------------------- /cyclegan/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training Options 3 | """ 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | 8 | parser.add_argument('--batch_size', type=int, default=1, help='batch size') 9 | parser.add_argument('--input_size', type=int, default=256, help='input size') 10 | parser.add_argument('--num_pool', type=int, default=50, help='number of image pool') 11 | parser.add_argument('--ngf', type=int, default=32, help='number of generator\'s filter') 12 | parser.add_argument('--ndf', type=int, default=64, help='number of discriminator\'s filter') 13 | parser.add_argument('--num_resnet', type=int, default=6, help='number of resnet blocks') 14 | parser.add_argument('--num_epochs', type=int, default=200, help='number of train epochs') 15 | parser.add_argument('--decay_epoch', type=int, default=100, help='learning rate decay') 16 | parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate for generator') 17 | parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate for discriminator') 18 | parser.add_argument('--lambdaA', type=float, default=10, help='lambdaA for cycle loss') 19 | parser.add_argument('--lambdaB', type=float, default=10, help='lambdaB for cycle loss') 20 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer') 21 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer') 22 | 23 | 24 | def parse_args(is_training=False): 25 | if is_training: 26 | parser.add_argument('--data_A_dir', type=str, required=True, help='A data dir for training') 27 | parser.add_argument('--data_B_dir', type=str, required=True, help='B data dir for training') 28 | parser.add_argument('--test_data_A_dir', type=str, required=False, help='A data dir for testing') 29 | parser.add_argument('--test_data_B_dir', type=str, required=False, help='B data dir for testing') 30 | parser.add_argument('--output_dir', type=str, required=True, help='output dir') 31 | else: 32 | parser.add_argument('--model', type=str, required=True, help='Pre-trained model dir') 33 | parser.add_argument('--type', type=str, default='AtoB', help='Generator type: AtoB, BtoA') 34 | parser.add_argument('--src', type=str, required=False, help='File path of an source image') 35 | parser.add_argument('--out', type=str, required=False, help='File path of an output image file path which is transferred') 36 | parser.add_argument('--src_dir', type=str, required=False, help='A data dir you want to transfer') 37 | parser.add_argument('--out_dir', type=str, required=False, help='output dir ') 38 | 39 | return parser.parse_args() 40 | 41 | 42 | def print_help(): 43 | parser.print_help() 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cycle GAN 2 | 3 | Yet another Cycle GAN implementation in PyTorch. 4 | 5 | The purpose of this implementation is Well-structured, reusable and easily understandable. 6 | 7 | - [CycleGAN Paper](https://arxiv.org/pdf/1703.10593.pdf) 8 | - [Download datasets](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/) 9 | 10 | ## Prerequisites 11 | 12 | - System 13 | - Linux or macOS 14 | - CPU or (NVIDIA GPU + CUDA CuDNN) 15 | - It can run on Single CPU/GPU or Multi GPUs. 16 | - Python 3 17 | 18 | - Libraries 19 | - PyTorch >= 0.3.0 20 | - Torchvision >= 0.2.0 21 | - scipy >= 1.0.0 22 | - Pillow >= 0.2.0 23 | 24 | ## Training 25 | 26 | ```bash 27 | python train.py \ 28 | --data_A_dir=./datasets/apple2orange/trainA \ 29 | --data_B_dir=./datasets/apple2orange/trainB \ 30 | --output_dir=./outputs 31 | ``` 32 | 33 | If you set `test_data_A_dir` and `test_data_B_dir` then generate A->B and B->A when end of every epoch. 34 | 35 | ```bash 36 | python train.py \ 37 | --data_A_dir=./datasets/apple2orange/trainA \ 38 | --data_B_dir=./datasets/apple2orange/trainB \ 39 | --test_data_A_dir=./datasets/apple2orange/testA \ 40 | --test_data_B_dir=./datasets/apple2orange/testB \ 41 | --output_dir=./outputs 42 | ``` 43 | 44 | Use `python train.py --help` to see more options. 45 | 46 | ## Transferring 47 | 48 | For single file 49 | 50 | ```bash 51 | python transfer.py \ 52 | --model=./outputs/model \ 53 | --src=./datasets/apple2orange/testA/n07740461_41.jpg \ 54 | --out=./outputs/apple2orange.png 55 | ``` 56 | 57 | For directory 58 | 59 | ```bash 60 | python transfer.py \ 61 | --src_dir=./datasets/apple2orange/testA \ 62 | --out_dir=./outputs/testA 63 | ``` 64 | 65 | Use `python transfer.py --help` to see more options. 66 | 67 | ## File structures 68 | 69 | `network.py` and `model.py` is main implementations. 70 | 71 | - cyclegan 72 | - `config.py` : Training options 73 | - `network.py` : The neural network architecture of Cycle GAN 74 | - `model.py` : Calculate loss and optimizing 75 | - utils 76 | - `data.py` : Utilities for loading data 77 | - `logger.py` : Utilities for logging 78 | - `ops.py` : Utilities for tensor operations 79 | - `tester.py` : Utility functions especially for testing 80 | - `train.py` : A script for CycleGAN training 81 | - `transfer.py` : A script for transferring with pre-trained model 82 | 83 | # TODO 84 | 85 | - [ ] Visualizing training progress with Visdom 86 | - [ ] Add some nice generated images and videos :-) 87 | 88 | ## References 89 | 90 | - https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 91 | - https://hardikbansal.github.io/CycleGANBlog 92 | - https://github.com/togheppi/CycleGAN 93 | - https://github.com/znxlwm/pytorch-CycleGAN 94 | -------------------------------------------------------------------------------- /cyclegan/utils/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for loading training data 3 | """ 4 | import os 5 | import random 6 | 7 | from PIL import Image 8 | 9 | import torch 10 | import torch.utils.data as data 11 | from torchvision import transforms 12 | 13 | import cyclegan.utils.ops as ops 14 | 15 | 16 | class Dataset(data.Dataset): 17 | def __init__(self, data_dir, transform=None): 18 | super(Dataset, self).__init__() 19 | 20 | self.root_path = data_dir 21 | self.filenames = [x for x in sorted(os.listdir(data_dir))] 22 | 23 | self.transform = transform 24 | 25 | def __getitem__(self, index): 26 | filepath = os.path.join(self.root_path, self.filenames[index]) 27 | img = Image.open(filepath).convert('RGB') 28 | 29 | if self.transform is not None: 30 | img = self.transform(img) 31 | 32 | return img 33 | 34 | def __len__(self): 35 | return len(self.filenames) 36 | 37 | 38 | class DataLoader(data.DataLoader): 39 | def __init__(self, data_dir, input_size, batch_size, shuffle=False): 40 | transform = transforms.Compose([ 41 | transforms.Resize(input_size), 42 | transforms.ToTensor(), 43 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 44 | ]) 45 | 46 | dataset = Dataset(data_dir, transform) 47 | 48 | super(DataLoader, self).__init__(dataset, batch_size, shuffle) 49 | 50 | 51 | class ImagePool: 52 | def __init__(self, pool_size): 53 | self.pool_size = pool_size 54 | self.images = [] 55 | 56 | def random_swap(self, image): 57 | swap = image 58 | 59 | if random.uniform(0, 1) > 0.5: 60 | random_id = random.randint(0, self.pool_size - 1) 61 | old = self.images[random_id].clone() 62 | self.images[random_id] = image 63 | swap = old 64 | 65 | return swap 66 | 67 | def query(self, images): 68 | return_images = [] 69 | 70 | for image in images.data: 71 | image = torch.unsqueeze(image, 0) 72 | 73 | if len(self.images) < self.pool_size: 74 | self.images.append(image) 75 | return_images.append(image) 76 | else: 77 | return_image = self.random_swap(image) 78 | return_images.append(return_image) 79 | 80 | return ops.variable(torch.cat(return_images, 0)) 81 | 82 | 83 | def load_file(filepath, input_size): 84 | transform = transforms.Compose([ 85 | transforms.Resize(input_size), 86 | transforms.ToTensor(), 87 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 88 | ]) 89 | 90 | img = Image.open(filepath).convert('RGB') 91 | 92 | img = transform(img) 93 | img = torch.unsqueeze(img, 0) 94 | 95 | return img 96 | -------------------------------------------------------------------------------- /cyclegan/network.py: -------------------------------------------------------------------------------- 1 | """ 2 | The neural network architecture of Cycle GAN 3 | """ 4 | import torch 5 | import torch.nn.modules 6 | 7 | from cyclegan.utils.ops import Activation, normal_weights_initializer 8 | 9 | 10 | class ConvLayer(torch.nn.Module): 11 | def __init__(self, input_size, output_size, 12 | kernel_size=3, stride=2, padding=1, 13 | activation_fn=None, batch_norm=True): 14 | super(ConvLayer, self).__init__() 15 | 16 | self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding) 17 | 18 | self.activation = Activation(activation_fn, batch_norm, output_size) 19 | 20 | def forward(self, x): 21 | output = self.conv(x) 22 | output = self.activation.apply(output) 23 | 24 | return output 25 | 26 | 27 | class DeconvLayer(torch.nn.Module): 28 | def __init__(self, input_size, output_size, 29 | kernel_size=3, stride=2, padding=1, output_padding=1, 30 | activation_fn=None, batch_norm=True): 31 | super(DeconvLayer, self).__init__() 32 | 33 | self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, output_padding) 34 | 35 | self.activation = Activation(activation_fn, batch_norm, output_size) 36 | 37 | def forward(self, x): 38 | output = self.deconv(x) 39 | output = self.activation.apply(output) 40 | 41 | return output 42 | 43 | 44 | class ResnetBlock(torch.nn.Module): 45 | def __init__(self, num_filter, kernel_size=3, stride=1, padding=0): 46 | super(ResnetBlock, self).__init__() 47 | 48 | self.pad = torch.nn.ReflectionPad2d(padding) 49 | 50 | self.conv1 = ConvLayer(num_filter, num_filter, 51 | kernel_size=kernel_size, stride=stride, padding=0, 52 | activation_fn=Activation.ReLU, batch_norm=True) 53 | 54 | self.conv2 = ConvLayer(num_filter, num_filter, 55 | kernel_size=kernel_size, stride=stride, padding=0, 56 | activation_fn=None, batch_norm=True) 57 | 58 | def forward(self, x): 59 | output = self.pad(x) 60 | output = self.conv1(output) 61 | output = self.pad(output) 62 | output = self.conv2(output) 63 | 64 | return x + output 65 | 66 | 67 | class Generator(torch.nn.Module): 68 | def __init__(self, in_channels=3, out_channels=3, num_filter=32, num_resnet=6): 69 | super(Generator, self).__init__() 70 | 71 | self.encoder = self._init_encoder_layer(in_channels, num_filter) 72 | 73 | self.transformer = self._init_reset_block(num_filter, num_resnet) 74 | 75 | self.decoder = self._init_decoder_layer(out_channels, num_filter) 76 | 77 | @staticmethod 78 | def _init_encoder_layer(in_channels, num_filter): 79 | pad = torch.nn.ReflectionPad2d(3) 80 | 81 | conv1 = ConvLayer(in_channels, num_filter, 82 | kernel_size=7, stride=1, padding=0, 83 | activation_fn=Activation.ReLU, batch_norm=True) 84 | 85 | conv2 = ConvLayer(num_filter, num_filter * 2, 86 | kernel_size=3, stride=2, padding=1, 87 | activation_fn=Activation.ReLU, batch_norm=True) 88 | 89 | conv3 = ConvLayer(num_filter * 2, num_filter * 4, 90 | kernel_size=3, stride=2, padding=1, 91 | activation_fn=Activation.ReLU, batch_norm=True) 92 | 93 | return torch.nn.Sequential(pad, conv1, conv2, conv3) 94 | 95 | @staticmethod 96 | def _init_reset_block(num_filter, num_resnet): 97 | resnet_blocks = [] 98 | 99 | for i in range(num_resnet): 100 | resnet = ResnetBlock(num_filter * 4, 3, 1, 1) 101 | resnet_blocks.append(resnet) 102 | 103 | return torch.nn.Sequential(*resnet_blocks) 104 | 105 | @staticmethod 106 | def _init_decoder_layer(out_channels, num_filter): 107 | deconv1 = DeconvLayer(num_filter * 4, num_filter * 2, 108 | kernel_size=3, stride=2, padding=1, output_padding=1, 109 | activation_fn=Activation.ReLU, batch_norm=True) 110 | 111 | deconv2 = DeconvLayer(num_filter * 2, num_filter, 112 | kernel_size=3, stride=2, padding=1, output_padding=1, 113 | activation_fn=Activation.ReLU, batch_norm=True) 114 | 115 | pad = torch.nn.ReflectionPad2d(3) 116 | 117 | conv1 = ConvLayer(num_filter, out_channels, 118 | kernel_size=7, stride=1, padding=0, 119 | activation_fn=Activation.Tanh, batch_norm=False) 120 | 121 | return torch.nn.Sequential(deconv1, deconv2, pad, conv1) 122 | 123 | def init_weight(self, mean, std): 124 | for m in self._modules: 125 | normal_weights_initializer(self._modules[m], mean, std) 126 | 127 | def forward(self, x): 128 | output = self.encoder(x) 129 | output = self.transformer(output) 130 | output = self.decoder(output) 131 | 132 | return output 133 | 134 | 135 | class Discriminator(torch.nn.Module): 136 | def __init__(self, in_channels, out_channels, num_filter=64): 137 | super(Discriminator, self).__init__() 138 | 139 | self.discriminator = self._init_layer(in_channels, out_channels, num_filter) 140 | 141 | @staticmethod 142 | def _init_layer(in_channels, out_channels, num_filter): 143 | conv1 = ConvLayer(in_channels, num_filter, 144 | kernel_size=4, stride=2, padding=1, 145 | activation_fn=Activation.LeakyReLU, batch_norm=False) 146 | 147 | conv2 = ConvLayer(num_filter, num_filter * 2, 148 | kernel_size=4, stride=2, padding=1, 149 | activation_fn=Activation.LeakyReLU, batch_norm=True) 150 | 151 | conv3 = ConvLayer(num_filter * 2, num_filter * 4, 152 | kernel_size=4, stride=2, padding=1, 153 | activation_fn=Activation.LeakyReLU, batch_norm=True) 154 | 155 | conv4 = ConvLayer(num_filter * 4, num_filter * 8, 156 | kernel_size=4, stride=1, padding=1, 157 | activation_fn=Activation.LeakyReLU, batch_norm=True) 158 | 159 | conv5 = ConvLayer(num_filter * 8, out_channels, 160 | kernel_size=4, stride=1, padding=1, 161 | activation_fn=None, batch_norm=False) 162 | 163 | return torch.nn.Sequential(conv1, conv2, conv3, conv4, conv5) 164 | 165 | def init_weight(self, mean, std): 166 | for m in self._modules: 167 | normal_weights_initializer(self._modules[m], mean, std) 168 | 169 | def forward(self, x): 170 | output = self.discriminator(x) 171 | 172 | return output 173 | 174 | -------------------------------------------------------------------------------- /cyclegan/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main model of CycleGAN 3 | Calculate loss and optimizing 4 | """ 5 | import os 6 | import itertools 7 | from enum import Enum 8 | 9 | import torch 10 | 11 | import cyclegan.utils.ops as ops 12 | from cyclegan.utils.data import ImagePool 13 | from cyclegan.network import Generator, Discriminator 14 | 15 | 16 | class LossType(Enum): 17 | G = 1 18 | D_A = 2 19 | D_B = 3 20 | G_AB = 4 21 | G_BA = 5 22 | cycle_A = 6 23 | cycle_B = 7 24 | 25 | 26 | class CycleGAN: 27 | def __init__(self, ngf=32, ndf=64, num_resnet=6, 28 | lrG=0.0002, lrD=0.0002, beta1=0.5, beta2=0.999, 29 | lambdaA=10, lambdaB=10, num_pool=50): 30 | 31 | self.lrG, self.lrD = lrG, lrD 32 | self.beta1, self.beta2 = beta1, beta2 33 | self.lambdaA, self.lambdaB = lambdaA, lambdaB 34 | 35 | self.G_AB, self.G_BA, self.D_A, self.D_B =\ 36 | self._init_network(ngf, ndf, num_resnet) 37 | 38 | self.MSE_loss, self.L1_loss = self._init_loss() 39 | 40 | self.G_optimizer, self.D_A_optimizer, self.D_B_optimizer =\ 41 | self._init_optimizer(lrG, lrD, beta1, beta2) 42 | 43 | self.fake_A_pool = ImagePool(num_pool) 44 | self.fake_B_pool = ImagePool(num_pool) 45 | 46 | self._prepare_for_gpu() 47 | 48 | @staticmethod 49 | def _init_network(ngf, ndf, num_resnet): 50 | G_AB = Generator(3, 3, ngf, num_resnet) 51 | G_BA = Generator(3, 3, ngf, num_resnet) 52 | D_A = Discriminator(3, 1, ndf) 53 | D_B = Discriminator(3, 1, ndf) 54 | 55 | G_AB.init_weight(mean=0.0, std=0.02) 56 | G_BA.init_weight(mean=0.0, std=0.02) 57 | D_A.init_weight(mean=0.0, std=0.02) 58 | D_B.init_weight(mean=0.0, std=0.02) 59 | 60 | return G_AB, G_BA, D_A, D_B 61 | 62 | def _prepare_for_gpu(self): 63 | if torch.cuda.device_count() > 1: 64 | print("{0} GPUs are detected.".format(torch.cuda.device_count())) 65 | 66 | self.G_AB = torch.nn.DataParallel(self.G_AB) 67 | self.G_BA = torch.nn.DataParallel(self.G_BA) 68 | self.D_A = torch.nn.DataParallel(self.D_A) 69 | self.D_B = torch.nn.DataParallel(self.D_B) 70 | 71 | if torch.cuda.is_available(): 72 | self.G_AB.cuda() 73 | self.G_BA.cuda() 74 | self.D_A.cuda() 75 | self.D_B.cuda() 76 | 77 | @staticmethod 78 | def _init_loss(): 79 | if torch.cuda.is_available(): 80 | MSE_loss = torch.nn.MSELoss().cuda() 81 | L1_loss = torch.nn.L1Loss().cuda() 82 | else: 83 | MSE_loss = torch.nn.MSELoss() 84 | L1_loss = torch.nn.L1Loss() 85 | 86 | return MSE_loss, L1_loss 87 | 88 | def _init_optimizer(self, lrG, lrD, beta1, beta2): 89 | G_optimizer = torch.optim.Adam(itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()), 90 | lr=lrG, betas=(beta1, beta2)) 91 | D_A_optimizer = torch.optim.Adam(self.D_A.parameters(), lr=lrD, betas=(beta1, beta2)) 92 | D_B_optimizer = torch.optim.Adam(self.D_B.parameters(), lr=lrD, betas=(beta1, beta2)) 93 | 94 | return G_optimizer, D_A_optimizer, D_B_optimizer 95 | 96 | def decay_optimizer(self, num_epochs, decay_epoch): 97 | self.G_optimizer.param_groups[0]['lr'] -= self.lrG / (num_epochs - decay_epoch) 98 | self.D_A_optimizer.param_groups[0]['lr'] -= self.lrD / (num_epochs - decay_epoch) 99 | self.D_B_optimizer.param_groups[0]['lr'] -= self.lrD / (num_epochs - decay_epoch) 100 | 101 | def train_generator(self, real_A, real_B): 102 | # Train G : A -> B 103 | fake_B = self.G_AB(real_A) 104 | D_B_result = self.D_B(fake_B) 105 | G_AB_loss = self.MSE_loss(D_B_result, ops.variable(torch.ones(D_B_result.size()))) 106 | 107 | # Reconstruct and calculate cycle loss : fake B -> A 108 | recon_A = self.G_BA(fake_B) 109 | cycle_A_loss = self.L1_loss(recon_A, real_A) * self.lambdaA 110 | 111 | # Train G : B -> A 112 | fake_A = self.G_BA(real_B) 113 | D_A_result = self.D_A(fake_A) 114 | G_BA_loss = self.MSE_loss(D_A_result, ops.variable(torch.ones(D_A_result.size()))) 115 | 116 | # Reconstruct and calculate cycle loss : fake A -> B 117 | recon_B = self.G_AB(fake_A) 118 | cycle_B_loss = self.L1_loss(recon_B, real_B) * self.lambdaB 119 | 120 | # Optimize G for training generator 121 | G_loss = G_AB_loss + G_BA_loss + cycle_A_loss + cycle_B_loss 122 | ops.optimize(G_loss, self.G_optimizer) 123 | 124 | losses = { 125 | LossType.G: G_loss, 126 | LossType.G_AB: G_AB_loss, 127 | LossType.G_BA: G_BA_loss, 128 | LossType.cycle_A: cycle_A_loss, 129 | LossType.cycle_B: cycle_B_loss 130 | } 131 | 132 | return losses, fake_A, fake_B 133 | 134 | def train_discriminator(self, real_A, real_B, fake_A, fake_B): 135 | # Train D for A 136 | fake_A = self.fake_A_pool.query(fake_A) 137 | D_A_real_result = self.D_A(real_A) 138 | D_A_fake_result = self.D_A(fake_A) 139 | D_A_real_loss = self.MSE_loss(D_A_real_result, ops.variable(torch.ones(D_A_real_result.size()))) 140 | D_A_fake_loss = self.MSE_loss(D_A_fake_result, ops.variable(torch.zeros(D_A_fake_result.size()))) 141 | 142 | # Optimize D_A for training discriminator for A 143 | D_A_loss = (D_A_real_loss + D_A_fake_loss) * 0.5 144 | ops.optimize(D_A_loss, self.D_A_optimizer) 145 | 146 | # Train D for B 147 | fake_B = self.fake_B_pool.query(fake_B) 148 | D_B_real_result = self.D_B(real_B) 149 | D_B_fake_result = self.D_B(fake_B) 150 | D_B_real_loss = self.MSE_loss(D_B_real_result, ops.variable(torch.ones(D_B_real_result.size()))) 151 | D_B_fake_loss = self.MSE_loss(D_B_fake_result, ops.variable(torch.zeros(D_B_fake_result.size()))) 152 | 153 | # Optimize D_A for training discriminator for A 154 | D_B_loss = (D_B_real_loss + D_B_fake_loss) * 0.5 155 | ops.optimize(D_B_loss, self.D_B_optimizer) 156 | 157 | losses = { 158 | LossType.D_A: D_A_loss, 159 | LossType.D_B: D_B_loss 160 | } 161 | 162 | return losses 163 | 164 | def train(self, real_A, real_B): 165 | real_A = ops.variable(real_A) 166 | real_B = ops.variable(real_B) 167 | 168 | losses_G, fake_A, fake_B = self.train_generator(real_A, real_B) 169 | losses_D = self.train_discriminator(real_A, real_B, fake_A, fake_B) 170 | 171 | losses = {**losses_G, **losses_D} 172 | 173 | return losses 174 | 175 | def generate_A_to_B(self, real_A, recon=False): 176 | fake_B = self.G_AB(ops.variable(real_A)) 177 | recon_A = None 178 | 179 | if recon: 180 | recon_A = self.G_BA(fake_B) 181 | 182 | return fake_B, recon_A 183 | 184 | def generate_B_to_A(self, real_B, recon=False): 185 | fake_A = self.G_BA(ops.variable(real_B)) 186 | recon_B = None 187 | 188 | if recon: 189 | recon_B = self.G_AB(fake_A) 190 | 191 | return fake_A, recon_B 192 | 193 | def load(self, model_dir, generator='ALL', discriminator=False): 194 | """ 195 | Load pre-trained model 196 | :param model_dir: 197 | :param generator: 'ALL', 'AtoB', 'BtoA' 198 | :param discriminator: True/False 199 | :return: 200 | """ 201 | print("Loading model from {0}".format(model_dir)) 202 | 203 | if generator is ['AtoB', 'ALL']: 204 | self.G_AB.load_state_dict(torch.load(os.path.join(model_dir, 'generator_AB_param.pkl'))) 205 | 206 | if generator in ['BtoA', 'ALL']: 207 | self.G_BA.load_state_dict(torch.load(os.path.join(model_dir, 'generator_BA_param.pkl'))) 208 | 209 | if discriminator: 210 | self.D_A.load_state_dict(torch.load(os.path.join(model_dir, 'discriminator_A_param.pkl'))) 211 | self.D_B.load_state_dict(torch.load(os.path.join(model_dir, 'discriminator_B_param.pkl'))) 212 | 213 | def save(self, model_dir): 214 | print("Saving model into {0}".format(model_dir)) 215 | 216 | torch.save(self.G_AB.state_dict(), os.path.join(model_dir, 'generator_AB_param.pkl')) 217 | torch.save(self.G_BA.state_dict(), os.path.join(model_dir, 'generator_BA_param.pkl')) 218 | torch.save(self.D_A.state_dict(), os.path.join(model_dir, 'discriminator_A_param.pkl')) 219 | torch.save(self.D_B.state_dict(), os.path.join(model_dir, 'discriminator_B_param.pkl')) 220 | --------------------------------------------------------------------------------