├── examples
│   ├── 1.png
│   ├── 2.png
│   └── 3.png
├── train.sh
├── eval.sh
├── LICENSE
├── dataset
│   ├── parse.py
│   └── build_dataset.py
├── README.md
├── modeling
│   ├── loss.py
│   └── model.py
├── eval_model.py
└── train_model.py

/examples/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/z-x-yang/NS-Outpainting/HEAD/examples/1.png
--------------------------------------------------------------------------------

/examples/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/z-x-yang/NS-Outpainting/HEAD/examples/2.png
--------------------------------------------------------------------------------

/examples/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/z-x-yang/NS-Outpainting/HEAD/examples/3.png
--------------------------------------------------------------------------------

/train.sh:
--------------------------------------------------------------------------------
1 | python train_model.py --trainset-path /path/to/tf-record-trainset --testset-path /path/to/tf-record-testset
--------------------------------------------------------------------------------

/eval.sh:
--------------------------------------------------------------------------------
1 | python eval_model.py --trainset-path /path/to/tf-record-trainset --testset-path /path/to/tf-record-testset --checkpoint-path /path/to/checkpoint
2 |
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 z-x-yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/dataset/parse.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def parse_trainset(example_proto):
5 |
6 |     dics = {}
7 |     dics['image'] = tf.FixedLenFeature(shape=[], dtype=tf.string)
8 |
9 |     parsed_example = tf.parse_single_example(
10 |         serialized=example_proto, features=dics)
11 |     image = tf.decode_raw(parsed_example['image'], out_type=tf.uint8)
12 |
13 |     image = tf.reshape(image, shape=[72 * 2, 216 * 2, 3])
14 |
15 |     image = tf.random_crop(image, [64 * 2, 128 * 2, 3])
16 |     image = tf.image.random_flip_left_right(image)
17 |     image = tf.cast(image, tf.float32) / 255.
18 |     image = 2. * image - 1.
19 |
20 |     return image
21 |
22 |
23 | def parse_testset(example_proto):
24 |
25 |     dics = {}
26 |     dics['image'] = tf.FixedLenFeature(shape=[], dtype=tf.string)
27 |
28 |     parsed_example = tf.parse_single_example(
29 |         serialized=example_proto, features=dics)
30 |     image = tf.decode_raw(parsed_example['image'], out_type=tf.uint8)
31 |
32 |     image = tf.reshape(image, shape=[64 * 2, 128 * 2, 3])
33 |
34 |     image = tf.cast(image, tf.float32) * (2. / 255) - 1.0
35 |
36 |     return image
37 |
38 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Very Long Natural Scenery Image Prediction by Outpainting (NS-Outpainting)
2 | A neural architecture for scenery image outpainting ([ICCV 2019](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yang_Very_Long_Natural_Scenery_Image_Prediction_by_Outpainting_ICCV_2019_paper.pdf)), implemented in [TensorFlow](http://www.tensorflow.org).
3 |
4 | The architecture is able to generate a very long, high-quality prediction from a small input image by outpainting:
5 |
6 |
7 |
8 |
9 | ## Requirements and Preparation
10 |
11 | Please install `TensorFlow>=1.3.0` and `Python>=3.6`.
12 |
13 | For training and testing, we collected a new outpainting dataset of 6,000 images containing complex natural scenes. You can download the raw dataset from [here](https://drive.google.com/file/d/15rGKgeNHWqjs90An7wpZXJMz-zFaC1q0/view?usp=sharing) and split it into training and testing sets yourself. Alternatively, you can get our split from [here](https://drive.google.com/file/d/1LDRx0W6zo_eCZwN92pGgGZSCrqzB3KZ6/view?usp=sharing) (TFRecord format, 128 resolution, 5,000 images for training and 1,000 for testing).
14 |
15 | ## Usage
16 |
17 | For training and evaluation, you can use [train.sh](/train.sh) and [eval.sh](/eval.sh). Please remember to set the TFRecord dataset paths inside them.
18 |
19 | Besides, you can get our **pretrained model** from [here](https://drive.google.com/file/d/1-DLSwNkB93MMKaYVO1rmPP9iJllXDJrg/view?usp=sharing) and run eval_model.py to evaluate it.
20 |
21 | After running eval_model.py, the evaluation process will store 4 types of images:
22 | 1) "ori_xxx.jpg", the groundtruth images of size 128x256;
23 | 2) "m0_xxx.jpg", the 1-step predictions of size 128x256 without any post-processing;
24 | 3) "m1_xxx.jpg", the 1-step predictions of size 128x256 with smooth stitching;
25 | 4) "endless_xxx.jpg", the 4-step predictions of size 128x640.
26 |
27 | Notably, we measure Inception Score and Fréchet Inception Distance between "ori_xxx.jpg" and "m0_xxx.jpg" in our paper.
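
For reference, the smooth stitching used for the "m1_xxx.jpg" images is a plain linear cross-fade over the 128 overlapping columns. Below is a minimal NumPy sketch of what eval_model.py does; `gt_left`, `recon_left`, and `recon_right` are placeholder 128x128x3 arrays in [-1, 1] (the ground-truth left half, the generator's re-synthesis of it, and the predicted right half):

```python
import numpy as np

# Ramp falling from 1 to 0 across the 128 overlapping columns,
# mirroring stitch_mask1 / stitch_mask2 in eval_model.py.
mask1 = np.ones((128, 128, 3))
for i in range(128):
    mask1[:, i, :] = (127. - i) / 127.
mask2 = mask1[:, ::-1, :]

# Cross-fade the two left halves, then append the predicted right half
# (axis 1 is image width for a single HxWxC image).
blended_left = gt_left * mask1 + recon_left * mask2
m1 = np.concatenate((blended_left, recon_right), axis=1)
```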
28 | 29 | ## Citation 30 | ``` 31 | @inproceedings{yang2019very, 32 | title={Very Long Natural Scenery Image Prediction by Outpainting}, 33 | author={Yang, Zongxin and Dong, Jian and Liu, Ping and Yang, Yi and Yan, Shuicheng}, 34 | booktitle={Proceedings of the IEEE International Conference on Computer Vision}, 35 | pages={10561--10570}, 36 | year={2019} 37 | } 38 | ``` 39 | -------------------------------------------------------------------------------- /dataset/build_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | from glob import glob 4 | import numpy as np 5 | from PIL import Image 6 | import tensorflow as tf 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description='Model training.') 10 | parser.add_argument('--dataset-path', type=str, default='./scenery/') 11 | parser.add_argument('--result-path', type=str, default='./') 12 | 13 | args = parser.parse_args() 14 | dataset_path = args.dataset_path 15 | result_path = args.result_path 16 | 17 | 18 | if not os.path.exists(result_path): 19 | os.makedirs(result_path) 20 | 21 | train_list = os.listdir(dataset_path) 22 | random.shuffle(train_list) 23 | trainset = list(map(lambda x: os.path.join( 24 | dataset_path, x), train_list)) 25 | 26 | testset = trainset[0:1000] 27 | trainset = trainset[1000:] 28 | 29 | 30 | def build_trainset(image_list, name): 31 | len2 = len(image_list) 32 | print("len=", len2) 33 | writer = tf.python_io.TFRecordWriter(name) 34 | k = 0 35 | for i in range(len2): 36 | 37 | image = Image.open(image_list[i]) 38 | image = image.resize((432, 144), Image.BILINEAR) 39 | image = image.convert('RGB') 40 | 41 | image_bytes = image.tobytes() 42 | 43 | features = {} 44 | 45 | features['image'] = tf.train.Feature( 46 | bytes_list=tf.train.BytesList(value=[image_bytes])) 47 | 48 | tf_features = tf.train.Features(feature=features) 49 | 50 | tf_example = tf.train.Example(features=tf_features) 51 | 52 | tf_serialized = tf_example.SerializeToString() 53 | 54 | writer.write(tf_serialized) 55 | k = k + 1 56 | print(k) 57 | writer.close() 58 | 59 | 60 | def build_testset(image_list, name): 61 | len2 = len(image_list) 62 | print("len=", len2) 63 | writer = tf.python_io.TFRecordWriter(name) 64 | for i in range(len2): 65 | 66 | image = Image.open(image_list[i]) 67 | image = image.resize((256, 128), Image.BILINEAR) 68 | image = image.convert('RGB') 69 | 70 | image_flip = image.transpose(Image.FLIP_LEFT_RIGHT) 71 | 72 | image_bytes = image.tobytes() 73 | 74 | features = {} 75 | 76 | features['image'] = tf.train.Feature( 77 | bytes_list=tf.train.BytesList(value=[image_bytes])) 78 | 79 | tf_features = tf.train.Features(feature=features) 80 | 81 | tf_example = tf.train.Example(features=tf_features) 82 | 83 | tf_serialized = tf_example.SerializeToString() 84 | 85 | writer.write(tf_serialized) 86 | 87 | # flip image 88 | image = image_flip 89 | 90 | image_bytes = image.tobytes() 91 | 92 | features = {} 93 | 94 | features['image'] = tf.train.Feature( 95 | bytes_list=tf.train.BytesList(value=[image_bytes])) 96 | 97 | tf_features = tf.train.Features(feature=features) 98 | 99 | tf_example = tf.train.Example(features=tf_features) 100 | 101 | tf_serialized = tf_example.SerializeToString() 102 | 103 | writer.write(tf_serialized) 104 | 105 | writer.close() 106 | 107 | 108 | print('Build testset!') 109 | build_testset(testset, result_path + "/testset.tfr") 110 | print('Build trainset!') 111 | build_trainset(trainset, result_path + "/trainset.tfr") 112 | 113 | 
print('Done!') 114 | -------------------------------------------------------------------------------- /modeling/loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import math 4 | 5 | class Loss(): 6 | def __init__(self, cfg): 7 | self.cfg = cfg 8 | 9 | def masked_reconstruction_loss(self, gt, recon): 10 | loss_recon = tf.square(gt - recon) 11 | mask_values = np.ones((128, 128)) 12 | for j in range(128): 13 | mask_values[:, j] = (1. + math.cos(math.pi * j / 127.0)) * 0.5 14 | mask_values = np.expand_dims(mask_values, 0) 15 | mask_values = np.expand_dims(mask_values, 3) 16 | mask1 = tf.constant(1, dtype=tf.float32, shape=[1, 128, 128, 1]) 17 | mask2 = tf.constant(mask_values, dtype=tf.float32, shape=[1, 128, 128, 1]) 18 | mask = tf.concat([mask1, mask2], axis=2) 19 | loss_recon = loss_recon * mask 20 | loss_recon = tf.reduce_mean(loss_recon) 21 | return loss_recon 22 | 23 | def adversarial_loss(self, dis_fun, real, fake, name): 24 | adversarial_pos = dis_fun(real, name=name) 25 | adversarial_neg = dis_fun(fake, reuse=tf.AUTO_REUSE, name=name) 26 | 27 | loss_adv_D = - tf.reduce_mean(adversarial_pos - adversarial_neg) 28 | 29 | differences = fake - real 30 | alpha = tf.random_uniform(shape=[self.cfg.batch_size_per_gpu, 1, 1, 1]) 31 | interpolates = real + tf.multiply(alpha, differences) 32 | gradients = tf.gradients(dis_fun( 33 | interpolates, reuse=tf.AUTO_REUSE, name=name), [interpolates])[0] 34 | slopes = tf.sqrt(tf.reduce_sum( 35 | tf.square(gradients), [1, 2, 3]) + 1e-10) 36 | gradients_penalty = tf.reduce_mean((slopes - 1.) ** 2) 37 | loss_adv_D += self.cfg.lambda_gp * gradients_penalty 38 | 39 | loss_adv_G = -tf.reduce_mean(adversarial_neg) 40 | 41 | return loss_adv_D, loss_adv_G 42 | 43 | def global_adversarial_loss(self, dis_fun, real, fake): 44 | return self.adversarial_loss(dis_fun, real, fake, 'DIS') 45 | 46 | def local_adversarial_loss(self, dis_fun, real, fake): 47 | return self.adversarial_loss(dis_fun, real, fake, 'DIS2') 48 | 49 | 50 | def global_and_local_adv_loss(self, model, gt, recon): 51 | 52 | left_half_gt = tf.slice(gt, [0, 0, 0, 0], [self.cfg.batch_size_per_gpu, 128, 128, 3]) 53 | right_half_gt = tf.slice(gt, [0, 0, 128, 0], [self.cfg.batch_size_per_gpu, 128, 128, 3]) 54 | right_half_recon = tf.slice(recon, [0, 0, 128, 0], [self.cfg.batch_size_per_gpu, 128, 128, 3]) 55 | real = gt 56 | fake = tf.concat([left_half_gt, right_half_recon], axis=2) 57 | global_D, global_G = self.global_adversarial_loss(model.build_adversarial_global, real, fake) 58 | 59 | real = right_half_gt 60 | fake = right_half_recon 61 | local_D, local_G = self.local_adversarial_loss(model.build_adversarial_local, real, fake) 62 | 63 | loss_adv_D = global_D + local_D 64 | loss_adv_G = self.cfg.beta * global_G + (1 - self.cfg.beta) * local_G 65 | 66 | return loss_adv_G, loss_adv_D 67 | 68 | 69 | 70 | def average_losses(self, loss): 71 | tf.add_to_collection('losses', loss) 72 | 73 | # Assemble all of the losses for the current tower only. 74 | losses = tf.get_collection('losses') 75 | 76 | # Calculate the total loss for the current tower. 77 | regularization_losses = tf.get_collection( 78 | tf.GraphKeys.REGULARIZATION_LOSSES) 79 | total_loss = tf.add_n( 80 | losses + regularization_losses, name='total_loss') 81 | 82 | # Compute the moving average of all individual losses and the total 83 | # loss. 
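83 |         # loss. The ExponentialMovingAverage(0.9) below keeps a decayed
             # running average of each collected loss; wrapping total_loss in
             # control_dependencies ensures the averaging op runs whenever
             # total_loss is evaluated.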
84 | loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') 85 | loss_averages_op = loss_averages.apply(losses + [total_loss]) 86 | 87 | with tf.control_dependencies([loss_averages_op]): 88 | total_loss = tf.identity(total_loss) 89 | return total_loss 90 | 91 | def average_gradients(self, tower_grads): 92 | average_grads = [] 93 | for grad_and_vars in zip(*tower_grads): 94 | # Note that each grad_and_vars looks like the following: 95 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 96 | grads = [] 97 | # Average over the 'tower' dimension. 98 | g, _ = grad_and_vars[0] 99 | 100 | for g, _ in grad_and_vars: 101 | expanded_g = tf.expand_dims(g, 0) 102 | grads.append(expanded_g) 103 | grad = tf.concat(grads, axis=0) 104 | grad = tf.reduce_mean(grad, 0) 105 | 106 | # Keep in mind that the Variables are redundant because they are shared 107 | # across towers. So .. we will just return the first tower's pointer to 108 | # the Variable. 109 | v = grad_and_vars[0][1] 110 | grad_and_var = (grad, v) 111 | average_grads.append(grad_and_var) 112 | # clip 113 | if self.cfg.clip_gradient: 114 | gradients, variables = zip(*average_grads) 115 | gradients = [ 116 | None if gradient is None else tf.clip_by_average_norm(gradient, self.cfg.clip_gradient_value) 117 | for gradient in gradients] 118 | average_grads = zip(gradients, variables) 119 | return average_grads 120 | 121 | def feed_all_gpu(self, inp_dict, gpu_num, payload_per_gpu, images, params): 122 | for i in range(gpu_num): 123 | gt = params[i] 124 | start_pos = i * payload_per_gpu 125 | stop_pos = (i + 1) * payload_per_gpu 126 | inp_dict[gt] = images[start_pos:stop_pos] 127 | return inp_dict 128 | 129 | 130 | -------------------------------------------------------------------------------- /eval_model.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | from glob import glob 4 | import numpy as np 5 | from PIL import Image 6 | import tensorflow as tf 7 | from tensorflow.python.training.moving_averages import assign_moving_average 8 | import tensorflow.contrib.layers as ly 9 | from modeling.model import Model 10 | from modeling.loss import Loss 11 | from dataset.parse import parse_trainset, parse_testset 12 | import argparse 13 | import math 14 | 15 | parser = argparse.ArgumentParser(description='Model testing.') 16 | # experiment 17 | parser.add_argument('--date', type=str, default='0817') 18 | parser.add_argument('--exp-index', type=int, default=2) 19 | parser.add_argument('--f', action='store_true', default=False) 20 | 21 | # gpu 22 | parser.add_argument('--start-gpu', type=int, default=0) 23 | parser.add_argument('--num-gpu', type=int, default=1) 24 | 25 | # dataset 26 | parser.add_argument('--trainset-path', type=str, default='./dataset/trainset.tfr') 27 | parser.add_argument('--testset-path', type=str, default='./dataset/testset.tfr') 28 | parser.add_argument('--trainset-length', type=int, default=5041) 29 | parser.add_argument('--testset-length', type=int, default=2000) # we flip every image in testset 30 | 31 | # training 32 | parser.add_argument('--base-lr', type=float, default=0.0001) 33 | parser.add_argument('--batch-size', type=int, default=20) 34 | parser.add_argument('--weight-decay', type=float, default=0.00002) 35 | parser.add_argument('--epoch', type=int, default=1500) 36 | parser.add_argument('--lr-decay-epoch', type=int, default=1000) 37 | parser.add_argument('--critic-steps', type=int, default=3) 38 | parser.add_argument('--warmup-steps', 
type=int, default=1000) 39 | parser.add_argument('--workers', type=int, default=2) 40 | parser.add_argument('--clip-gradient', action='store_true', default=False) 41 | parser.add_argument('--clip-gradient-value', type=float, default=0.1) 42 | 43 | 44 | # modeling 45 | parser.add_argument('--beta', type=float, default=0.9) 46 | parser.add_argument('--lambda-gp', type=float, default=10) 47 | parser.add_argument('--lambda-rec', type=float, default=0.998) 48 | 49 | # checkpoint 50 | parser.add_argument('--log-path', type=str, default='./logs/') 51 | parser.add_argument('--checkpoint-path', type=str, default=None) 52 | parser.add_argument('--resume-step', type=int, default=0) 53 | 54 | 55 | args = parser.parse_args() 56 | 57 | 58 | # prepare path 59 | base_path = args.log_path 60 | exp_date = args.date 61 | if exp_date is None: 62 | print('Exp date error!') 63 | import sys 64 | sys.exit() 65 | exp_name = exp_date + '/' + str(args.exp_index) 66 | print("Start Exp:", exp_name) 67 | output_path = base_path + exp_name + '/' 68 | model_path = output_path + 'models/' 69 | tensorboard_path = output_path + 'log/' 70 | result_path = output_path + 'results/' 71 | 72 | if not os.path.exists(model_path): 73 | os.makedirs(model_path) 74 | if not os.path.exists(tensorboard_path): 75 | os.makedirs(tensorboard_path) 76 | if not os.path.exists(result_path): 77 | os.makedirs(result_path) 78 | elif not args.f: 79 | if args.checkpoint_path is None: 80 | print('Exp exist!') 81 | import sys 82 | sys.exit() 83 | else: 84 | import shutil 85 | shutil.rmtree(model_path) 86 | os.makedirs(model_path) 87 | shutil.rmtree(tensorboard_path) 88 | os.makedirs(tensorboard_path) 89 | 90 | # prepare gpu 91 | num_gpu = args.num_gpu 92 | start_gpu = args.start_gpu 93 | gpu_id = str(start_gpu) 94 | for i in range(num_gpu - 1): 95 | gpu_id = gpu_id + ',' + str(start_gpu + i + 1) 96 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 97 | args.batch_size_per_gpu = int(args.batch_size / args.num_gpu) 98 | 99 | 100 | 101 | 102 | model = Model(args) 103 | loss = Loss(args) 104 | 105 | config = tf.ConfigProto(allow_soft_placement=True) 106 | config.gpu_options.allow_growth = True 107 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 108 | 109 | print("Start building model...") 110 | with tf.Session(config=config) as sess: 111 | with tf.device('/cpu:0'): 112 | learning_rate = tf.placeholder(tf.float32, []) 113 | lambda_rec = tf.placeholder(tf.float32, []) 114 | 115 | train_op_G = tf.train.AdamOptimizer( 116 | learning_rate=learning_rate, beta1=0.5, beta2=0.9) 117 | train_op_D = tf.train.AdamOptimizer( 118 | learning_rate=learning_rate, beta1=0.5, beta2=0.9) 119 | 120 | 121 | trainset = tf.data.TFRecordDataset(filenames=[args.trainset_path]) 122 | trainset = trainset.shuffle(args.trainset_length) 123 | trainset = trainset.map(parse_trainset, num_parallel_calls=args.workers) 124 | trainset = trainset.batch(args.batch_size).repeat() 125 | 126 | train_iterator = trainset.make_one_shot_iterator() 127 | train_im = train_iterator.get_next() 128 | 129 | testset = tf.data.TFRecordDataset(filenames=[args.testset_path]) 130 | testset = testset.map(parse_testset, num_parallel_calls=args.workers) 131 | testset = testset.batch(args.batch_size).repeat() 132 | 133 | test_iterator = testset.make_one_shot_iterator() 134 | test_im = test_iterator.get_next() 135 | 136 | print('build model on gpu tower') 137 | models = [] 138 | params = [] 139 | for gpu_id in range(num_gpu): 140 | with tf.device('/gpu:%d' % gpu_id): 141 | 
print('tower_%d' % gpu_id) 142 | with tf.name_scope('tower_%d' % gpu_id): 143 | with tf.variable_scope('cpu_variables', reuse=gpu_id > 0): 144 | 145 | groundtruth = tf.placeholder( 146 | tf.float32, [args.batch_size_per_gpu, 128, 256, 3], name='groundtruth') 147 | left_gt = tf.slice(groundtruth, [0, 0, 0, 0], [args.batch_size_per_gpu, 128, 128, 3]) 148 | 149 | 150 | reconstruction_ori, reconstruction = model.build_reconstruction(left_gt) 151 | right_recon = tf.slice(reconstruction, [0, 0, 128, 0], [args.batch_size_per_gpu, 128, 128, 3]) 152 | 153 | loss_rec = loss.masked_reconstruction_loss(groundtruth, reconstruction) 154 | loss_adv_G, loss_adv_D = loss.global_and_local_adv_loss(model, groundtruth, reconstruction) 155 | 156 | reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 157 | loss_G = loss_adv_G * (1 - lambda_rec) + loss_rec * lambda_rec + sum(reg_losses) 158 | loss_D = loss_adv_D 159 | 160 | var_G = list(filter(lambda x: x.name.startswith( 161 | 'cpu_variables/GEN'), tf.trainable_variables())) 162 | var_D = list(filter(lambda x: x.name.startswith( 163 | 'cpu_variables/DIS'), tf.trainable_variables())) 164 | 165 | 166 | grad_g = train_op_G.compute_gradients( 167 | loss_G, var_list=var_G) 168 | grad_d = train_op_D.compute_gradients( 169 | loss_D, var_list=var_D) 170 | 171 | models.append((reconstruction, right_recon)) 172 | params.append(groundtruth) 173 | 174 | print('Done.') 175 | 176 | print('Start reducing towers on cpu...') 177 | 178 | reconstructions, right_recons = zip(*models) 179 | groundtruths = params 180 | 181 | with tf.device('/gpu:0'): 182 | 183 | reconstructions = tf.concat(reconstructions, axis=0) 184 | right_recons = tf.concat(right_recons, axis=0) 185 | 186 | print('Done.') 187 | 188 | 189 | iters = 0 190 | saver = tf.train.Saver(max_to_keep=5) 191 | if args.checkpoint_path is None: 192 | sess.run(tf.global_variables_initializer()) 193 | else: 194 | print('Start loading checkpoint...') 195 | saver.restore(sess, args.checkpoint_path) 196 | iters = args.resume_step 197 | print('Done.') 198 | 199 | 200 | 201 | 202 | print('run eval...') 203 | 204 | 205 | stitch_mask1 = np.ones((args.batch_size, 128, 128, 3)) 206 | for i in range(128): 207 | stitch_mask1[:, :, i, :] = 1. / 127. * (127. 
- i)
208 |         stitch_mask2 = stitch_mask1[:, :, ::-1, :]
209 |
210 |
211 |         ii = 0
212 |
213 |         for _ in range(math.floor(args.testset_length / args.batch_size)):
214 |             test_oris = sess.run([test_im])[0]
215 |             origins1 = test_oris.copy()
216 |
217 |             oris = None
218 |             # run 4 outpainting steps, reusing each prediction as the next input
219 |             print('oris ' + str(ii))
220 |             for _ in range(4):
221 |                 inp_dict = {}
222 |                 inp_dict = loss.feed_all_gpu(inp_dict, args.num_gpu, args.batch_size_per_gpu, test_oris, params)
223 |
224 |                 if oris is None:
225 |                     reconstruction_vals, prediction_vals = sess.run(
226 |                         [reconstructions, right_recons],
227 |                         feed_dict=inp_dict)
228 |
229 |                     oris = reconstruction_vals
230 |                     pred1 = oris[:, :, :128, :]
231 |                     pred2 = oris[:, :, -128:, :]
232 |                     gt = origins1[:, :, :128, :]
233 |                     p1_m0 = np.concatenate((gt, pred2), axis=2)
234 |                     p1_m1 = np.concatenate((gt * stitch_mask1 + pred1 * stitch_mask2, pred2), axis=2)
235 |                 else:
236 |                     reconstruction_vals, prediction_vals = sess.run(
237 |                         [reconstructions, right_recons],
238 |                         feed_dict=inp_dict)
239 |                     A = oris[:, :, -128:, :]
240 |                     B = reconstruction_vals[:, :, :128, :]
241 |                     C = A * stitch_mask1 + B * stitch_mask2
242 |                     oris = np.concatenate((oris[:, :, :-128, :], C, prediction_vals), axis=2)
243 |                 test_oris = np.concatenate((prediction_vals, prediction_vals), axis=2)
244 |             predictions1 = oris
245 |
246 |             jj = ii
247 |             for ori, m0, m1, endless in zip(origins1, p1_m0, p1_m1, predictions1):
248 |                 name = str(jj) + '.jpg'
249 |                 ori = (255. * (ori + 1) / 2.).astype(np.uint8)
250 |                 Image.fromarray(ori).save(os.path.join(
251 |                     result_path, 'ori_' + name))
252 |
253 |                 m0 = (255. * (m0 + 1) / 2.).astype(np.uint8)
254 |                 Image.fromarray(m0).save(os.path.join(
255 |                     result_path, 'm0_' + name))
256 |
257 |                 m1 = (255. * (m1 + 1) / 2.).astype(np.uint8)
258 |                 Image.fromarray(m1).save(os.path.join(
259 |                     result_path, 'm1_' + name))
260 |
261 |                 endless = (255.
* (endless + 1) / 2.).astype(np.uint8) 262 | Image.fromarray(endless).save(os.path.join( 263 | result_path, 'endless_' + name)) 264 | jj += 1 265 | 266 | 267 | ii += args.batch_size 268 | -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | from glob import glob 4 | import numpy as np 5 | from PIL import Image 6 | import tensorflow as tf 7 | from tensorflow.python.training.moving_averages import assign_moving_average 8 | import tensorflow.contrib.layers as ly 9 | from modeling.model import Model 10 | from modeling.loss import Loss 11 | from dataset.parse import parse_trainset, parse_testset 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser(description='Model training.') 15 | # experiment 16 | parser.add_argument('--date', type=str, default='0817') 17 | parser.add_argument('--exp-index', type=int, default=2) 18 | parser.add_argument('--f', action='store_true', default=False) 19 | 20 | # gpu 21 | parser.add_argument('--start-gpu', type=int, default=0) 22 | parser.add_argument('--num-gpu', type=int, default=2) 23 | 24 | # dataset 25 | parser.add_argument('--trainset-path', type=str, default='./dataset/trainset.tfr') 26 | parser.add_argument('--testset-path', type=str, default='./dataset/testset.tfr') 27 | parser.add_argument('--trainset-length', type=int, default=5041) 28 | parser.add_argument('--testset-length', type=int, default=2000) # we flip every image in testset 29 | 30 | # training 31 | parser.add_argument('--base-lr', type=float, default=0.0001) 32 | parser.add_argument('--batch-size', type=int, default=32) 33 | parser.add_argument('--weight-decay', type=float, default=0.00002) 34 | parser.add_argument('--epoch', type=int, default=1500) 35 | parser.add_argument('--lr-decay-epoch', type=int, default=1000) 36 | parser.add_argument('--critic-steps', type=int, default=3) 37 | parser.add_argument('--warmup-steps', type=int, default=1000) 38 | parser.add_argument('--workers', type=int, default=2) 39 | parser.add_argument('--clip-gradient', action='store_true', default=False) 40 | parser.add_argument('--clip-gradient-value', type=float, default=0.1) 41 | 42 | 43 | # modeling 44 | parser.add_argument('--beta', type=float, default=0.9) 45 | parser.add_argument('--lambda-gp', type=float, default=10) 46 | parser.add_argument('--lambda-rec', type=float, default=0.998) 47 | 48 | # checkpoint 49 | parser.add_argument('--log-path', type=str, default='./logs/') 50 | parser.add_argument('--checkpoint-path', type=str, default=None) 51 | parser.add_argument('--resume-step', type=int, default=0) 52 | 53 | 54 | args = parser.parse_args() 55 | 56 | 57 | # prepare path 58 | base_path = args.log_path 59 | exp_date = args.date 60 | if exp_date is None: 61 | print('Exp date error!') 62 | import sys 63 | sys.exit() 64 | exp_name = exp_date + '/' + str(args.exp_index) 65 | print("Start Exp:", exp_name) 66 | output_path = base_path + exp_name + '/' 67 | model_path = output_path + 'models/' 68 | tensorboard_path = output_path + 'log/' 69 | result_path = output_path + 'results/' 70 | 71 | if not os.path.exists(model_path): 72 | os.makedirs(model_path) 73 | if not os.path.exists(tensorboard_path): 74 | os.makedirs(tensorboard_path) 75 | if not os.path.exists(result_path): 76 | os.makedirs(result_path) 77 | elif not args.f: 78 | if args.checkpoint_path is None: 79 | print('Exp exist!') 80 | import sys 81 | sys.exit() 82 | else: 83 | import 
shutil 84 | shutil.rmtree(model_path) 85 | os.makedirs(model_path) 86 | shutil.rmtree(tensorboard_path) 87 | os.makedirs(tensorboard_path) 88 | 89 | # prepare gpu 90 | num_gpu = args.num_gpu 91 | start_gpu = args.start_gpu 92 | gpu_id = str(start_gpu) 93 | for i in range(num_gpu - 1): 94 | gpu_id = gpu_id + ',' + str(start_gpu + i + 1) 95 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 96 | args.batch_size_per_gpu = int(args.batch_size / args.num_gpu) 97 | 98 | 99 | 100 | 101 | model = Model(args) 102 | loss = Loss(args) 103 | 104 | config = tf.ConfigProto(allow_soft_placement=True) 105 | config.gpu_options.allow_growth = True 106 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 107 | 108 | print("Start building model...") 109 | with tf.Session(config=config) as sess: 110 | with tf.device('/cpu:0'): 111 | learning_rate = tf.placeholder(tf.float32, []) 112 | lambda_rec = tf.placeholder(tf.float32, []) 113 | 114 | train_op_G = tf.train.AdamOptimizer( 115 | learning_rate=learning_rate, beta1=0.5, beta2=0.9) 116 | train_op_D = tf.train.AdamOptimizer( 117 | learning_rate=learning_rate, beta1=0.5, beta2=0.9) 118 | 119 | 120 | trainset = tf.data.TFRecordDataset(filenames=[args.trainset_path]) 121 | trainset = trainset.shuffle(args.trainset_length) 122 | trainset = trainset.map(parse_trainset, num_parallel_calls=args.workers) 123 | trainset = trainset.batch(args.batch_size).repeat() 124 | 125 | train_iterator = trainset.make_one_shot_iterator() 126 | train_im = train_iterator.get_next() 127 | 128 | testset = tf.data.TFRecordDataset(filenames=[args.testset_path]) 129 | testset = testset.map(parse_testset, num_parallel_calls=args.workers) 130 | testset = testset.batch(args.batch_size).repeat() 131 | 132 | test_iterator = testset.make_one_shot_iterator() 133 | test_im = test_iterator.get_next() 134 | 135 | print('build model on gpu tower') 136 | models = [] 137 | params = [] 138 | for gpu_id in range(num_gpu): 139 | with tf.device('/gpu:%d' % gpu_id): 140 | print('tower_%d' % gpu_id) 141 | with tf.name_scope('tower_%d' % gpu_id): 142 | with tf.variable_scope('cpu_variables', reuse=gpu_id > 0): 143 | 144 | groundtruth = tf.placeholder( 145 | tf.float32, [args.batch_size_per_gpu, 128, 256, 3], name='groundtruth') 146 | left_gt = tf.slice(groundtruth, [0, 0, 0, 0], [args.batch_size_per_gpu, 128, 128, 3]) 147 | 148 | 149 | reconstruction_ori, reconstruction = model.build_reconstruction(left_gt) 150 | right_recon = tf.slice(reconstruction, [0, 0, 128, 0], [args.batch_size_per_gpu, 128, 128, 3]) 151 | 152 | loss_rec = loss.masked_reconstruction_loss(groundtruth, reconstruction) 153 | loss_adv_G, loss_adv_D = loss.global_and_local_adv_loss(model, groundtruth, reconstruction) 154 | 155 | reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 156 | loss_G = loss_adv_G * (1 - lambda_rec) + loss_rec * lambda_rec + sum(reg_losses) 157 | loss_D = loss_adv_D 158 | 159 | var_G = list(filter(lambda x: x.name.startswith( 160 | 'cpu_variables/GEN'), tf.trainable_variables())) 161 | var_D = list(filter(lambda x: x.name.startswith( 162 | 'cpu_variables/DIS'), tf.trainable_variables())) 163 | 164 | 165 | grad_g = train_op_G.compute_gradients( 166 | loss_G, var_list=var_G) 167 | grad_d = train_op_D.compute_gradients( 168 | loss_D, var_list=var_D) 169 | 170 | models.append((grad_g, grad_d, loss_G, loss_D, loss_adv_G, loss_rec, reconstruction)) 171 | params.append(groundtruth) 172 | 173 | print('Done.') 174 | 175 | print('Start reducing towers on cpu...') 176 | 177 | 
grad_gs, grad_ds, loss_Gs, loss_Ds, loss_adv_Gs, loss_recs, reconstructions = zip(*models) 178 | groundtruths = params 179 | 180 | with tf.device('/gpu:0'): 181 | aver_loss_g = tf.reduce_mean(loss_Gs) 182 | aver_loss_d = tf.reduce_mean(loss_Ds) 183 | aver_loss_ag = tf.reduce_mean(loss_adv_Gs) 184 | aver_loss_rec = tf.reduce_mean(loss_recs) 185 | 186 | train_op_G = train_op_G.apply_gradients( 187 | loss.average_gradients(grad_gs)) 188 | train_op_D = train_op_D.apply_gradients( 189 | loss.average_gradients(grad_ds)) 190 | 191 | groundtruths = tf.concat(groundtruths, axis=0) 192 | reconstructions = tf.concat(reconstructions, axis=0) 193 | 194 | tf.summary.scalar('loss_g', aver_loss_g) 195 | tf.summary.scalar('loss_d', aver_loss_d) 196 | tf.summary.scalar('loss_ag', aver_loss_ag) 197 | tf.summary.scalar('loss_rec', aver_loss_rec) 198 | tf.summary.image('groundtruth', groundtruths, 2) 199 | tf.summary.image('reconstruction', reconstructions, 2) 200 | 201 | merged = tf.summary.merge_all() 202 | writer = tf.summary.FileWriter(tensorboard_path, sess.graph) 203 | 204 | print('Done.') 205 | 206 | 207 | iters = 0 208 | saver = tf.train.Saver(max_to_keep=5) 209 | if args.checkpoint_path is None: 210 | sess.run(tf.global_variables_initializer()) 211 | else: 212 | print('Start loading checkpoint...') 213 | saver.restore(sess, args.checkpoint_path) 214 | iters = args.resume_step 215 | print('Done.') 216 | 217 | 218 | 219 | 220 | print('Start training...') 221 | 222 | for epoch in range(args.epoch): 223 | 224 | if epoch > args.lr_decay_epoch: 225 | learning_rate_val = args.base_lr / 10 226 | else: 227 | learning_rate_val = args.base_lr 228 | 229 | for start, end in zip( 230 | range(0, args.trainset_length, args.batch_size), 231 | range(args.batch_size, args.trainset_length, args.batch_size)): 232 | 233 | if iters == 0 and args.checkpoint_path is None: 234 | print('Start pretraining G!') 235 | for t in range(args.warmup_steps): 236 | if t % 20 == 0: 237 | print("Step:", t) 238 | images = sess.run([train_im])[0] 239 | if len(images) < args.batch_size: 240 | images = sess.run([train_im])[0] 241 | 242 | inp_dict = {} 243 | inp_dict = loss.feed_all_gpu(inp_dict, args.num_gpu, args.batch_size_per_gpu, images, params) 244 | inp_dict[learning_rate] = learning_rate_val 245 | inp_dict[lambda_rec] = 1. 
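                        # lambda_rec is fed as 1.0 during warmup, so loss_G
                        # reduces to the pure reconstruction loss (the
                        # adversarial term is weighted by 1 - lambda_rec = 0);
                        # the generator is thus pre-trained before the
                        # WGAN-GP critic joins in.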
246 |
247 |                         _ = sess.run(
248 |                             [train_op_G],
249 |                             feed_dict=inp_dict)
250 |                     print('Pre-train G Done!')
251 |
252 |                 if (iters < 25 and args.checkpoint_path is None) or iters % 500 == 0:
253 |                     n_cir = 30
254 |                 else:
255 |                     n_cir = args.critic_steps
256 |
257 |                 for t in range(n_cir):
258 |                     images = sess.run([train_im])[0]
259 |                     if len(images) < args.batch_size:
260 |                         images = sess.run([train_im])[0]
261 |
262 |                     inp_dict = {}
263 |                     inp_dict = loss.feed_all_gpu(inp_dict, args.num_gpu, args.batch_size_per_gpu, images, params)
264 |                     inp_dict[learning_rate] = learning_rate_val
265 |                     inp_dict[lambda_rec] = args.lambda_rec
266 |
267 |                     _ = sess.run(
268 |                         [train_op_D],
269 |                         feed_dict=inp_dict)
270 |
271 |
272 |                 if iters % 50 == 0:
273 |
274 |                     _, g_val, ag_val, rs, d_val = sess.run(
275 |                         [train_op_G, aver_loss_g, aver_loss_ag, merged, aver_loss_d],
276 |                         feed_dict=inp_dict)
277 |                     writer.add_summary(rs, iters)
278 |
279 |                 else:
280 |
281 |                     _, g_val, ag_val, d_val = sess.run(
282 |                         [train_op_G, aver_loss_g, aver_loss_ag, aver_loss_d],
283 |                         feed_dict=inp_dict)
284 |                 if iters % 20 == 0:
285 |                     print("Iter:", iters, 'loss_g:', g_val, 'loss_d:', d_val, 'loss_adv_g:', ag_val)
286 |
287 |                 iters += 1
288 |
289 |             saver.save(sess, model_path, global_step=iters)
290 |
291 |             # testing
292 |             if epoch > 0:
293 |                 ii = 0
294 |                 g_vals = 0
295 |                 d_vals = 0
296 |                 ag_vals = 0
297 |                 n_batchs = 0
298 |                 for _ in range(int(args.testset_length / args.batch_size)):
299 |                     test_oris = sess.run([test_im])[0]
300 |                     if len(test_oris) < args.batch_size:
301 |                         test_oris = sess.run([test_im])[0]
302 |
303 |                     inp_dict = {}
304 |                     inp_dict = loss.feed_all_gpu(inp_dict, args.num_gpu, args.batch_size_per_gpu, test_oris, params)
305 |                     inp_dict[learning_rate] = learning_rate_val
306 |                     inp_dict[lambda_rec] = args.lambda_rec
307 |
308 |                     reconstruction_vals, g_val, d_val, ag_val = sess.run(
309 |                         [reconstructions, aver_loss_g, aver_loss_d, aver_loss_ag],
310 |                         feed_dict=inp_dict)
311 |
312 |                     g_vals += g_val
313 |                     d_vals += d_val
314 |                     ag_vals += ag_val
315 |                     n_batchs += 1
316 |
317 |                     # Save test result every 100 epochs
318 |                     if epoch % 100 == 0:
319 |
320 |                         for rec_val, test_ori in zip(reconstruction_vals, test_oris):
321 |                             rec_hid = (255. * (rec_val + 1) /
322 |                                        2.).astype(np.uint8)
323 |                             test_ori = (255. * (test_ori + 1) /
324 |                                         2.).astype(np.uint8)
325 |                             Image.fromarray(rec_hid).save(os.path.join(
326 |                                 result_path, 'img_' + str(ii) + '.' + str(int(iters / 100)) + '.jpg'))
327 |                             if epoch == 0:
328 |                                 Image.fromarray(test_ori).save(
329 |                                     os.path.join(result_path, 'img_' + str(ii) + '.' + str(int(iters / 100)) + '.ori.jpg'))
330 |                             ii += 1
331 |                 g_vals /= n_batchs
332 |                 d_vals /= n_batchs
333 |                 ag_vals /= n_batchs
334 |
335 |                 summary = tf.Summary()
336 |                 summary.value.add(tag='eval/g',
337 |                                   simple_value=g_vals)
338 |                 summary.value.add(tag='eval/d',
339 |                                   simple_value=d_vals)
340 |                 summary.value.add(tag='eval/ag',
341 |                                   simple_value=ag_vals)
342 |                 writer.add_summary(summary, iters)
343 |
344 |                 print("=========================================================================")
345 |                 print('loss_g:', g_vals, 'loss_d:', d_vals, 'loss_adv_g:', ag_vals)
346 |                 print("=========================================================================")
347 |
348 |                 if np.isnan(reconstruction_vals.min()) or np.isnan(reconstruction_vals.max()):
349 |                     print("NaN detected!!")
350 |
--------------------------------------------------------------------------------
/modeling/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib.layers as ly
3 |
4 |
5 | class Model():
6 |     def __init__(self, cfg):
7 |         self.cfg = cfg
8 |
9 |     def new_atrous_conv_layer(self, bottom, filter_shape, rate, name=None):
10 |         with tf.variable_scope(name):
11 |             regularizer = tf.contrib.layers.l2_regularizer(self.cfg.weight_decay)
12 |             initializer = tf.contrib.layers.xavier_initializer()
13 |             W = tf.get_variable(
14 |                 "W",
15 |                 shape=filter_shape,
16 |                 regularizer=regularizer,
17 |                 initializer=initializer)
18 |
19 |             x = tf.nn.atrous_conv2d(
20 |                 bottom, W, rate, padding='SAME')
21 |         return x
22 |
23 |     def identity_block(self, X_input, kernel_size, filters, stage, block, is_relu=False):
24 |
25 |         if is_relu:
26 |             activation_fn = tf.nn.relu
27 |
28 |         else:
29 |             activation_fn = self.leaky_relu
30 |
31 |         normalizer_fn = ly.instance_norm
32 |
33 |
34 |         # defining name basis
35 |         conv_name_base = 'res' + str(stage) + block + '_branch'
36 |
37 |         with tf.variable_scope("id_block_stage" + str(stage) + block):
38 |             filter1, filter2, filter3 = filters
39 |             X_shortcut = X_input
40 |             regularizer = tf.contrib.layers.l2_regularizer(self.cfg.weight_decay)
41 |             initializer = tf.contrib.layers.xavier_initializer()
42 |
43 |             # First component of main path
44 |             x = tf.layers.conv2d(X_input, filter1,
45 |                                  kernel_size=(1, 1), strides=(1, 1), name=conv_name_base + '2a', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False)
46 |             x = normalizer_fn(x)
47 |             x = activation_fn(x)
48 |
49 |             # Second component of main path
50 |             x = tf.layers.conv2d(x, filter2, (kernel_size, kernel_size),
51 |                                  padding='same', name=conv_name_base + '2b', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False)
52 |             x = normalizer_fn(x)
53 |             x = activation_fn(x)
54 |
55 |             # Third component of main path
56 |             x = tf.layers.conv2d(x, filter3, kernel_size=(
57 |                 1, 1), name=conv_name_base + '2c', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False)
58 |             x = normalizer_fn(x)
59 |
60 |             # Final step: Add shortcut value to main path, and pass it through
61 |             x = tf.add(x, X_shortcut)
62 |             x = activation_fn(x)
63 |
64 |         return x
65 |
66 |     def convolutional_block(self, X_input, kernel_size, filters, stage, block, stride=2, is_relu=False):
67 |
68 |         if is_relu:
69 |             activation_fn = tf.nn.relu
70 |
71 |         else:
72 |             activation_fn = self.leaky_relu
73 |
74 |         normalizer_fn = ly.instance_norm
75 |
76 |         # defining name basis
77 |         conv_name_base = 'res' + str(stage) + block + '_branch'
78 |
79 |         with tf.variable_scope("conv_block_stage" +
str(stage) + block): 80 | 81 | regularizer = tf.contrib.layers.l2_regularizer(self.cfg.weight_decay) 82 | initializer = tf.contrib.layers.xavier_initializer() 83 | # initializer = tf.variance_scaling_initializer(scale=1.0,mode='fan_in') 84 | 85 | # Retrieve Filters 86 | filter1, filter2, filter3 = filters 87 | 88 | # Save the input value 89 | X_shortcut = X_input 90 | 91 | # First component of main path 92 | x = tf.layers.conv2d(X_input, filter1, 93 | kernel_size=(1, 1), 94 | strides=(1, 1), 95 | name=conv_name_base + '2a', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False) 96 | x = normalizer_fn(x) 97 | x = activation_fn(x) 98 | 99 | # Second component of main path 100 | x = tf.layers.conv2d(x, filter2, (kernel_size, kernel_size), strides=(stride, stride), name=conv_name_base + 101 | '2b', padding='same', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False) 102 | x = normalizer_fn(x) 103 | x = activation_fn(x) 104 | 105 | # Third component of main path 106 | x = tf.layers.conv2d(x, filter3, (1, 1), name=conv_name_base + '2c', 107 | kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False) 108 | x = normalizer_fn(x) 109 | 110 | 111 | # SHORTCUT PATH 112 | X_shortcut = tf.layers.conv2d(X_shortcut, filter3, (1, 1), 113 | strides=(stride, stride), name=conv_name_base + '1', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False) 114 | X_shortcut = normalizer_fn(X_shortcut) 115 | 116 | # Final step: Add shortcut value to main path, and pass it through 117 | # a RELU activation 118 | x = tf.add(X_shortcut, x) 119 | x = activation_fn(x) 120 | 121 | return x 122 | 123 | def leaky_relu(self, x, name=None, leak=0.2): 124 | f1 = 0.5 * (1 + leak) 125 | f2 = 0.5 * (1 - leak) 126 | return f1 * x + f2 * abs(x) 127 | 128 | def in_lrelu(self, x, name=None): 129 | x = tf.contrib.layers.instance_norm(x) 130 | x = self.leaky_relu(x) 131 | return x 132 | 133 | def in_relu(self, x, name=None): 134 | x = tf.contrib.layers.instance_norm(x) 135 | x = tf.nn.relu(x) 136 | return x 137 | 138 | def rct(self, x): 139 | regularizer = tf.contrib.layers.l2_regularizer(self.cfg.weight_decay) 140 | output_size = x.get_shape().as_list()[3] 141 | size = 512 142 | layer_num = 2 143 | activation_fn = tf.tanh 144 | x = ly.conv2d(x, size, 1, stride=1, activation_fn=None, 145 | normalizer_fn=None, padding='SAME', weights_regularizer=regularizer, biases_initializer=None) 146 | x = self.in_lrelu(x) 147 | x = tf.transpose(x, [0, 2, 1, 3]) 148 | x = tf.reshape(x, [-1, 4, 4 * size]) 149 | x = tf.transpose(x, [1, 0, 2]) 150 | # encoder_inputs = x 151 | x = tf.reshape(x, [-1, 4 * size]) 152 | x_split = tf.split(x, 4, 0) 153 | 154 | ys = [] 155 | with tf.variable_scope('LSTM'): 156 | with tf.variable_scope('encoder'): 157 | lstm_cell = tf.contrib.rnn.LSTMCell( 158 | 4 * size, activation=activation_fn) 159 | lstm_cell = tf.contrib.rnn.MultiRNNCell( 160 | [lstm_cell] * layer_num, state_is_tuple=True) 161 | 162 | init_state = lstm_cell.zero_state(self.cfg.batch_size_per_gpu, dtype=tf.float32) 163 | now, _state = lstm_cell(x_split[0], init_state) 164 | now, _state = lstm_cell(x_split[1], _state) 165 | now, _state = lstm_cell(x_split[2], _state) 166 | now, _state = lstm_cell(x_split[3], _state) 167 | 168 | with tf.variable_scope('decoder'): 169 | lstm_cell = tf.contrib.rnn.BasicLSTMCell( 170 | 4 * size, activation=activation_fn) 171 | lstm_cell2 = tf.contrib.rnn.MultiRNNCell( 172 | [lstm_cell] * layer_num, state_is_tuple=True) 173 | #predict 
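                # the decoder is seeded with the last encoded feature column
                # (x_split[3]) and the encoder's final state, then unrolled
                # for 4 steps; each step emits one 4x1x512 column, and the
                # 4 columns are concatenated below into the predicted
                # right-half feature map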
174 | now, _state = lstm_cell2(x_split[3], _state) 175 | ys.append(tf.reshape(now, [-1, 4, 1, size])) 176 | now, _state = lstm_cell2(now, _state) 177 | ys.append(tf.reshape(now, [-1, 4, 1, size])) 178 | now, _state = lstm_cell2(now, _state) 179 | ys.append(tf.reshape(now, [-1, 4, 1, size])) 180 | now, _state = lstm_cell2(now, _state) 181 | ys.append(tf.reshape(now, [-1, 4, 1, size])) 182 | 183 | 184 | y = tf.concat(ys, axis=2) 185 | 186 | y = ly.conv2d(y, output_size, 1, stride=1, activation_fn=None, 187 | normalizer_fn=None, padding='SAME', weights_regularizer=regularizer, biases_initializer=None) 188 | y = self.in_lrelu(y) 189 | return y 190 | 191 | 192 | 193 | def shc(self, x, shortcut, channels): 194 | regularizer = tf.contrib.layers.l2_regularizer(self.cfg.weight_decay) 195 | x = ly.conv2d(x, channels / 2, 1, stride=1, activation_fn=tf.nn.relu, 196 | normalizer_fn=tf.contrib.layers.instance_norm, padding='SAME', weights_regularizer=regularizer) 197 | x = ly.conv2d(x, channels / 2, 3, stride=1, activation_fn=tf.nn.relu, 198 | normalizer_fn=tf.contrib.layers.instance_norm, padding='SAME', weights_regularizer=regularizer) 199 | x = ly.conv2d(x, channels, 1, stride=1, activation_fn=None, 200 | normalizer_fn=tf.contrib.layers.instance_norm, padding='SAME', weights_regularizer=regularizer) 201 | return tf.add(shortcut, x) 202 | 203 | 204 | def grb(self, x, filters, rate, name): 205 | activation_fn = tf.nn.relu 206 | normalizer_fn = ly.instance_norm 207 | shortcut = x 208 | x1 = self.new_atrous_conv_layer(x, [3, 1, filters, filters], rate, name+'_a1') 209 | x1 = normalizer_fn(x1) 210 | x1 = activation_fn(x1) 211 | x1 = self.new_atrous_conv_layer(x1, [1, 7, filters, filters], rate, name+'_a2') 212 | x1 = normalizer_fn(x1) 213 | 214 | x2 = self.new_atrous_conv_layer(x, [1, 7, filters, filters], rate, name+'_b1') 215 | x2 = normalizer_fn(x2) 216 | x2 = activation_fn(x2) 217 | x2 = self.new_atrous_conv_layer(x2, [3, 1, filters, filters], rate, name+'_b2') 218 | x2 = normalizer_fn(x2) 219 | 220 | x = tf.add(shortcut, x1) 221 | x = tf.add(x, x2) 222 | x = activation_fn(x) 223 | return x 224 | 225 | def build_reconstruction(self, images, reuse=None): 226 | 227 | with tf.variable_scope('GEN', reuse=reuse): 228 | x = images 229 | normalizer_fn = ly.instance_norm 230 | regularizer = tf.contrib.layers.l2_regularizer(self.cfg.weight_decay) 231 | initializer = tf.contrib.layers.xavier_initializer() 232 | # stage 1 233 | 234 | x = tf.layers.conv2d(x, filters=64, kernel_size=(4, 4), strides=( 235 | 2, 2), name='conv0', kernel_regularizer=regularizer, padding='same', kernel_initializer=initializer, use_bias=False) 236 | x = self.in_lrelu(x) 237 | short_cut0 = x 238 | x = tf.layers.conv2d(x, filters=128, kernel_size=(4, 4), strides=( 239 | 2, 2), name='conv1', padding='same', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False) 240 | x = self.in_lrelu(x) 241 | short_cut1 = x 242 | 243 | # stage 2 244 | x = self.convolutional_block(x, kernel_size=3, filters=[ 245 | 64, 64, 256], stage=2, block='a', stride=2) 246 | x = self.identity_block( 247 | x, 3, [64, 64, 256], stage=2, block='b') 248 | x = self.identity_block( 249 | x, 3, [64, 64, 256], stage=2, block='c') 250 | short_cut2 = x 251 | 252 | # stage 3 253 | x = self.convolutional_block(x, kernel_size=3, filters=[128, 128, 512], 254 | stage=3, block='a', stride=2) 255 | x = self.identity_block( 256 | x, 3, [128, 128, 512], stage=3, block='b') 257 | x = self.identity_block( 258 | x, 3, [128, 128, 512], stage=3, block='c') 259 | x = 
self.identity_block( 260 | x, 3, [128, 128, 512], stage=3, block='d',) 261 | short_cut3 = x 262 | 263 | # stage 4 264 | x = self.convolutional_block(x, kernel_size=3, filters=[ 265 | 256, 256, 1024], stage=4, block='a', stride=2) 266 | x = self.identity_block( 267 | x, 3, [256, 256, 1024], stage=4, block='b') 268 | x = self.identity_block( 269 | x, 3, [256, 256, 1024], stage=4, block='c') 270 | x = self.identity_block( 271 | x, 3, [256, 256, 1024], stage=4, block='d') 272 | x = self.identity_block( 273 | x, 3, [256, 256, 1024], stage=4, block='e') 274 | short_cut4 = x 275 | 276 | # rct transfer 277 | train = self.rct(x) 278 | 279 | 280 | # stage -4 281 | train = tf.concat([short_cut4, train], axis=2) 282 | 283 | train = self.grb(train, 1024, 1, 't4') 284 | train = self.identity_block( 285 | train, 3, [256, 256, 1024], stage=-4, block='b', is_relu=True) 286 | train = self.identity_block( 287 | train, 3, [256, 256, 1024], stage=-4, block='c', is_relu=True) 288 | 289 | 290 | train = ly.conv2d_transpose(train, 512, 4, stride=2, 291 | activation_fn=None, normalizer_fn=normalizer_fn, padding='SAME', weights_initializer=initializer, weights_regularizer=regularizer, biases_initializer=None) 292 | sc, kp = tf.split(train, 2, axis=2) 293 | sc = tf.nn.relu(sc) 294 | merge = tf.concat([short_cut3, sc], axis=3) 295 | merge = self.shc(merge, short_cut3, 512) 296 | merge = self.in_relu(merge) 297 | train = tf.concat( 298 | [merge, kp], axis=2) 299 | 300 | 301 | # stage -3 302 | train = self.grb(train, 512, 2, 't3') 303 | train = self.identity_block( 304 | train, 3, [128, 128, 512], stage=-3, block='b', is_relu=True) 305 | train = self.identity_block( 306 | train, 3, [128, 128, 512], stage=-3, block='c', is_relu=True) 307 | train = self.identity_block( 308 | train, 3, [128, 128, 512], stage=-3, block='d', is_relu=True) 309 | 310 | 311 | 312 | train = ly.conv2d_transpose(train, 256, 4, stride=2, 313 | activation_fn=None, normalizer_fn=normalizer_fn, padding='SAME', weights_initializer=initializer, weights_regularizer=regularizer, biases_initializer=None) 314 | sc, kp = tf.split(train, 2, axis=2) 315 | sc = tf.nn.relu(sc) 316 | merge = tf.concat([short_cut2, sc], axis=3) 317 | merge = self.shc(merge, short_cut2, 256) 318 | merge = self.in_relu(merge) 319 | train = tf.concat( 320 | [merge, kp], axis=2) 321 | 322 | # stage -2 323 | train = self.grb(train, 256, 4, 't2') 324 | train = self.identity_block( 325 | train, 3, [64, 64, 256], stage=-2, block='b', is_relu=True) 326 | train = self.identity_block( 327 | train, 3, [64, 64, 256], stage=-2, block='c', is_relu=True) 328 | train = self.identity_block( 329 | train, 3, [64, 64, 256], stage=-2, block='d', is_relu=True) 330 | train = self.identity_block( 331 | train, 3, [64, 64, 256], stage=-2, block='e', is_relu=True) 332 | 333 | train = ly.conv2d_transpose(train, 128, 4, stride=2, 334 | activation_fn=None, normalizer_fn=normalizer_fn, padding='SAME', weights_initializer=initializer, weights_regularizer=regularizer, biases_initializer=None) 335 | sc, kp = tf.split(train, 2, axis=2) 336 | sc = tf.nn.relu(sc) 337 | merge = tf.concat([short_cut1, sc], axis=3) 338 | merge = self.shc(merge, short_cut1, 128) 339 | merge = self.in_relu(merge) 340 | train = tf.concat( 341 | [merge, kp], axis=2) 342 | 343 | 344 | # stage -1 345 | 346 | train = ly.conv2d_transpose(train, 64, 4, stride=2, 347 | activation_fn=None, normalizer_fn=normalizer_fn, padding='SAME', weights_initializer=initializer, weights_regularizer=regularizer, biases_initializer=None) 348 | sc, kp = 
tf.split(train, 2, axis=2) 349 | sc = tf.nn.relu(sc) 350 | merge = tf.concat([short_cut0, sc], axis=3) 351 | merge = self.shc(merge, short_cut0, 64) 352 | merge = self.in_relu(merge) 353 | train = tf.concat( 354 | [merge, kp], axis=2) 355 | 356 | # stage -0 357 | recon = ly.conv2d_transpose(train, 3, 4, stride=2, 358 | activation_fn=None, padding='SAME', weights_initializer=initializer, weights_regularizer=regularizer, biases_initializer=None) 359 | 360 | return recon, tf.nn.tanh(recon) 361 | 362 | def build_adversarial_global(self, img, reuse=None, name=None): 363 | bs = img.get_shape().as_list()[0] 364 | with tf.variable_scope(name, reuse=reuse): 365 | 366 | def lrelu(x, leak=0.2, name="lrelu"): 367 | with tf.variable_scope(name): 368 | f1 = 0.5 * (1 + leak) 369 | f2 = 0.5 * (1 - leak) 370 | return f1 * x + f2 * abs(x) 371 | 372 | size = 128 373 | normalizer_fn = ly.instance_norm 374 | activation_fn = lrelu 375 | 376 | img = ly.conv2d(img, num_outputs=size / 2, kernel_size=4, 377 | stride=2, activation_fn=activation_fn) 378 | img = ly.conv2d(img, num_outputs=size, kernel_size=4, 379 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn) 380 | img = ly.conv2d(img, num_outputs=size * 2, kernel_size=4, 381 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn) 382 | img = ly.conv2d(img, num_outputs=size * 4, kernel_size=4, 383 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn) 384 | img = ly.conv2d(img, num_outputs=size * 4, kernel_size=4, 385 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn) 386 | 387 | logit = ly.fully_connected(tf.reshape( 388 | img, [bs, -1]), 1, activation_fn=None) 389 | 390 | return logit 391 | 392 | def build_adversarial_local(self, img, reuse=None, name=None): 393 | bs = img.get_shape().as_list()[0] 394 | with tf.variable_scope(name, reuse=reuse): 395 | 396 | def lrelu(x, leak=0.2, name="lrelu"): 397 | with tf.variable_scope(name): 398 | f1 = 0.5 * (1 + leak) 399 | f2 = 0.5 * (1 - leak) 400 | return f1 * x + f2 * abs(x) 401 | 402 | size = 128 403 | normalizer_fn = ly.instance_norm 404 | activation_fn = lrelu 405 | 406 | img = ly.conv2d(img, num_outputs=size / 2, kernel_size=4, 407 | stride=2, activation_fn=activation_fn) 408 | img = ly.conv2d(img, num_outputs=size, kernel_size=4, 409 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn) 410 | img = ly.conv2d(img, num_outputs=size * 2, kernel_size=4, 411 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn) 412 | img = ly.conv2d(img, num_outputs=size * 2, kernel_size=4, 413 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn) 414 | 415 | logit = ly.fully_connected(tf.reshape( 416 | img, [bs, -1]), 1, activation_fn=None) 417 | 418 | return logit 419 | 420 | 421 | --------------------------------------------------------------------------------