├── .gitignore
├── 3rdparty
│   └── replicate
│       ├── cog.yaml
│       └── predict.py
├── LICENSE
├── README.md
├── configs.py
├── dataset
│   └── .gitignore
├── discriminator
│   ├── discriminator_patch.py
│   └── discriminator_spatch.py
├── docs
│   └── tips.md
├── frozen_model
│   ├── image_translator.py
│   └── test_frozen_model.py
├── gan
│   ├── gan.py
│   └── spatchgan.py
├── generator
│   └── generator_basic_res.py
├── imagedata.py
├── images
│   ├── SPatchGAN_D_20210317_3x.jpg
│   └── s2a_cmp_github_downsized.jpg
├── main.py
├── ops.py
├── output
│   └── .gitignore
├── requirements.txt
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | bash/
3 | __pycache__/
--------------------------------------------------------------------------------
/3rdparty/replicate/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 |   gpu: true
3 |   python_version: "3.8"
4 |   system_packages:
5 |     - "libgl1-mesa-glx"
6 |     - "libglib2.0-0"
7 |   python_packages:
8 |     - "tensorflow-gpu==2.6.0"
9 |     - "numpy==1.19.4"
10 |     - "opencv-python==4.5.3.56"
11 |     - "gast==0.4.0"
12 |     - "ipython==7.19.0"
13 |
14 | predict: "predict.py:Predictor"
--------------------------------------------------------------------------------
/3rdparty/replicate/predict.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../..')
3 | import tempfile
4 | from pathlib import Path
5 | import argparse
6 | import cv2
7 | import numpy as np
8 | import tensorflow as tf
9 | from tensorflow.python.platform import gfile
10 | import cog
11 | # A TF2-compatible ImageTranslator is defined below, matching the TF 2.6 environment in cog.yaml.
12 |
13 |
14 | class Predictor(cog.Predictor):
15 |     def setup(self):
16 |         args = parse_arguments()
17 |         args.model = 'SPatchGAN_selfie2anime_scale3_cyc20_20210831.pb'
18 |         size = 256
19 |         self.translator = ImageTranslator(model_path=args.model,
20 |                                           size=size,
21 |                                           n_threads_intra=args.n_threads_intra,
22 |                                           n_threads_inter=args.n_threads_inter)
23 |
24 |     @cog.input(
25 |         "image",
26 |         type=Path,
27 |         help="Input image; the model generates a female anime-style face. Supported formats: .png, .jpg and .jpeg.",
28 |     )
29 |     def predict(self, image):
30 |         out_path = Path(tempfile.mkdtemp()) / "out.png"
31 |         img = cv2.imread(str(image), cv2.IMREAD_COLOR)
32 |         output = self.translator.translate(img)
33 |         cv2.imwrite(str(out_path), output)
34 |         return out_path
35 |
36 |
37 | def parse_arguments():
38 |     parser = argparse.ArgumentParser()
39 |     parser.add_argument('--image', type=str, help='Image file path or directory.')
40 |     parser.add_argument('--model', type=str, help='.pb model path.')
41 |     parser.add_argument('--n_threads_inter', type=int, default=1, help='Number of inter op threads.')
42 |     parser.add_argument('--n_threads_intra', type=int, default=1, help='Number of intra op threads.')
43 |     parser.add_argument('--n_iters', type=int, default=1)
44 |     return parser.parse_args('')
45 |
46 |
47 | class ImageTranslator:
48 |     def __init__(self, model_path, size, n_threads_intra=1, n_threads_inter=1):
49 |         config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=n_threads_intra,
50 |                                           inter_op_parallelism_threads=n_threads_inter)
51 |         config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
52 |         config.gpu_options.allow_growth = True
53 |         self._size = size
54 |         self._graph = tf.Graph()
55 |         self._sess = tf.compat.v1.Session(config=config, graph=self._graph)
56 |
57 |         self._pb_file_path = model_path
58 |
self._restore_from_pb() 59 | self._input_op = self._graph.get_tensor_by_name('test_domain_A:0') 60 | self._output_op = self._graph.get_tensor_by_name('test_fake_B:0') 61 | 62 | def _restore_from_pb(self): 63 | with self._graph.as_default(): 64 | with gfile.FastGFile(self._pb_file_path, 'rb') as f: 65 | graph_def = tf.compat.v1.GraphDef() 66 | graph_def.ParseFromString(f.read()) 67 | tf.import_graph_def(graph_def, name='') 68 | 69 | def _input_transform(self, img): 70 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 71 | img = cv2.resize(img, dsize=(self._size, self._size)) 72 | img = np.expand_dims(img, axis=0) 73 | image_input = img / 127.5 - 1 74 | return image_input 75 | 76 | @staticmethod 77 | def _output_transform(output): 78 | output = ((output + 1.) / 2) * 255.0 79 | image_output = cv2.cvtColor(output.astype('uint8'), cv2.COLOR_RGB2BGR) 80 | return image_output 81 | 82 | def translate(self, image): 83 | """ Translate an image from the source domain to the target domain""" 84 | image_input = self._input_transform(image) 85 | output = self._sess.run(self._output_op, feed_dict={self._input_op: image_input})[0] 86 | return self._output_transform(output) 87 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, NetEase Games AI Lab. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## SPatchGAN: Official TensorFlow Implementation
2 |
3 | ### Paper
4 | "SPatchGAN: A Statistical Feature Based Discriminator for Unsupervised Image-to-Image Translation" (ICCV 2021)
5 |
6 | [Paper (arXiv)](https://arxiv.org/abs/2103.16219)
7 | [Video (YouTube)](https://www.youtube.com/watch?v=JY7eq6q5qpk)
8 | [Google Drive](https://drive.google.com/file/d/1r62hIUkQrQGE_6PI_ovS4sdfHPjfSfIK/view?usp=sharing)
9 |
10 |
11 |
12 |
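At a glance, the method replaces dense patch-level logits in the discriminator with a small set of statistical features per scale. Written out in the notation of ``adv_loss`` in [ops.py](ops.py) (a transcription of the code, not a formula from the paper), each of the $K$ statistical heads $D_k$ contributes a least-squares term, with target $t = 1$ for real and $t = 0$ for fake samples:

$$\mathcal{L}_{\mathrm{adv}} = \frac{1}{K} \sum_{k=1}^{K} \mathbb{E}\big[ (D_k(x) - t)^2 \big]$$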

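Concretely, the discriminator reduces the feature map at each scale to its channel-wise mean, max and uncorrected standard deviation over the spatial dimensions, and feeds each statistic to a small MLP head. Below is a minimal NumPy sketch of these statistics, for illustration only; the actual TensorFlow implementation lives in [discriminator/discriminator_spatch.py](discriminator/discriminator_spatch.py).

```python
import numpy as np

def spatial_stats(feat):
    """NHWC feature map -> channel-wise (mean, max, stddev), each of shape (N, C)."""
    mean = feat.mean(axis=(1, 2))
    gmax = feat.max(axis=(1, 2))
    centered = feat - feat.mean(axis=(1, 2), keepdims=True)
    stddev = np.sqrt((centered ** 2).mean(axis=(1, 2)))  # uncorrected, as in the TF code
    return mean, gmax, stddev

feat = np.random.rand(4, 16, 16, 1024).astype(np.float32)  # a dummy feature map
for name, stat in zip(('mean', 'max', 'stddev'), spatial_stats(feat)):
    print(name, stat.shape)  # (4, 1024) each; an MLP head maps each vector to a logit
```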
13 |
14 |
15 |
16 | ### Web Demos
17 | by [CJWBW](https://github.com/CJWBW)
18 |
19 |
20 | ### Environment
21 | - CUDA 10.0
22 | - Python 3.6
23 | - ``pip install -r requirements.txt``
24 |
25 | ### Dataset
26 |
27 | - Dataset structure (dataset_struct='plain')
28 | ```
29 | - dataset
30 |     - <dataset_name>
31 |         - trainA
32 |             - 1.jpg
33 |             - 2.jpg
34 |             - ...
35 |         - trainB
36 |             - 3.jpg
37 |             - 4.jpg
38 |             - ...
39 |         - testA
40 |             - 5.jpg
41 |             - 6.jpg
42 |             - ...
43 |         - testB
44 |             - 7.jpg
45 |             - 8.jpg
46 |             - ...
47 | ```
48 |
49 | - Supported extensions: jpg, jpeg, png
50 | - An additional level of subdirectories is also supported by setting dataset_struct to 'tree', e.g.,
51 | ```
52 | - trainA
53 |     - subdir1
54 |         - 1.jpg
55 |         - 2.jpg
56 |         - ...
57 |     - subdir2
58 |         - ...
59 | ```
60 |
61 | - Selfie-to-anime:
62 |     - The dataset can be downloaded from [U-GAT-IT](https://github.com/taki0112/UGATIT).
63 |
64 | - Male-to-female and glasses removal:
65 |     - The datasets can be downloaded from [Council-GAN](https://github.com/Onr/Council-GAN).
66 |     - The images must be center cropped from 218x178 to 178x178 before training or testing.
67 |     - For glasses removal, only the male images are used in the experiments in our paper. Note that the dataset from Council-GAN has already been split into two subdirectories, "1" for male and "2" for female.
68 |
69 | ### Training
70 |
71 | - Set the suffix to anything descriptive, e.g., the date.
72 | - Selfie-to-Anime
73 | ```bash
74 | python main.py --dataset selfie2anime --augment_type resize_crop --n_scales_dis 3 --suffix scale3_cyc20_20210831 --phase train
75 | ```
76 |
77 | - Male-to-Female
78 | ```bash
79 | python main.py --dataset male2female --cyc_weight 10 --suffix cyc10_20210831 --phase train
80 | ```
81 |
82 | - Glasses Removal
83 | ```bash
84 | python main.py --dataset glasses-male --cyc_weight 30 --suffix cyc30_20210831 --phase train
85 | ```
86 | - Find the output in ``./output/SPatchGAN_<dataset_name>_<suffix>``
87 | - The same command can be used to continue training based on the latest checkpoint.
88 | - For a new task, we recommend using the default settings as a starting point, and adjusting the hyperparameters according to the [tips](docs/tips.md).
89 | - Check [configs.py](configs.py) for all the hyperparameters.
90 |
91 | ### Testing with the latest checkpoint
92 | - Replace ``--phase train`` with ``--phase test``
93 |
94 | ### Save a frozen model (.pb)
95 | - Replace ``--phase train`` with ``--phase freeze_graph``
96 | - Find the saved frozen model in ``./output/SPatchGAN_<dataset_name>_<suffix>/checkpoint/pb``
97 |
98 | ### Testing with the frozen model
99 | ```bash
100 | cd frozen_model
101 | python test_frozen_model.py --image <image_path_or_dir> --output_dir <output_dir> --model <pb_model_path>
102 | ```
103 |
104 | ### Pretrained Models
105 | - Download the pretrained models from [huggingface](https://huggingface.co/necoarc/spatchgan-model), and put them in the output directory.
106 | - You can test the checkpoints (in ./checkpoint) or the frozen models (in ./checkpoint/pb). Either way produces the same results.
107 | - The results generated by the pretrained models are slightly different from those in the paper, since we have rerun the training after code refactoring.
108 | - We set ``n_scales_dis`` to 3 for the pretrained selfie2anime model to further improve the performance. It was 4 in the paper. See more details in the [tips](docs/tips.md).
109 | - We also provide the generated results of the last 100 test images (in ./gen, sorted by name, no cherry-picking) for calibration purposes.
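The frozen models can also be called from Python rather than the command line: the ``ImageTranslator`` class behind ``test_frozen_model.py`` is importable directly. A minimal usage sketch, assuming you run from the repository root; the file names below are placeholders:

```python
import cv2
from frozen_model.image_translator import ImageTranslator

# 'model.pb' and 'selfie.jpg' are placeholder paths - substitute your own files.
translator = ImageTranslator(model_path='model.pb', size=256)
img = cv2.imread('selfie.jpg', cv2.IMREAD_COLOR)  # BGR image of any resolution
output = translator.translate(img)                # resized to 256x256 and translated, BGR
cv2.imwrite('anime.png', output)
```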
110 | 111 | ### Other Implementations 112 | - We provide a PyTorch implementation of the SPatchGAN discriminator in [spatchgan_discriminator_pytorch.py](https://gist.github.com/NetEase-GameAI/6b93a3fa4c8ab7a59a75eeacca33712f). 113 | 114 | ### Citation 115 | ``` 116 | @InProceedings{Shao_2021_ICCV, 117 | author = {Shao, Xuning and Zhang, Weidong}, 118 | title = {SPatchGAN: A Statistical Feature Based Discriminator for Unsupervised Image-to-Image Translation}, 119 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 120 | month = {October}, 121 | year = {2021}, 122 | pages = {6546-6555} 123 | } 124 | ``` 125 | 126 | ### Acknowledgement 127 | - Our code is partially based on [U-GAT-IT](https://github.com/taki0112/UGATIT). 128 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def _str2bool(x): 6 | return x.lower() == 'true' 7 | 8 | 9 | def _none_or_str(x): 10 | if x == 'None': 11 | return None 12 | else: 13 | return x 14 | 15 | 16 | def parse_args(): 17 | """Configurations.""" 18 | desc = "TensorFlow implementation of SPatchGAN." 19 | parser = argparse.ArgumentParser(description=desc) 20 | 21 | # General configs 22 | parser.add_argument('--network', type=str, default='spatchgan', help='Network type: [spatchgan].') 23 | parser.add_argument('--phase', type=str, default='train', 24 | help='Phase: [train / test / freeze_graph].') 25 | parser.add_argument('--dataset', type=str, required=True, help='Name of the training dataset.') 26 | parser.add_argument('--test_dataset', type=str, default=None, 27 | help='Name of the testing dataset. Same as the training dataset by default.') 28 | parser.add_argument('--dataset_struct', type=str, default='plain', help='Dataset type: [plain / tree].') 29 | parser.add_argument('--suffix', type=str, default=None, help='suffix for the model name.') 30 | 31 | # Training configs 32 | parser.add_argument('--n_steps', type=int, default=50, help='Number of training steps.') 33 | parser.add_argument('--n_iters_per_step', type=int, default=10000, help='Number of iterations per step') 34 | parser.add_argument('--batch_size', type=int, default=4, help='Batch size.') 35 | parser.add_argument('--img_save_freq', type=int, default=1000, help='Image saving frequency in iteration.') 36 | parser.add_argument('--ckpt_save_freq', type=int, default=1000, help='Checkpoint saving frequency in iteration.') 37 | parser.add_argument('--summary_freq', type=int, default=100, help='TensorFlow summary frequency.') 38 | parser.add_argument('--decay_step', type=int, default=10, help='Starting point for learning rate decay.') 39 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate.') 40 | parser.add_argument('--adv_weight', type=float, default=4.0, help='Adversarial loss weight.') 41 | parser.add_argument('--reg_weight', type=float, default=1.0, help='Regularization weight.') 42 | parser.add_argument('--cyc_weight', type=float, default=20.0, help='Weak cycle loss weight.') 43 | parser.add_argument('--id_weight', type=float, default=10.0, help='Identity loss weight.') 44 | parser.add_argument('--gan_type', type=str, default='lsgan', help='GAN loss type: [lsgan].') 45 | 46 | # Input configs 47 | parser.add_argument('--img_size', type=int, default=256, help='The size of input images.') 48 | parser.add_argument('--augment_type', type=_none_or_str, 
default='pad_crop', 49 | help='Augmentation method: [pad_crop / resize_crop / None].') 50 | 51 | # Discriminator configs 52 | parser.add_argument('--dis_type', type=str, default='spatch', help='D type: [spatch / patch].') 53 | parser.add_argument('--logits_type_dis', type=str, default='stats', help='D logits calculation method: [stats].') 54 | parser.add_argument('--ch_dis', type=int, default=256, help='Base channel number of D.') 55 | parser.add_argument('--n_downsample_init_dis', type=int, default=2, 56 | help='Number of downsampling layers in the initial feature extraction block.') 57 | parser.add_argument('--n_scales_dis', type=int, default=4, help='Number of scales in D.') 58 | parser.add_argument('--sn_dis', type=_none_or_str, default='fast', help='Spectral norm type: [fast / full / None]') 59 | parser.add_argument('--n_adapt_dis', type=int, default=2, help='Number of layers in each adaptation block.') 60 | parser.add_argument('--n_mix_dis', type=int, default=2, help='Number of mixing layers in each MLP.') 61 | parser.add_argument('--mean_dis', type=_str2bool, default=True, help='Use the gap output in D.') 62 | parser.add_argument('--max_dis', type=_str2bool, default=True, help='Use the gmp output in D.') 63 | parser.add_argument('--stddev_dis', type=_str2bool, default=True, help='Use the stddev output in D.') 64 | 65 | # Generator configs 66 | parser.add_argument('--gen_type', type=str, default='basic_res', help='G type: [basic_res].') 67 | parser.add_argument('--block_type_gen', type=str, default='v1', help='G residual block type: [v1].') 68 | parser.add_argument('--ch_gen', type=int, default=128, help='Base channel number of forward G.') 69 | parser.add_argument('--ch_gen_bw', type=int, default=512, help='Base channel number of backward G.') 70 | parser.add_argument('--upsample_type_gen', type=str, default='nearest', 71 | help='Upsampling method: [nearest / bilinear].') 72 | parser.add_argument('--n_updownsample_gen', type=int, default=3, 73 | help='Number of up/downsampling layers in forward G.') 74 | parser.add_argument('--n_updownsample_gen_bw', type=int, default=0, 75 | help='Number of up/downsampling layers in backward G.') 76 | parser.add_argument('--n_res_gen', type=int, default=8, help='Number of residual blocks in forward G.') 77 | parser.add_argument('--n_res_gen_bw', type=int, default=8, help='Number of residual blocks in backward G.') 78 | parser.add_argument('--n_enhanced_upsample_gen', type=int, default=1, 79 | help='Number of enhanced upsampling blocks that include multiple mixing layers.') 80 | parser.add_argument('--n_mix_upsample_gen', type=int, default=2, 81 | help='Number of mixing layers in an enhanced upsampling block.') 82 | parser.add_argument('--resize_factor_gen_bw', type=int, default=8, 83 | help='The resizing factor of input images for backward G.') 84 | 85 | # Directory names 86 | parser.add_argument('--dataset_dir', type=str, 87 | default=os.path.join(os.path.dirname(__file__), 'dataset'), 88 | help='Directory for all datasets') 89 | parser.add_argument('--output_dir', type=str, 90 | default=os.path.join(os.path.dirname(__file__), 'output'), 91 | help='Directory for all outputs') 92 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 93 | help='Directory for the checkpoints') 94 | parser.add_argument('--result_dir', type=str, default='gen', 95 | help='Directory for the generated images.') 96 | parser.add_argument('--log_dir', type=str, default='logs', 97 | help='Directory for the training logs.') 98 | 
parser.add_argument('--sample_dir', type=str, default='samples', 99 | help='Directory for the training sample images.') 100 | 101 | return _check_args(parser.parse_args()) 102 | 103 | 104 | def _check_args(args): 105 | if args is None: 106 | raise RuntimeError('Invalid arguments!') 107 | return args 108 | -------------------------------------------------------------------------------- /dataset/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /discriminator/discriminator_patch.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ops import conv, lrelu 3 | 4 | 5 | class DiscriminatorPatch: 6 | """Multiscale PatchGan discriminator for MUNIT / Council-GAN / ACL-GAN.""" 7 | def __init__(self, ch, n_downsample_init, n_scales, sn): 8 | self._ch = ch # 64 in MUNIT 9 | self._n_downsample_init = n_downsample_init # 4 in MUNIT 10 | self._n_scales = n_scales # 3 in MUNIT 11 | self._sn = sn 12 | 13 | def discriminate(self, x, reuse=False, scope='dis'): 14 | """Calculate the patch based logits.""" 15 | with tf.variable_scope(scope, reuse=reuse): 16 | logits = [] 17 | for i in range(self._n_scales): 18 | logits_patch = self._discriminator_per_scale(x, reuse=reuse, scope='scale_{}'.format(i)) 19 | logits.append(logits_patch) 20 | x = tf.layers.average_pooling2d(x, pool_size=3, strides=2, padding='SAME') 21 | return logits 22 | 23 | def _discriminator_per_scale(self, x, reuse=False, scope='scale_0'): 24 | with tf.variable_scope(scope, reuse=reuse): 25 | channel = self._ch 26 | for i in range(self._n_downsample_init): 27 | with tf.variable_scope('down_{}'.format(i)): 28 | x = conv(x, channel, kernel=4, stride=2, pad=1, sn=self._sn) 29 | x = lrelu(x, 0.2) 30 | channel *= 2 31 | x = conv(x, channels=1, kernel=1, stride=1, sn=self._sn, scope='logits') 32 | x = tf.identity(x, 'D_logits_patch') 33 | return x 34 | -------------------------------------------------------------------------------- /discriminator/discriminator_spatch.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ops import conv, lrelu, global_avg_pooling, global_max_pooling, fully_connected 3 | 4 | 5 | class DiscriminatorSPatch: 6 | """SPatchGAN discriminator.""" 7 | def __init__(self, ch, n_downsample_init, n_scales, n_adapt, n_mix, 8 | logits_type: str, stats: list, sn): 9 | self._ch = ch 10 | self._n_downsample_init = n_downsample_init 11 | self._n_scales = n_scales 12 | self._n_adapt = n_adapt 13 | self._n_mix = n_mix 14 | self._logits_type = logits_type 15 | self._stats = stats 16 | self._sn = sn 17 | 18 | def discriminate(self, x, reuse=False, scope='dis'): 19 | """Calculate the statistical feature based logits.""" 20 | with tf.variable_scope(scope, reuse=reuse): 21 | channel = self._ch 22 | logits_list = [] 23 | 24 | for i in range(self._n_downsample_init): 25 | with tf.variable_scope('down_{}'.format(i)): 26 | # (256, 256, 3) -> (128, 128, 256) -> (64, 64, 512) 27 | x = conv(x, channel, kernel=4, stride=2, pad=1, sn=self._sn) 28 | x = lrelu(x) 29 | channel *= 2 30 | 31 | for i in range(self._n_scales): 32 | with tf.variable_scope('scale_{}'.format(i)): 33 | # (64, 64, 512) -> (32, 32, 1024) -> (16, 16, 1024) -> (8, 8, 1024) -> (4, 4, 1024) 34 | x = conv(x, channel, kernel=4, stride=2, pad=1, 
sn=self._sn, scope='conv_k4')
35 |                     x = lrelu(x)
36 |                     logits = self._dis_logits(x)
37 |                     logits_list.extend(logits)
38 |
39 |             return logits_list
40 |
41 |     def _dis_logits(self, x, scope='dis_logits'):
42 |         if self._logits_type == 'stats':
43 |             return self._dis_logits_stats(x, scope=scope)
44 |         else:
45 |             raise ValueError('Invalid logits_type_dis!')
46 |
47 |     def _dis_logits_stats(self, x, scope='dis_logits'):
48 |         with tf.variable_scope(scope):
49 |             logits_list = []
50 |             channel = x.shape[-1].value
51 |
52 |             for i in range(self._n_adapt):
53 |                 with tf.variable_scope('premix_{}'.format(i)):
54 |                     x = conv(x, channel, sn=self._sn)
55 |                     x = lrelu(x)
56 |
57 |             if 'mean' in self._stats:
58 |                 with tf.variable_scope('gap'):
59 |                     x_gap = global_avg_pooling(x)
60 |                     x_gap_logits = self._mlp_logits(x_gap)
61 |                     x_gap_logits = tf.identity(x_gap_logits, 'D_logits_gap')
62 |                     logits_list.append(x_gap_logits)
63 |
64 |             if 'max' in self._stats:
65 |                 with tf.variable_scope('gmp'):
66 |                     x_gmp = global_max_pooling(x)
67 |                     x_gmp_logits = self._mlp_logits(x_gmp)
68 |                     x_gmp_logits = tf.identity(x_gmp_logits, 'D_logits_gmp')
69 |                     logits_list.append(x_gmp_logits)
70 |
71 |             if 'stddev' in self._stats:
72 |                 with tf.variable_scope('stddev'):
73 |                     # Calculate the channel-wise uncorrected standard deviation
74 |                     x_diff_square = tf.square(x - tf.reduce_mean(x, axis=[1, 2], keepdims=True))
75 |                     x_stddev = tf.sqrt(global_avg_pooling(x_diff_square))
76 |                     x_stddev_logits = self._mlp_logits(x_stddev)
77 |                     x_stddev_logits = tf.identity(x_stddev_logits, 'D_logits_stddev')
78 |                     logits_list.append(x_stddev_logits)
79 |
80 |             return logits_list
81 |
82 |     def _mlp_logits(self, x, n_ch=None, scope='dis_logits_mix'):
83 |         with tf.variable_scope(scope):
84 |             shape = x.shape.as_list()
85 |             channel = n_ch or shape[-1]
86 |             if len(shape) == 2:
87 |                 for i in range(self._n_mix):
88 |                     x = fully_connected(x, units=channel, sn=self._sn, scope='mix_'+str(i))
89 |                     x = lrelu(x)
90 |                 x = fully_connected(x, units=1, sn=self._sn, scope='logits')
91 |             elif len(shape) == 4:
92 |                 for i in range(self._n_mix):
93 |                     x = conv(x, channels=channel, kernel=1, stride=1, sn=self._sn, scope='mix_'+str(i))
94 |                     x = lrelu(x)
95 |                 x = conv(x, channels=1, kernel=1, stride=1, sn=self._sn, scope='logits')
96 |             return x
97 |
98 |
--------------------------------------------------------------------------------
/docs/tips.md:
--------------------------------------------------------------------------------
1 | ## Hyperparameter Tuning
2 |
3 | - Discriminator
4 |     - We set the number of scales (``n_scales_dis``) to a constant (4) in the paper, to verify that good performance can be achieved on different tasks with a fixed network structure. In practice, we have found that a reduced number of scales (3) is often discriminative enough and more stable, especially for tasks that require a significant shape deformation. In such cases, the 4th scale, which is the most discriminative one in the default setting, tends to be overkill.
5 |     - Reducing the number of base channels (``ch_dis``) is an effective way to accelerate the training process.
6 |
7 | - Generator
8 |     - To improve the inference speed, you may want to reduce the number of base channels (``ch_gen``) or the number of enhanced upsampling layers (``n_enhanced_upsample_gen``).
9 |
10 | - Weak cycle
11 |     - The weight of the weak cycle constraint (``cyc_weight``) can be adjusted on a per-task basis. A large value generally makes shape deformation harder.
On the other hand, it helps to keep the generated image correlated with the source image. 12 | - The input of the backward generator is resized by 1/(``resize_factor_gen_bw``). The number of downsampling and upsampling layers of the backward generator is set by ``n_updownsample_gen_bw``. These two parameters can be adjusted to change the weak cycle to a full resolution forward cycle, or something in between. The effect is somewhat similar to increasing the cycle weight. 13 | 14 | - Training with a higher resolution 15 | - You may want to adjust several parameters if the input / output resolution is higher than 256x256. Take 512x512 as an example. A good starting point will be setting ``img_size`` to 512, increasing ``n_downsample_init_dis`` from 2 to 3, and reducing ``ch_dis``, ``ch_gen`` and ``ch_gen_bw`` by a factor of two. -------------------------------------------------------------------------------- /frozen_model/image_translator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import warnings 5 | warnings.filterwarnings('ignore', category=FutureWarning) 6 | import tensorflow as tf 7 | from tensorflow.python.platform import gfile 8 | 9 | curPath = os.path.abspath(os.path.dirname(__file__)) 10 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 11 | 12 | 13 | class ImageTranslator: 14 | def __init__(self, model_path, size, n_threads_intra=1, n_threads_inter=1): 15 | config = tf.ConfigProto(intra_op_parallelism_threads=n_threads_intra, 16 | inter_op_parallelism_threads=n_threads_inter) 17 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 18 | config.gpu_options.allow_growth = True 19 | self._size = size 20 | self._graph = tf.Graph() 21 | self._sess = tf.Session(config=config, graph=self._graph) 22 | 23 | self._pb_file_path = model_path 24 | self._restore_from_pb() 25 | self._input_op = self._graph.get_tensor_by_name('test_domain_A:0') 26 | self._output_op = self._graph.get_tensor_by_name('test_fake_B:0') 27 | 28 | def _restore_from_pb(self): 29 | with self._graph.as_default(): 30 | with gfile.FastGFile(self._pb_file_path, 'rb') as f: 31 | graph_def = tf.GraphDef() 32 | graph_def.ParseFromString(f.read()) 33 | tf.import_graph_def(graph_def, name='') 34 | 35 | def _input_transform(self, img): 36 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 37 | img = cv2.resize(img, dsize=(self._size, self._size)) 38 | img = np.expand_dims(img, axis=0) 39 | image_input = img / 127.5 - 1 40 | return image_input 41 | 42 | @staticmethod 43 | def _output_transform(output): 44 | output = ((output + 1.) 
/ 2) * 255.0 45 | image_output = cv2.cvtColor(output.astype('uint8'), cv2.COLOR_RGB2BGR) 46 | return image_output 47 | 48 | def translate(self, image): 49 | """ Translate an image from the source domain to the target domain""" 50 | image_input = self._input_transform(image) 51 | output = self._sess.run(self._output_op, feed_dict={self._input_op: image_input})[0] 52 | return self._output_transform(output) 53 | -------------------------------------------------------------------------------- /frozen_model/test_frozen_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import time 6 | 7 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 8 | from utils import get_img_paths 9 | from frozen_model.image_translator import ImageTranslator 10 | 11 | cv2.setNumThreads(1) 12 | 13 | 14 | def _get_images(img_loc: str) -> list: 15 | """ Find the image(s) in a given location. 16 | :param img_loc: either an image file path, or a directory containing images. 17 | :return: a list of image paths. 18 | """ 19 | image_list = [] 20 | if os.path.isfile(img_loc): 21 | image_list.append(img_loc) 22 | elif os.path.isdir(img_loc): 23 | image_list.extend(get_img_paths(img_loc)) 24 | return image_list 25 | 26 | 27 | def main(args): 28 | """ Translate all images in a given location with a frozen model (.pb).""" 29 | images = _get_images(args.image) 30 | if len(images) == 0: 31 | raise RuntimeError('No image in {}!'.format(args.image)) 32 | 33 | size = 256 34 | translator = ImageTranslator(model_path=args.model, 35 | size=size, 36 | n_threads_intra=args.n_threads_intra, 37 | n_threads_inter=args.n_threads_inter) 38 | os.makedirs(args.output_dir, exist_ok=True) 39 | 40 | st = time.time() 41 | for i, image in enumerate(images): 42 | print('{}: {}'.format(i, image)) 43 | img = cv2.imread(image, cv2.IMREAD_COLOR) 44 | if img is None: 45 | print('Invalid image {}'.format(image)) 46 | continue 47 | else: 48 | output = None 49 | for i_iter in range(args.n_iters): 50 | print('Iter: {}'.format(i_iter)) 51 | output = translator.translate(img) 52 | save_path = os.path.join(args.output_dir, os.path.basename(image)) 53 | cv2.imwrite(save_path, output) 54 | 55 | time_total = time.time() - st 56 | time_cost_per_img_ms = int((time_total / len(images) / args.n_iters) * 1000) 57 | print('Time cost per image: {} ms'.format(time_cost_per_img_ms)) 58 | 59 | 60 | def _parse_arguments(): 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('--image', type=str, required=True, help='Image file path or directory.') 63 | parser.add_argument('--output_dir', type=str, required=True, help='Output directory.') 64 | parser.add_argument('--model', type=str, required=True, help='.pb model path.') 65 | parser.add_argument('--n_threads_inter', type=int, default=1, help='Number of inter op threads.') 66 | parser.add_argument('--n_threads_intra', type=int, default=1, help='Number of intra op threads.') 67 | parser.add_argument('--n_iters', type=int, default=1) 68 | return parser.parse_args() 69 | 70 | 71 | if __name__ == '__main__': 72 | main(_parse_arguments()) -------------------------------------------------------------------------------- /gan/gan.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | from datetime import datetime 4 | import time 5 | from utils import get_img_paths, save_images, load_test_data 6 | 7 | 8 | class GAN: 9 | """Base class 
for GANs."""
10 |     def __init__(self, model_name, sess, args):
11 |         # General
12 |         self._model_name = model_name
13 |         self._sess = sess
14 |         self._saver = None
15 |         self._dataset_name = args.dataset
16 |         self._test_dataset_name = args.test_dataset or args.dataset
17 |         self._dataset_struct = args.dataset_struct
18 |         self._suffix = args.suffix
19 |
20 |         # Directories
21 |         self._dataset_dir = args.dataset_dir
22 |         model_dir = "{}_{}_{}".format(self._model_name, self._dataset_name, self._suffix)
23 |         self._checkpoint_dir = os.path.join(args.output_dir, model_dir, args.checkpoint_dir)
24 |         self._sample_dir = os.path.join(args.output_dir, model_dir, args.sample_dir)
25 |         self._log_dir = os.path.join(args.output_dir, model_dir, args.log_dir)
26 |         self._result_dir = os.path.join(args.output_dir, model_dir, args.result_dir)
27 |         for dir_ in [self._checkpoint_dir, self._sample_dir, self._log_dir, self._result_dir]:
28 |             os.makedirs(dir_, exist_ok=True)
29 |
30 |         # Input
31 |         self._img_size = args.img_size
32 |         train_a_dir = os.path.join(self._dataset_dir, self._dataset_name, 'trainA')
33 |         train_b_dir = os.path.join(self._dataset_dir, self._dataset_name, 'trainB')
34 |         self._train_a_dataset = get_img_paths(train_a_dir, self._dataset_struct)
35 |         self._train_b_dataset = get_img_paths(train_b_dir, self._dataset_struct)
36 |         self._dataset_num = max(len(self._train_a_dataset), len(self._train_b_dataset))
37 |
38 |         # Generator
39 |         self._gen = None
40 |
41 |         print()
42 |         print('##### Information #####')
43 |         print('Number of trainA/B images: {}/{}'.format(len(self._train_a_dataset), len(self._train_b_dataset)))
44 |         print()
45 |
46 |     def build_model_train(self):
47 |         """To be implemented by the subclass."""
48 |         pass
49 |
50 |     def build_model_test(self):
51 |         """Build the graph for testing."""
52 |         self._test_domain_a = tf.placeholder(tf.float32, [1, self._img_size, self._img_size, 3],
53 |                                              name='test_domain_A')
54 |         test_fake_b = self._gen.translate(self._test_domain_a, scope='gen_a2b')
55 |         self._test_fake_b = tf.identity(test_fake_b, 'test_fake_B')
56 |
57 |     def train(self):
58 |         """To be implemented by the subclass."""
59 |         pass
60 |
61 |     def test(self):
62 |         """Translate test images."""
63 |         test_a_dir = os.path.join(self._dataset_dir, self._test_dataset_name, 'testA')
64 |         test_a_files = get_img_paths(test_a_dir, self._dataset_struct)
65 |
66 |         if self._saver is None:
67 |             self._saver = tf.train.Saver()
68 |         could_load, checkpoint_counter = self._load_ckpt(self._checkpoint_dir)
69 |         if could_load:
70 |             print(" [*] Load SUCCESS")
71 |         else:
72 |             print(" [!]
Load failed...") 73 | raise RuntimeError("Failed to load the checkpoint") 74 | 75 | dataset_tag = '' if self._test_dataset_name == self._dataset_name else self._test_dataset_name + '_' 76 | result_dir = os.path.join(self._result_dir, dataset_tag + str(checkpoint_counter)) 77 | os.makedirs(result_dir, exist_ok=True) 78 | 79 | st = time.time() 80 | for sample_file in test_a_files: # A -> B 81 | print('Processing source image: ' + sample_file) 82 | src = load_test_data(sample_file, size=self._img_size) 83 | fake_img = self._sess.run(self._test_fake_b, feed_dict={self._test_domain_a: src}) 84 | 85 | if self._dataset_struct == 'plain': 86 | dst_dir = result_dir 87 | elif self._dataset_struct == 'tree': 88 | src_dir = os.path.dirname(sample_file) 89 | dirname_level1 = os.path.basename(src_dir) 90 | dirname_level2 = os.path.basename(os.path.dirname(src_dir)) 91 | dst_dir = os.path.join(result_dir, dirname_level2, dirname_level1) 92 | os.makedirs(dst_dir, exist_ok=True) 93 | else: 94 | raise RuntimeError('Invalid dataset_type!') 95 | image_path = os.path.join(dst_dir, os.path.basename(sample_file)) 96 | save_images(fake_img[[0],:], [1, 1], image_path) 97 | 98 | time_cost = time.time() - st 99 | time_cost_per_img_ms = round(time_cost * 1000 / len(test_a_files)) 100 | print('Time cost per image: {} ms'.format(time_cost_per_img_ms)) 101 | 102 | def freeze_graph(self): 103 | """Generate the .pb model.""" 104 | self._saver = tf.train.Saver() 105 | could_load, checkpoint_counter = self._load_ckpt(self._checkpoint_dir) 106 | 107 | if could_load: 108 | print(" [*] Load SUCCESS") 109 | else: 110 | print(" [!] Load failed...") 111 | raise RuntimeError("Failed to load the checkpoint") 112 | 113 | output_dir = os.path.join(self._checkpoint_dir, 'pb') 114 | os.makedirs(output_dir, exist_ok=True) 115 | time_stamp = datetime.now().strftime('%Y%m%d-%H%M%S') 116 | output_file = os.path.join(output_dir, 'output_graph_' + time_stamp + '.pb') 117 | 118 | frozen_graph_def = tf.graph_util.convert_variables_to_constants( 119 | sess=self._sess, 120 | input_graph_def=self._sess.graph_def, 121 | output_node_names=['test_fake_B']) 122 | 123 | # Save the frozen graph 124 | with open(output_file, 'wb') as f: 125 | f.write(frozen_graph_def.SerializeToString()) 126 | 127 | def _save_ckpt(self, checkpoint_dir, step): 128 | os.makedirs(checkpoint_dir, exist_ok=True) 129 | self._saver.save(self._sess, os.path.join(checkpoint_dir, self._model_name + '.model'), global_step=step) 130 | 131 | def _load_ckpt(self, checkpoint_dir): 132 | print(" [*] Reading checkpoints...") 133 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 134 | if ckpt and ckpt.model_checkpoint_path: 135 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 136 | self._saver.restore(self._sess, os.path.join(checkpoint_dir, ckpt_name)) 137 | counter = int(ckpt_name.split('-')[-1]) 138 | print(" [*] Success to read {}".format(ckpt_name)) 139 | return True, counter 140 | else: 141 | print(" [*] Failed to find a checkpoint") 142 | return False, 0 143 | 144 | -------------------------------------------------------------------------------- /gan/spatchgan.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import tensorflow as tf 4 | import numpy as np 5 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch 6 | from gan.gan import GAN 7 | from utils import summary_by_keywords, batch_resize, save_images 8 | from ops import l1_loss, adv_loss, 
regularization_loss
9 | from imagedata import ImageData
10 | from discriminator.discriminator_spatch import DiscriminatorSPatch
11 | from discriminator.discriminator_patch import DiscriminatorPatch
12 | from generator.generator_basic_res import GeneratorBasicRes
13 |
14 |
15 | class SPatchGAN(GAN):
16 |     """SPatchGAN framework."""
17 |     def __init__(self, model_name, sess, args):
18 |         super().__init__(model_name, sess, args)
19 |
20 |         # Training
21 |         self._n_steps = args.n_steps
22 |         self._n_iters_per_step = args.n_iters_per_step
23 |         self._batch_size = args.batch_size
24 |         self._img_save_freq = args.img_save_freq
25 |         self._ckpt_save_freq = args.ckpt_save_freq
26 |         self._summary_freq = args.summary_freq
27 |         self._decay_step = args.decay_step
28 |         self._init_lr = args.lr
29 |         self._adv_weight = args.adv_weight
30 |         self._reg_weight = args.reg_weight
31 |         self._cyc_weight = args.cyc_weight
32 |         self._id_weight = args.id_weight
33 |         self._gan_type = args.gan_type
34 |
35 |         # Input
36 |         self._augment_type = args.augment_type
37 |
38 |         # Discriminator
39 |         self._dis = self._create_dis(args)
40 |
41 |         # Generator
42 |         self._gen = self._create_gen(args)
43 |         self._gen_bw = self._create_gen_bw(args)
44 |         self._resolution_bw = self._img_size // args.resize_factor_gen_bw
45 |
46 |     @staticmethod
47 |     def _create_dis(args):
48 |         if args.dis_type == 'spatch':
49 |             stats = []
50 |             if args.mean_dis:
51 |                 stats.append('mean')
52 |             if args.max_dis:
53 |                 stats.append('max')
54 |             if args.stddev_dis:
55 |                 stats.append('stddev')
56 |             return DiscriminatorSPatch(ch=args.ch_dis,
57 |                                        n_downsample_init=args.n_downsample_init_dis,
58 |                                        n_scales=args.n_scales_dis,
59 |                                        n_adapt=args.n_adapt_dis,
60 |                                        n_mix=args.n_mix_dis,
61 |                                        logits_type=args.logits_type_dis,
62 |                                        stats=stats,
63 |                                        sn=args.sn_dis)
64 |         elif args.dis_type == 'patch':
65 |             return DiscriminatorPatch(ch=args.ch_dis,
66 |                                       n_downsample_init=args.n_downsample_init_dis,
67 |                                       n_scales=args.n_scales_dis,
68 |                                       sn=args.sn_dis)
69 |
70 |         else:
71 |             raise ValueError('Invalid dis_type!')
72 |
73 |     @staticmethod
74 |     def _create_gen(args):
75 |         if args.gen_type == 'basic_res':
76 |             return GeneratorBasicRes(ch=args.ch_gen,
77 |                                      n_updownsample=args.n_updownsample_gen,
78 |                                      n_res=args.n_res_gen,
79 |                                      n_enhanced_upsample=args.n_enhanced_upsample_gen,
80 |                                      n_mix_upsample=args.n_mix_upsample_gen,
81 |                                      block_type=args.block_type_gen,
82 |                                      upsample_type=args.upsample_type_gen)
83 |         else:
84 |             raise ValueError('Invalid gen_type!')
85 |
86 |     @staticmethod
87 |     def _create_gen_bw(args):
88 |         if args.gen_type == 'basic_res':
89 |             return GeneratorBasicRes(ch=args.ch_gen_bw,
90 |                                      n_updownsample=args.n_updownsample_gen_bw,
91 |                                      n_res=args.n_res_gen_bw,
92 |                                      n_enhanced_upsample=args.n_enhanced_upsample_gen,
93 |                                      n_mix_upsample=args.n_mix_upsample_gen,
94 |                                      block_type=args.block_type_gen,
95 |                                      upsample_type=args.upsample_type_gen)
96 |         else:
97 |             raise ValueError('Invalid gen_type!')
98 |
99 |     def _fetch_data(self, dataset):
100 |         gpu_device = '/gpu:0'
101 |         imgdata = ImageData(self._img_size, self._augment_type)
102 |         train_dataset = tf.data.Dataset.from_tensor_slices(dataset)
103 |         train_dataset = train_dataset.apply(shuffle_and_repeat(self._dataset_num)) \
104 |             .apply(map_and_batch(imgdata.image_processing, self._batch_size,
105 |                                  num_parallel_batches=16, drop_remainder=True)) \
106 |             .apply(prefetch_to_device(gpu_device, None))
107 |         train_iterator = train_dataset.make_one_shot_iterator()
108 |         return train_iterator.get_next()
109 |
110 |     def build_model_train(self):
111 |
"""Build the graph for training.""" 112 | self._lr = tf.placeholder(tf.float32, name='learning_rate') 113 | 114 | # Input images 115 | self._domain_a = self._fetch_data(self._train_a_dataset) 116 | self._domain_b = self._fetch_data(self._train_b_dataset) 117 | 118 | # Forward generation 119 | self._x_ab = self._gen.translate(self._domain_a, scope='gen_a2b') 120 | 121 | # Backward generation 122 | if self._cyc_weight > 0.0: 123 | self._a_lowres = tf.image.resize_images(self._domain_a, [self._resolution_bw, self._resolution_bw]) 124 | self._ab_lowres = tf.image.resize_images(self._x_ab, [self._resolution_bw, self._resolution_bw]) 125 | self._aba_lowres = self._gen_bw.translate(self._ab_lowres, scope='gen_b2a') 126 | else: 127 | self._aba_lowres = tf.zeros([self._batch_size, self._resolution_bw, self._resolution_bw, 3]) 128 | 129 | # Identity mapping 130 | self._x_bb = self._gen.translate(self._domain_b, reuse=True, scope='gen_a2b') \ 131 | if self._id_weight > 0.0 else tf.zeros_like(self._domain_b) 132 | 133 | # Discriminator 134 | b_logits = self._dis.discriminate(self._domain_b, scope='dis_b') 135 | ab_logits = self._dis.discriminate(self._x_ab, reuse=True, scope='dis_b') 136 | 137 | # Adversarial loss for G 138 | adv_loss_gen_ab = self._adv_weight * adv_loss(ab_logits, self._gan_type, target='real') 139 | 140 | # Adversarial loss for D 141 | adv_loss_dis_b = self._adv_weight * adv_loss(b_logits, self._gan_type, target='real') 142 | adv_loss_dis_b += self._adv_weight * adv_loss(ab_logits, self._gan_type, target='fake') 143 | 144 | # Identity loss 145 | id_loss_bb = self._id_weight * l1_loss(self._domain_b, self._x_bb) \ 146 | if self._id_weight > 0.0 else 0.0 147 | cyc_loss_aba = self._cyc_weight * l1_loss(self._a_lowres, self._aba_lowres) \ 148 | if self._cyc_weight > 0.0 else 0.0 149 | 150 | # Weight decay 151 | reg_loss_gen = self._reg_weight * regularization_loss('gen_') 152 | reg_loss_dis = self._reg_weight * regularization_loss('dis_') 153 | 154 | # Overall loss 155 | self._gen_loss_all = adv_loss_gen_ab \ 156 | + id_loss_bb \ 157 | + cyc_loss_aba \ 158 | + reg_loss_gen 159 | 160 | self._dis_loss_all = adv_loss_dis_b \ 161 | + reg_loss_dis 162 | 163 | 164 | """ Training """ 165 | t_vars = tf.trainable_variables() 166 | vars_gen = [var for var in t_vars if 'gen_' in var.name] 167 | vars_dis = [var for var in t_vars if 'dis_' in var.name] 168 | 169 | self._optim_gen = tf.train.AdamOptimizer(self._lr, beta1=0.5, beta2=0.999)\ 170 | .minimize(self._gen_loss_all, var_list=vars_gen) 171 | self._optim_dis = tf.train.AdamOptimizer(self._lr, beta1=0.5, beta2=0.999)\ 172 | .minimize(self._dis_loss_all, var_list=vars_dis) 173 | 174 | """" Summary """ 175 | # Record the IN scaling factor for each residual block. 
176 |         summary_scale_res = summary_by_keywords(['gamma', 'resblock', 'res2'], node_type='variable')
177 |         summary_logits_gen = summary_by_keywords('pre_tanh', node_type='tensor')
178 |         summary_logits_dis = summary_by_keywords(['D_logits_'], node_type='tensor')
179 |
180 |         summary_list_gen = []
181 |         summary_list_gen.append(tf.summary.scalar("gen_loss_all", self._gen_loss_all))
182 |         summary_list_gen.append(tf.summary.scalar("adv_loss_gen_ab", adv_loss_gen_ab))
183 |         summary_list_gen.append(tf.summary.scalar("id_loss_bb", id_loss_bb))
184 |         summary_list_gen.append(tf.summary.scalar("cyc_loss_aba", cyc_loss_aba))
185 |         summary_list_gen.append(tf.summary.scalar("reg_loss_gen", reg_loss_gen))
186 |         summary_list_gen.extend(summary_scale_res)
187 |         summary_list_gen.extend(summary_logits_gen)
188 |         self._summary_gen = tf.summary.merge(summary_list_gen)
189 |
190 |         summary_list_dis = []
191 |         summary_list_dis.append(tf.summary.scalar("dis_loss_all", self._dis_loss_all))
192 |         summary_list_dis.append(tf.summary.scalar("adv_loss_dis_b", adv_loss_dis_b))
193 |         summary_list_dis.append(tf.summary.scalar("reg_loss_dis", reg_loss_dis))
194 |         summary_list_dis.extend(summary_logits_dis)
195 |         self._summary_dis = tf.summary.merge(summary_list_dis)
196 |
197 |     def train(self):
198 |         """Run training iterations."""
199 |         tf.global_variables_initializer().run()
200 |         self._saver = tf.train.Saver()
201 |         writer = tf.summary.FileWriter(self._log_dir, self._sess.graph)
202 |
203 |         # Restore the checkpoint if it exists.
204 |         could_load, checkpoint_counter = self._load_ckpt(self._checkpoint_dir)
205 |         if could_load:
206 |             counter = checkpoint_counter + 1
207 |             start_step = counter // self._n_iters_per_step
208 |             start_batch_id = counter - start_step * self._n_iters_per_step
209 |             print(" [*] Load SUCCESS")
210 |         else:
211 |             counter = 0
212 |             start_step = 0
213 |             start_batch_id = 0
214 |             print(" [!]
Load failed...") 215 | 216 | # Looping over steps 217 | start_time = time.time() 218 | for step in range(start_step, self._n_steps): 219 | lr = self._init_lr if step < self._decay_step else \ 220 | self._init_lr * (self._n_steps - step) / (self._n_steps - self._decay_step) 221 | for batch_id in range(start_batch_id, self._n_iters_per_step): 222 | train_feed_dict = { 223 | self._lr: lr 224 | } 225 | 226 | # Update D 227 | loss_dis, summary_str_dis, _ = self._sess.run([self._dis_loss_all, self._summary_dis, self._optim_dis], 228 | feed_dict=train_feed_dict) 229 | 230 | # Update G 231 | batch_a_images, batch_b_images, fake_b, identity_b, aba_lowres, loss_gen, summary_str_gen, _ = \ 232 | self._sess.run([self._domain_a, self._domain_b, 233 | self._x_ab, self._x_bb, self._aba_lowres, 234 | self._gen_loss_all, self._summary_gen, self._optim_gen], 235 | feed_dict=train_feed_dict) 236 | 237 | # display training status 238 | print("Step: [%2d] [%5d/%5d] time: %4.4f D_loss: %.8f, G_loss: %.8f" 239 | % (step, batch_id, self._n_iters_per_step, time.time() - start_time, loss_dis, loss_gen)) 240 | 241 | if (counter+1) % self._summary_freq == 0: 242 | writer.add_summary(summary_str_dis, counter) 243 | writer.add_summary(summary_str_gen, counter) 244 | 245 | if (counter+1) % self._img_save_freq == 0: 246 | aba_lowres_resize = batch_resize(aba_lowres, self._img_size) 247 | merged = np.vstack([batch_a_images, fake_b, aba_lowres_resize, batch_b_images, identity_b]) 248 | save_images(merged, [5, self._batch_size], 249 | os.path.join(self._sample_dir, 'sample_{:03d}_{:05d}.jpg'.format(step, batch_id))) 250 | 251 | if (counter+1) % self._ckpt_save_freq == 0: 252 | self._save_ckpt(self._checkpoint_dir, counter) 253 | 254 | counter += 1 255 | 256 | # After each step, start_batch_id is set to zero. 257 | # Non-zero value is only for the first step after loading a pre-trained model. 258 | start_batch_id = 0 259 | 260 | # Save the final model. 
261 | self._save_ckpt(self._checkpoint_dir, counter-1) 262 | -------------------------------------------------------------------------------- /generator/generator_basic_res.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ops import conv, instance_norm, layer_norm, relu, tanh, resblock_v1, nearest_up, bilinear_up 3 | 4 | 5 | class GeneratorBasicRes: 6 | """Basic residual block based generator for SPatchGAN.""" 7 | def __init__(self, ch, n_updownsample, n_res, n_enhanced_upsample, n_mix_upsample, block_type, upsample_type): 8 | self._ch = ch 9 | self._n_updownsample = n_updownsample 10 | self._n_res = n_res 11 | self._n_enhanced_upsample = n_enhanced_upsample 12 | self._n_mix_upsample = n_mix_upsample 13 | self._block_type = block_type 14 | self._upsample_type = upsample_type 15 | 16 | def translate(self, x, reuse=False, scope='gen'): 17 | """Build the generator graph.""" 18 | with tf.variable_scope(scope, reuse=reuse) : 19 | channel = self._ch 20 | 21 | # Downsampling 22 | for i in range(self._n_updownsample): 23 | with tf.variable_scope('down_{}'.format(i)): 24 | # (256, 256, 3) -> (128, 128, 128) -> (64, 64, 256) -> (32, 32, 512) 25 | x = conv(x, channel, kernel=3, stride=2, pad=1) 26 | x = instance_norm(x) 27 | if i < self._n_updownsample - 1: 28 | x = relu(x) 29 | channel *= 2 30 | 31 | if self._n_updownsample == 0: 32 | with tf.variable_scope('mix_init'): 33 | x = conv(x, channel, kernel=3, pad=1) 34 | x = instance_norm(x) 35 | 36 | for i in range(self._n_res): 37 | with tf.variable_scope('res_{}'.format(i)): 38 | x = self._conv_block(x, block_type=self._block_type, channel=channel) 39 | 40 | for i in range(self._n_updownsample): 41 | with tf.variable_scope('up_{}'.format(i)): 42 | # (32, 32, 512) -> (64, 64, 512) -> (128, 128, 256) -> (256, 256, 128) 43 | x = self._upsample(x, method=self._upsample_type) 44 | channel = channel if i < self._n_enhanced_upsample else channel // 2 45 | n_mix_upsample = self._n_mix_upsample if i < self._n_enhanced_upsample else 1 46 | for j in range(n_mix_upsample): 47 | with tf.variable_scope('mix_{}'.format(j)): 48 | x = conv(x, channel, kernel=3, stride=1, pad=1) 49 | x = layer_norm(x) 50 | x = relu(x) 51 | 52 | if self._n_updownsample == 0: 53 | with tf.variable_scope('mix_end'): 54 | x = conv(x, channel, kernel=3, stride=1, pad=1) 55 | x = layer_norm(x) 56 | x = relu(x) 57 | 58 | with tf.variable_scope('logits'): 59 | # (256, 256, 128) -> (256, 256, 3) 60 | x = conv(x, channels=3, kernel=3, pad=1, scope='G_logit') 61 | x = tf.identity(x, 'pre_tanh') 62 | x = tanh(x) 63 | 64 | return x 65 | 66 | @staticmethod 67 | def _conv_block(x, block_type, channel, scope='resblock_0'): 68 | if block_type == 'v1': 69 | x = resblock_v1(x, channel=channel, scope=scope) 70 | else: 71 | raise ValueError('Wrong block_type!') 72 | return x 73 | 74 | @staticmethod 75 | def _upsample(x, method: str = 'nearest'): 76 | if method == 'nearest': 77 | x = nearest_up(x) 78 | elif method == 'bilinear': 79 | x = bilinear_up(x) 80 | else: 81 | raise ValueError('Invalid upsampling method!') 82 | return x 83 | -------------------------------------------------------------------------------- /imagedata.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class ImageData: 5 | """Input image processing and augmentation.""" 6 | def __init__(self, load_size, augment_type): 7 | self._load_size = load_size 8 | self._augment_type = augment_type 9 | 
10 | def image_processing(self, filename): 11 | x = tf.read_file(filename) 12 | x_decode = tf.image.decode_jpeg(x, channels=3) 13 | 14 | if self._augment_type is None: 15 | img = tf.image.resize_images(x_decode, [self._load_size, self._load_size]) 16 | elif self._augment_type == 'pad_crop': 17 | img = _augmentation_pad_crop(x_decode, self._load_size) 18 | elif self._augment_type == 'resize_crop': 19 | img = _augmentation_resize_crop(x_decode, self._load_size) 20 | else: 21 | raise ValueError('Invalid augment_type!') 22 | 23 | img = tf.cast(img, tf.float32) / 127.5 - 1 24 | return img 25 | 26 | 27 | def _augmentation_pad_crop(image, size_out): 28 | image = tf.image.resize(image, [size_out, size_out]) 29 | image = tf.cast(image, tf.uint8) 30 | # The shape info will be lost after random jpeg quality. 31 | image = tf.image.random_jpeg_quality(image, min_jpeg_quality=50, max_jpeg_quality=100) 32 | image = tf.reshape(image, [size_out, size_out, 3]) 33 | 34 | pad_size = round(size_out * 0.05) 35 | # White padding 36 | image = tf.pad(image, paddings=[[pad_size, pad_size], [pad_size, pad_size], [0, 0]], constant_values=255) 37 | image = tf.random_crop(image, [size_out, size_out, 3]) 38 | image = _augmentation_general(image) 39 | return image 40 | 41 | 42 | def _augmentation_resize_crop(image, size_out): 43 | aug_rand = tf.random_uniform([]) 44 | image = tf.cond(aug_rand < 0.5, 45 | lambda: _ugatit_resize_crop(image, size_out), 46 | lambda: tf.image.resize(image, [size_out, size_out])) 47 | image = tf.cast(image, tf.uint8) 48 | # The shape info will be lost after random jpeg quality. 49 | image = tf.image.random_jpeg_quality(image, min_jpeg_quality=50, max_jpeg_quality=100) 50 | image = tf.reshape(image, [size_out, size_out, 3]) 51 | image = _augmentation_general(image) 52 | return image 53 | 54 | 55 | def _augmentation_general(image): 56 | # Operations that preserve the shape and are safe for most images. 57 | # These color changes should be done after padding to apply the changes on the paddings. 58 | image = tf.image.random_flip_left_right(image) 59 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 
60 | image = tf.image.random_saturation(image, lower=0.8, upper=1.2) 61 | image = tf.image.random_hue(image, max_delta=0.02) 62 | image = tf.image.random_contrast(image, lower=0.8, upper=1.2) 63 | return image 64 | 65 | 66 | def _ugatit_resize_crop(image, size_out): 67 | augment_size = size_out 68 | if size_out == 256: 69 | augment_size += 30 70 | elif size_out == 512: 71 | augment_size += 60 72 | else: 73 | # Generalize the augmentation strategy in U-GAT-IT 74 | augment_size += round(augment_size * 0.1) 75 | image = tf.image.resize_images(image, [augment_size, augment_size]) 76 | image = tf.random_crop(image, [size_out, size_out, 3]) 77 | return image 78 | -------------------------------------------------------------------------------- /images/SPatchGAN_D_20210317_3x.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NetEase-GameAI/SPatchGAN/64ac3547545dcdafee9cf7ccf1956410a2793ad5/images/SPatchGAN_D_20210317_3x.jpg -------------------------------------------------------------------------------- /images/s2a_cmp_github_downsized.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NetEase-GameAI/SPatchGAN/64ac3547545dcdafee9cf7ccf1956410a2793ad5/images/s2a_cmp_github_downsized.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | from configs import parse_args 4 | from utils import show_all_variables 5 | from gan.spatchgan import SPatchGAN 6 | 7 | 8 | def main(): 9 | """General entry point for running GANs.""" 10 | args = parse_args() 11 | os.makedirs(args.output_dir, exist_ok=True) 12 | 13 | gpu_options = tf.GPUOptions(allow_growth=True) 14 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)) as sess: 15 | 16 | if args.network == 'spatchgan': 17 | gan = SPatchGAN('SPatchGAN', sess, args) 18 | else: 19 | raise RuntimeError('Invalid network!') 20 | 21 | if args.phase == 'train': 22 | gan.build_model_train() 23 | show_all_variables() 24 | gan.train() 25 | print(" [*] Training finished!") 26 | elif args.phase == 'test': 27 | gan.build_model_test() 28 | show_all_variables() 29 | gan.test() 30 | print(" [*] Test finished!") 31 | elif args.phase == 'freeze_graph': 32 | gan.build_model_test() 33 | show_all_variables() 34 | gan.freeze_graph() 35 | print(" [*] Graph frozen!") 36 | else: 37 | raise RuntimeError('Invalid phase!') 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tf_contrib 3 | 4 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02) 5 | weight_regularizer = tf_contrib.layers.l2_regularizer(scale=0.0001) 6 | 7 | 8 | def conv(x, channels, kernel=1, stride=1, pad=0, pad_type: str = 'zero', use_bias=True, 9 | sn: str = None, scope: str = 'conv_0'): 10 | """Convolution layer.""" 11 | with tf.variable_scope(scope): 12 | if pad > 0 : 13 | if (kernel - stride) % 2 == 0: 14 | pad_top = pad 15 | pad_bottom = pad 16 | pad_left = pad 17 | pad_right = pad 18 | 19 | else: 20 | # For kernel = 3, stride = 2, pad=1: 21 | # pad_top = pad_left = 1 22 | # pad_bottom = pad_right = 0 23 | # h_out = (h_in + 1 - 3) / 2 + 
1 = h_in / 2 24 | pad_top = pad 25 | pad_bottom = kernel - stride - pad_top 26 | pad_left = pad 27 | pad_right = kernel - stride - pad_left 28 | 29 | if pad_type == 'zero': 30 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]) 31 | if pad_type == 'reflect': 32 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT') 33 | 34 | if sn is not None: 35 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init, 36 | regularizer=weight_regularizer) 37 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w, method=sn), 38 | strides=[1, stride, stride, 1], padding='VALID') 39 | if use_bias: 40 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 41 | x = tf.nn.bias_add(x, bias) 42 | 43 | else: 44 | x = tf.layers.conv2d(inputs=x, filters=channels, 45 | kernel_size=kernel, kernel_initializer=weight_init, 46 | kernel_regularizer=weight_regularizer, 47 | strides=stride, use_bias=use_bias) 48 | return x 49 | 50 | 51 | def fully_connected(x, units, use_bias=True, sn: str = None, scope='linear'): 52 | """Fully connected layer.""" 53 | with tf.variable_scope(scope): 54 | x = tf.layers.flatten(x) 55 | shape = x.get_shape().as_list() 56 | channels = shape[-1] 57 | 58 | if sn is not None: 59 | w = tf.get_variable("kernel", [channels, units], tf.float32, 60 | initializer=weight_init, regularizer=weight_regularizer) 61 | if use_bias: 62 | bias = tf.get_variable("bias", [units], 63 | initializer=tf.constant_initializer(0.0)) 64 | 65 | x = tf.matmul(x, spectral_norm(w, method=sn)) + bias 66 | else: 67 | x = tf.matmul(x, spectral_norm(w, method=sn)) 68 | 69 | else: 70 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, use_bias=use_bias) 71 | 72 | return x 73 | 74 | 75 | def instance_norm(x, scope='instance_norm'): 76 | """Instance normalization layer.""" 77 | return tf_contrib.layers.instance_norm(x, 78 | epsilon=1e-05, 79 | center=True, scale=True, 80 | scope=scope) 81 | 82 | 83 | def layer_norm(x, scope='layer_norm'): 84 | """Layer normalization layer.""" 85 | return tf_contrib.layers.layer_norm(x, 86 | center=True, scale=True, 87 | scope=scope) 88 | 89 | 90 | def spectral_norm(w, n_iters=1, method: str = 'fast'): 91 | """Spectral normalization layer.""" 92 | w_shape = w.shape.as_list() 93 | w = tf.reshape(w, [-1, w_shape[-1]]) 94 | 95 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False) 96 | 97 | u_hat = tf.stop_gradient(u) if method == 'full' else u 98 | v_hat = None 99 | for i in range(n_iters): 100 | """ 101 | power iteration 102 | Usually iteration = 1 will be enough 103 | """ 104 | v_ = tf.matmul(u_hat, tf.transpose(w)) 105 | v_hat = tf.nn.l2_normalize(v_) 106 | 107 | u_ = tf.matmul(v_hat, w) 108 | u_hat = tf.nn.l2_normalize(u_) 109 | 110 | if method == 'fast': 111 | u_hat = tf.stop_gradient(u_hat) 112 | v_hat = tf.stop_gradient(v_hat) 113 | elif method == 'full': 114 | pass 115 | else: 116 | raise RuntimeError('Invalid sn method!') 117 | 118 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) 119 | 120 | with tf.control_dependencies([u.assign(u_hat)]): 121 | w_norm = w / sigma 122 | w_norm = tf.reshape(w_norm, w_shape) 123 | 124 | return w_norm 125 | 126 | 127 | def lrelu(x, alpha=0.2): 128 | """Leaky ReLU.""" 129 | return tf.nn.leaky_relu(x, alpha) 130 | 131 | 132 | def relu(x): 133 | """ReLU.""" 134 | return tf.nn.relu(x) 135 | 136 | 137 
/output/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.18.1
2 | opencv-python==3.4.3.18
3 | tensorflow-gpu==1.14.0
4 | gast==0.2.2
--------------------------------------------------------------------------------
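These pins are load-bearing: ops.py and utils.py depend on tf.layers and tf.contrib, both of which were removed in TensorFlow 2.x, so the code will not even import on a newer stack. A small guard such as the following (a sketch, not part of the repo) fails fast with a clear message:

import tensorflow as tf

# requirements.txt pins tensorflow-gpu==1.14.0; tf.contrib and tf.layers
# exist only in the 1.x line.
if not tf.__version__.startswith('1.'):
    raise RuntimeError('Found TensorFlow {}; this code base requires 1.x '
                       '(tf.contrib and tf.layers were removed in 2.x).'.format(tf.__version__))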
/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import tensorflow as tf
3 | import numpy as np
4 | from glob import glob
5 | from tensorflow.contrib import slim
6 |
7 |
8 | def show_all_variables():
9 |     """Show all trainable variables in the graph."""
10 |     model_vars = tf.trainable_variables()
11 |     slim.model_analyzer.analyze_vars(model_vars, print_info=True)
12 |
13 |
14 | def get_img_paths(input_dir: str, dataset_struct: str = 'plain'):
15 |     """Get all images in an input directory."""
16 |     exts = ['jpg', 'jpeg', 'png']
17 |     imgs = []
18 |     for ext in exts:
19 |         if dataset_struct == 'plain':
20 |             pattern = input_dir + '/*.{}'.format(ext)
21 |         elif dataset_struct == 'tree':
22 |             pattern = input_dir + '/*/*.{}'.format(ext)
23 |         else:
24 |             raise ValueError('Invalid dataset_struct!')
25 |         imgs.extend(glob(pattern))
26 |     return imgs
27 |
28 |
29 | def get_img_paths_auto(input_dir: str):
30 |     """Auto-detect the directory structure and get all images."""
31 |     dataset = get_img_paths(input_dir)
32 |     if len(dataset) == 0:
33 |         dataset = get_img_paths(input_dir, dataset_struct='tree')
34 |     return dataset
35 |
36 |
37 | def summary_by_keywords(keywords, node_type='tensor'):
38 |     """Generate summaries for the TF nodes whose names contain all of the keywords."""
39 |     summary_list = []
40 |     if node_type == 'tensor':
41 |         all_nodes = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]
42 |     elif node_type == 'variable':
43 |         all_nodes = tf.trainable_variables()
44 |     else:
45 |         raise RuntimeError('Invalid node_type!')
46 |
47 |     keyword_list = keywords if isinstance(keywords, list) else [keywords]
48 |
49 |     # Include a node only if its name contains all keywords.
50 |     nodes = [node for node in all_nodes if all(keyword in node.name for keyword in keyword_list)]
51 |
52 |     for node in nodes:
53 |         summary_list.append(tf.summary.scalar(node.name + "_min", tf.reduce_min(node)))
54 |         summary_list.append(tf.summary.scalar(node.name + "_max", tf.reduce_max(node)))
55 |         node_mean = tf.reduce_mean(node)
56 |         summary_list.append(tf.summary.scalar(node.name + "_mean", node_mean))
57 |         # Calculate the uncorrected (population) standard deviation.
58 |         node_stddev = tf.sqrt(tf.reduce_mean(tf.square(node - node_mean)))
59 |         summary_list.append(tf.summary.scalar(node.name + "_stddev", node_stddev))
60 |
61 |     return summary_list
62 |
63 |
64 | def batch_resize(x, img_size):
65 |     """Resize a batch of NHWC NumPy images."""
66 |     x_up = np.zeros((x.shape[0], img_size, img_size, 3))
67 |     for i in range(x.shape[0]):
68 |         x_up[i, :, :, :] = cv2.resize(x[i, :, :, :], dsize=(img_size, img_size))
69 |     return x_up
70 |
71 |
72 | def save_images(images, size, image_path):
73 |     """Save a grid of images; size is [n_rows, n_cols]."""
74 |     return _imsave(_inverse_transform(images), size, image_path)
75 |
76 |
77 | def _inverse_transform(images):
78 |     return ((images + 1.) / 2) * 255.0
79 |
80 |
81 | def _imsave(images, size, path):
82 |     images = _merge(images, size)
83 |     images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)
84 |
85 |     return cv2.imwrite(path, images)
86 |
87 |
88 | def _merge(images, size):
89 |     h, w = images.shape[1], images.shape[2]
90 |     img = np.zeros((h * size[0], w * size[1], 3))
91 |     for idx, image in enumerate(images):
92 |         i = idx % size[1]
93 |         j = idx // size[1]
94 |         img[h*j:h*(j+1), w*i:w*(i+1), :] = image
95 |
96 |     return img
97 |
98 |
99 | def load_test_data(image_path, size=256):
100 |     """Load a test image with OpenCV and normalize it to [-1, 1]."""
101 |     img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
102 |     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
103 |     img = cv2.resize(img, dsize=(size, size))
104 |     img = np.expand_dims(img, axis=0)
105 |     img = img / 127.5 - 1
106 |
107 |     return img
108 |
--------------------------------------------------------------------------------
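Taken together, load_test_data and save_images form the repo's [-1, 1] round trip: _merge tiles N images into a size[0]-by-size[1] grid, and _inverse_transform maps them back to [0, 255] before writing. A minimal round-trip sketch (the file paths are hypothetical):

from utils import load_test_data, save_images

# Load one RGB test image, normalized to [-1, 1] with shape (1, 256, 256, 3).
img = load_test_data('dataset/testA/example.jpg', size=256)
assert -1.0 <= img.min() and img.max() <= 1.0

# Write it back as a 1x1 grid; size is [n_rows, n_cols].
save_images(img, size=[1, 1], image_path='output/example_roundtrip.png')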