├── .gitignore
├── README.md
├── download_dataset.sh
├── imgs
    ├── AtoB_n02381460_4530.jpg
    ├── AtoB_n02381460_4660.jpg
    ├── AtoB_n02381460_510.jpg
    ├── AtoB_n02381460_8980.jpg
    ├── BtoA_n02391049_1760.jpg
    ├── BtoA_n02391049_3070.jpg
    ├── BtoA_n02391049_5100.jpg
    ├── BtoA_n02391049_7150.jpg
    ├── n02381460_4530.jpg
    ├── n02381460_4660.jpg
    ├── n02381460_510.jpg
    ├── n02381460_8980.jpg
    ├── n02391049_1760.jpg
    ├── n02391049_3070.jpg
    ├── n02391049_5100.jpg
    ├── n02391049_7150.jpg
    └── teaser.jpg
├── main.py
├── model.py
├── module.py
├── ops.py
├── requirements.txt
└── utils.py


/.gitignore:
--------------------------------------------------------------------------------
__pycache__
.idea/*
logs/*
checkpoint/*
datasets/*
test/*
sample/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CycleGAN

TensorFlow implementation for learning image-to-image translation **without** input-output pairs.
The method was proposed by [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/) et al. in
[Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/pdf/1703.10593.pdf).
For example, from the paper:

## Update Results
The results of this implementation:

- Horses -> Zebras
- Zebras -> Horses
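
The translations above are learned with the objective that `_build_model` in `model.py` assembles: each generator is pushed to fool the discriminator of its target domain, and `--L1_lambda` (default 10.0) weights an L1 cycle-consistency term that asks A -> B -> A and B -> A -> B to reproduce the input. The NumPy sketch below only evaluates that objective on arrays for the default `--use_lsgan` setting; the function names are hypothetical, it trains nothing, and it assumes discriminator outputs are grids of patch scores such as the 32 x 32 x 1 map returned by `discriminator` in `module.py`.

```python
import numpy as np


def lsgan_loss(pred, target_value):
    """Least-squares GAN term: mean squared distance of scores to a constant label."""
    return float(np.mean((np.asarray(pred) - target_value) ** 2))


def l1_loss(a, b):
    """Mean absolute error, used for the cycle-consistency terms."""
    return float(np.mean(np.abs(np.asarray(a) - np.asarray(b))))


def generator_loss(d_a_fake, d_b_fake, real_a, rec_a, real_b, rec_b, l1_lambda=10.0):
    """Fool both discriminators (label 1) and reconstruct both inputs."""
    adversarial = lsgan_loss(d_a_fake, 1.0) + lsgan_loss(d_b_fake, 1.0)
    cycle = l1_loss(real_a, rec_a) + l1_loss(real_b, rec_b)
    return adversarial + l1_lambda * cycle


def discriminator_loss(d_real, d_fake):
    """One discriminator: real patches -> 1, generated patches -> 0, averaged."""
    return (lsgan_loss(d_real, 1.0) + lsgan_loss(d_fake, 0.0)) / 2


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    real_a = rng.uniform(-1.0, 1.0, size=(1, 256, 256, 3))
    real_b = rng.uniform(-1.0, 1.0, size=(1, 256, 256, 3))
    fooled = np.ones((1, 32, 32, 1))  # discriminator scores every patch as real
    # Perfect reconstructions and fully fooled discriminators give zero loss.
    print(generator_loss(fooled, fooled, real_a, real_a, real_b, real_b))
    # A reconstruction off by 0.1 everywhere costs roughly 10.0 * 0.1 = 1.0.
    print(generator_loss(fooled, fooled, real_a, real_a + 0.1, real_b, real_b))
```

Note that `mae_criterion` in `module.py` is the squared-error (LSGAN) term that `lsgan_loss` mirrors here, and the discriminator loss halves the sum of its real and fake terms in the same way as `da_loss` and `db_loss`.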

You can download the pretrained model from [this url](https://1drv.ms/u/s!AroAdu0uts_gj5tA93GnwyfRpvBIDA)
and extract the rar file to `./checkpoint/`.

## Prerequisites
- tensorflow r1.1 (a TensorFlow 1.x release is required; the code relies on `tf.contrib` and `tf.Session`, which were removed in TensorFlow 2)
- numpy 1.11.0
- scipy 0.17.0
- pillow 3.3.0

## Getting Started
### Installation
- Install tensorflow from https://github.com/tensorflow/tensorflow
- Clone this repo:
```bash
git clone https://github.com/xhujoy/CycleGAN-tensorflow
cd CycleGAN-tensorflow
```

### Train
- Download a dataset (e.g. zebra and horse images from ImageNet):
```bash
bash ./download_dataset.sh horse2zebra
```
- Train a model:
```bash
CUDA_VISIBLE_DEVICES=0 python main.py --dataset_dir=horse2zebra
```
- Use TensorBoard to visualize the training details:
```bash
tensorboard --logdir=./logs
```

### Test
- Finally, test the model:
```bash
CUDA_VISIBLE_DEVICES=0 python main.py --dataset_dir=horse2zebra --phase=test --which_direction=AtoB
```

## Training and Test Details
To train a model, pass the name of a dataset folder under `./datasets/` (images are read from `./datasets/<dataset_name>/trainA` and `trainB`):
```bash
CUDA_VISIBLE_DEVICES=0 python main.py --dataset_dir=dataset_name
```
Models are saved to `./checkpoint/<dataset_name>_<fine_size>/` (the root can be changed by passing `--checkpoint_dir=your_dir`).

To test the model (use `--which_direction=BtoA` for the reverse mapping):
```bash
CUDA_VISIBLE_DEVICES=0 python main.py --dataset_dir=dataset_name --phase=test --which_direction=AtoB
```
Translated images and an `AtoB_index.html` (or `BtoA_index.html`) comparison page are written to `./test/`.

## Datasets
Download the datasets using the following script:
```bash
bash ./download_dataset.sh dataset_name
```
- `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/).
- `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/).
- `maps`: 1096 training images scraped from Google Maps.
- `horse2zebra`: 939 horse images and 1177 zebra images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `wild horse` and `zebra`.
- `apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `apple` and `navel orange`.
- `summer2winter_yosemite`: 1273 summer Yosemite images and 854 winter Yosemite images downloaded using the Flickr API. See the CycleGAN paper for more details.
- `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, `cezanne2photo`: The art images were downloaded from [Wikiart](https://www.wikiart.org/). The real photos were downloaded from Flickr using the combination of tags *landscape* and *landscapephotography*. The training set size of each class is Monet: 1074, Cezanne: 584, Van Gogh: 401, Ukiyo-e: 1433, Photographs: 6853.
- `iphone2dslr_flower`: Both classes of images were downloaded from Flickr. The training set size of each class is iPhone: 1813, DSLR: 3316. See the CycleGAN paper for more details.
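
Before a long run it can help to confirm that a dataset landed where the code looks for it: `train` and `sample_model` in `model.py` glob `./datasets/<dataset_dir>/trainA`, `trainB`, `testA` and `testB`, and one epoch runs `min(len(trainA), len(trainB), train_size) // batch_size` iterations. The helper below is a hypothetical sketch, not part of this repository; it only counts files with the same `*.*` pattern and reports the resulting iterations per epoch.

```python
import os
from glob import glob


def check_dataset(name, root="./datasets", batch_size=1, train_size=int(1e8)):
    """Count files in each split model.py reads and estimate iterations per epoch."""
    counts = {}
    for split in ("trainA", "trainB", "testA", "testB"):
        # Same pattern as model.py: ./datasets/<name>/<split>/*.*
        counts[split] = len(glob(os.path.join(root, name, split, "*.*")))
    # Mirrors batch_idxs in model.py: the smaller training domain bounds an epoch.
    iterations = min(counts["trainA"], counts["trainB"], train_size) // batch_size
    return counts, iterations


if __name__ == "__main__":
    counts, iterations = check_dataset("horse2zebra")
    print(counts, "iterations per epoch:", iterations)
    if iterations == 0:
        print("No training pairs found; try: bash ./download_dataset.sh horse2zebra")
```

With the default `--batch_size=1`, the epoch length equals the number of images in the smaller training domain.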

## Reference
- The Torch implementation of CycleGAN, https://github.com/junyanz/CycleGAN
- The TensorFlow implementation of pix2pix, https://github.com/yenchenlin/pix2pix-tensorflow
--------------------------------------------------------------------------------
/download_dataset.sh:
--------------------------------------------------------------------------------
mkdir -p datasets
FILE=$1

if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" ]]; then
    echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
    exit 1
fi

URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./datasets/$FILE.zip
TARGET_DIR=./datasets/$FILE/
wget "$URL" -O "$ZIP_FILE"
mkdir -p "$TARGET_DIR"
unzip "$ZIP_FILE" -d ./datasets/
rm "$ZIP_FILE"
--------------------------------------------------------------------------------
/imgs/AtoB_n02381460_4530.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/AtoB_n02381460_4530.jpg
--------------------------------------------------------------------------------
/imgs/AtoB_n02381460_4660.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/AtoB_n02381460_4660.jpg
--------------------------------------------------------------------------------
/imgs/AtoB_n02381460_510.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/AtoB_n02381460_510.jpg
--------------------------------------------------------------------------------
/imgs/AtoB_n02381460_8980.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/AtoB_n02381460_8980.jpg
--------------------------------------------------------------------------------
/imgs/BtoA_n02391049_1760.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/BtoA_n02391049_1760.jpg
--------------------------------------------------------------------------------
/imgs/BtoA_n02391049_3070.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/BtoA_n02391049_3070.jpg
--------------------------------------------------------------------------------
/imgs/BtoA_n02391049_5100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/BtoA_n02391049_5100.jpg
--------------------------------------------------------------------------------
/imgs/BtoA_n02391049_7150.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/BtoA_n02391049_7150.jpg
--------------------------------------------------------------------------------
/imgs/n02381460_4530.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02381460_4530.jpg
--------------------------------------------------------------------------------
/imgs/n02381460_4660.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02381460_4660.jpg
--------------------------------------------------------------------------------
/imgs/n02381460_510.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02381460_510.jpg
--------------------------------------------------------------------------------
/imgs/n02381460_8980.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02381460_8980.jpg
--------------------------------------------------------------------------------
/imgs/n02391049_1760.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02391049_1760.jpg
--------------------------------------------------------------------------------
/imgs/n02391049_3070.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02391049_3070.jpg
--------------------------------------------------------------------------------
/imgs/n02391049_5100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02391049_5100.jpg
--------------------------------------------------------------------------------
/imgs/n02391049_7150.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02391049_7150.jpg
--------------------------------------------------------------------------------
/imgs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/teaser.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os
import tensorflow as tf
tf.set_random_seed(19)
from model import cyclegan


def str2bool(value):
    """Parse 1/0, true/false, yes/no. argparse's type=bool treats any non-empty string (even 'False') as True."""
    if value.lower() in ('1', 'true', 'yes', 'y'):
        return True
    if value.lower() in ('0', 'false', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser(description='')
parser.add_argument('--dataset_dir', dest='dataset_dir', default='horse2zebra', help='name of the dataset folder under ./datasets')
parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epochs')
parser.add_argument('--epoch_step', dest='epoch_step', type=int, default=100, help='epoch at which lr starts to decay linearly towards 0')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch')
parser.add_argument('--train_size', dest='train_size', type=int, default=int(1e8), help='# images used to train')
parser.add_argument('--load_size', dest='load_size', type=int, default=286, help='scale images to this size')
parser.add_argument('--fine_size', dest='fine_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--ngf', dest='ngf', type=int, default=64, help='# of generator filters in first conv layer')
parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discriminator filters in first conv layer')
parser.add_argument('--input_nc', dest='input_nc', type=int, default=3, help='# of input image channels')
parser.add_argument('--output_nc', dest='output_nc', type=int, default=3, help='# of output image channels')
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA')
parser.add_argument('--phase', dest='phase', default='train', help='train, test')
parser.add_argument('--save_freq', dest='save_freq', type=int, default=1000, help='save a model every save_freq iterations')
parser.add_argument('--print_freq', dest='print_freq', type=int, default=100, help='save sample images every print_freq iterations')
parser.add_argument('--continue_train', dest='continue_train', type=str2bool, default=False, help='if continue training, load the latest model: 1: true, 0: false')
parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here')
parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='sample images are saved here')
parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test outputs are saved here')
parser.add_argument('--L1_lambda', dest='L1_lambda', type=float, default=10.0, help='weight on L1 term in objective')
parser.add_argument('--use_resnet', dest='use_resnet', type=str2bool, default=True, help='use the ResNet (residual block) generator instead of the U-Net')
parser.add_argument('--use_lsgan', dest='use_lsgan', type=str2bool, default=True, help='use the least-squares (LSGAN) loss instead of sigmoid cross-entropy')
parser.add_argument('--max_size', dest='max_size', type=int, default=50, help='max size of image pool, 0 means do not use image pool')

args = parser.parse_args()


def main(_):
    # Create the output folders up front so save/sample/test can write into them.
    for folder in (args.checkpoint_dir, args.sample_dir, args.test_dir):
        if not os.path.exists(folder):
            os.makedirs(folder)

    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session(config=tfconfig) as sess:
        model = cyclegan(sess, args)
        if args.phase == 'train':
            model.train(args)
        else:
            model.test(args)


if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
from __future__ import division
import os
import time
from glob import glob
import tensorflow as tf
import numpy as np
from collections import namedtuple

from module import *
from utils import *


class cyclegan(object):
    def __init__(self, sess, args):
        self.sess = sess
        self.batch_size = args.batch_size
        self.image_size = args.fine_size
        self.input_c_dim = args.input_nc
        self.output_c_dim = args.output_nc
        self.L1_lambda = args.L1_lambda
        self.dataset_dir = args.dataset_dir

        self.discriminator = discriminator
        if args.use_resnet:
            self.generator = generator_resnet
        else:
            self.generator = generator_unet
        if args.use_lsgan:
            self.criterionGAN = mae_criterion
        else:
            self.criterionGAN = sce_criterion

        OPTIONS = namedtuple('OPTIONS', 'batch_size image_size \
                              gf_dim df_dim output_c_dim is_training')
        self.options = OPTIONS._make((args.batch_size, args.fine_size,
                                      args.ngf,
                                      args.ndf, args.output_nc,
                                      args.phase == 'train'))

        self._build_model()
        self.saver = tf.train.Saver()
        self.pool = ImagePool(args.max_size)

    def _build_model(self):
        # A and B images are fed together, stacked along the channel axis
        self.real_data = tf.placeholder(tf.float32,
                                        [None, self.image_size, self.image_size,
                                         self.input_c_dim + self.output_c_dim],
                                        name='real_A_and_B_images')

        self.real_A = self.real_data[:, :, :, :self.input_c_dim]
        self.real_B = self.real_data[:, :, :, self.input_c_dim:self.input_c_dim + self.output_c_dim]

        # Both cycles: A -> B -> A and B -> A -> B (the third argument is `reuse`,
        # so the second call of each generator shares the weights of the first)
        self.fake_B = self.generator(self.real_A, self.options, False, name="generatorA2B")
        self.fake_A_ = self.generator(self.fake_B, self.options, False, name="generatorB2A")
        self.fake_A = self.generator(self.real_B, self.options, True, name="generatorB2A")
        self.fake_B_ = self.generator(self.fake_A, self.options, True, name="generatorA2B")

        self.DB_fake = self.discriminator(self.fake_B, self.options, reuse=False, name="discriminatorB")
        self.DA_fake = self.discriminator(self.fake_A, self.options, reuse=False, name="discriminatorA")
        # Generator objectives: adversarial terms plus L1 cycle-consistency terms weighted by L1_lambda
        self.g_loss_a2b = self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
        self.g_loss_b2a = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
        self.g_loss = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \
            + self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)

        # Fakes for the discriminator update are fed back in after passing through the image pool
        self.fake_A_sample = tf.placeholder(tf.float32,
                                            [None, self.image_size, self.image_size,
                                             self.input_c_dim], name='fake_A_sample')
        self.fake_B_sample = tf.placeholder(tf.float32,
                                            [None, self.image_size, self.image_size,
                                             self.output_c_dim], name='fake_B_sample')
        self.DB_real = self.discriminator(self.real_B, self.options, reuse=True, name="discriminatorB")
        self.DA_real = self.discriminator(self.real_A, self.options, reuse=True, name="discriminatorA")
        self.DB_fake_sample = self.discriminator(self.fake_B_sample, self.options, reuse=True, name="discriminatorB")
        self.DA_fake_sample = self.discriminator(self.fake_A_sample, self.options, reuse=True, name="discriminatorA")

        # Discriminator objectives: real -> 1, fake -> 0, averaged per discriminator
        self.db_loss_real = self.criterionGAN(self.DB_real, tf.ones_like(self.DB_real))
        self.db_loss_fake = self.criterionGAN(self.DB_fake_sample, tf.zeros_like(self.DB_fake_sample))
        self.db_loss = (self.db_loss_real + self.db_loss_fake) / 2
        self.da_loss_real = self.criterionGAN(self.DA_real, tf.ones_like(self.DA_real))
        self.da_loss_fake = self.criterionGAN(self.DA_fake_sample, tf.zeros_like(self.DA_fake_sample))
        self.da_loss = (self.da_loss_real + self.da_loss_fake) / 2
        self.d_loss = self.da_loss + self.db_loss

        self.g_loss_a2b_sum = tf.summary.scalar("g_loss_a2b", self.g_loss_a2b)
        self.g_loss_b2a_sum = tf.summary.scalar("g_loss_b2a", self.g_loss_b2a)
        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.g_sum = tf.summary.merge([self.g_loss_a2b_sum, self.g_loss_b2a_sum, self.g_loss_sum])
        self.db_loss_sum = tf.summary.scalar("db_loss", self.db_loss)
        self.da_loss_sum = tf.summary.scalar("da_loss", self.da_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.db_loss_real_sum = tf.summary.scalar("db_loss_real", self.db_loss_real)
        self.db_loss_fake_sum = tf.summary.scalar("db_loss_fake", self.db_loss_fake)
        self.da_loss_real_sum = tf.summary.scalar("da_loss_real", self.da_loss_real)
        self.da_loss_fake_sum = tf.summary.scalar("da_loss_fake", self.da_loss_fake)
        self.d_sum = tf.summary.merge(
            [self.da_loss_sum, self.da_loss_real_sum, self.da_loss_fake_sum,
             self.db_loss_sum, self.db_loss_real_sum, self.db_loss_fake_sum,
             self.d_loss_sum]
        )

        self.test_A = tf.placeholder(tf.float32,
                                     [None, self.image_size, self.image_size,
                                      self.input_c_dim], name='test_A')
        self.test_B = tf.placeholder(tf.float32,
                                     [None, self.image_size, self.image_size,
                                      self.output_c_dim], name='test_B')
        self.testB = self.generator(self.test_A, self.options, True, name="generatorA2B")
        self.testA = self.generator(self.test_B, self.options, True, name="generatorB2A")

        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]
        for var in t_vars: print(var.name)

    def train(self, args):
        """Train cyclegan"""
        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
        self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.d_loss, var_list=self.d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.g_loss, var_list=self.g_vars)

        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)
        self.writer = tf.summary.FileWriter("./logs", self.sess.graph)

        counter = 1
        start_time = time.time()

        if args.continue_train:
            if self.load(args.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

        for epoch in range(args.epoch):
            dataA = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/trainA'))
            dataB = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/trainB'))
            np.random.shuffle(dataA)
            np.random.shuffle(dataB)
            batch_idxs = min(min(len(dataA), len(dataB)), args.train_size) // self.batch_size
            # Constant lr for the first epoch_step epochs, then linear decay towards 0
            lr = args.lr if epoch < args.epoch_step else args.lr*(args.epoch-epoch)/(args.epoch-args.epoch_step)

            for idx in range(0, batch_idxs):
                batch_files = list(zip(dataA[idx * self.batch_size:(idx + 1) * self.batch_size],
                                       dataB[idx * self.batch_size:(idx + 1) * self.batch_size]))
                batch_images = [load_train_data(batch_file, args.load_size, args.fine_size) for batch_file in batch_files]
                batch_images = np.array(batch_images).astype(np.float32)

                # Update G network and record fake outputs
                fake_A, fake_B, _, summary_str = self.sess.run(
                    [self.fake_A, self.fake_B, self.g_optim, self.g_sum],
                    feed_dict={self.real_data: batch_images, self.lr: lr})
                self.writer.add_summary(summary_str, counter)
                # Pass the fakes through the image pool (size --max_size) before the D update
                [fake_A, fake_B] = self.pool([fake_A, fake_B])

                # Update D network
                _, summary_str = self.sess.run(
                    [self.d_optim, self.d_sum],
                    feed_dict={self.real_data: batch_images,
                               self.fake_A_sample: fake_A,
                               self.fake_B_sample: fake_B,
                               self.lr: lr})
                self.writer.add_summary(summary_str, counter)

                counter += 1
                print(("Epoch: [%2d] [%4d/%4d] time: %4.4f" % (
                    epoch, idx, batch_idxs, time.time() - start_time)))

                if np.mod(counter, args.print_freq) == 1:
                    self.sample_model(args.sample_dir, epoch, idx)

                if np.mod(counter, args.save_freq) == 2:
                    self.save(args.checkpoint_dir, counter)

    def save(self, checkpoint_dir, step):
        model_name = "cyclegan.model"
        model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoint...")

        model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        else:
            return False

    def sample_model(self, sample_dir, epoch, idx):
        dataA = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testA'))
        dataB = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testB'))
        np.random.shuffle(dataA)
        np.random.shuffle(dataB)
        batch_files = list(zip(dataA[:self.batch_size], dataB[:self.batch_size]))
        sample_images = [load_train_data(batch_file, is_testing=True) for batch_file in batch_files]
        sample_images = np.array(sample_images).astype(np.float32)

        fake_A, fake_B = self.sess.run(
            [self.fake_A, self.fake_B],
            feed_dict={self.real_data: sample_images}
        )
        save_images(fake_A, [self.batch_size, 1],
                    './{}/A_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx))
        save_images(fake_B, [self.batch_size, 1],
                    './{}/B_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx))

    def test(self, args):
        """Test cyclegan"""
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)
        if args.which_direction == 'AtoB':
            sample_files = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testA'))
        elif args.which_direction == 'BtoA':
            sample_files = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testB'))
        else:
            raise Exception('--which_direction must be AtoB or BtoA')

        if self.load(args.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # write html for visual comparison: one table row per image (name, input, output)
        index_path = os.path.join(args.test_dir, '{0}_index.html'.format(args.which_direction))
        index = open(index_path, "w")
        index.write("<html><body><table><tr>")
        index.write("<th>name</th><th>input</th><th>output</th></tr>")

        out_var, in_var = (self.testB, self.test_A) if args.which_direction == 'AtoB' else (
            self.testA, self.test_B)

        for sample_file in sample_files:
            print('Processing image: ' + sample_file)
            sample_image = [load_test_data(sample_file, args.fine_size)]
            sample_image = np.array(sample_image).astype(np.float32)
            image_path = os.path.join(args.test_dir,
                                      '{0}_{1}'.format(args.which_direction, os.path.basename(sample_file)))
            fake_img = self.sess.run(out_var, feed_dict={in_var: sample_image})
            save_images(fake_img, [1, 1], image_path)
            index.write("<tr><td>%s</td>" % os.path.basename(image_path))
            index.write("<td><img src='%s'></td>" % (sample_file if os.path.isabs(sample_file) else (
                '..' + os.path.sep + sample_file)))
            index.write("<td><img src='%s'></td>" % (image_path if os.path.isabs(image_path) else (
                '..' + os.path.sep + image_path)))
            index.write("</tr>")
        index.write("</table></body></html>")
        index.close()
--------------------------------------------------------------------------------
/module.py:
--------------------------------------------------------------------------------
from __future__ import division
import tensorflow as tf
from ops import *
from utils import *


def discriminator(image, options, reuse=False, name="discriminator"):

    with tf.variable_scope(name):
        # image is 256 x 256 x input_c_dim
        if reuse:
            tf.get_variable_scope().reuse_variables()
        else:
            assert tf.get_variable_scope().reuse is False

        h0 = lrelu(conv2d(image, options.df_dim, name='d_h0_conv'))
        # h0 is (128 x 128 x self.df_dim)
        h1 = lrelu(instance_norm(conv2d(h0, options.df_dim*2, name='d_h1_conv'), 'd_bn1'))
        # h1 is (64 x 64 x self.df_dim*2)
        h2 = lrelu(instance_norm(conv2d(h1, options.df_dim*4, name='d_h2_conv'), 'd_bn2'))
        # h2 is (32 x 32 x self.df_dim*4)
        h3 = lrelu(instance_norm(conv2d(h2, options.df_dim*8, s=1, name='d_h3_conv'), 'd_bn3'))
        # h3 is (32 x 32 x self.df_dim*8)
        h4 = conv2d(h3, 1, s=1, name='d_h3_pred')
        # h4 is (32 x 32 x 1): one real/fake score per image patch
        return h4


def generator_unet(image, options, reuse=False, name="generator"):

    # tf.nn.dropout takes keep_prob in TF 1.x: keep half the units while training, all at test time
    dropout_rate = 0.5 if options.is_training else 1.0
    with tf.variable_scope(name):
        # image is 256 x 256 x input_c_dim
        if reuse:
            tf.get_variable_scope().reuse_variables()
        else:
            assert tf.get_variable_scope().reuse is False

        # Encoder: each stride-2 conv halves the spatial size
        e1 = instance_norm(conv2d(image, options.gf_dim, name='g_e1_conv'))
        # e1 is (128 x 128 x self.gf_dim)
        e2 = instance_norm(conv2d(lrelu(e1), options.gf_dim*2, name='g_e2_conv'), 'g_bn_e2')
        # e2 is (64 x 64 x self.gf_dim*2)
        e3 = instance_norm(conv2d(lrelu(e2), options.gf_dim*4, name='g_e3_conv'), 'g_bn_e3')
        # e3 is (32 x 32 x self.gf_dim*4)
        e4 = instance_norm(conv2d(lrelu(e3), options.gf_dim*8, name='g_e4_conv'), 'g_bn_e4')
        # e4 is (16 x 16 x self.gf_dim*8)
        e5 = instance_norm(conv2d(lrelu(e4), options.gf_dim*8, name='g_e5_conv'), 'g_bn_e5')
        # e5 is (8 x 8 x self.gf_dim*8)
        e6 = instance_norm(conv2d(lrelu(e5), options.gf_dim*8, name='g_e6_conv'), 'g_bn_e6')
        # e6 is (4 x 4 x self.gf_dim*8)
        e7 = instance_norm(conv2d(lrelu(e6), options.gf_dim*8, name='g_e7_conv'), 'g_bn_e7')
        # e7 is (2 x 2 x self.gf_dim*8)
        e8 = instance_norm(conv2d(lrelu(e7), options.gf_dim*8, name='g_e8_conv'), 'g_bn_e8')
        # e8 is (1 x 1 x self.gf_dim*8)

        # Decoder: each stride-2 deconv doubles the size; skip connections concatenate the mirrored encoder output
        d1 = deconv2d(tf.nn.relu(e8), options.gf_dim*8, name='g_d1')
        d1 = tf.nn.dropout(d1, dropout_rate)
        d1 = tf.concat([instance_norm(d1, 'g_bn_d1'), e7], 3)
        # d1 is (2 x 2 x self.gf_dim*8*2)

        d2 = deconv2d(tf.nn.relu(d1), options.gf_dim*8, name='g_d2')
        d2 = tf.nn.dropout(d2, dropout_rate)
        d2 = tf.concat([instance_norm(d2, 'g_bn_d2'), e6], 3)
        # d2 is (4 x 4 x self.gf_dim*8*2)

        d3 = deconv2d(tf.nn.relu(d2), options.gf_dim*8, name='g_d3')
        d3 = tf.nn.dropout(d3, dropout_rate)
        d3 = tf.concat([instance_norm(d3, 'g_bn_d3'), e5], 3)
        # d3 is (8 x 8 x self.gf_dim*8*2)

        d4 = deconv2d(tf.nn.relu(d3), options.gf_dim*8, name='g_d4')
        d4 = tf.concat([instance_norm(d4, 'g_bn_d4'), e4], 3)
        # d4 is (16 x 16 x self.gf_dim*8*2)

        d5 = deconv2d(tf.nn.relu(d4), options.gf_dim*4, name='g_d5')
        d5 = tf.concat([instance_norm(d5, 'g_bn_d5'), e3], 3)
        # d5 is (32 x 32 x self.gf_dim*4*2)

        d6 = deconv2d(tf.nn.relu(d5), options.gf_dim*2, name='g_d6')
        d6 = tf.concat([instance_norm(d6, 'g_bn_d6'), e2], 3)
        # d6 is (64 x 64 x self.gf_dim*2*2)

        d7 = deconv2d(tf.nn.relu(d6), options.gf_dim, name='g_d7')
        d7 = tf.concat([instance_norm(d7, 'g_bn_d7'), e1], 3)
        # d7 is (128 x 128 x self.gf_dim*1*2)

        d8 = deconv2d(tf.nn.relu(d7),
                      options.output_c_dim, name='g_d8')
        # d8 is (256 x 256 x output_c_dim)

        return tf.nn.tanh(d8)


def generator_resnet(image, options, reuse=False, name="generator"):

    with tf.variable_scope(name):
        # image is 256 x 256 x input_c_dim
        if reuse:
            tf.get_variable_scope().reuse_variables()
        else:
            assert tf.get_variable_scope().reuse is False

        def residule_block(x, dim, ks=3, s=1, name='res'):
            # reflect pad -> conv -> instance norm -> ReLU -> reflect pad -> conv -> instance norm, plus identity skip
            p = int((ks - 1) / 2)
            y = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
            y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name+'_c1'), name+'_bn1')
            y = tf.pad(tf.nn.relu(y), [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
            y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name+'_c2'), name+'_bn2')
            return y + x

        # Justin Johnson's model from https://github.com/jcjohnson/fast-neural-style/
        # The network with 9 blocks consists of: c7s1-32, d64, d128, R128, R128, R128,
        # R128, R128, R128, R128, R128, R128, u64, u32, c7s1-3
        # (with the default gf_dim=64 the widths below are doubled: c7s1-64, d128, d256, 9 x R256, u128, u64, c7s1-3)
        c0 = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
        c1 = tf.nn.relu(instance_norm(conv2d(c0, options.gf_dim, 7, 1, padding='VALID', name='g_e1_c'), 'g_e1_bn'))
        c2 = tf.nn.relu(instance_norm(conv2d(c1, options.gf_dim*2, 3, 2, name='g_e2_c'), 'g_e2_bn'))
        c3 = tf.nn.relu(instance_norm(conv2d(c2, options.gf_dim*4, 3, 2, name='g_e3_c'), 'g_e3_bn'))
        # define G network with 9 resnet blocks
        r1 = residule_block(c3, options.gf_dim*4, name='g_r1')
        r2 = residule_block(r1, options.gf_dim*4, name='g_r2')
        r3 = residule_block(r2, options.gf_dim*4, name='g_r3')
        r4 = residule_block(r3, options.gf_dim*4, name='g_r4')
        r5 = residule_block(r4, options.gf_dim*4, name='g_r5')
        r6 = residule_block(r5, options.gf_dim*4, name='g_r6')
        r7 = residule_block(r6, options.gf_dim*4, name='g_r7')
        r8 = residule_block(r7, options.gf_dim*4, name='g_r8')
        r9 = residule_block(r8, options.gf_dim*4, name='g_r9')

        d1 = deconv2d(r9, options.gf_dim*2, 3, 2, name='g_d1_dc')
        d1 = tf.nn.relu(instance_norm(d1, 'g_d1_bn'))
        d2 = deconv2d(d1, options.gf_dim, 3, 2, name='g_d2_dc')
        d2 = tf.nn.relu(instance_norm(d2, 'g_d2_bn'))
        d2 = tf.pad(d2, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
        pred = tf.nn.tanh(conv2d(d2, options.output_c_dim, 7, 1, padding='VALID', name='g_pred_c'))

        return pred


def abs_criterion(in_, target):
    return tf.reduce_mean(tf.abs(in_ - target))


def mae_criterion(in_, target):
    # Despite the name, this is a mean *squared* error: the least-squares (LSGAN) loss
    return tf.reduce_mean((in_-target)**2)


def sce_criterion(logits, labels):
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
import math
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.framework import ops

from utils import *

def batch_norm(x, name="batch_norm"):
    return tf.contrib.layers.batch_norm(x, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope=name)

def instance_norm(input, name="instance_norm"):
    # Normalize each channel over the spatial axes (H, W) of each sample, then apply a learned scale and offset
    with tf.variable_scope(name):
        depth = input.get_shape()[3]
        scale = tf.get_variable("scale", [depth], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
        offset = tf.get_variable("offset", [depth], initializer=tf.constant_initializer(0.0))
        mean, variance = tf.nn.moments(input, axes=[1,2], keep_dims=True)
        epsilon = 1e-5
        inv = tf.rsqrt(variance + epsilon)
        normalized = (input-mean)*inv
        return scale*normalized + offset

def conv2d(input_, output_dim, ks=4, s=2, stddev=0.02, padding='SAME', name="conv2d"):
    with tf.variable_scope(name):
        return slim.conv2d(input_, output_dim, ks, s, padding=padding, activation_fn=None,
                           weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
                           biases_initializer=None)

def deconv2d(input_, output_dim, ks=4, s=2, stddev=0.02, name="deconv2d"):
    with tf.variable_scope(name):
        return slim.conv2d_transpose(input_, output_dim, ks, s, padding='SAME', activation_fn=None,
                                     weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
                                     biases_initializer=None)

def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, leak*x)

def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):

    with tf.variable_scope(scope or "Linear"):
        matrix = tf.get_variable("Matrix", [input_.get_shape()[-1], output_size], tf.float32,
                                 tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable("bias", [output_size],
                               initializer=tf.constant_initializer(bias_start))
        if with_w:
            return tf.matmul(input_, matrix) + bias, matrix, bias
        else:
            return tf.matmul(input_, matrix) + bias
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorflow-gpu<2.0
numpy
scipy
pillow
imageio
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
"""
Some code from https://github.com/Newmu/dcgan_code
"""
from __future__ import division
import math
import pprint
import scipy.misc
import numpy as np
import copy
try:
    _imread = scipy.misc.imread
except AttributeError:
    from imageio import imread as _imread

pp = pprint.PrettyPrinter()

get_stddev = lambda x, k_h, k_w:
1/math.sqrt(k_w*k_h*x.get_shape()[-1]) 18 | 19 | # ----------------------------- 20 | # new added functions for cyclegan 21 | class ImagePool(object): 22 | def __init__(self, maxsize=50): 23 | self.maxsize = maxsize 24 | self.num_img = 0 25 | self.images = [] 26 | 27 | def __call__(self, image): 28 | if self.maxsize <= 0: 29 | return image 30 | if self.num_img < self.maxsize: 31 | self.images.append(image) 32 | self.num_img += 1 33 | return image 34 | if np.random.rand() > 0.5: 35 | idx = int(np.random.rand()*self.maxsize) 36 | tmp1 = copy.copy(self.images[idx])[0] 37 | self.images[idx][0] = image[0] 38 | idx = int(np.random.rand()*self.maxsize) 39 | tmp2 = copy.copy(self.images[idx])[1] 40 | self.images[idx][1] = image[1] 41 | return [tmp1, tmp2] 42 | else: 43 | return image 44 | 45 | def load_test_data(image_path, fine_size=256): 46 | img = imread(image_path) 47 | img = scipy.misc.imresize(img, [fine_size, fine_size]) 48 | img = img/127.5 - 1 49 | return img 50 | 51 | def load_train_data(image_path, load_size=286, fine_size=256, is_testing=False): 52 | img_A = imread(image_path[0]) 53 | img_B = imread(image_path[1]) 54 | if not is_testing: 55 | img_A = scipy.misc.imresize(img_A, [load_size, load_size]) 56 | img_B = scipy.misc.imresize(img_B, [load_size, load_size]) 57 | h1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size))) 58 | w1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size))) 59 | img_A = img_A[h1:h1+fine_size, w1:w1+fine_size] 60 | img_B = img_B[h1:h1+fine_size, w1:w1+fine_size] 61 | 62 | if np.random.random() > 0.5: 63 | img_A = np.fliplr(img_A) 64 | img_B = np.fliplr(img_B) 65 | else: 66 | img_A = scipy.misc.imresize(img_A, [fine_size, fine_size]) 67 | img_B = scipy.misc.imresize(img_B, [fine_size, fine_size]) 68 | 69 | img_A = img_A/127.5 - 1. 70 | img_B = img_B/127.5 - 1. 

    img_AB = np.concatenate((img_A, img_B), axis=2)
    # img_AB shape: (fine_size, fine_size, input_c_dim + output_c_dim)
    return img_AB

# -----------------------------

def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale = False):
    return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)

def save_images(images, size, image_path):
    return imsave(inverse_transform(images), size, image_path)

def imread(path, is_grayscale = False):
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    if is_grayscale:
        return _imread(path, flatten=True).astype(np.float64)
    else:
        return _imread(path, mode='RGB').astype(np.float64)

def merge_images(images, size):
    return inverse_transform(images)

def merge(images, size):
    # Tile a batch of HxWx3 images into a size[0] x size[1] grid, row by row.
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j*h:j*h+h, i*w:i*w+w, :] = image

    return img

def imsave(images, size, path):
    return scipy.misc.imsave(path, merge(images, size))

def center_crop(x, crop_h, crop_w=None,
                resize_h=64, resize_w=64):
    # crop_w defaults to crop_h (square crop); transform() relies on this default.
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    j = int(round((h - crop_h)/2.))
    i = int(round((w - crop_w)/2.))
    return scipy.misc.imresize(
        x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w])

def transform(image, npx=64, is_crop=True, resize_w=64):
    # npx: width/height of the center crop, in pixels
    if is_crop:
        cropped_image = center_crop(image, npx, resize_w=resize_w)
    else:
        cropped_image = image
    return np.array(cropped_image)/127.5 - 1.

def inverse_transform(images):
    # Map pixel values from [-1, 1] back to [0, 1].
    return (images+1.)/2.
--------------------------------------------------------------------------------
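`generator_resnet` in module.py returns an image with the same spatial size as its input: a reflect-padded 7x7 VALID convolution, two stride-2 downsamplings, nine size-preserving residual blocks, two stride-2 transposed convolutions, and a final reflect-padded 7x7 convolution. The pure-Python trace below sketches that arithmetic. It assumes TensorFlow's usual output-size rules (SAME: ceil(n / s); VALID: (n - k) // s + 1; SAME-padded transposed convolution: n * s), and the helper names `conv_out`, `deconv_out` and `trace_generator_resnet` are hypothetical, not part of the repository.

```python
import math


def conv_out(n, k, s, padding):
    """Output length along one axis for a TF-style convolution."""
    if padding == "SAME":
        return math.ceil(n / s)
    if padding == "VALID":
        return (n - k) // s + 1
    raise ValueError(f"unknown padding: {padding}")


def deconv_out(n, s):
    """A SAME-padded transposed convolution scales the length by the stride."""
    return n * s


def trace_generator_resnet(size=256, n_blocks=9):
    """Return (stage, spatial size) pairs for the generator_resnet layout."""
    trace = []
    n = conv_out(size + 2 * 3, 7, 1, "VALID")  # reflect-pad 3, then 7x7 VALID conv
    trace.append(("g_e1", n))
    n = conv_out(n, 3, 2, "SAME")  # 3x3 stride-2 conv
    trace.append(("g_e2", n))
    n = conv_out(n, 3, 2, "SAME")
    trace.append(("g_e3", n))
    for i in range(1, n_blocks + 1):
        # Each residual conv: reflect-pad 1, then 3x3 VALID conv with stride 1.
        m = conv_out(n + 2, 3, 1, "VALID")
        m = conv_out(m + 2, 3, 1, "VALID")
        assert m == n, "the shortcut y + x needs matching shapes"
        trace.append((f"g_r{i}", n))
    n = deconv_out(n, 2)
    trace.append(("g_d1", n))
    n = deconv_out(n, 2)
    trace.append(("g_d2", n))
    n = conv_out(n + 2 * 3, 7, 1, "VALID")  # reflect-pad 3, then 7x7 VALID conv
    trace.append(("g_pred", n))
    return trace


if __name__ == "__main__":
    for stage, n in trace_generator_resnet():
        print(stage, n)
```

For a 256x256 input the encoder gives 256, 128, 64, every residual block stays at 64, and the decoder returns to 128, 256 and finally 256. Reflect padding before VALID convolutions is what keeps the residual blocks size-preserving without zero borders.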
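The three criteria at the end of module.py differ only in the per-element penalty. The NumPy sketch below mirrors them so their values can be checked outside a TensorFlow graph; it is an illustration, not the code path used in training. It makes visible that `mae_criterion` squares the error, i.e. it is a mean squared error (the least-squares form the CycleGAN paper uses for its adversarial loss), while `abs_criterion` is the L1 penalty the paper uses for cycle consistency. The `sce_criterion` version assumes the numerically stable expression documented for `tf.nn.sigmoid_cross_entropy_with_logits`, max(x, 0) - x*z + log(1 + exp(-|x|)).

```python
import numpy as np


def abs_criterion(in_, target):
    """L1: mean absolute error."""
    return np.mean(np.abs(in_ - target))


def mae_criterion(in_, target):
    """Despite the name: mean *squared* error (least-squares GAN loss)."""
    return np.mean((in_ - target) ** 2)


def sce_criterion(logits, labels):
    """Sigmoid cross-entropy on logits, in the numerically stable form."""
    x, z = logits, labels
    return np.mean(np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x))))


if __name__ == "__main__":
    pred = np.array([0.0, 0.5, 2.0])
    target = np.ones(3)
    print(abs_criterion(pred, target))  # (1 + 0.5 + 1) / 3, about 0.833
    print(mae_criterion(pred, target))  # (1 + 0.25 + 1) / 3 = 0.75
```

The stable form avoids evaluating `exp` of a large positive number, which the naive `-(z*log(p) + (1-z)*log(1-p))` can overflow on for large-magnitude logits.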
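`ImagePool` in utils.py implements the image-history buffer CycleGAN uses to update the discriminator with a mix of current and past generator outputs: it stores the first `maxsize` items and afterwards, with probability 0.5, hands back a randomly chosen stored item and keeps the new one in its place. The sketch below is a simplified single-item variant with an injectable random generator, so the behaviour can be tested deterministically. The class name `HistoryPool` and the `rng` parameter are hypothetical, and unlike the original it swaps a whole item rather than the two halves of an `[A, B]` pair separately.

```python
import random


class HistoryPool:
    """Hypothetical simplified ImagePool: a buffer of past generated samples."""

    def __init__(self, maxsize=50, rng=None):
        self.maxsize = maxsize
        self.images = []
        self.rng = rng if rng is not None else random.Random()

    def __call__(self, image):
        if self.maxsize <= 0:  # pool disabled: pass the sample through
            return image
        if len(self.images) < self.maxsize:  # still filling: store and return it
            self.images.append(image)
            return image
        if self.rng.random() > 0.5:
            # Swap: return a random older sample, keep the new one in its slot.
            idx = self.rng.randrange(self.maxsize)
            old, self.images[idx] = self.images[idx], image
            return old
        return image  # otherwise return the new sample unchanged


if __name__ == "__main__":
    pool = HistoryPool(maxsize=3, rng=random.Random(0))
    print([pool(i) for i in range(10)])
```

Passing `maxsize=0` disables the pool, matching the `maxsize <= 0` early return in the original. Every input is either returned immediately or kept until a later swap returns it, so the returned samples plus the final pool contents account for each input exactly once after the pool has filled.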