├── .gitignore
├── 3rdparty
│   └── replicate
│       ├── cog.yaml
│       └── predict.py
├── LICENSE
├── README.md
├── configs.py
├── dataset
│   └── .gitignore
├── discriminator
│   ├── discriminator_patch.py
│   └── discriminator_spatch.py
├── docs
│   └── tips.md
├── frozen_model
│   ├── image_translator.py
│   └── test_frozen_model.py
├── gan
│   ├── gan.py
│   └── spatchgan.py
├── generator
│   └── generator_basic_res.py
├── imagedata.py
├── images
│   ├── SPatchGAN_D_20210317_3x.jpg
│   └── s2a_cmp_github_downsized.jpg
├── main.py
├── ops.py
├── output
│   └── .gitignore
├── requirements.txt
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | bash/
3 | __pycache__/
4 |
--------------------------------------------------------------------------------
/3rdparty/replicate/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | gpu: true
3 | python_version: "3.8"
4 | system_packages:
5 | - "libgl1-mesa-glx"
6 | - "libglib2.0-0"
7 | python_packages:
8 | - "tensorflow-gpu==2.6.0"
9 | - "numpy==1.19.4"
10 | - "opencv-python==4.5.3.56"
11 | - "gast==0.4.0"
12 | - "ipython==7.19.0"
13 |
14 | predict: "predict.py:Predictor"
15 |
--------------------------------------------------------------------------------
/3rdparty/replicate/predict.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../..')
3 | import tempfile
4 | from pathlib import Path
5 | import argparse
6 | import cv2
7 | import numpy as np
8 | import tensorflow as tf
9 | from tensorflow.python.platform import gfile
10 | import cog
11 | # Note: a TF2-compatible ImageTranslator is defined below; frozen_model.image_translator uses the TF1 API and is not imported here.
12 |
13 |
14 | class Predictor(cog.Predictor):
15 | def setup(self):
16 | args = parse_arguments()
17 | args.model = 'SPatchGAN_selfie2anime_scale3_cyc20_20210831.pb'
18 | size = 256
19 | self.translator = ImageTranslator(model_path=args.model,
20 | size=size,
21 | n_threads_intra=args.n_threads_intra,
22 | n_threads_inter=args.n_threads_inter)
23 |
24 | @cog.input(
25 | "image",
26 | type=Path,
27 |         help="Input image (.png, .jpg or .jpeg); the model generates a female anime-style face.",
28 | )
29 | def predict(self, image):
30 | out_path = Path(tempfile.mkdtemp()) / "out.png"
31 | img = cv2.imread(str(image), cv2.IMREAD_COLOR)
32 | output = self.translator.translate(img)
33 | cv2.imwrite(str(out_path), output)
34 | return out_path
35 |
36 |
37 | def parse_arguments():
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument('--image', type=str, help='Image file path or directory.')
40 | parser.add_argument('--model', type=str, help='.pb model path.')
41 | parser.add_argument('--n_threads_inter', type=int, default=1, help='Number of inter op threads.')
42 | parser.add_argument('--n_threads_intra', type=int, default=1, help='Number of intra op threads.')
43 | parser.add_argument('--n_iters', type=int, default=1)
44 |     return parser.parse_args([])  # Parse the defaults only; no CLI arguments are passed by cog.
45 |
46 |
47 | class ImageTranslator:
48 | def __init__(self, model_path, size, n_threads_intra=1, n_threads_inter=1):
49 | config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=n_threads_intra,
50 | inter_op_parallelism_threads=n_threads_inter)
51 | config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
52 | config.gpu_options.allow_growth = True
53 | self._size = size
54 | self._graph = tf.Graph()
55 | self._sess = tf.compat.v1.Session(config=config, graph=self._graph)
56 |
57 | self._pb_file_path = model_path
58 | self._restore_from_pb()
59 | self._input_op = self._graph.get_tensor_by_name('test_domain_A:0')
60 | self._output_op = self._graph.get_tensor_by_name('test_fake_B:0')
61 |
62 | def _restore_from_pb(self):
63 | with self._graph.as_default():
64 | with gfile.FastGFile(self._pb_file_path, 'rb') as f:
65 | graph_def = tf.compat.v1.GraphDef()
66 | graph_def.ParseFromString(f.read())
67 | tf.import_graph_def(graph_def, name='')
68 |
69 | def _input_transform(self, img):
70 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
71 | img = cv2.resize(img, dsize=(self._size, self._size))
72 | img = np.expand_dims(img, axis=0)
73 | image_input = img / 127.5 - 1
74 | return image_input
75 |
76 | @staticmethod
77 | def _output_transform(output):
78 | output = ((output + 1.) / 2) * 255.0
79 | image_output = cv2.cvtColor(output.astype('uint8'), cv2.COLOR_RGB2BGR)
80 | return image_output
81 |
82 | def translate(self, image):
83 | """ Translate an image from the source domain to the target domain"""
84 | image_input = self._input_transform(image)
85 | output = self._sess.run(self._output_op, feed_dict={self._input_op: image_input})[0]
86 | return self._output_transform(output)
87 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, NetEase Games AI Lab.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## SPatchGAN: Official TensorFlow Implementation
2 |
3 | ### Paper
4 | "SPatchGAN: A Statistical Feature Based Discriminator for Unsupervised Image-to-Image Translation" (ICCV 2021)
5 |
6 | [arXiv paper](https://arxiv.org/abs/2103.16219)
7 | [Video (YouTube)](https://www.youtube.com/watch?v=JY7eq6q5qpk)
8 | [Google Drive](https://drive.google.com/file/d/1r62hIUkQrQGE_6PI_ovS4sdfHPjfSfIK/view?usp=sharing)
9 |
10 | ![SPatchGAN discriminator](images/SPatchGAN_D_20210317_3x.jpg)
11 |
12 | ![Selfie-to-anime comparison](images/s2a_cmp_github_downsized.jpg)
13 |
14 |
15 |
16 | ### Web Demos
17 | Web demo on Replicate by [CJWBW](https://github.com/CJWBW) (see ``3rdparty/replicate``).
18 |
19 |
20 | ### Environment
21 | - CUDA 10.0
22 | - Python 3.6
23 | - ``pip install -r requirements.txt``
24 |
25 | ### Dataset
26 |
27 | - Dataset structure (dataset_struct='plain')
28 | ```
29 | - dataset
30 |     - <dataset_name>
31 |         - trainA
32 |             - 1.jpg
33 |             - 2.jpg
34 |             - ...
35 |         - trainB
36 |             - 3.jpg
37 |             - 4.jpg
38 |             - ...
39 |         - testA
40 |             - 5.jpg
41 |             - 6.jpg
42 |             - ...
43 |         - testB
44 |             - 7.jpg
45 |             - 8.jpg
46 |             - ...
47 | ```
48 |
49 | - Supported extensions: jpg, jpeg, png
50 | - An additional level of subdirectories is also supported by setting dataset_struct to 'tree', e.g.,
51 | ```
52 | - trainA
53 |     - subdir1
54 |         - 1.jpg
55 |         - 2.jpg
56 |         - ...
57 |     - subdir2
58 |         - ...
59 | ```
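- For example, a minimal command sketch for the tree structure (the dataset name is a placeholder):
```bash
python main.py --dataset <dataset_name> --dataset_struct tree --phase train
```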
60 |
61 | - Selfie-to-anime:
62 | - The dataset can be downloaded from [U-GAT-IT](https://github.com/taki0112/UGATIT).
63 |
64 | - Male-to-female and glasses removal:
65 | - The datasets can be downloaded from [Council-GAN](https://github.com/Onr/Council-GAN).
66 | - The images must be center cropped from 218x178 to 178x178 before training or testing.
67 | - For glasses removal, only the male images are used in the experiments in our paper. Note that the dataset from Council-GAN has already been split into two subdirectories, "1" for male and "2" for female.
68 |
69 | ### Training
70 |
71 | - Set the suffix to anything descriptive, e.g., the date.
72 | - Selfie-to-Anime
73 | ```bash
74 | python main.py --dataset selfie2anime --augment_type resize_crop --n_scales_dis 3 --suffix scale3_cyc20_20210831 --phase train
75 | ```
76 |
77 | - Male-to-Female
78 | ```bash
79 | python main.py --dataset male2female --cyc_weight 10 --suffix cyc10_20210831 --phase train
80 | ```
81 |
82 | - Glasses Removal
83 | ```bash
84 | python main.py --dataset glasses-male --cyc_weight 30 --suffix cyc30_20210831 --phase train
85 | ```
86 | - Find the output in ``./output/SPatchGAN_<dataset>_<suffix>``
87 | - The same command can be used to continue training based on the latest checkpoint.
88 | - For a new task, we recommend starting from the default settings and adjusting the hyperparameters according to the [tips](docs/tips.md).
89 | - Check [configs.py](configs.py) for all the hyperparameters.
90 |
91 | ### Testing with the latest checkpoint
92 | - Replace ``--phase train`` with ``--phase test``
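- For example, reusing the selfie2anime command above:
```bash
python main.py --dataset selfie2anime --augment_type resize_crop --n_scales_dis 3 --suffix scale3_cyc20_20210831 --phase test
```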
93 |
94 | ### Save a frozen model (.pb)
95 | - Replace ``--phase train`` with ``--phase freeze_graph``
96 | - Find the saved frozen model in ``./output/SPatchGAN_<dataset>_<suffix>/checkpoint/pb``
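- For example:
```bash
python main.py --dataset selfie2anime --augment_type resize_crop --n_scales_dis 3 --suffix scale3_cyc20_20210831 --phase freeze_graph
```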
97 |
98 | ### Testing with the frozen model
99 | ```bash
100 | cd frozen_model
101 | python test_frozen_model.py --image <image_path_or_dir> --output_dir <output_dir> --model <pb_model_path>
102 | ```
103 |
104 | ### Pretrained Models
105 | - Download the pretrained models from [Hugging Face](https://huggingface.co/necoarc/spatchgan-model), and put them in the output directory.
106 | - You can test the checkpoints (in ./checkpoint) or the frozen models (in ./checkpoint/pb). Either way produces the same results.
107 | - The results generated by the pretrained models are slightly different from those in the paper, since we have rerun the training after code refactoring.
108 | - We set ``n_scales_dis`` to 3 for the pretrained selfie2anime model to further improve the performance. It was 4 in the paper. See more details in the [tips](docs/tips.md).
109 | - We also provide the generated results of the last 100 test images (in ./gen, sorted by name, no cherry-picking) for calibration purposes.
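- As a rough sketch (the directory name follows the selfie2anime training command above; the exact files inside the archive may differ), the pretrained model is expected to unpack as:
```
output
└── SPatchGAN_selfie2anime_scale3_cyc20_20210831
    ├── checkpoint
    │   └── pb
    └── gen
```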
110 |
111 | ### Other Implementations
112 | - We provide a PyTorch implementation of the SPatchGAN discriminator in [spatchgan_discriminator_pytorch.py](https://gist.github.com/NetEase-GameAI/6b93a3fa4c8ab7a59a75eeacca33712f).
113 |
114 | ### Citation
115 | ```
116 | @InProceedings{Shao_2021_ICCV,
117 | author = {Shao, Xuning and Zhang, Weidong},
118 | title = {SPatchGAN: A Statistical Feature Based Discriminator for Unsupervised Image-to-Image Translation},
119 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
120 | month = {October},
121 | year = {2021},
122 | pages = {6546-6555}
123 | }
124 | ```
125 |
126 | ### Acknowledgement
127 | - Our code is partially based on [U-GAT-IT](https://github.com/taki0112/UGATIT).
128 |
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 |
5 | def _str2bool(x):
6 | return x.lower() == 'true'
7 |
8 |
9 | def _none_or_str(x):
10 | if x == 'None':
11 | return None
12 | else:
13 | return x
14 |
15 |
16 | def parse_args():
17 | """Configurations."""
18 | desc = "TensorFlow implementation of SPatchGAN."
19 | parser = argparse.ArgumentParser(description=desc)
20 |
21 | # General configs
22 | parser.add_argument('--network', type=str, default='spatchgan', help='Network type: [spatchgan].')
23 | parser.add_argument('--phase', type=str, default='train',
24 | help='Phase: [train / test / freeze_graph].')
25 | parser.add_argument('--dataset', type=str, required=True, help='Name of the training dataset.')
26 | parser.add_argument('--test_dataset', type=str, default=None,
27 | help='Name of the testing dataset. Same as the training dataset by default.')
28 | parser.add_argument('--dataset_struct', type=str, default='plain', help='Dataset type: [plain / tree].')
29 | parser.add_argument('--suffix', type=str, default=None, help='suffix for the model name.')
30 |
31 | # Training configs
32 | parser.add_argument('--n_steps', type=int, default=50, help='Number of training steps.')
33 | parser.add_argument('--n_iters_per_step', type=int, default=10000, help='Number of iterations per step')
34 | parser.add_argument('--batch_size', type=int, default=4, help='Batch size.')
35 | parser.add_argument('--img_save_freq', type=int, default=1000, help='Image saving frequency in iteration.')
36 | parser.add_argument('--ckpt_save_freq', type=int, default=1000, help='Checkpoint saving frequency in iteration.')
37 | parser.add_argument('--summary_freq', type=int, default=100, help='TensorFlow summary frequency.')
38 | parser.add_argument('--decay_step', type=int, default=10, help='Starting point for learning rate decay.')
39 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate.')
40 | parser.add_argument('--adv_weight', type=float, default=4.0, help='Adversarial loss weight.')
41 | parser.add_argument('--reg_weight', type=float, default=1.0, help='Regularization weight.')
42 | parser.add_argument('--cyc_weight', type=float, default=20.0, help='Weak cycle loss weight.')
43 | parser.add_argument('--id_weight', type=float, default=10.0, help='Identity loss weight.')
44 | parser.add_argument('--gan_type', type=str, default='lsgan', help='GAN loss type: [lsgan].')
45 |
46 | # Input configs
47 | parser.add_argument('--img_size', type=int, default=256, help='The size of input images.')
48 | parser.add_argument('--augment_type', type=_none_or_str, default='pad_crop',
49 | help='Augmentation method: [pad_crop / resize_crop / None].')
50 |
51 | # Discriminator configs
52 | parser.add_argument('--dis_type', type=str, default='spatch', help='D type: [spatch / patch].')
53 | parser.add_argument('--logits_type_dis', type=str, default='stats', help='D logits calculation method: [stats].')
54 | parser.add_argument('--ch_dis', type=int, default=256, help='Base channel number of D.')
55 | parser.add_argument('--n_downsample_init_dis', type=int, default=2,
56 | help='Number of downsampling layers in the initial feature extraction block.')
57 | parser.add_argument('--n_scales_dis', type=int, default=4, help='Number of scales in D.')
58 | parser.add_argument('--sn_dis', type=_none_or_str, default='fast', help='Spectral norm type: [fast / full / None]')
59 | parser.add_argument('--n_adapt_dis', type=int, default=2, help='Number of layers in each adaptation block.')
60 | parser.add_argument('--n_mix_dis', type=int, default=2, help='Number of mixing layers in each MLP.')
61 | parser.add_argument('--mean_dis', type=_str2bool, default=True, help='Use the gap output in D.')
62 | parser.add_argument('--max_dis', type=_str2bool, default=True, help='Use the gmp output in D.')
63 | parser.add_argument('--stddev_dis', type=_str2bool, default=True, help='Use the stddev output in D.')
64 |
65 | # Generator configs
66 | parser.add_argument('--gen_type', type=str, default='basic_res', help='G type: [basic_res].')
67 | parser.add_argument('--block_type_gen', type=str, default='v1', help='G residual block type: [v1].')
68 | parser.add_argument('--ch_gen', type=int, default=128, help='Base channel number of forward G.')
69 | parser.add_argument('--ch_gen_bw', type=int, default=512, help='Base channel number of backward G.')
70 | parser.add_argument('--upsample_type_gen', type=str, default='nearest',
71 | help='Upsampling method: [nearest / bilinear].')
72 | parser.add_argument('--n_updownsample_gen', type=int, default=3,
73 | help='Number of up/downsampling layers in forward G.')
74 | parser.add_argument('--n_updownsample_gen_bw', type=int, default=0,
75 | help='Number of up/downsampling layers in backward G.')
76 | parser.add_argument('--n_res_gen', type=int, default=8, help='Number of residual blocks in forward G.')
77 | parser.add_argument('--n_res_gen_bw', type=int, default=8, help='Number of residual blocks in backward G.')
78 | parser.add_argument('--n_enhanced_upsample_gen', type=int, default=1,
79 | help='Number of enhanced upsampling blocks that include multiple mixing layers.')
80 | parser.add_argument('--n_mix_upsample_gen', type=int, default=2,
81 | help='Number of mixing layers in an enhanced upsampling block.')
82 | parser.add_argument('--resize_factor_gen_bw', type=int, default=8,
83 | help='The resizing factor of input images for backward G.')
84 |
85 | # Directory names
86 | parser.add_argument('--dataset_dir', type=str,
87 | default=os.path.join(os.path.dirname(__file__), 'dataset'),
88 | help='Directory for all datasets')
89 | parser.add_argument('--output_dir', type=str,
90 | default=os.path.join(os.path.dirname(__file__), 'output'),
91 | help='Directory for all outputs')
92 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
93 | help='Directory for the checkpoints')
94 | parser.add_argument('--result_dir', type=str, default='gen',
95 | help='Directory for the generated images.')
96 | parser.add_argument('--log_dir', type=str, default='logs',
97 | help='Directory for the training logs.')
98 | parser.add_argument('--sample_dir', type=str, default='samples',
99 | help='Directory for the training sample images.')
100 |
101 | return _check_args(parser.parse_args())
102 |
103 |
104 | def _check_args(args):
105 | if args is None:
106 | raise RuntimeError('Invalid arguments!')
107 | return args
108 |
--------------------------------------------------------------------------------
/dataset/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/discriminator/discriminator_patch.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from ops import conv, lrelu
3 |
4 |
5 | class DiscriminatorPatch:
6 |     """Multiscale PatchGAN discriminator for MUNIT / Council-GAN / ACL-GAN."""
7 | def __init__(self, ch, n_downsample_init, n_scales, sn):
8 | self._ch = ch # 64 in MUNIT
9 | self._n_downsample_init = n_downsample_init # 4 in MUNIT
10 | self._n_scales = n_scales # 3 in MUNIT
11 | self._sn = sn
12 |
13 | def discriminate(self, x, reuse=False, scope='dis'):
14 | """Calculate the patch based logits."""
15 | with tf.variable_scope(scope, reuse=reuse):
16 | logits = []
17 | for i in range(self._n_scales):
18 | logits_patch = self._discriminator_per_scale(x, reuse=reuse, scope='scale_{}'.format(i))
19 | logits.append(logits_patch)
20 | x = tf.layers.average_pooling2d(x, pool_size=3, strides=2, padding='SAME')
21 | return logits
22 |
23 | def _discriminator_per_scale(self, x, reuse=False, scope='scale_0'):
24 | with tf.variable_scope(scope, reuse=reuse):
25 | channel = self._ch
26 | for i in range(self._n_downsample_init):
27 | with tf.variable_scope('down_{}'.format(i)):
28 | x = conv(x, channel, kernel=4, stride=2, pad=1, sn=self._sn)
29 | x = lrelu(x, 0.2)
30 | channel *= 2
31 | x = conv(x, channels=1, kernel=1, stride=1, sn=self._sn, scope='logits')
32 | x = tf.identity(x, 'D_logits_patch')
33 | return x
34 |
--------------------------------------------------------------------------------
/discriminator/discriminator_spatch.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from ops import conv, lrelu, global_avg_pooling, global_max_pooling, fully_connected
3 |
4 |
5 | class DiscriminatorSPatch:
6 | """SPatchGAN discriminator."""
7 | def __init__(self, ch, n_downsample_init, n_scales, n_adapt, n_mix,
8 | logits_type: str, stats: list, sn):
9 | self._ch = ch
10 | self._n_downsample_init = n_downsample_init
11 | self._n_scales = n_scales
12 | self._n_adapt = n_adapt
13 | self._n_mix = n_mix
14 | self._logits_type = logits_type
15 | self._stats = stats
16 | self._sn = sn
17 |
18 | def discriminate(self, x, reuse=False, scope='dis'):
19 | """Calculate the statistical feature based logits."""
20 | with tf.variable_scope(scope, reuse=reuse):
21 | channel = self._ch
22 | logits_list = []
23 |
24 | for i in range(self._n_downsample_init):
25 | with tf.variable_scope('down_{}'.format(i)):
26 | # (256, 256, 3) -> (128, 128, 256) -> (64, 64, 512)
27 | x = conv(x, channel, kernel=4, stride=2, pad=1, sn=self._sn)
28 | x = lrelu(x)
29 | channel *= 2
30 |
31 | for i in range(self._n_scales):
32 | with tf.variable_scope('scale_{}'.format(i)):
33 | # (64, 64, 512) -> (32, 32, 1024) -> (16, 16, 1024) -> (8, 8, 1024) -> (4, 4, 1024)
34 | x = conv(x, channel, kernel=4, stride=2, pad=1, sn=self._sn, scope='conv_k4')
35 | x = lrelu(x)
36 | logits = self._dis_logits(x)
37 | logits_list.extend(logits)
38 |
39 | return logits_list
40 |
41 | def _dis_logits(self, x, scope='dis_logits'):
42 | if self._logits_type == 'stats':
43 | return self._dis_logits_stats(x, scope=scope)
44 | else:
45 | raise ValueError('Invalid logits_type_dis!')
46 |
47 | def _dis_logits_stats(self, x, scope='dis_logits'):
48 | with tf.variable_scope(scope):
49 | logits_list = []
50 | channel = x.shape[-1].value
51 |
52 | for i in range(self._n_adapt):
53 | with tf.variable_scope('premix_{}'.format(i)):
54 | x = conv(x, channel, sn=self._sn)
55 | x = lrelu(x)
56 |
57 | if 'mean' in self._stats:
58 | with tf.variable_scope('gap'):
59 | x_gap = global_avg_pooling(x)
60 | x_gap_logits = self._mlp_logits(x_gap)
61 | x_gap_logits = tf.identity(x_gap_logits, 'D_logits_gap')
62 | logits_list.append(x_gap_logits)
63 |
64 | if 'max' in self._stats:
65 | with tf.variable_scope('gmp'):
66 | x_gmp = global_max_pooling(x)
67 | x_gmp_logits = self._mlp_logits(x_gmp)
68 | x_gmp_logits = tf.identity(x_gmp_logits, 'D_logits_gmp')
69 | logits_list.append(x_gmp_logits)
70 |
71 | if 'stddev' in self._stats:
72 | with tf.variable_scope('stddev'):
73 | # Calculate the channel-wise uncorrected standard deviation
74 | x_diff_square = tf.square(x - tf.reduce_mean(x, axis=[1, 2], keepdims=True))
75 | x_stddev = tf.sqrt(global_avg_pooling(x_diff_square))
76 | x_stddev_logits = self._mlp_logits(x_stddev)
77 | x_stddev_logits = tf.identity(x_stddev_logits, 'D_logits_stddev')
78 | logits_list.append(x_stddev_logits)
79 |
80 | return logits_list
81 |
82 | def _mlp_logits(self, x, n_ch=None, scope='dis_logits_mix'):
83 | with tf.variable_scope(scope):
84 | shape = x.shape.as_list()
85 | channel = n_ch or shape[-1]
86 | if len(shape) == 2:
87 | for i in range(self._n_mix):
88 | x = fully_connected(x, units=channel, sn=self._sn, scope='mix_'+str(i))
89 | x = lrelu(x)
90 | x = fully_connected(x, units=1, sn=self._sn, scope='logits')
91 | elif len(shape) == 4:
92 | for i in range(self._n_mix):
93 | x = conv(x, channels=channel, kernel=1, stride=1, sn=self._sn, scope='mix_'+str(i))
94 | x = lrelu(x)
95 | x = conv(x, channels=1, kernel=1, stride=1, sn=self._sn, scope='logits')
96 | return x
97 |
98 |
--------------------------------------------------------------------------------
/docs/tips.md:
--------------------------------------------------------------------------------
1 | ## Hyperparameter Tuning
2 |
3 | - Discriminator
4 |     - We fixed the number of scales (``n_scales_dis``) at 4 in the paper to verify that good performance can be achieved on different tasks with a single network structure. In practice, we found that a reduced number of scales (3) is often discriminative enough and more stable, especially for tasks that require significant shape deformation. In such cases, the 4th scale, which is the most discriminative one in the default setting, tends to be overkill.
5 | - Reducing the number of base channels (``ch_dis``) is an effective way to accelerate the training process.
6 |
7 | - Generator
8 | - To improve the inference speed, you may want to reduce the number of base channels (``ch_gen``) or the number of enhanced upsampling layers (``n_enhanced_upsample_gen``).
9 |
10 | - Weak cycle
11 |     - The weight of the weak cycle constraint (``cyc_weight``) can be adjusted per task. A larger value generally suppresses shape deformation more strongly; on the other hand, it helps keep the generated image correlated with the source image.
12 |     - The input of the backward generator is downscaled by a factor of ``resize_factor_gen_bw``. The number of downsampling and upsampling layers of the backward generator is set by ``n_updownsample_gen_bw``. These two parameters can be adjusted to change the weak cycle into a full-resolution forward cycle, or something in between. The effect is somewhat similar to increasing the cycle weight (see the sketch below).
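
A minimal sketch of a full-resolution forward cycle, assuming ``resize_factor_gen_bw 1`` disables the input downscaling and ``n_updownsample_gen_bw 3`` matches the three up/downsampling layers of the forward generator (the dataset name and suffix are placeholders):
```bash
# Note: ch_gen_bw (default 512) may also need to be reduced to keep the full-resolution backward pass affordable.
python main.py --dataset <dataset_name> --resize_factor_gen_bw 1 --n_updownsample_gen_bw 3 --suffix fullcyc_<date> --phase train
```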
13 |
14 | - Training with a higher resolution
15 |     - You may want to adjust several parameters if the input / output resolution is higher than 256x256. Take 512x512 as an example: a good starting point is to set ``img_size`` to 512, increase ``n_downsample_init_dis`` from 2 to 3, and halve ``ch_dis``, ``ch_gen`` and ``ch_gen_bw``, as in the sketch below.
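
A minimal command sketch for the 512x512 starting point above (the dataset name and suffix are placeholders):
```bash
python main.py --dataset <dataset_name> --img_size 512 --n_downsample_init_dis 3 --ch_dis 128 --ch_gen 64 --ch_gen_bw 256 --suffix 512_<date> --phase train
```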
--------------------------------------------------------------------------------
/frozen_model/image_translator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import warnings
5 | warnings.filterwarnings('ignore', category=FutureWarning)
6 | import tensorflow as tf
7 | from tensorflow.python.platform import gfile
8 |
9 | curPath = os.path.abspath(os.path.dirname(__file__))
10 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
11 |
12 |
13 | class ImageTranslator:
14 | def __init__(self, model_path, size, n_threads_intra=1, n_threads_inter=1):
15 | config = tf.ConfigProto(intra_op_parallelism_threads=n_threads_intra,
16 | inter_op_parallelism_threads=n_threads_inter)
17 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
18 | config.gpu_options.allow_growth = True
19 | self._size = size
20 | self._graph = tf.Graph()
21 | self._sess = tf.Session(config=config, graph=self._graph)
22 |
23 | self._pb_file_path = model_path
24 | self._restore_from_pb()
25 | self._input_op = self._graph.get_tensor_by_name('test_domain_A:0')
26 | self._output_op = self._graph.get_tensor_by_name('test_fake_B:0')
27 |
28 | def _restore_from_pb(self):
29 | with self._graph.as_default():
30 | with gfile.FastGFile(self._pb_file_path, 'rb') as f:
31 | graph_def = tf.GraphDef()
32 | graph_def.ParseFromString(f.read())
33 | tf.import_graph_def(graph_def, name='')
34 |
35 | def _input_transform(self, img):
36 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
37 | img = cv2.resize(img, dsize=(self._size, self._size))
38 | img = np.expand_dims(img, axis=0)
39 | image_input = img / 127.5 - 1
40 | return image_input
41 |
42 | @staticmethod
43 | def _output_transform(output):
44 | output = ((output + 1.) / 2) * 255.0
45 | image_output = cv2.cvtColor(output.astype('uint8'), cv2.COLOR_RGB2BGR)
46 | return image_output
47 |
48 | def translate(self, image):
49 | """ Translate an image from the source domain to the target domain"""
50 | image_input = self._input_transform(image)
51 | output = self._sess.run(self._output_op, feed_dict={self._input_op: image_input})[0]
52 | return self._output_transform(output)
53 |
--------------------------------------------------------------------------------
/frozen_model/test_frozen_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 | import time
6 |
7 | sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
8 | from utils import get_img_paths
9 | from frozen_model.image_translator import ImageTranslator
10 |
11 | cv2.setNumThreads(1)
12 |
13 |
14 | def _get_images(img_loc: str) -> list:
15 | """ Find the image(s) in a given location.
16 | :param img_loc: either an image file path, or a directory containing images.
17 | :return: a list of image paths.
18 | """
19 | image_list = []
20 | if os.path.isfile(img_loc):
21 | image_list.append(img_loc)
22 | elif os.path.isdir(img_loc):
23 | image_list.extend(get_img_paths(img_loc))
24 | return image_list
25 |
26 |
27 | def main(args):
28 | """ Translate all images in a given location with a frozen model (.pb)."""
29 | images = _get_images(args.image)
30 | if len(images) == 0:
31 | raise RuntimeError('No image in {}!'.format(args.image))
32 |
33 | size = 256
34 | translator = ImageTranslator(model_path=args.model,
35 | size=size,
36 | n_threads_intra=args.n_threads_intra,
37 | n_threads_inter=args.n_threads_inter)
38 | os.makedirs(args.output_dir, exist_ok=True)
39 |
40 | st = time.time()
41 | for i, image in enumerate(images):
42 | print('{}: {}'.format(i, image))
43 | img = cv2.imread(image, cv2.IMREAD_COLOR)
44 | if img is None:
45 | print('Invalid image {}'.format(image))
46 | continue
47 | else:
48 | output = None
49 | for i_iter in range(args.n_iters):
50 | print('Iter: {}'.format(i_iter))
51 | output = translator.translate(img)
52 | save_path = os.path.join(args.output_dir, os.path.basename(image))
53 | cv2.imwrite(save_path, output)
54 |
55 | time_total = time.time() - st
56 | time_cost_per_img_ms = int((time_total / len(images) / args.n_iters) * 1000)
57 | print('Time cost per image: {} ms'.format(time_cost_per_img_ms))
58 |
59 |
60 | def _parse_arguments():
61 | parser = argparse.ArgumentParser()
62 | parser.add_argument('--image', type=str, required=True, help='Image file path or directory.')
63 | parser.add_argument('--output_dir', type=str, required=True, help='Output directory.')
64 | parser.add_argument('--model', type=str, required=True, help='.pb model path.')
65 | parser.add_argument('--n_threads_inter', type=int, default=1, help='Number of inter op threads.')
66 | parser.add_argument('--n_threads_intra', type=int, default=1, help='Number of intra op threads.')
67 | parser.add_argument('--n_iters', type=int, default=1)
68 | return parser.parse_args()
69 |
70 |
71 | if __name__ == '__main__':
72 | main(_parse_arguments())
--------------------------------------------------------------------------------
/gan/gan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | from datetime import datetime
4 | import time
5 | from utils import get_img_paths, save_images, load_test_data
6 |
7 |
8 | class GAN:
9 | """Base class for GANs."""
10 | def __init__(self, model_name, sess, args):
11 | # General
12 | self._model_name = model_name
13 | self._sess = sess
14 | self._saver = None
15 | self._dataset_name = args.dataset
16 | self._test_dataset_name = args.test_dataset or args.dataset
17 | self._dataset_struct = args.dataset_struct
18 | self._suffix = args.suffix
19 |
20 | # Directories
21 | self._dataset_dir = args.dataset_dir
22 | model_dir = "{}_{}_{}".format(self._model_name, self._dataset_name, self._suffix)
23 | self._checkpoint_dir = os.path.join(args.output_dir, model_dir, args.checkpoint_dir)
24 | self._sample_dir = os.path.join(args.output_dir, model_dir, args.sample_dir)
25 | self._log_dir = os.path.join(args.output_dir, model_dir, args.log_dir)
26 | self._result_dir = os.path.join(args.output_dir, model_dir, args.result_dir)
27 | for dir_ in [self._checkpoint_dir, self._sample_dir, self._log_dir, self._result_dir]:
28 | os.makedirs(dir_, exist_ok=True)
29 |
30 | # Input
31 | self._img_size = args.img_size
32 | train_a_dir = os.path.join(self._dataset_dir, self._dataset_name, 'trainA')
33 | train_b_dir = os.path.join(self._dataset_dir, self._dataset_name, 'trainB')
34 | self._train_a_dataset = get_img_paths(train_a_dir, self._dataset_struct)
35 | self._train_b_dataset = get_img_paths(train_b_dir, self._dataset_struct)
36 | self._dataset_num = max(len(self._train_a_dataset), len(self._train_b_dataset))
37 |
38 | # Generator
39 | self._gen = None
40 |
41 | print()
42 | print('##### Information #####')
43 |         print('Number of trainA/B images: {}/{}'.format(len(self._train_a_dataset), len(self._train_b_dataset)))
44 | print()
45 |
46 | def build_model_train(self):
47 | """To be implemented by the subclass."""
48 | pass
49 |
50 | def build_model_test(self):
51 | """Build the graph for testing."""
52 | self._test_domain_a = tf.placeholder(tf.float32, [1, self._img_size, self._img_size, 3],
53 | name='test_domain_A')
54 | test_fake_b = self._gen.translate(self._test_domain_a, scope='gen_a2b')
55 | self._test_fake_b = tf.identity(test_fake_b, 'test_fake_B')
56 |
57 | def train(self):
58 | """To be implemented by the subclass."""
59 | pass
60 |
61 | def test(self):
62 | """Translate test images."""
63 |         test_a_dir = os.path.join(self._dataset_dir, self._test_dataset_name, 'testA')
64 |         test_a_files = get_img_paths(test_a_dir, self._dataset_struct)
65 |
66 | if self._saver is None:
67 | self._saver = tf.train.Saver()
68 | could_load, checkpoint_counter = self._load_ckpt(self._checkpoint_dir)
69 | if could_load:
70 | print(" [*] Load SUCCESS")
71 |         else:
72 | print(" [!] Load failed...")
73 | raise RuntimeError("Failed to load the checkpoint")
74 |
75 | dataset_tag = '' if self._test_dataset_name == self._dataset_name else self._test_dataset_name + '_'
76 | result_dir = os.path.join(self._result_dir, dataset_tag + str(checkpoint_counter))
77 | os.makedirs(result_dir, exist_ok=True)
78 |
79 | st = time.time()
80 | for sample_file in test_a_files: # A -> B
81 | print('Processing source image: ' + sample_file)
82 | src = load_test_data(sample_file, size=self._img_size)
83 | fake_img = self._sess.run(self._test_fake_b, feed_dict={self._test_domain_a: src})
84 |
85 | if self._dataset_struct == 'plain':
86 | dst_dir = result_dir
87 | elif self._dataset_struct == 'tree':
88 | src_dir = os.path.dirname(sample_file)
89 | dirname_level1 = os.path.basename(src_dir)
90 | dirname_level2 = os.path.basename(os.path.dirname(src_dir))
91 | dst_dir = os.path.join(result_dir, dirname_level2, dirname_level1)
92 | os.makedirs(dst_dir, exist_ok=True)
93 | else:
94 | raise RuntimeError('Invalid dataset_type!')
95 | image_path = os.path.join(dst_dir, os.path.basename(sample_file))
96 | save_images(fake_img[[0],:], [1, 1], image_path)
97 |
98 | time_cost = time.time() - st
99 | time_cost_per_img_ms = round(time_cost * 1000 / len(test_a_files))
100 | print('Time cost per image: {} ms'.format(time_cost_per_img_ms))
101 |
102 | def freeze_graph(self):
103 | """Generate the .pb model."""
104 | self._saver = tf.train.Saver()
105 | could_load, checkpoint_counter = self._load_ckpt(self._checkpoint_dir)
106 |
107 | if could_load:
108 | print(" [*] Load SUCCESS")
109 | else:
110 | print(" [!] Load failed...")
111 | raise RuntimeError("Failed to load the checkpoint")
112 |
113 | output_dir = os.path.join(self._checkpoint_dir, 'pb')
114 | os.makedirs(output_dir, exist_ok=True)
115 | time_stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
116 | output_file = os.path.join(output_dir, 'output_graph_' + time_stamp + '.pb')
117 |
118 | frozen_graph_def = tf.graph_util.convert_variables_to_constants(
119 | sess=self._sess,
120 | input_graph_def=self._sess.graph_def,
121 | output_node_names=['test_fake_B'])
122 |
123 | # Save the frozen graph
124 | with open(output_file, 'wb') as f:
125 | f.write(frozen_graph_def.SerializeToString())
126 |
127 | def _save_ckpt(self, checkpoint_dir, step):
128 | os.makedirs(checkpoint_dir, exist_ok=True)
129 | self._saver.save(self._sess, os.path.join(checkpoint_dir, self._model_name + '.model'), global_step=step)
130 |
131 | def _load_ckpt(self, checkpoint_dir):
132 | print(" [*] Reading checkpoints...")
133 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
134 | if ckpt and ckpt.model_checkpoint_path:
135 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
136 | self._saver.restore(self._sess, os.path.join(checkpoint_dir, ckpt_name))
137 | counter = int(ckpt_name.split('-')[-1])
138 |             print(" [*] Successfully read {}".format(ckpt_name))
139 | return True, counter
140 | else:
141 | print(" [*] Failed to find a checkpoint")
142 | return False, 0
143 |
144 |
--------------------------------------------------------------------------------
/gan/spatchgan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import tensorflow as tf
4 | import numpy as np
5 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
6 | from gan.gan import GAN
7 | from utils import summary_by_keywords, batch_resize, save_images
8 | from ops import l1_loss, adv_loss, regularization_loss
9 | from imagedata import ImageData
10 | from discriminator.discriminator_spatch import DiscriminatorSPatch
11 | from discriminator.discriminator_patch import DiscriminatorPatch
12 | from generator.generator_basic_res import GeneratorBasicRes
13 |
14 |
15 | class SPatchGAN(GAN):
16 | """SPatchGAN framework."""
17 | def __init__(self, model_name, sess, args):
18 | super().__init__(model_name, sess, args)
19 |
20 | # Training
21 | self._n_steps = args.n_steps
22 | self._n_iters_per_step = args.n_iters_per_step
23 | self._batch_size = args.batch_size
24 | self._img_save_freq = args.img_save_freq
25 | self._ckpt_save_freq = args.ckpt_save_freq
26 | self._summary_freq = args.summary_freq
27 | self._decay_step = args.decay_step
28 | self._init_lr = args.lr
29 | self._adv_weight = args.adv_weight
30 | self._reg_weight = args.reg_weight
31 | self._cyc_weight = args.cyc_weight
32 | self._id_weight = args.id_weight
33 | self._gan_type = args.gan_type
34 |
35 | # Input
36 | self._augment_type = args.augment_type
37 |
38 | # Discriminator
39 | self._dis = self._create_dis(args)
40 |
41 | # Generator
42 | self._gen = self._create_gen(args)
43 | self._gen_bw = self._create_gen_bw(args)
44 | self._resolution_bw = self._img_size // args.resize_factor_gen_bw
45 |
46 | @staticmethod
47 | def _create_dis(args):
48 | if args.dis_type == 'spatch':
49 | stats = []
50 | if args.mean_dis:
51 | stats.append('mean')
52 | if args.max_dis:
53 | stats.append('max')
54 |             if args.stddev_dis:
55 | stats.append('stddev')
56 | return DiscriminatorSPatch(ch=args.ch_dis,
57 | n_downsample_init=args.n_downsample_init_dis,
58 | n_scales=args.n_scales_dis,
59 | n_adapt=args.n_adapt_dis,
60 | n_mix=args.n_mix_dis,
61 | logits_type=args.logits_type_dis,
62 | stats=stats,
63 | sn=args.sn_dis)
64 | elif args.dis_type == 'patch':
65 | return DiscriminatorPatch(ch=args.ch_dis,
66 | n_downsample_init=args.n_downsample_init_dis,
67 | n_scales=args.n_scales_dis,
68 | sn=args.sn_dis)
69 |
70 | else:
71 | raise ValueError('Invalid dis_type!')
72 |
73 | @staticmethod
74 | def _create_gen(args):
75 | if args.gen_type == 'basic_res':
76 | return GeneratorBasicRes(ch=args.ch_gen,
77 | n_updownsample=args.n_updownsample_gen,
78 | n_res=args.n_res_gen,
79 | n_enhanced_upsample=args.n_enhanced_upsample_gen,
80 | n_mix_upsample=args.n_mix_upsample_gen,
81 | block_type=args.block_type_gen,
82 | upsample_type=args.upsample_type_gen)
83 | else:
84 | raise ValueError('Invalid gen_type!')
85 |
86 | @staticmethod
87 | def _create_gen_bw(args):
88 | if args.gen_type == 'basic_res':
89 | return GeneratorBasicRes(ch=args.ch_gen_bw,
90 | n_updownsample=args.n_updownsample_gen_bw,
91 | n_res=args.n_res_gen_bw,
92 | n_enhanced_upsample=args.n_enhanced_upsample_gen,
93 | n_mix_upsample=args.n_mix_upsample_gen,
94 | block_type=args.block_type_gen,
95 | upsample_type=args.upsample_type_gen)
96 | else:
97 | raise ValueError('Invalid gen_type!')
98 |
99 | def _fetch_data(self, dataset):
100 | gpu_device = '/gpu:0'
101 | imgdata = ImageData(self._img_size, self._augment_type)
102 | train_dataset = tf.data.Dataset.from_tensor_slices(dataset)
103 | train_dataset = train_dataset.apply(shuffle_and_repeat(self._dataset_num)) \
104 | .apply(map_and_batch(imgdata.image_processing, self._batch_size,
105 | num_parallel_batches=16, drop_remainder=True)) \
106 | .apply(prefetch_to_device(gpu_device, None))
107 | train_iterator = train_dataset.make_one_shot_iterator()
108 | return train_iterator.get_next()
109 |
110 | def build_model_train(self):
111 | """Build the graph for training."""
112 | self._lr = tf.placeholder(tf.float32, name='learning_rate')
113 |
114 | # Input images
115 | self._domain_a = self._fetch_data(self._train_a_dataset)
116 | self._domain_b = self._fetch_data(self._train_b_dataset)
117 |
118 | # Forward generation
119 | self._x_ab = self._gen.translate(self._domain_a, scope='gen_a2b')
120 |
121 | # Backward generation
122 | if self._cyc_weight > 0.0:
123 | self._a_lowres = tf.image.resize_images(self._domain_a, [self._resolution_bw, self._resolution_bw])
124 | self._ab_lowres = tf.image.resize_images(self._x_ab, [self._resolution_bw, self._resolution_bw])
125 | self._aba_lowres = self._gen_bw.translate(self._ab_lowres, scope='gen_b2a')
126 | else:
127 | self._aba_lowres = tf.zeros([self._batch_size, self._resolution_bw, self._resolution_bw, 3])
128 |
129 | # Identity mapping
130 | self._x_bb = self._gen.translate(self._domain_b, reuse=True, scope='gen_a2b') \
131 | if self._id_weight > 0.0 else tf.zeros_like(self._domain_b)
132 |
133 | # Discriminator
134 | b_logits = self._dis.discriminate(self._domain_b, scope='dis_b')
135 | ab_logits = self._dis.discriminate(self._x_ab, reuse=True, scope='dis_b')
136 |
137 | # Adversarial loss for G
138 | adv_loss_gen_ab = self._adv_weight * adv_loss(ab_logits, self._gan_type, target='real')
139 |
140 | # Adversarial loss for D
141 | adv_loss_dis_b = self._adv_weight * adv_loss(b_logits, self._gan_type, target='real')
142 | adv_loss_dis_b += self._adv_weight * adv_loss(ab_logits, self._gan_type, target='fake')
143 |
144 | # Identity loss
145 | id_loss_bb = self._id_weight * l1_loss(self._domain_b, self._x_bb) \
146 | if self._id_weight > 0.0 else 0.0
147 | cyc_loss_aba = self._cyc_weight * l1_loss(self._a_lowres, self._aba_lowres) \
148 | if self._cyc_weight > 0.0 else 0.0
149 |
150 | # Weight decay
151 | reg_loss_gen = self._reg_weight * regularization_loss('gen_')
152 | reg_loss_dis = self._reg_weight * regularization_loss('dis_')
153 |
154 | # Overall loss
155 | self._gen_loss_all = adv_loss_gen_ab \
156 | + id_loss_bb \
157 | + cyc_loss_aba \
158 | + reg_loss_gen
159 |
160 | self._dis_loss_all = adv_loss_dis_b \
161 | + reg_loss_dis
162 |
163 |
164 | """ Training """
165 | t_vars = tf.trainable_variables()
166 | vars_gen = [var for var in t_vars if 'gen_' in var.name]
167 | vars_dis = [var for var in t_vars if 'dis_' in var.name]
168 |
169 | self._optim_gen = tf.train.AdamOptimizer(self._lr, beta1=0.5, beta2=0.999)\
170 | .minimize(self._gen_loss_all, var_list=vars_gen)
171 | self._optim_dis = tf.train.AdamOptimizer(self._lr, beta1=0.5, beta2=0.999)\
172 | .minimize(self._dis_loss_all, var_list=vars_dis)
173 |
174 |         """ Summary """
175 | # Record the IN scaling factor for each residual block.
176 | summary_scale_res = summary_by_keywords(['gamma', 'resblock', 'res2'], node_type='variable')
177 | summary_logits_gen = summary_by_keywords('pre_tanh', node_type='tensor')
178 | summary_logits_dis = summary_by_keywords(['D_logits_'], node_type='tensor')
179 |
180 | summary_list_gen = []
181 | summary_list_gen.append(tf.summary.scalar("gen_loss_all", self._gen_loss_all))
182 | summary_list_gen.append(tf.summary.scalar("adv_loss_gen_ab", adv_loss_gen_ab))
183 | summary_list_gen.append(tf.summary.scalar("id_loss_bb", id_loss_bb))
184 | summary_list_gen.append(tf.summary.scalar("cyc_loss_aba", cyc_loss_aba))
185 | summary_list_gen.append(tf.summary.scalar("reg_loss_gen", reg_loss_gen))
186 | summary_list_gen.extend(summary_scale_res)
187 | summary_list_gen.extend(summary_logits_gen)
188 | self._summary_gen = tf.summary.merge(summary_list_gen)
189 |
190 | summary_list_dis = []
191 | summary_list_dis.append(tf.summary.scalar("dis_loss_all", self._dis_loss_all))
192 | summary_list_dis.append(tf.summary.scalar("adv_loss_dis_b", adv_loss_dis_b))
193 | summary_list_dis.append(tf.summary.scalar("reg_loss_dis", reg_loss_dis))
194 | summary_list_dis.extend(summary_logits_dis)
195 | self._summary_dis = tf.summary.merge(summary_list_dis)
196 |
197 | def train(self):
198 | """Run training iterations."""
199 | tf.global_variables_initializer().run()
200 | self._saver = tf.train.Saver()
201 | writer = tf.summary.FileWriter(self._log_dir, self._sess.graph)
202 |
203 |         # Restore the checkpoint if it exists.
204 | could_load, checkpoint_counter = self._load_ckpt(self._checkpoint_dir)
205 | if could_load:
206 | counter = checkpoint_counter + 1
207 | start_step = counter // self._n_iters_per_step
208 | start_batch_id = counter - start_step * self._n_iters_per_step
209 | print(" [*] Load SUCCESS")
210 | else:
211 | counter = 0
212 | start_step = 0
213 | start_batch_id = 0
214 | print(" [!] Load failed...")
215 |
216 | # Looping over steps
217 | start_time = time.time()
218 | for step in range(start_step, self._n_steps):
219 | lr = self._init_lr if step < self._decay_step else \
220 | self._init_lr * (self._n_steps - step) / (self._n_steps - self._decay_step)
221 | for batch_id in range(start_batch_id, self._n_iters_per_step):
222 | train_feed_dict = {
223 | self._lr: lr
224 | }
225 |
226 | # Update D
227 | loss_dis, summary_str_dis, _ = self._sess.run([self._dis_loss_all, self._summary_dis, self._optim_dis],
228 | feed_dict=train_feed_dict)
229 |
230 | # Update G
231 | batch_a_images, batch_b_images, fake_b, identity_b, aba_lowres, loss_gen, summary_str_gen, _ = \
232 | self._sess.run([self._domain_a, self._domain_b,
233 | self._x_ab, self._x_bb, self._aba_lowres,
234 | self._gen_loss_all, self._summary_gen, self._optim_gen],
235 | feed_dict=train_feed_dict)
236 |
237 | # display training status
238 | print("Step: [%2d] [%5d/%5d] time: %4.4f D_loss: %.8f, G_loss: %.8f"
239 | % (step, batch_id, self._n_iters_per_step, time.time() - start_time, loss_dis, loss_gen))
240 |
241 | if (counter+1) % self._summary_freq == 0:
242 | writer.add_summary(summary_str_dis, counter)
243 | writer.add_summary(summary_str_gen, counter)
244 |
245 | if (counter+1) % self._img_save_freq == 0:
246 | aba_lowres_resize = batch_resize(aba_lowres, self._img_size)
247 | merged = np.vstack([batch_a_images, fake_b, aba_lowres_resize, batch_b_images, identity_b])
248 | save_images(merged, [5, self._batch_size],
249 | os.path.join(self._sample_dir, 'sample_{:03d}_{:05d}.jpg'.format(step, batch_id)))
250 |
251 | if (counter+1) % self._ckpt_save_freq == 0:
252 | self._save_ckpt(self._checkpoint_dir, counter)
253 |
254 | counter += 1
255 |
256 | # After each step, start_batch_id is set to zero.
257 | # Non-zero value is only for the first step after loading a pre-trained model.
258 | start_batch_id = 0
259 |
260 | # Save the final model.
261 | self._save_ckpt(self._checkpoint_dir, counter-1)
262 |
--------------------------------------------------------------------------------
/generator/generator_basic_res.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from ops import conv, instance_norm, layer_norm, relu, tanh, resblock_v1, nearest_up, bilinear_up
3 |
4 |
5 | class GeneratorBasicRes:
6 | """Basic residual block based generator for SPatchGAN."""
7 | def __init__(self, ch, n_updownsample, n_res, n_enhanced_upsample, n_mix_upsample, block_type, upsample_type):
8 | self._ch = ch
9 | self._n_updownsample = n_updownsample
10 | self._n_res = n_res
11 | self._n_enhanced_upsample = n_enhanced_upsample
12 | self._n_mix_upsample = n_mix_upsample
13 | self._block_type = block_type
14 | self._upsample_type = upsample_type
15 |
16 | def translate(self, x, reuse=False, scope='gen'):
17 | """Build the generator graph."""
18 |         with tf.variable_scope(scope, reuse=reuse):
19 | channel = self._ch
20 |
21 | # Downsampling
22 | for i in range(self._n_updownsample):
23 | with tf.variable_scope('down_{}'.format(i)):
24 | # (256, 256, 3) -> (128, 128, 128) -> (64, 64, 256) -> (32, 32, 512)
25 | x = conv(x, channel, kernel=3, stride=2, pad=1)
26 | x = instance_norm(x)
27 | if i < self._n_updownsample - 1:
28 | x = relu(x)
29 | channel *= 2
30 |
31 | if self._n_updownsample == 0:
32 | with tf.variable_scope('mix_init'):
33 | x = conv(x, channel, kernel=3, pad=1)
34 | x = instance_norm(x)
35 |
36 | for i in range(self._n_res):
37 | with tf.variable_scope('res_{}'.format(i)):
38 | x = self._conv_block(x, block_type=self._block_type, channel=channel)
39 |
40 | for i in range(self._n_updownsample):
41 | with tf.variable_scope('up_{}'.format(i)):
42 | # (32, 32, 512) -> (64, 64, 512) -> (128, 128, 256) -> (256, 256, 128)
43 | x = self._upsample(x, method=self._upsample_type)
44 | channel = channel if i < self._n_enhanced_upsample else channel // 2
45 | n_mix_upsample = self._n_mix_upsample if i < self._n_enhanced_upsample else 1
46 | for j in range(n_mix_upsample):
47 | with tf.variable_scope('mix_{}'.format(j)):
48 | x = conv(x, channel, kernel=3, stride=1, pad=1)
49 | x = layer_norm(x)
50 | x = relu(x)
51 |
52 | if self._n_updownsample == 0:
53 | with tf.variable_scope('mix_end'):
54 | x = conv(x, channel, kernel=3, stride=1, pad=1)
55 | x = layer_norm(x)
56 | x = relu(x)
57 |
58 | with tf.variable_scope('logits'):
59 | # (256, 256, 128) -> (256, 256, 3)
60 | x = conv(x, channels=3, kernel=3, pad=1, scope='G_logit')
61 | x = tf.identity(x, 'pre_tanh')
62 | x = tanh(x)
63 |
64 | return x
65 |
66 | @staticmethod
67 | def _conv_block(x, block_type, channel, scope='resblock_0'):
68 | if block_type == 'v1':
69 | x = resblock_v1(x, channel=channel, scope=scope)
70 | else:
71 | raise ValueError('Wrong block_type!')
72 | return x
73 |
74 | @staticmethod
75 | def _upsample(x, method: str = 'nearest'):
76 | if method == 'nearest':
77 | x = nearest_up(x)
78 | elif method == 'bilinear':
79 | x = bilinear_up(x)
80 | else:
81 | raise ValueError('Invalid upsampling method!')
82 | return x
83 |
--------------------------------------------------------------------------------
/imagedata.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class ImageData:
5 | """Input image processing and augmentation."""
6 | def __init__(self, load_size, augment_type):
7 | self._load_size = load_size
8 | self._augment_type = augment_type
9 |
10 | def image_processing(self, filename):
11 | x = tf.read_file(filename)
12 | x_decode = tf.image.decode_jpeg(x, channels=3)
13 |
14 | if self._augment_type is None:
15 | img = tf.image.resize_images(x_decode, [self._load_size, self._load_size])
16 | elif self._augment_type == 'pad_crop':
17 | img = _augmentation_pad_crop(x_decode, self._load_size)
18 | elif self._augment_type == 'resize_crop':
19 | img = _augmentation_resize_crop(x_decode, self._load_size)
20 | else:
21 | raise ValueError('Invalid augment_type!')
22 |
23 | img = tf.cast(img, tf.float32) / 127.5 - 1
24 | return img
25 |
26 |
27 | def _augmentation_pad_crop(image, size_out):
28 | image = tf.image.resize(image, [size_out, size_out])
29 | image = tf.cast(image, tf.uint8)
30 | # The shape info will be lost after random jpeg quality.
31 | image = tf.image.random_jpeg_quality(image, min_jpeg_quality=50, max_jpeg_quality=100)
32 | image = tf.reshape(image, [size_out, size_out, 3])
33 |
34 | pad_size = round(size_out * 0.05)
35 | # White padding
36 | image = tf.pad(image, paddings=[[pad_size, pad_size], [pad_size, pad_size], [0, 0]], constant_values=255)
37 | image = tf.random_crop(image, [size_out, size_out, 3])
38 | image = _augmentation_general(image)
39 | return image
40 |
41 |
42 | def _augmentation_resize_crop(image, size_out):
43 | aug_rand = tf.random_uniform([])
44 | image = tf.cond(aug_rand < 0.5,
45 | lambda: _ugatit_resize_crop(image, size_out),
46 | lambda: tf.image.resize(image, [size_out, size_out]))
47 | image = tf.cast(image, tf.uint8)
48 | # The shape info will be lost after random jpeg quality.
49 | image = tf.image.random_jpeg_quality(image, min_jpeg_quality=50, max_jpeg_quality=100)
50 | image = tf.reshape(image, [size_out, size_out, 3])
51 | image = _augmentation_general(image)
52 | return image
53 |
54 |
55 | def _augmentation_general(image):
56 | # Operations that preserve the shape and are safe for most images.
57 |     # These color changes are applied after padding so that they also affect the padded area.
58 | image = tf.image.random_flip_left_right(image)
59 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
60 | image = tf.image.random_saturation(image, lower=0.8, upper=1.2)
61 | image = tf.image.random_hue(image, max_delta=0.02)
62 | image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
63 | return image
64 |
65 |
66 | def _ugatit_resize_crop(image, size_out):
67 | augment_size = size_out
68 | if size_out == 256:
69 | augment_size += 30
70 | elif size_out == 512:
71 | augment_size += 60
72 | else:
73 | # Generalize the augmentation strategy in U-GAT-IT
74 | augment_size += round(augment_size * 0.1)
75 | image = tf.image.resize_images(image, [augment_size, augment_size])
76 | image = tf.random_crop(image, [size_out, size_out, 3])
77 | return image
78 |
--------------------------------------------------------------------------------
/images/SPatchGAN_D_20210317_3x.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NetEase-GameAI/SPatchGAN/64ac3547545dcdafee9cf7ccf1956410a2793ad5/images/SPatchGAN_D_20210317_3x.jpg
--------------------------------------------------------------------------------
/images/s2a_cmp_github_downsized.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NetEase-GameAI/SPatchGAN/64ac3547545dcdafee9cf7ccf1956410a2793ad5/images/s2a_cmp_github_downsized.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | from configs import parse_args
4 | from utils import show_all_variables
5 | from gan.spatchgan import SPatchGAN
6 |
7 |
8 | def main():
9 | """General entry point for running GANs."""
10 | args = parse_args()
11 | os.makedirs(args.output_dir, exist_ok=True)
12 |
13 | gpu_options = tf.GPUOptions(allow_growth=True)
14 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)) as sess:
15 |
16 | if args.network == 'spatchgan':
17 | gan = SPatchGAN('SPatchGAN', sess, args)
18 | else:
19 | raise RuntimeError('Invalid network!')
20 |
21 | if args.phase == 'train':
22 | gan.build_model_train()
23 | show_all_variables()
24 | gan.train()
25 | print(" [*] Training finished!")
26 | elif args.phase == 'test':
27 | gan.build_model_test()
28 | show_all_variables()
29 | gan.test()
30 | print(" [*] Test finished!")
31 | elif args.phase == 'freeze_graph':
32 | gan.build_model_test()
33 | show_all_variables()
34 | gan.freeze_graph()
35 | print(" [*] Graph frozen!")
36 | else:
37 | raise RuntimeError('Invalid phase!')
38 |
39 |
40 | if __name__ == '__main__':
41 | main()
42 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib as tf_contrib
3 |
4 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
5 | weight_regularizer = tf_contrib.layers.l2_regularizer(scale=0.0001)
6 |
7 |
8 | def conv(x, channels, kernel=1, stride=1, pad=0, pad_type: str = 'zero', use_bias=True,
9 | sn: str = None, scope: str = 'conv_0'):
10 | """Convolution layer."""
11 | with tf.variable_scope(scope):
12 |         if pad > 0:
13 | if (kernel - stride) % 2 == 0:
14 | pad_top = pad
15 | pad_bottom = pad
16 | pad_left = pad
17 | pad_right = pad
18 |
19 | else:
20 | # For kernel = 3, stride = 2, pad=1:
21 | # pad_top = pad_left = 1
22 | # pad_bottom = pad_right = 0
23 | # h_out = (h_in + 1 - 3) / 2 + 1 = h_in / 2
24 | pad_top = pad
25 | pad_bottom = kernel - stride - pad_top
26 | pad_left = pad
27 | pad_right = kernel - stride - pad_left
28 |
29 | if pad_type == 'zero':
30 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
31 | if pad_type == 'reflect':
32 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')
33 |
34 | if sn is not None:
35 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
36 | regularizer=weight_regularizer)
37 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w, method=sn),
38 | strides=[1, stride, stride, 1], padding='VALID')
39 | if use_bias:
40 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
41 | x = tf.nn.bias_add(x, bias)
42 |
43 | else:
44 | x = tf.layers.conv2d(inputs=x, filters=channels,
45 | kernel_size=kernel, kernel_initializer=weight_init,
46 | kernel_regularizer=weight_regularizer,
47 | strides=stride, use_bias=use_bias)
48 | return x
49 |
50 |
51 | def fully_connected(x, units, use_bias=True, sn: str = None, scope='linear'):
52 | """Fully connected layer."""
53 | with tf.variable_scope(scope):
54 | x = tf.layers.flatten(x)
55 | shape = x.get_shape().as_list()
56 | channels = shape[-1]
57 |
58 | if sn is not None:
59 | w = tf.get_variable("kernel", [channels, units], tf.float32,
60 | initializer=weight_init, regularizer=weight_regularizer)
61 | if use_bias:
62 | bias = tf.get_variable("bias", [units],
63 | initializer=tf.constant_initializer(0.0))
64 |
65 | x = tf.matmul(x, spectral_norm(w, method=sn)) + bias
66 | else:
67 | x = tf.matmul(x, spectral_norm(w, method=sn))
68 |
69 | else:
70 | x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, use_bias=use_bias)
71 |
72 | return x
73 |
74 |
75 | def instance_norm(x, scope='instance_norm'):
76 | """Instance normalization layer."""
77 | return tf_contrib.layers.instance_norm(x,
78 | epsilon=1e-05,
79 | center=True, scale=True,
80 | scope=scope)
81 |
82 |
83 | def layer_norm(x, scope='layer_norm'):
84 | """Layer normalization layer."""
85 | return tf_contrib.layers.layer_norm(x,
86 | center=True, scale=True,
87 | scope=scope)
88 |
89 |
90 | def spectral_norm(w, n_iters=1, method: str = 'fast'):
91 |     """Apply spectral normalization to a weight tensor."""
92 | w_shape = w.shape.as_list()
93 | w = tf.reshape(w, [-1, w_shape[-1]])
94 |
95 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)
96 |
97 | u_hat = tf.stop_gradient(u) if method == 'full' else u
98 | v_hat = None
99 |     for _ in range(n_iters):
100 |         # Power iteration to estimate the leading singular vectors.
101 |         # A single iteration (n_iters=1) is usually enough, since u is
102 |         # persisted as a variable and refined at every training step.
103 |         # (A standalone NumPy sketch follows this file.)
104 | v_ = tf.matmul(u_hat, tf.transpose(w))
105 | v_hat = tf.nn.l2_normalize(v_)
106 |
107 | u_ = tf.matmul(v_hat, w)
108 | u_hat = tf.nn.l2_normalize(u_)
109 |
110 | if method == 'fast':
111 | u_hat = tf.stop_gradient(u_hat)
112 | v_hat = tf.stop_gradient(v_hat)
113 | elif method == 'full':
114 | pass
115 | else:
116 | raise RuntimeError('Invalid sn method!')
117 |
118 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
119 |
120 | with tf.control_dependencies([u.assign(u_hat)]):
121 | w_norm = w / sigma
122 | w_norm = tf.reshape(w_norm, w_shape)
123 |
124 | return w_norm
125 |
126 |
127 | def lrelu(x, alpha=0.2):
128 | """Leaky ReLU."""
129 | return tf.nn.leaky_relu(x, alpha)
130 |
131 |
132 | def relu(x):
133 | """ReLU."""
134 | return tf.nn.relu(x)
135 |
136 |
137 | def tanh(x):
138 | """Tanh."""
139 | return tf.tanh(x)
140 |
141 |
142 | def global_avg_pooling(x):
143 | """Global average pooling for the NHWC data."""
144 | gap = tf.reduce_mean(x, axis=[1, 2])
145 | return gap
146 |
147 |
148 | def global_max_pooling(x):
149 | """Global max pooling for the NHWC data."""
150 | gmp = tf.reduce_max(x, axis=[1, 2])
151 | return gmp
152 |
153 |
154 | def nearest_up(x, scale_factor=2):
155 | """Nearest neighbor upsampling."""
156 | _, h, w, _ = x.get_shape().as_list()
157 | new_size = [h * scale_factor, w * scale_factor]
158 | return tf.image.resize_nearest_neighbor(x, size=new_size)
159 |
160 |
161 | def bilinear_up(x, scale_factor=2):
162 | """Bilinear upsampling."""
163 | _, h, w, _ = x.get_shape().as_list()
164 | new_size = [h * scale_factor, w * scale_factor]
165 | return tf.image.resize_images(x, size=new_size)
166 |
167 |
168 | def resblock_v1(x_init, channel, pad_type: str = 'zero', use_bias=True, is_res=True, scope='resblock_0'):
169 | """Residual block."""
170 | with tf.variable_scope(scope):
171 | with tf.variable_scope('res1'):
172 | x = conv(x_init, channel, kernel=3, pad=1, pad_type=pad_type, use_bias=use_bias)
173 | x = instance_norm(x)
174 | x = relu(x)
175 |
176 | with tf.variable_scope('res2'):
177 | x = conv(x, channel, kernel=3, pad=1, pad_type=pad_type, use_bias=use_bias)
178 | x = instance_norm(x)
179 |
180 | x_ret = x + x_init if is_res else x
181 | return x_ret
182 |
183 |
184 | def l1_loss(x, y):
185 | """Calculate the L1 loss."""
186 | loss = tf.reduce_mean(tf.abs(x - y))
187 | return loss
188 |
189 |
190 | def regularization_loss(scope_name: str):
191 | """Collect the regularization loss."""
192 | collection_regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
193 |
194 | loss = []
195 |     for item in collection_regularization:
196 |         if scope_name in item.name:
197 | loss.append(item)
198 |
199 | return tf.reduce_sum(loss)
200 |
201 |
202 | def adv_loss(x, loss_func: str, target: str):
203 |     """Calculate the adversarial loss (a numeric sketch follows this file)."""
204 | loss_list = []
205 | logits_list = x if isinstance(x, list) else [x]
206 |     for logits in logits_list:
207 | if loss_func == 'lsgan':
208 | if target == 'real':
209 | target_val = 1.0
210 | elif target == 'fake':
211 | target_val = 0.0
212 | else:
213 | raise ValueError('Invalid target {} for adv_loss'.format(target))
214 | loss = tf.squared_difference(logits, target_val)
215 | else:
216 | raise ValueError('Invalid loss_func {} for adv_loss'.format(loss_func))
217 | loss = tf.reduce_mean(loss) / len(logits_list)
218 | loss_list.append(loss)
219 |
220 | return sum(loss_list)
221 |
--------------------------------------------------------------------------------
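The asymmetric padding in conv() is chosen so that a strided convolution halves the spatial size exactly. A quick arithmetic check in plain Python; the helper below is illustrative, not part of ops.py:

def conv_out_size(h_in, kernel, stride, pad_top, pad_bottom):
    # VALID convolution on a padded input: floor((H + pads - K) / S) + 1.
    return (h_in + pad_top + pad_bottom - kernel) // stride + 1

# kernel=3, stride=2, pad=1 -> pad_top=1, pad_bottom=0: exact halving.
assert conv_out_size(256, kernel=3, stride=2, pad_top=1, pad_bottom=0) == 128
# kernel=3, stride=1, pad=1 -> symmetric padding keeps the size unchanged.
assert conv_out_size(256, kernel=3, stride=1, pad_top=1, pad_bottom=1) == 256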
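spectral_norm() estimates the largest singular value of the reshaped weight matrix by power iteration and divides the weights by it. A NumPy sketch of the same estimator for offline verification; the names below are illustrative, not repo code:

import numpy as np

def power_iteration_sigma(w, n_iters=100):
    # w: 2-D matrix; returns an estimate of its largest singular value,
    # mirroring the u/v updates in spectral_norm().
    u = np.random.randn(1, w.shape[-1])
    v = None
    for _ in range(n_iters):
        v = u @ w.T
        v /= np.linalg.norm(v)
        u = v @ w
        u /= np.linalg.norm(u)
    return (v @ w @ u.T).item()

w = np.random.default_rng(0).standard_normal((64, 32))
sigma_true = np.linalg.svd(w, compute_uv=False)[0]
assert abs(power_iteration_sigma(w) - sigma_true) / sigma_true < 1e-3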
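adv_loss() implements the least-squares GAN objective: logits are pushed toward 1.0 for real targets and 0.0 for fake targets, and the per-scale losses are averaged when the discriminator returns a list. A small numeric sketch of the same computation (illustrative NumPy, not repo code):

import numpy as np

def lsgan_loss(logits_list, target_val):
    # Mean squared error against the target, averaged over scales.
    return sum(np.mean((lg - target_val) ** 2) for lg in logits_list) / len(logits_list)

fake_logits = [np.array([0.2, 0.4]), np.array([0.1])]
# The generator wants fakes scored as real (target 1.0):
print(lsgan_loss(fake_logits, target_val=1.0))  # (0.5 + 0.81) / 2 = 0.655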
/output/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.18.1
2 | opencv-python==3.4.3.18
3 | tensorflow-gpu==1.14.0
4 | gast==0.2.2
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import tensorflow as tf
3 | import numpy as np
4 | from glob import glob
5 | from tensorflow.contrib import slim
6 |
7 |
8 | def show_all_variables():
9 | """Show all variables in the graph."""
10 | model_vars = tf.trainable_variables()
11 | slim.model_analyzer.analyze_vars(model_vars, print_info=True)
12 |
13 |
14 | def get_img_paths(input_dir: str, dataset_struct: str = 'plain'):
15 | """Get all images in an input directory."""
16 | exts = ['jpg', 'jpeg', 'png']
17 | imgs = []
18 | for ext in exts:
19 | if dataset_struct == 'plain':
20 | pattern = input_dir + '/*.{}'.format(ext)
21 | elif dataset_struct == 'tree':
22 | pattern = input_dir + '/*/*.{}'.format(ext)
23 | else:
24 | raise ValueError('Invalid dataset_struct!')
25 | imgs.extend(glob(pattern))
26 | return imgs
27 |
28 |
29 | def get_img_paths_auto(input_dir: str):
30 |     """Auto-detect the directory structure and get all images."""
31 | dataset = get_img_paths(input_dir)
32 | if len(dataset) == 0:
33 | dataset = get_img_paths(input_dir, dataset_struct='tree')
34 | return dataset
35 |
36 |
37 | def summary_by_keywords(keywords, node_type='tensor'):
38 | """Generate summary for the tf nodes whose names match the keywords."""
39 | summary_list = []
40 | if node_type == 'tensor':
41 | all_nodes = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]
42 | elif node_type == 'variable':
43 | all_nodes = tf.trainable_variables()
44 | else:
45 |         raise RuntimeError('Invalid node_type!')
46 |
47 | keyword_list = keywords if isinstance(keywords, list) else [keywords]
48 |
49 | # Include a node if its name contains all keywords.
50 | nodes = [node for node in all_nodes if all(keyword in node.name for keyword in keyword_list)]
51 |
52 | for node in nodes:
53 | summary_list.append(tf.summary.scalar(node.name + "_min", tf.reduce_min(node)))
54 | summary_list.append(tf.summary.scalar(node.name + "_max", tf.reduce_max(node)))
55 | node_mean = tf.reduce_mean(node)
56 | summary_list.append(tf.summary.scalar(node.name + "_mean", node_mean))
57 |         # Uncorrected (population) standard deviation: normalized by N, not N - 1.
58 | node_stddev = tf.sqrt(tf.reduce_mean(tf.square(node - node_mean)))
59 | summary_list.append(tf.summary.scalar(node.name + "_stddev", node_stddev))
60 |
61 | return summary_list
62 |
63 |
64 | def batch_resize(x, img_size):
65 |     """Resize a batch of NHWC NumPy images."""
66 | x_up = np.zeros((x.shape[0], img_size, img_size, 3))
67 | for i in range(x.shape[0]):
68 | x_up[i, :, :, :] = cv2.resize(x[i, :, :, :], dsize=(img_size, img_size))
69 | return x_up
70 |
71 |
72 | def save_images(images, size, image_path):
73 |     """Save a grid of images; expects values in [-1, 1] (usage sketch after this file)."""
74 | return _imsave(_inverse_transform(images), size, image_path)
75 |
76 |
77 | def _inverse_transform(images):
78 |     return ((images + 1.) / 2) * 255.0  # Map [-1, 1] back to [0, 255].
79 |
80 |
81 | def _imsave(images, size, path):
82 | images = _merge(images, size)
83 | images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)
84 |
85 | return cv2.imwrite(path, images)
86 |
87 |
88 | def _merge(images, size):
89 | h, w = images.shape[1], images.shape[2]
90 | img = np.zeros((h * size[0], w * size[1], 3))
91 | for idx, image in enumerate(images):
92 | i = idx % size[1]
93 | j = idx // size[1]
94 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image
95 |
96 | return img
97 |
98 |
99 | def load_test_data(image_path, size=256):
100 | """Load test images with OpenCV."""
101 | img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
102 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
103 | img = cv2.resize(img, dsize=(size, size))
104 | img = np.expand_dims(img, axis=0)
105 |     img = img / 127.5 - 1  # Map [0, 255] to [-1, 1].
106 |
107 | return img
108 |
--------------------------------------------------------------------------------
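save_images() expects a batch in the generator's tanh range [-1, 1] and tiles it into a size[0] x size[1] grid; load_test_data() applies the inverse mapping (img / 127.5 - 1). A minimal usage sketch, assuming an existing output/ directory (the path below is our choice, not repo code):

import numpy as np
from utils import save_images

# Four random RGB images in [-1, 1], tiled into a 2x2 grid and written to disk.
batch = np.random.uniform(-1.0, 1.0, size=(4, 64, 64, 3))
save_images(batch, size=(2, 2), image_path='output/grid.png')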