├── assets
│   ├── apple2orange.png
│   ├── paper-figure.png
│   ├── vangogh2photo.png
│   ├── training-apple2orange.png
│   ├── training-vangogh2photo.png
│   └── wrong-initialization.png
├── utils.py
├── download_cyclegan_dataset.sh
├── LICENSE
├── discriminator.py
├── .gitignore
├── generator.py
├── data_loader.py
├── README.md
├── cycle-gan.py
├── ops.py
└── model.py

--------------------------------------------------------------------------------
/assets/apple2orange.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/CycleGAN-Tensorflow/HEAD/assets/apple2orange.png

--------------------------------------------------------------------------------
/assets/paper-figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/CycleGAN-Tensorflow/HEAD/assets/paper-figure.png

--------------------------------------------------------------------------------
/assets/vangogh2photo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/CycleGAN-Tensorflow/HEAD/assets/vangogh2photo.png

--------------------------------------------------------------------------------
/assets/training-apple2orange.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/CycleGAN-Tensorflow/HEAD/assets/training-apple2orange.png

--------------------------------------------------------------------------------
/assets/training-vangogh2photo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/CycleGAN-Tensorflow/HEAD/assets/training-vangogh2photo.png

--------------------------------------------------------------------------------
/assets/wrong-initialization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/CycleGAN-Tensorflow/HEAD/assets/wrong-initialization.png

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import logging
import os


# set up logging; without a handler, logger.info() output would be dropped
logging.basicConfig()
logger = logging.getLogger('cycle-gan')
logger.setLevel(logging.INFO)
logger.info('Start CycleGAN')

def makedirs(path):
    if not os.path.exists(path):
        os.makedirs(path)

--------------------------------------------------------------------------------
/download_cyclegan_dataset.sh:
--------------------------------------------------------------------------------
#!/bin/bash
FILE=$1

if [[ $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then
    echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
    exit 1
fi

URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./datasets/$FILE.zip
TARGET_DIR=./datasets/$FILE/
mkdir -p ./datasets
wget -N $URL -O $ZIP_FILE
mkdir -p $TARGET_DIR
unzip $ZIP_FILE -d ./datasets/
rm $ZIP_FILE
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Youngwoon Lee

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/discriminator.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from utils import logger
import ops


class Discriminator(object):
    def __init__(self, name, is_train, norm='instance', activation='leaky'):
        logger.info('Init Discriminator %s', name)
        self.name = name
        self._is_train = is_train
        self._norm = norm
        self._activation = activation
        self._reuse = False

    def __call__(self, input):
        with tf.variable_scope(self.name, reuse=self._reuse):
            D = ops.conv_block(input, 64, 'C64', 4, 2, self._is_train,
                               self._reuse, norm=None, activation=self._activation)
            D = ops.conv_block(D, 128, 'C128', 4, 2, self._is_train,
                               self._reuse, self._norm, self._activation)
            D = ops.conv_block(D, 256, 'C256', 4, 2, self._is_train,
                               self._reuse, self._norm, self._activation)
            D = ops.conv_block(D, 512, 'C512', 4, 2, self._is_train,
                               self._reuse, self._norm, self._activation)
            D = ops.conv_block(D, 1, 'C1', 4, 1, self._is_train,
                               self._reuse, norm=None, activation=None, bias=True)
            D = tf.reduce_mean(D, axis=[1, 2, 3])

        self._reuse = True
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
        return D
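
The stack above is the PatchGAN-style C64-C128-C256-C512 discriminator from the paper's appendix; the final 4x4 convolution produces a grid of per-patch logits that `tf.reduce_mean` collapses into a single score per image. A minimal shape check (a sketch assuming TensorFlow 1.x, as pinned in the README; the scope name `D_check` is arbitrary and introduced here):

```
import tensorflow as tf
from discriminator import Discriminator

# a batch of four 128x128 RGB images
images = tf.placeholder(tf.float32, [4, 128, 128, 3], name='images')
D = Discriminator('D_check', is_train=tf.constant(False))
scores = D(images)         # the 8x8 patch logits averaged to one score per image
print(scores.get_shape())  # (4,)
```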
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Tensorflow logs / datasets / results
logs/
datasets/
results/

# Temporary files
*.zip
*.swp
*~

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/generator.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from utils import logger
import ops


class Generator(object):
    def __init__(self, name, is_train, norm='instance', activation='relu',
                 image_size=128):
        logger.info('Init Generator %s', name)
        self.name = name
        self._is_train = is_train
        self._norm = norm
        self._activation = activation
        self._num_res_block = 6 if image_size == 128 else 9
        self._reuse = False

    def __call__(self, input):
        with tf.variable_scope(self.name, reuse=self._reuse):
            G = ops.conv_block(input, 32, 'c7s1-32', 7, 1, self._is_train,
                               self._reuse, self._norm, self._activation, pad='REFLECT')
            G = ops.conv_block(G, 64, 'd64', 3, 2, self._is_train,
                               self._reuse, self._norm, self._activation)
            G = ops.conv_block(G, 128, 'd128', 3, 2, self._is_train,
                               self._reuse, self._norm, self._activation)
            for i in range(self._num_res_block):
                G = ops.residual(G, 128, 'R128_{}'.format(i), self._is_train,
                                 self._reuse, self._norm)
            G = ops.deconv_block(G, 64, 'u64', 3, 2, self._is_train,
                                 self._reuse, self._norm, self._activation)
            G = ops.deconv_block(G, 32, 'u32', 3, 2, self._is_train,
                                 self._reuse, self._norm, self._activation)
            G = ops.conv_block(G, 3, 'c7s1-3', 7, 1, self._is_train,
                               self._reuse, norm=None, activation='tanh', pad='REFLECT')

        self._reuse = True
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
        return G
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
import os
from glob import glob

from scipy.misc import imread, imresize
import numpy as np
from tqdm import tqdm
import h5py

datasets = ['ae_photos', 'apple2orange', 'summer2winter_yosemite', 'horse2zebra',
            'monet2photo', 'cezanne2photo', 'ukiyoe2photo', 'vangogh2photo',
            'maps', 'cityscapes', 'facades', 'iphone2dslr_flower']

def read_image(path):
    image = imread(path)
    if len(image.shape) != 3 or image.shape[2] != 3:
        print('Wrong image {} with shape {}'.format(path, image.shape))
        return None

    # range of pixel values = [-1.0, 1.0]
    image = image.astype(np.float32) / 255.0
    image = image * 2.0 - 1.0
    return image

def read_images(base_dir):
    ret = []
    for dir_name in ['trainA', 'trainB', 'testA', 'testB']:
        data_dir = os.path.join(base_dir, dir_name)
        paths = glob(os.path.join(data_dir, '*.jpg'))
        print('# images in {}: {}'.format(data_dir, len(paths)))

        images = []
        for path in tqdm(paths):
            image = read_image(path)
            if image is not None:
                images.append(image)
        ret.append((dir_name, images))
    return ret

def store_h5py(base_dir, dir_name, images, image_size):
    f = h5py.File(os.path.join(base_dir, '{}_{}.hy'.format(dir_name, image_size)), 'w')
    for i in range(len(images)):
        grp = f.create_group(str(i))
        if images[i].shape[0] != image_size:
            image = imresize(images[i], (image_size, image_size, 3))
            # range of pixel values = [-1.0, 1.0]
            image = image.astype(np.float32) / 255.0
            image = image * 2.0 - 1.0
            grp['image'] = image
        else:
            grp['image'] = images[i]
    f.close()

def convert_h5py(task_name):
    print('Generating h5py file')
    base_dir = os.path.join('datasets', task_name)
    data = read_images(base_dir)
    for dir_name, images in data:
        if images[0].shape[0] == 256:
            store_h5py(base_dir, dir_name, images, 256)
        store_h5py(base_dir, dir_name, images, 128)

def read_h5py(task_name, image_size):
    base_dir = 'datasets/' + task_name
    paths = glob(os.path.join(base_dir, '*_{}.hy'.format(image_size)))
    if len(paths) != 4:
        convert_h5py(task_name)
    ret = []
    for dir_name in ['trainA', 'trainB', 'testA', 'testB']:
        try:
            dataset = h5py.File(os.path.join(base_dir, '{}_{}.hy'.format(dir_name, image_size)), 'r')
        except (IOError, OSError):
            raise IOError('Dataset is not available. Please try again.')

        images = []
        for key in dataset:
            images.append(dataset[key]['image'].value.astype(np.float32))
        ret.append(images)
    return ret

def download_dataset(task_name):
    print('Download data %s' % task_name)
    cmd = './download_cyclegan_dataset.sh ' + task_name
    os.system(cmd)

def get_data(task_name, image_size):
    assert task_name in datasets, 'Dataset {} is not available'.format(task_name)

    if not os.path.exists('datasets'):
        os.makedirs('datasets')

    base_dir = os.path.join('datasets', task_name)
    print('Check data %s' % base_dir)
    if not os.path.exists(base_dir):
        print('Dataset not found. Start downloading...')
        download_dataset(task_name)
        convert_h5py(task_name)

    print('Load data %s' % task_name)
    train_A, train_B, test_A, test_B = read_h5py(task_name, image_size)
    return train_A, train_B, test_A, test_B
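
Note that `scipy.misc.imread` and `imresize` ship only with the older SciPy pinned in the README; they were removed in later SciPy releases. If you must run on a newer stack, shims along these lines can stand in (a sketch assuming `imageio` is installed, which this repo does not declare; a full `imresize` replacement would additionally need to mimic scipy's rescaling of float input to [0, 255] before resizing):

```
# Hypothetical shims for the removed scipy.misc functions (assumes imageio).
import imageio
import numpy as np

def imread(path):
    # like scipy.misc.imread: returns an HxWxC uint8 array for RGB files
    return np.asarray(imageio.imread(path))

def imsave(path, image):
    # like scipy.misc.imsave for this repo's tanh outputs:
    # map floats in [-1, 1] to uint8 before writing
    image = ((np.clip(image, -1.0, 1.0) + 1.0) * 127.5).astype(np.uint8)
    imageio.imwrite(path, image)
```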
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CycleGAN implementation in Tensorflow

As part of the implementation series of [Joseph Lim's group at USC](http://csail.mit.edu/~lim), our motivation is to accelerate (or sometimes delay) research in the AI community by promoting open-source projects. To this end, we implement state-of-the-art research papers and share them publicly with concise reports. Please visit our [group github site](https://github.com/gitlimlab) for other projects.

This project is implemented by [Youngwoon Lee](https://github.com/youngwoon), and the code has been reviewed by [Honghua Dong](https://github.com/dhh1995) before being published.

## Description

This repo is a [Tensorflow](https://www.tensorflow.org/) implementation of CycleGAN on Pix2Pix datasets: [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).

This paper presents a framework for the **image-to-image translation** task, where we are interested in converting an image from one domain (e.g., zebra) to another domain (e.g., horse). It transforms a given image by finding a one-to-one mapping between unpaired data from the two domains.

The framework consists of two generators and two discriminators. Generator *G_ab* aims to translate an image in domain *a* (zebra) to its domain-*b* version (horse), while generator *G_ba* aims to translate an image in domain *b* to its domain-*a* version. Meanwhile, discriminator *D_a* verifies whether given images belong to domain *a*, and discriminator *D_b* does the same for domain *b*.

The entire framework therefore consists of two GAN loops, which are trained to perform the image-to-image translations *a*->*b*->*a* and *b*->*a*->*b*. When training these GANs, a **cycle-consistency loss**, the sum of the reconstruction errors (*a*->*b*->*a* and *b*->*a*->*b*), is added to the adversarial loss. Without a one-to-one mapping between the two domains *a* and *b*, the framework cannot reconstruct the original image, which leads to a large cycle-consistency loss. The cycle-consistency loss therefore alleviates mode collapse by encouraging a one-to-one mapping between the two domains.
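
In the paper's notation, the full objective is the sum of the two adversarial losses and the weighted cycle-consistency term (a sketch; `lambda` corresponds to `--cycle_loss_coeff`, 10 by default, and `model.py` implements the GAN terms in their least-squares form):

```
\mathcal{L}(G_{ab}, G_{ba}, D_a, D_b)
  = \mathcal{L}_{\mathrm{GAN}}(G_{ab}, D_b) + \mathcal{L}_{\mathrm{GAN}}(G_{ba}, D_a)
  + \lambda \big( \mathbb{E}_{a} \lVert G_{ba}(G_{ab}(a)) - a \rVert_1
                + \mathbb{E}_{b} \lVert G_{ab}(G_{ba}(b)) - b \rVert_1 \big)
```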

![paper-figure](assets/paper-figure.png)

## Dependencies

- Ubuntu 16.04
- Python 2.7
- [Tensorflow 1.1.0](https://www.tensorflow.org/)
- [NumPy](https://pypi.python.org/pypi/numpy)
- [SciPy](https://pypi.python.org/pypi/scipy)
- [Pillow](https://pillow.readthedocs.io/en/4.0.x/)
- [tqdm](https://github.com/tqdm/tqdm)
- [h5py](http://docs.h5py.org/en/latest/)

## Usage

- Execute the following command to download the specified dataset and train a model:

```
$ python cycle-gan.py --task apple2orange --image_size 256
```

- To train on 256x256 images, set `--image_size` to 256; otherwise, images are resized to and generated at 128x128.
  Once training ends, the test images are converted to the target domain and the results are saved to `./results/apple2orange_2017-07-07_07-07-07/`.
- Available datasets: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos


- Check the training status on Tensorboard:

```
$ tensorboard --logdir=./logs
```

> **Carefully check Tensorboard for the first 1000 iterations. You need to run the experiment again if dark and bright regions are reversed, as in the example below. This GAN implementation is sensitive to the initialization.**

![wrong-example](assets/wrong-initialization.png)

## Results

### apple2orange

![apple2orange](assets/apple2orange.png)

![training-apple2orange](assets/training-apple2orange.png)

### vangogh2photo

![vangogh2photo](assets/vangogh2photo.png)

![training-vangogh2photo](assets/training-vangogh2photo.png)

## References

- [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593)
- [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
- The official implementation in Torch: https://github.com/junyanz/CycleGAN
- The data downloading script is from the author's code.
--------------------------------------------------------------------------------
/cycle-gan.py:
--------------------------------------------------------------------------------
import argparse
import sys
import signal
import os
from datetime import datetime

import tensorflow as tf

from data_loader import get_data
from model import CycleGAN
from utils import logger, makedirs


# parsing cmd arguments
parser = argparse.ArgumentParser(description="Run commands")
parser.add_argument('-t', '--train', default=True, type=bool,
                    help="Training mode")
parser.add_argument('--task', type=str, default='apple2orange',
                    help='Task name')
parser.add_argument('--cycle_loss_coeff', type=float, default=10,
                    help='Cycle-consistency loss coefficient')
parser.add_argument('--instance_normalization', default=True, type=bool,
                    help="Use instance norm instead of batch norm")
parser.add_argument('--log_step', default=100, type=int,
                    help="Tensorboard log frequency")
parser.add_argument('--batch_size', default=1, type=int,
                    help="Batch size")
parser.add_argument('--image_size', default=128, type=int,
                    help="Image size")
parser.add_argument('--load_model', default='',
                    help='Model path to load (e.g., train_2017-07-07_01-23-45)')


class FastSaver(tf.train.Saver):
    # skip writing the (large) meta graph on every checkpoint save
    def save(self, sess, save_path, global_step=None, latest_filename=None,
             meta_graph_suffix="meta", write_meta_graph=True):
        super(FastSaver, self).save(sess, save_path, global_step, latest_filename,
                                    meta_graph_suffix, False)


def run(args):
    logger.info('Read data:')
    train_A, train_B, test_A, test_B = get_data(args.task, args.image_size)

    logger.info('Build graph:')
    model = CycleGAN(args)

    variables_to_save = tf.global_variables()
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(variables_to_save)

    logger.info('Trainable vars:')
    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 tf.get_variable_scope().name)
    for v in var_list:
        logger.info(' %s %s', v.name, v.get_shape())

    if args.load_model != '':
        model_name = args.load_model
    else:
        model_name = '{}_{}'.format(args.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    logdir = './logs'
    makedirs(logdir)
    logdir = os.path.join(logdir, model_name)
    logger.info('Events directory: %s', logdir)
    summary_writer = tf.summary.FileWriter(logdir)

    def init_fn(sess):
        logger.info('Initializing all parameters.')
        sess.run(init_all_op)

    sv = tf.train.Supervisor(is_chief=True,
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=init_fn,
                             summary_writer=summary_writer,
                             ready_op=tf.report_uninitialized_variables(variables_to_save),
                             global_step=model.global_step,
                             save_model_secs=300,
                             save_summaries_secs=30)

    if args.train:
        logger.info("Starting training session.")
        with sv.managed_session() as sess:
            model.train(sess, summary_writer, train_A, train_B)

    logger.info("Starting testing session.")
    with sv.managed_session() as sess:
        base_dir = os.path.join('results', model_name)
        makedirs(base_dir)
        model.test(sess, test_A, test_B, base_dir)

def main():
    args, unparsed = parser.parse_known_args()

    def shutdown(signal, frame):
        tf.logging.warn('Received signal %s: exiting', signal)
        sys.exit(128 + signal)
    signal.signal(signal.SIGHUP, shutdown)
    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)

    run(args)

if __name__ == "__main__":
    main()
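
One caveat in the parser above: argparse's `type=bool` applies Python's `bool()` to the raw string, and any non-empty string is truthy, so `--train False` still evaluates to `True`. If you need to toggle `--train` or `--instance_normalization` from the command line, a conventional workaround is a small converter (a sketch; `str2bool` is a name introduced here, not part of this repo):

```
import argparse

def str2bool(v):
    # argparse passes the raw string; map common spellings explicitly
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

# usage: parser.add_argument('-t', '--train', default=True, type=str2bool)
```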
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def _norm(input, is_train, reuse=True, norm=None):
    assert norm in ['instance', 'batch', None]
    if norm == 'instance':
        with tf.variable_scope('instance_norm', reuse=reuse):
            eps = 1e-5
            mean, sigma = tf.nn.moments(input, [1, 2], keep_dims=True)
            normalized = (input - mean) / (tf.sqrt(sigma) + eps)
            out = normalized
            # Optionally apply a learned scale and shift (not mandatory)
            #c = input.get_shape()[-1]
            #shift = tf.get_variable('shift', shape=[c],
            #                        initializer=tf.zeros_initializer())
            #scale = tf.get_variable('scale', shape=[c],
            #                        initializer=tf.random_normal_initializer(1.0, 0.02))
            #out = scale * normalized + shift
    elif norm == 'batch':
        with tf.variable_scope('batch_norm', reuse=reuse):
            out = tf.contrib.layers.batch_norm(input,
                                               decay=0.99, center=True,
                                               scale=True, is_training=is_train,
                                               updates_collections=None)
    else:
        out = input

    return out

def _activation(input, activation=None):
    assert activation in ['relu', 'leaky', 'tanh', 'sigmoid', None]
    if activation == 'relu':
        return tf.nn.relu(input)
    elif activation == 'leaky':
        return tf.contrib.keras.layers.LeakyReLU(0.2)(input)
    elif activation == 'tanh':
        return tf.tanh(input)
    elif activation == 'sigmoid':
        return tf.sigmoid(input)
    else:
        return input

def conv2d(input, num_filters, filter_size, stride, reuse=False,
           pad='SAME', dtype=tf.float32, bias=False):
    stride_shape = [1, stride, stride, 1]
    filter_shape = [filter_size, filter_size, input.get_shape()[3], num_filters]

    w = tf.get_variable('w', filter_shape, dtype, tf.random_normal_initializer(0.0, 0.02))
    if pad == 'REFLECT':
        # reflection padding that preserves spatial size for stride-1 convs
        p = (filter_size - 1) // 2
        x = tf.pad(input, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')
        conv = tf.nn.conv2d(x, w, stride_shape, padding='VALID')
    else:
        assert pad in ['SAME', 'VALID']
        conv = tf.nn.conv2d(input, w, stride_shape, padding=pad)

    if bias:
        b = tf.get_variable('b', [1, 1, 1, num_filters], initializer=tf.constant_initializer(0.0))
        conv = conv + b
    return conv

def conv2d_transpose(input, num_filters, filter_size, stride, reuse,
                     pad='SAME', dtype=tf.float32):
    assert pad == 'SAME'
    n, h, w, c = input.get_shape().as_list()
    stride_shape = [1, stride, stride, 1]
    filter_shape = [filter_size, filter_size, num_filters, c]
    output_shape = [n, h * stride, w * stride, num_filters]

    # named 'weights' to avoid shadowing the spatial width w above
    weights = tf.get_variable('w', filter_shape, dtype, tf.random_normal_initializer(0.0, 0.02))
    deconv = tf.nn.conv2d_transpose(input, weights, output_shape, stride_shape, pad)
    return deconv

def conv_block(input, num_filters, name, k_size, stride, is_train, reuse, norm,
               activation, pad='SAME', bias=False):
    with tf.variable_scope(name, reuse=reuse):
        out = conv2d(input, num_filters, k_size, stride, reuse, pad, bias=bias)
        out = _norm(out, is_train, reuse, norm)
        out = _activation(out, activation)
        return out

def residual(input, num_filters, name, is_train, reuse, norm, pad='REFLECT'):
    with tf.variable_scope(name, reuse=reuse):
        with tf.variable_scope('res1', reuse=reuse):
            out = conv2d(input, num_filters, 3, 1, reuse, pad)
            out = _norm(out, is_train, reuse, norm)
            out = tf.nn.relu(out)

        with tf.variable_scope('res2', reuse=reuse):
            out = conv2d(out, num_filters, 3, 1, reuse, pad)
            out = _norm(out, is_train, reuse, norm)

        return tf.nn.relu(input + out)

def deconv_block(input, num_filters, name, k_size, stride, is_train, reuse, norm, activation):
    with tf.variable_scope(name, reuse=reuse):
        out = conv2d_transpose(input, num_filters, k_size, stride, reuse)
        out = _norm(out, is_train, reuse, norm)
        out = _activation(out, activation)
        return out
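
A note on `_norm` above: with `norm='instance'`, each sample's feature maps are normalized per channel over the spatial axes only, so statistics never mix across the batch (see the Instance Normalization reference in the README). A NumPy sketch of the same computation, for illustration only (not used by the repo):

```
import numpy as np

def instance_norm_ref(x, eps=1e-5):
    # x: [N, H, W, C]; statistics per sample and per channel over H and W,
    # matching _norm (note eps is added outside the sqrt, as in ops.py)
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / (np.sqrt(var) + eps)
```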
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import os
import random

from tqdm import trange
from scipy.misc import imsave
import tensorflow as tf
import numpy as np

from generator import Generator
from discriminator import Discriminator
from utils import logger


class HistoryQueue(object):
    def __init__(self, shape=[128, 128, 3], size=50):
        self._size = size
        self._shape = shape
        self._count = 0
        self._queue = []

    def query(self, image):
        if len(image.shape) == 3:
            image = np.expand_dims(image, axis=0)
        if self._size == 0:
            return image
        if self._count < self._size:
            self._count += 1
            self._queue.append(image)
            return image

        p = random.random()
        if p > 0.5:
            idx = random.randrange(0, self._size)
            ret = self._queue[idx]
            self._queue[idx] = image
            return ret
        else:
            return image
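
# Note: HistoryQueue keeps a buffer of the 50 most recent generated images
# and, with probability 0.5, answers a query with a randomly recalled old
# fake instead of the newest one, so the discriminators also see samples
# from earlier generator states (the image-history trick described in the
# CycleGAN paper). A usage sketch (illustration only; nothing in this repo
# calls it at module level):
#
#   queue = HistoryQueue(shape=[128, 128, 3], size=50)
#   fake = np.random.uniform(-1., 1., size=[1, 128, 128, 3]).astype(np.float32)
#   shown = queue.query(fake)  # the new image, or a randomly recalled old one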

class CycleGAN(object):
    def __init__(self, args):
        self._log_step = args.log_step
        self._batch_size = args.batch_size
        self._image_size = args.image_size
        self._cycle_loss_coeff = args.cycle_loss_coeff

        self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
        self._image_shape = [self._image_size, self._image_size, 3]

        self.is_train = tf.placeholder(tf.bool, name='is_train')
        self.lr = tf.placeholder(tf.float32, name='lr')
        self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None)

        image_a = self.image_a = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_a')
        image_b = self.image_b = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_b')
        history_fake_a = self.history_fake_a = \
            tf.placeholder(tf.float32, [None] + self._image_shape, name='history_fake_a')
        history_fake_b = self.history_fake_b = \
            tf.placeholder(tf.float32, [None] + self._image_shape, name='history_fake_b')

        # Data augmentation: resize up, random-crop back to size, random flip
        def augment_image(image):
            image = tf.image.resize_images(image, [self._augment_size, self._augment_size])
            image = tf.random_crop(image, [self._batch_size] + self._image_shape)
            image = tf.map_fn(tf.image.random_flip_left_right, image)
            return image

        image_a = tf.cond(self.is_train,
                          lambda: augment_image(image_a),
                          lambda: image_a)
        image_b = tf.cond(self.is_train,
                          lambda: augment_image(image_b),
                          lambda: image_b)

        # Generator
        G_ab = Generator('G_ab', is_train=self.is_train,
                         norm='instance', activation='relu', image_size=self._image_size)
        G_ba = Generator('G_ba', is_train=self.is_train,
                         norm='instance', activation='relu', image_size=self._image_size)

        # Discriminator
        D_a = Discriminator('D_a', is_train=self.is_train,
                            norm='instance', activation='leaky')
        D_b = Discriminator('D_b', is_train=self.is_train,
                            norm='instance', activation='leaky')

        # Generate images (a->b->a and b->a->b)
        image_ab = self.image_ab = G_ab(image_a)
        image_aba = self.image_aba = G_ba(image_ab)
        image_ba = self.image_ba = G_ba(image_b)
        image_bab = self.image_bab = G_ab(image_ba)

        # Discriminate real/fake images
        D_real_a = D_a(image_a)
        D_fake_a = D_a(image_ba)
        D_real_b = D_b(image_b)
        D_fake_b = D_b(image_ab)
        D_history_fake_a = D_a(history_fake_a)
        D_history_fake_b = D_b(history_fake_b)

        # Least-squares loss for the GAN discriminators (real label smoothed to 0.9)
        loss_D_a = (tf.reduce_mean(tf.squared_difference(D_real_a, 0.9)) +
                    tf.reduce_mean(tf.square(D_history_fake_a))) * 0.5
        loss_D_b = (tf.reduce_mean(tf.squared_difference(D_real_b, 0.9)) +
                    tf.reduce_mean(tf.square(D_history_fake_b))) * 0.5

        # Least-squares loss for the GAN generators
        loss_G_ab = tf.reduce_mean(tf.squared_difference(D_fake_b, 0.9))
        loss_G_ba = tf.reduce_mean(tf.squared_difference(D_fake_a, 0.9))

        # L1 norm for the reconstruction error
        loss_rec_aba = tf.reduce_mean(tf.abs(image_a - image_aba))
        loss_rec_bab = tf.reduce_mean(tf.abs(image_b - image_bab))
        loss_cycle = self._cycle_loss_coeff * (loss_rec_aba + loss_rec_bab)

        loss_G_ab_final = loss_G_ab + loss_cycle
        loss_G_ba_final = loss_G_ba + loss_cycle

        # Optimizer
        self.optimizer_D_a = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(loss_D_a, var_list=D_a.var_list, global_step=self.global_step)
        self.optimizer_D_b = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(loss_D_b, var_list=D_b.var_list)
        self.optimizer_G_ab = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(loss_G_ab_final, var_list=G_ab.var_list)
        self.optimizer_G_ba = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
            .minimize(loss_G_ba_final, var_list=G_ba.var_list)

        # Summaries
        self.loss_D_a = loss_D_a
        self.loss_D_b = loss_D_b
        self.loss_G_ab = loss_G_ab
        self.loss_G_ba = loss_G_ba
        self.loss_cycle = loss_cycle

        tf.summary.scalar('loss/dis_A', loss_D_a)
        tf.summary.scalar('loss/dis_B', loss_D_b)
        tf.summary.scalar('loss/gen_AB', loss_G_ab)
        tf.summary.scalar('loss/gen_BA', loss_G_ba)
        tf.summary.scalar('loss/cycle', loss_cycle)
        tf.summary.scalar('model/D_a_real', tf.reduce_mean(D_real_a))
        tf.summary.scalar('model/D_a_fake', tf.reduce_mean(D_fake_a))
        tf.summary.scalar('model/D_b_real', tf.reduce_mean(D_real_b))
        tf.summary.scalar('model/D_b_fake', tf.reduce_mean(D_fake_b))
        tf.summary.scalar('model/lr', self.lr)
        tf.summary.image('A/A', image_a[0:1])
        tf.summary.image('A/A-B', image_ab[0:1])
        tf.summary.image('A/A-B-A', image_aba[0:1])
        tf.summary.image('B/B', image_b[0:1])
        tf.summary.image('B/B-A', image_ba[0:1])
        tf.summary.image('B/B-A-B', image_bab[0:1])
        self.summary_op = tf.summary.merge_all()

    def train(self, sess, summary_writer, data_A, data_B):
        logger.info('Start training.')
        logger.info(' {} images from A'.format(len(data_A)))
        logger.info(' {} images from B'.format(len(data_B)))

        data_size = min(len(data_A), len(data_B))
        num_batch = data_size // self._batch_size
        epoch_length = num_batch * self._batch_size

        # 100 epochs at a constant learning rate, then ~100 epochs of linear
        # decay to zero (the *_iter variables actually count epochs)
        num_initial_iter = 100
        num_decay_iter = 100
        lr = lr_initial = 0.0002
        lr_decay = lr_initial / num_decay_iter

        history_a = HistoryQueue(shape=self._image_shape, size=50)
        history_b = HistoryQueue(shape=self._image_shape, size=50)

        initial_step = sess.run(self.global_step)
        num_global_step = (num_initial_iter + num_decay_iter) * epoch_length
        t = trange(initial_step, num_global_step,
                   total=num_global_step, initial=initial_step)

        for step in t:
            # TODO: resume training with global_step
            epoch = step // epoch_length
            iter = step % epoch_length

            if epoch > num_initial_iter:
                lr = max(0.0, lr_initial - (epoch - num_initial_iter) * lr_decay)

            if iter == 0:
                random.shuffle(data_A)
                random.shuffle(data_B)

            image_a = np.stack(data_A[iter*self._batch_size:(iter+1)*self._batch_size])
            image_b = np.stack(data_B[iter*self._batch_size:(iter+1)*self._batch_size])
            fake_a, fake_b = sess.run([self.image_ba, self.image_ab],
                                      feed_dict={self.image_a: image_a,
                                                 self.image_b: image_b,
                                                 self.is_train: True})
            fake_a = history_a.query(fake_a)
            fake_b = history_b.query(fake_b)

            fetches = [self.loss_D_a, self.loss_D_b, self.loss_G_ab,
                       self.loss_G_ba, self.loss_cycle,
                       self.optimizer_D_a, self.optimizer_D_b,
                       self.optimizer_G_ab, self.optimizer_G_ba]
            if step % self._log_step == 0:
                fetches += [self.summary_op]

            fetched = sess.run(fetches, feed_dict={self.image_a: image_a,
                                                   self.image_b: image_b,
                                                   self.is_train: True,
                                                   self.lr: lr,
                                                   self.history_fake_a: fake_a,
                                                   self.history_fake_b: fake_b})

            if step % self._log_step == 0:
                summary_writer.add_summary(fetched[-1], step)
                summary_writer.flush()
                t.set_description(
                    'Loss: D_a({:.3f}) D_b({:.3f}) G_ab({:.3f}) G_ba({:.3f}) cycle({:.3f})'.format(
                        fetched[0], fetched[1], fetched[2], fetched[3], fetched[4]))
    def test(self, sess, data_A, data_B, base_dir):
        step = 0
        for data in data_A:
            step += 1
            fetches = [self.image_ab, self.image_aba]
            image_a = np.expand_dims(data, axis=0)
            image_ab, image_aba = sess.run(fetches, feed_dict={self.image_a: image_a,
                                                               self.is_train: False})
            images = np.concatenate((image_a, image_ab, image_aba), axis=2)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'a_to_b_{}.jpg'.format(step)), images)

        step = 0
        for data in data_B:
            step += 1
            fetches = [self.image_ba, self.image_bab]
            image_b = np.expand_dims(data, axis=0)
            image_ba, image_bab = sess.run(fetches, feed_dict={self.image_b: image_b,
                                                               self.is_train: False})
            images = np.concatenate((image_b, image_ba, image_bab), axis=2)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'b_to_a_{}.jpg'.format(step)), images)

--------------------------------------------------------------------------------