├── .gitignore ├── LICENSE ├── README.md ├── main.py ├── ops.py ├── pix2pix.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Junho Kim (1993.01.12) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pix2pix-Tensorflow 2 | SImple Tensorflow implementations of " Image-to-Image Translation with Conditional Adversarial Networks" (CVPR 2017) 3 | 4 | ## Requirements 5 | * Tensorflow 1.4 6 | * Python 3.6 7 | 8 | ## Usage 9 | ```bash 10 | ├── dataset 11 |    └── YOUR_DATASET_NAME 12 |    ├── trainA 13 |           ├── 1.jpg (format doesn't matter) 14 | ├── 2.png 15 | └── ... 16 |    ├── trainB (target list) 17 | ├── 1_.jpg 18 | ├── 2_.png 19 | └── ... 20 |    ├── testA 21 |    ├── aaa.jpg 22 | ├── bbb.png 23 | └── ... 24 | 25 | ``` 26 | 27 | ```bash 28 | > python main.py --dataset cat --gray_to_RGB True 29 | ``` 30 | 31 | 32 | ## Author 33 | Junho Kim 34 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from pix2pix import pix2pix 2 | import argparse 3 | from utils import * 4 | 5 | """parsing and configuration""" 6 | def parse_args(): 7 | desc = "Tensorflow implementation of jh_GAN" 8 | parser = argparse.ArgumentParser(description=desc) 9 | parser.add_argument('--phase', type=str, default='train', help='train or test ?') 10 | parser.add_argument('--dataset', type=str, default='meet', help='dataset_name') 11 | 12 | parser.add_argument('--epoch', type=int, default=200, help='The number of epochs to run') 13 | parser.add_argument('--batch_size', type=int, default=1, help='The size of batch per gpu') 14 | parser.add_argument('--print_freq', type=int, default=100, help='The number of image_print_freq') 15 | 16 | parser.add_argument('--lr', type=float, default=0.0002, help='The learning rate') 17 | parser.add_argument('--L1_weight', type=float, default=10.0, help='The L1 lambda') 18 | 19 | 20 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer') 21 | parser.add_argument('--repeat', type=int, default=9, help='img size : 256 -> 9, 128 -> 6') 22 | 23 | parser.add_argument('--img_size', type=int, default=256, help='The size of image') 24 | parser.add_argument('--gray_to_RGB', type=bool, default=False, help='Gray -> RGB') 25 | 26 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 27 | help='Directory name to save the checkpoints') 28 | parser.add_argument('--result_dir', type=str, default='results', 29 | help='Directory name to save the generated images') 30 | parser.add_argument('--log_dir', type=str, default='logs', 31 | help='Directory name to save training logs') 32 | parser.add_argument('--sample_dir', type=str, default='samples', 33 | help='Directory name to save the samples on training') 34 | 35 | return check_args(parser.parse_args()) 36 | 37 | """checking arguments""" 38 | def check_args(args): 39 | # --checkpoint_dir 40 | check_folder(args.checkpoint_dir) 41 | 42 | # --result_dir 43 | check_folder(args.result_dir) 44 | 45 | # --result_dir 46 | check_folder(args.log_dir) 47 | 48 | # --sample_dir 49 | check_folder(args.sample_dir) 50 | 51 | # --epoch 52 | try: 53 | assert args.epoch >= 1 54 | except: 55 | print('number of epochs must be larger than or equal to one') 56 | 57 | # --batch_size 58 | try: 59 | assert args.batch_size >= 1 60 | except: 61 | print('batch size must be larger than or equal to one') 62 | return args 63 | 64 | """main""" 65 | def main(): 66 | # parse arguments 67 | args = parse_args() 68 | if args is None: 69 | exit() 70 | 71 | # open session 72 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 73 | gan = pix2pix(sess, args) 74 | 75 | # build graph 76 | gan.build_model() 77 | 78 | # show network architecture 79 | show_all_variables() 80 | 81 | if args.phase == 'train' : 82 | # launch the graph in a session 83 | gan.train() 84 | print(" [*] Training finished!") 85 | 86 | if args.phase == 'test' : 87 | gan.test() 88 | print(" [*] Test finished!") 89 | 90 | if __name__ == '__main__': 91 | main() -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tf_contrib 3 | 4 | 5 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02) 6 | """ 7 | pad = (k-1) // 2 8 | size = (I-k+1+2p) // s 9 | """ 10 | def conv(x, channels, kernel=4, stride=2, pad=1, scope='conv_0'): 11 | with tf.variable_scope(scope): 12 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]]) 13 | x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=weight_init, strides=stride) 14 | 15 | return x 16 | 17 | 18 | def deconv(x, channels, kernel=4, stride=2, scope='deconv_0'): 19 | with tf.variable_scope(scope): 20 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=weight_init, strides=stride, padding='SAME') 21 | 22 | return x 23 | 24 | 25 | def resblock(x_init, channels, scope='resblock_0'): 26 | with tf.variable_scope(scope): 27 | with tf.variable_scope('res1'): 28 | x = tf.pad(x_init, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT') 29 | x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=3, kernel_initializer=weight_init, strides=1) 30 | x = batch_norm(x) 31 | x = relu(x) 32 | 33 | with tf.variable_scope('res2'): 34 | x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT') 35 | x = tf.layers.conv2d(inputs=x, filters=channels, kernel_size=3, kernel_initializer=weight_init, strides=1) 36 | x = batch_norm(x) 37 | 38 | return x + x_init 39 | 40 | def flatten(x) : 41 | return tf.layers.flatten(x) 42 | 43 | def lrelu(x, alpha=0.01): 44 | # pytorch alpha is 0.01 45 | return tf.nn.leaky_relu(x, alpha) 46 | 47 | 48 | def relu(x): 49 | return tf.nn.relu(x) 50 | 51 | 52 | def sigmoid(x): 53 | return tf.sigmoid(x) 54 | 55 | 56 | def tanh(x): 57 | return tf.tanh(x) 58 | 59 | 60 | def batch_norm(x, is_training=True, scope='batch_norm'): 61 | return tf_contrib.layers.batch_norm(x, decay=0.9, epsilon=1e-05, center=True, scale=True, updates_collections=None, is_training=is_training, scope=scope) 62 | 63 | 64 | def instance_norm(x, scope='instance_norm'): 65 | return tf_contrib.layers.instance_norm(x, 66 | epsilon=1e-05, 67 | center=True, scale=True, 68 | scope=scope) 69 | 70 | def L1_loss(x, y): 71 | loss = tf.reduce_mean(tf.abs(x - y)) 72 | 73 | return loss 74 | 75 | 76 | 77 | def L2_loss(x, y): 78 | loss = tf.reduce_mean(tf.square(x - y)) 79 | 80 | return loss 81 | 82 | def discriminator_loss(real, fake): 83 | 84 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real)) 85 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake)) 86 | 87 | loss = real_loss + fake_loss 88 | 89 | return loss 90 | 91 | 92 | def generator_loss(fake): 93 | 94 | loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake)) 95 | 96 | return loss 97 | 98 | 99 | 100 | 101 | """ 102 | def discriminator_loss(loss_func, real, fake): 103 | loss = None 104 | 105 | if loss_func == 'wgan-gp' : 106 | loss = tf.reduce_mean(fake) - tf.reduce_mean(real) 107 | 108 | if loss_func == 'lsgan' : 109 | real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0)) 110 | fake_loss = tf.reduce_mean(tf.square(fake)) 111 | loss = (real_loss + fake_loss) * 0.5 112 | 113 | if loss_func == 'gan' : 114 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real)) 115 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake)) 116 | loss = (real_loss + fake_loss) * 0.5 117 | 118 | return loss 119 | 120 | 121 | def generator_loss(loss_func, fake): 122 | loss = None 123 | 124 | if loss_func == 'wgan-gp' : 125 | loss = -tf.reduce_mean(fake) 126 | 127 | if loss_func == 'lsgan' : 128 | loss = tf.reduce_mean(tf.squared_difference(fake, 1.0)) 129 | 130 | if loss_func == 'gan' : 131 | loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake)) 132 | 133 | return loss 134 | 135 | """ -------------------------------------------------------------------------------- /pix2pix.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | from utils import * 3 | from glob import glob 4 | import time 5 | 6 | class pix2pix(object): 7 | def __init__(self, sess, args): 8 | self.model_name = 'pix2pix' 9 | self.sess = sess 10 | self.checkpoint_dir = args.checkpoint_dir 11 | self.result_dir = args.result_dir 12 | self.log_dir = args.log_dir 13 | self.sample_dir = args.sample_dir 14 | self.dataset_name = args.dataset 15 | 16 | self.epoch = args.epoch # 100000 17 | self.batch_size = args.batch_size 18 | self.print_freq = args.print_freq 19 | 20 | self.ch = args.ch 21 | self.repeat = args.repeat 22 | 23 | """ Weight """ 24 | self.L1_weight = args.L1_weight 25 | self.lr = args.lr 26 | 27 | self.img_size = args.img_size 28 | self.gray_to_RGB = args.gray_to_RGB 29 | 30 | if self.gray_to_RGB : 31 | self.input_ch = 1 32 | self.output_ch = 3 33 | else : 34 | self.input_ch = 3 35 | self.output_ch = 3 36 | 37 | self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size, gray_to_RGB=self.gray_to_RGB) 38 | self.num_batches = max(len(self.trainA), len(self.trainB)) // self.batch_size 39 | 40 | self.sample_dir = os.path.join(args.sample_dir, self.model_dir) 41 | check_folder(self.sample_dir) 42 | 43 | def generator(self, x, is_training=True, reuse=False, scope="generator"): 44 | channel = self.ch 45 | with tf.variable_scope(scope, reuse=reuse) : 46 | x = conv(x, channel, kernel=7, stride=1, pad=3, scope='conv_0') 47 | x = batch_norm(x, is_training, scope='conv_batch_0') 48 | x = relu(x) 49 | 50 | # Encoder 51 | for i in range(2) : 52 | x = conv(x, channel*2, kernel=3, stride=2, pad=1, scope='en_conv_'+str(i)) 53 | x = batch_norm(x, is_training, scope='en_batch_'+str(i)) 54 | x = relu(x) 55 | channel = channel * 2 56 | 57 | # Bottle-neck 58 | for i in range(self.repeat) : 59 | x = resblock(x, channel, scope='resblock_'+str(i)) 60 | 61 | # Decoder 62 | for i in range(2) : 63 | x = deconv(x, channel//2, kernel=4, stride=2, scope='deconv_'+str(i)) 64 | x = batch_norm(x, is_training, scope='de_batch_'+str(i)) 65 | x = relu(x) 66 | 67 | channel = channel // 2 68 | 69 | x = conv(x, channels=3, kernel=7, stride=1, pad=3, scope='last_conv') # NO BATCH NORM 70 | x = tanh(x) 71 | 72 | return x 73 | 74 | def discriminator(self, x, is_training=True, reuse=False, scope="discriminator"): 75 | channel = self.ch 76 | with tf.variable_scope(scope, reuse=reuse): 77 | x = conv(x, channel, kernel=4, stride=2, pad=1, scope='first_conv') # NO BATCH NORM 78 | x = lrelu(x, 0.2) 79 | 80 | for i in range(2) : 81 | x = conv(x, channel*2, kernel=4, stride=2, pad=1, scope='conv_'+str(i)) 82 | x = batch_norm(x, is_training, scope='batch_'+str(i)) 83 | x = lrelu(x, 0.2) 84 | channel = channel * 2 85 | 86 | x = conv(x, channel, kernel=3, stride=1, pad=1, scope='conv_2') 87 | x = batch_norm(x, is_training, scope='batch_2') 88 | x = lrelu(x, 0.2) 89 | 90 | x = conv(x, channels=1, kernel=3, stride=1, pad=1, scope='last_conv') 91 | 92 | return x 93 | 94 | 95 | def build_model(self): 96 | 97 | """ Graph Image""" 98 | self.real_A = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.input_ch], name='real_A') # gray 99 | 100 | self.real_B = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.output_ch], name='real_B') # rgb 101 | 102 | 103 | """ Loss Function """ 104 | D_real_logit = self.discriminator(self.real_B, reuse=False) 105 | 106 | self.fake_B = self.generator(self.real_A) 107 | D_fake_logit = self.discriminator(self.fake_B, reuse=True) 108 | 109 | self.d_loss = discriminator_loss(real=D_real_logit, fake=D_fake_logit) 110 | 111 | self.g_loss = generator_loss(fake=D_fake_logit) + self.L1_weight * L1_loss(self.real_B, self.fake_B) 112 | 113 | """ Training """ 114 | t_vars = tf.trainable_variables() 115 | G_vars = [var for var in t_vars if 'generator' in var.name] 116 | D_vars = [var for var in t_vars if 'discriminator' in var.name] 117 | 118 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.g_loss, var_list=G_vars) 119 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.d_loss, var_list=D_vars) 120 | 121 | """" Summary """ 122 | self.G_loss_summary = tf.summary.scalar("Generator_loss", self.g_loss) 123 | self.D_loss_summary = tf.summary.scalar("Discriminator_loss", self.d_loss) 124 | 125 | """ Test """ 126 | self.test_real_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.input_ch], name='test_real_A') 127 | self.sample = self.generator(self.test_real_A, is_training=False, reuse=True) 128 | 129 | 130 | def train(self): 131 | # initialize all variables 132 | tf.global_variables_initializer().run() 133 | 134 | # saver to save model 135 | self.saver = tf.train.Saver() 136 | 137 | # summary writer 138 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) 139 | 140 | 141 | # restore check-point if it exits 142 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 143 | if could_load: 144 | start_epoch = (int)(checkpoint_counter / self.num_batches) 145 | start_batch_id = checkpoint_counter - start_epoch * self.num_batches 146 | counter = checkpoint_counter 147 | print(" [*] Load SUCCESS") 148 | else: 149 | start_epoch = 0 150 | start_batch_id = 0 151 | counter = 1 152 | print(" [!] Load failed...") 153 | 154 | # loop for epoch 155 | start_time = time.time() 156 | for epoch in range(start_epoch, self.epoch): 157 | self.trainA, self.trainB = shuffle(self.trainA, self.trainB) 158 | # get batch data 159 | for idx in range(start_batch_id, self.num_batches): 160 | batch_A_images = self.trainA[idx * self.batch_size : (idx + 1) * self.batch_size] 161 | batch_B_images = self.trainB[idx * self.batch_size : (idx + 1) * self.batch_size] 162 | 163 | train_feed_dict = { 164 | self.real_A : batch_A_images, 165 | self.real_B : batch_B_images, 166 | } 167 | 168 | # Update D 169 | _, d_loss, summary_str = self.sess.run([self.D_optim, self.d_loss, self.D_loss_summary], 170 | feed_dict=train_feed_dict) 171 | self.writer.add_summary(summary_str, counter) 172 | 173 | # Update G 174 | fake_B, _, g_loss, summary_str = self.sess.run([self.fake_B, self.G_optim, self.g_loss, self.G_loss_summary], 175 | feed_dict=train_feed_dict) 176 | self.writer.add_summary(summary_str, counter) 177 | 178 | # display training status 179 | counter += 1 180 | print("Epoch: [%2d] [%4d/%4d] time: %4.4f d_loss: %.8f, g_loss: %.8f" \ 181 | % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss)) 182 | 183 | if np.mod(counter, self.print_freq) == 0: 184 | save_images(batch_A_images, [self.batch_size, 1], 185 | './{}/real_A_{:3d}_{:04d}.jpg'.format(self.sample_dir, epoch, idx)) 186 | save_images(batch_B_images, [self.batch_size, 1], 187 | './{}/real_B_{:03d}_{:04d}.jpg'.format(self.sample_dir, epoch, idx)) 188 | 189 | save_images(fake_B, [self.batch_size, 1], 190 | './{}/fake_B_{:03d}_{:04d}.jpg'.format(self.sample_dir, epoch, idx)) 191 | 192 | # After an epoch, start_batch_id is set to zero 193 | # non-zero value is only for the first epoch after loading pre-trained model 194 | start_batch_id = 0 195 | 196 | # save model 197 | # self.save(self.checkpoint_dir, counter) 198 | 199 | # save model for final step 200 | self.save(self.checkpoint_dir, counter) 201 | 202 | 203 | @property 204 | def model_dir(self): 205 | return "{}_{}".format(self.model_name, self.dataset_name) 206 | 207 | def save(self, checkpoint_dir, step): 208 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 209 | 210 | if not os.path.exists(checkpoint_dir): 211 | os.makedirs(checkpoint_dir) 212 | 213 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 214 | 215 | def load(self, checkpoint_dir): 216 | import re 217 | print(" [*] Reading checkpoints...") 218 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 219 | 220 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 221 | if ckpt and ckpt.model_checkpoint_path: 222 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 223 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 224 | counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0)) 225 | print(" [*] Success to read {}".format(ckpt_name)) 226 | return True, counter 227 | else: 228 | print(" [*] Failed to find a checkpoint") 229 | return False, 0 230 | 231 | def test(self): 232 | tf.global_variables_initializer().run() 233 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA')) 234 | 235 | self.saver = tf.train.Saver() 236 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 237 | 238 | if could_load : 239 | print(" [*] Load SUCCESS") 240 | else : 241 | print(" [!] Load failed...") 242 | 243 | # write html for visual comparison 244 | index_path = os.path.join(self.result_dir, 'index.html') 245 | index = open(index_path, 'w') 246 | index.write("") 247 | index.write("") 248 | 249 | for sample_file in test_A_files : # A -> B 250 | print('Processing A image: ' + sample_file) 251 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size, gray_to_RGB=self.gray_to_RGB)) 252 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file))) 253 | 254 | fake_img = self.sess.run(self.sample, feed_dict = {self.test_real_A : sample_image}) 255 | 256 | save_images(fake_img, [1, 1], image_path) 257 | index.write("" % os.path.basename(image_path)) 258 | index.write("" % (sample_file if os.path.isabs(sample_file) else ( 259 | '..' + os.path.sep + sample_file), self.img_size, self.img_size)) 260 | index.write("" % (image_path if os.path.isabs(image_path) else ( 261 | '..' + os.path.sep + image_path), self.img_size, self.img_size)) 262 | index.write("") 263 | 264 | index.close() 265 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import slim 3 | from scipy import misc 4 | import os, random 5 | import numpy as np 6 | from glob import glob 7 | 8 | def prepare_data(dataset_name, size, gray_to_RGB=False): 9 | input_list = sorted(glob('./dataset/{}/*.*'.format(dataset_name + '/trainA'))) 10 | target_list = sorted(glob('./dataset/{}/*.*'.format(dataset_name + '/trainB'))) 11 | 12 | trainA = [] 13 | trainB = [] 14 | 15 | if gray_to_RGB : 16 | for image in input_list: 17 | trainA.append(np.expand_dims(misc.imresize(misc.imread(image, mode='L'), [size, size]), axis=-1)) 18 | 19 | for image in input_list: 20 | trainB.append(misc.imresize(misc.imread(image, mode='RGB'), [size, size])) 21 | 22 | # trainA = np.repeat(trainA, repeats=3, axis=-1) 23 | # trainA = np.array(trainA).astype(np.float32)[:, :, :, None] 24 | 25 | else : 26 | for image in input_list : 27 | trainA.append(misc.imresize(misc.imread(image, mode='RGB'), [size, size])) 28 | 29 | for image in target_list : 30 | trainB.append(misc.imresize(misc.imread(image, mode='RGB'), [size, size])) 31 | 32 | 33 | trainA = preprocessing(np.asarray(trainA)) 34 | trainB = preprocessing(np.asarray(trainB)) 35 | 36 | return trainA, trainB 37 | 38 | def shuffle(x, y) : 39 | seed = np.random.random_integers(low=0, high=1000) 40 | np.random.seed(seed) 41 | np.random.shuffle(x) 42 | 43 | np.random.seed(seed) 44 | np.random.shuffle(y) 45 | 46 | return x, y 47 | 48 | def load_test_data(image_path, size=256, gray_to_RGB=False): 49 | if gray_to_RGB : 50 | img = misc.imread(image_path, mode='L') 51 | img = misc.imresize(img, [size, size]) 52 | img = np.expand_dims(img, axis=-1) 53 | else : 54 | img = misc.imread(image_path, mode='RGB') 55 | img = misc.imresize(img, [size, size]) 56 | img = np.expand_dims(img, axis=0) 57 | img = preprocessing(img) 58 | 59 | return img 60 | 61 | 62 | def preprocessing(x): 63 | 64 | x = x/127.5 - 1 # -1 ~ 1 65 | return x 66 | 67 | def augmentation(image, augment_size): 68 | seed = random.randint(0, 2 ** 31 - 1) 69 | ori_image_shape = tf.shape(image) 70 | image = tf.image.random_flip_left_right(image, seed=seed) 71 | image = tf.image.resize_images(image, [augment_size, augment_size]) 72 | image = tf.random_crop(image, ori_image_shape, seed=seed) 73 | return image 74 | 75 | def save_images(images, size, image_path): 76 | return imsave(inverse_transform(images), size, image_path) 77 | 78 | def inverse_transform(images): 79 | return (images+1.) / 2 80 | 81 | def imsave(images, size, path): 82 | return misc.imsave(path, merge(images, size)) 83 | 84 | def merge(images, size): 85 | h, w = images.shape[1], images.shape[2] 86 | img = np.zeros((h * size[0], w * size[1], 3)) 87 | for idx, image in enumerate(images): 88 | i = idx % size[1] 89 | j = idx // size[1] 90 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image 91 | 92 | return img 93 | 94 | def show_all_variables(): 95 | model_vars = tf.trainable_variables() 96 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 97 | 98 | def check_folder(log_dir): 99 | if not os.path.exists(log_dir): 100 | os.makedirs(log_dir) 101 | return log_dir 102 | --------------------------------------------------------------------------------
nameinputoutput
%s