├── examples
│   ├── 1.png
│   ├── 2.png
│   └── 3.png
├── train.sh
├── eval.sh
├── LICENSE
├── dataset
│   ├── parse.py
│   └── build_dataset.py
├── README.md
├── modeling
│   ├── loss.py
│   └── model.py
├── eval_model.py
└── train_model.py
/examples/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/z-x-yang/NS-Outpainting/HEAD/examples/1.png
--------------------------------------------------------------------------------
/examples/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/z-x-yang/NS-Outpainting/HEAD/examples/2.png
--------------------------------------------------------------------------------
/examples/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/z-x-yang/NS-Outpainting/HEAD/examples/3.png
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python train_model.py --trainset-path /path/to/tf-record-trainset --testset-path /path/to/tf-record-testset
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | python eval_model.py --trainset-path /path/to/tf-record-trainset --testset-path /path/to/tf-record-testset --checkpoint-path /path/to/checkpoint
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 z-x-yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/dataset/parse.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def parse_trainset(example_proto):
5 |
6 | dics = {}
7 | dics['image'] = tf.FixedLenFeature(shape=[], dtype=tf.string)
8 |
9 | parsed_example = tf.parse_single_example(
10 | serialized=example_proto, features=dics)
11 | image = tf.decode_raw(parsed_example['image'], out_type=tf.uint8)
12 |
13 |     image = tf.reshape(image, shape=[72 * 2, 216 * 2, 3])  # raw records store 144x432 RGB images
14 | 
15 |     image = tf.random_crop(image, [64 * 2, 128 * 2, 3])  # random 128x256 crop
16 |     image = tf.image.random_flip_left_right(image)
17 |     image = tf.cast(image, tf.float32) / 255.
18 |     image = 2. * image - 1.  # scale to [-1, 1]
19 |
20 | return image
21 |
22 |
23 | def parse_testset(example_proto):
24 |
25 | dics = {}
26 | dics['image'] = tf.FixedLenFeature(shape=[], dtype=tf.string)
27 |
28 | parsed_example = tf.parse_single_example(
29 | serialized=example_proto, features=dics)
30 | image = tf.decode_raw(parsed_example['image'], out_type=tf.uint8)
31 |
32 |     image = tf.reshape(image, shape=[64 * 2, 128 * 2, 3])  # test records are stored at the final 128x256 size
33 | 
34 |     image = tf.cast(image, tf.float32) * (2. / 255) - 1.0  # scale to [-1, 1]
35 |
36 | return image
37 |
38 |
--------------------------------------------------------------------------------
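
A minimal usage sketch for the parsers above (TensorFlow 1.x; it assumes a TFRecord built by dataset/build_dataset.py at the hypothetical path ./dataset/trainset.tfr):

```python
import tensorflow as tf
from dataset.parse import parse_trainset

# Each parsed training example is a float32 tensor of shape
# (128, 256, 3) with values scaled to [-1, 1].
dataset = tf.data.TFRecordDataset(['./dataset/trainset.tfr'])
dataset = dataset.map(parse_trainset)
image = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    img = sess.run(image)
    print(img.shape, img.min(), img.max())
```

This mirrors the input pipelines built in train_model.py and eval_model.py below.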
/README.md:
--------------------------------------------------------------------------------
1 | # Very Long Natural Scenery Image Prediction by Outpainting (NS-Outpainting)
2 | A neural architecture for scenery image outpainting ([ICCV 2019](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yang_Very_Long_Natural_Scenery_Image_Prediction_by_Outpainting_ICCV_2019_paper.pdf)), implemented in [TensorFlow](http://www.tensorflow.org).
3 |
4 | The architecture is able to generate a very long, high-quality prediction from a small input image by outpainting:
5 |
6 | ![](examples/1.png)
7 | ![](examples/2.png)
8 | ![](examples/3.png)
9 | ## Requirements and Preparation
10 |
11 | Please install `TensorFlow>=1.3.0` and `Python>=3.6`.
12 |
13 | For training and testing, we collected a new outpainting dataset of 6,000 images containing complex natural scenes. You can download the raw dataset from [here](https://drive.google.com/file/d/15rGKgeNHWqjs90An7wpZXJMz-zFaC1q0/view?usp=sharing) and split it into training and testing sets yourself, or use our split from [here](https://drive.google.com/file/d/1LDRx0W6zo_eCZwN92pGgGZSCrqzB3KZ6/view?usp=sharing) (TFRecord format, 128 resolution, 5,000 images for training and 1,000 for testing).
14 |
15 | ## Usage
16 |
17 | For training and evaluation, you can use [train.sh](/train.sh) and [eval.sh](/eval.sh). Please remember to set the TFRecord dataset paths inside them.
18 |
19 | Besides, you can get our **pretrained model** from [here](https://drive.google.com/file/d/1-DLSwNkB93MMKaYVO1rmPP9iJllXDJrg/view?usp=sharing) and run eval_model.py to evaluate it.
20 |
21 | After running eval_model.py, the evaluation process stores 4 types of images:
22 | 1) "ori_xxx.jpg", the ground-truth images of size 128x256;
23 | 2) "m0_xxx.jpg", the 1-step predictions of size 128x256 without any post-processing;
24 | 3) "m1_xxx.jpg", the 1-step predictions of size 128x256 with smooth stitching;
25 | 4) "endless_xxx.jpg", the 4-step predictions of size 128x640.
26 | 
27 | Notably, in our paper we measure the Inception Score and Inception Distance between "ori_xxx.jpg" and "m0_xxx.jpg".
28 |
29 | ## Citation
30 | ```
31 | @inproceedings{yang2019very,
32 | title={Very Long Natural Scenery Image Prediction by Outpainting},
33 | author={Yang, Zongxin and Dong, Jian and Liu, Ping and Yang, Yi and Yan, Shuicheng},
34 | booktitle={Proceedings of the IEEE International Conference on Computer Vision},
35 | pages={10561--10570},
36 | year={2019}
37 | }
38 | ```
39 |
--------------------------------------------------------------------------------
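
If you prefer to build the TFRecord files yourself, dataset/build_dataset.py (below) reads raw images from `--dataset-path` (default `./scenery/`), holds out the first 1,000 shuffled images as the test split, and writes `trainset.tfr` and `testset.tfr` to `--result-path`, e.g. `python dataset/build_dataset.py --dataset-path /path/to/scenery --result-path ./dataset/`.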
/dataset/build_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | from glob import glob
4 | import numpy as np
5 | from PIL import Image
6 | import tensorflow as tf
7 | import argparse
8 |
9 | parser = argparse.ArgumentParser(description='Dataset building.')
10 | parser.add_argument('--dataset-path', type=str, default='./scenery/')
11 | parser.add_argument('--result-path', type=str, default='./')
12 |
13 | args = parser.parse_args()
14 | dataset_path = args.dataset_path
15 | result_path = args.result_path
16 |
17 |
18 | if not os.path.exists(result_path):
19 | os.makedirs(result_path)
20 |
21 | train_list = os.listdir(dataset_path)
22 | random.shuffle(train_list)
23 | trainset = list(map(lambda x: os.path.join(
24 | dataset_path, x), train_list))
25 |
26 | testset = trainset[0:1000]
27 | trainset = trainset[1000:]
28 |
29 |
30 | def build_trainset(image_list, name):
31 | len2 = len(image_list)
32 | print("len=", len2)
33 | writer = tf.python_io.TFRecordWriter(name)
34 | k = 0
35 | for i in range(len2):
36 |
37 | image = Image.open(image_list[i])
38 |         image = image.resize((432, 144), Image.BILINEAR)  # (width, height); parse.py reshapes this to [144, 432, 3]
39 | image = image.convert('RGB')
40 |
41 | image_bytes = image.tobytes()
42 |
43 | features = {}
44 |
45 | features['image'] = tf.train.Feature(
46 | bytes_list=tf.train.BytesList(value=[image_bytes]))
47 |
48 | tf_features = tf.train.Features(feature=features)
49 |
50 | tf_example = tf.train.Example(features=tf_features)
51 |
52 | tf_serialized = tf_example.SerializeToString()
53 |
54 | writer.write(tf_serialized)
55 | k = k + 1
56 | print(k)
57 | writer.close()
58 |
59 |
60 | def build_testset(image_list, name):
61 | len2 = len(image_list)
62 | print("len=", len2)
63 | writer = tf.python_io.TFRecordWriter(name)
64 | for i in range(len2):
65 |
66 | image = Image.open(image_list[i])
67 |         image = image.resize((256, 128), Image.BILINEAR)  # test images are stored at the final 256x128 (width, height)
68 |         image = image.convert('RGB')
69 | 
70 |         image_flip = image.transpose(Image.FLIP_LEFT_RIGHT)  # each image is also written flipped: 1,000 images -> 2,000 records
71 |
72 | image_bytes = image.tobytes()
73 |
74 | features = {}
75 |
76 | features['image'] = tf.train.Feature(
77 | bytes_list=tf.train.BytesList(value=[image_bytes]))
78 |
79 | tf_features = tf.train.Features(feature=features)
80 |
81 | tf_example = tf.train.Example(features=tf_features)
82 |
83 | tf_serialized = tf_example.SerializeToString()
84 |
85 | writer.write(tf_serialized)
86 |
87 | # flip image
88 | image = image_flip
89 |
90 | image_bytes = image.tobytes()
91 |
92 | features = {}
93 |
94 | features['image'] = tf.train.Feature(
95 | bytes_list=tf.train.BytesList(value=[image_bytes]))
96 |
97 | tf_features = tf.train.Features(feature=features)
98 |
99 | tf_example = tf.train.Example(features=tf_features)
100 |
101 | tf_serialized = tf_example.SerializeToString()
102 |
103 | writer.write(tf_serialized)
104 |
105 | writer.close()
106 |
107 |
108 | print('Build testset!')
109 | build_testset(testset, result_path + "/testset.tfr")
110 | print('Build trainset!')
111 | build_trainset(trainset, result_path + "/trainset.tfr")
112 |
113 | print('Done!')
114 |
--------------------------------------------------------------------------------
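
A quick sanity check after building, as a hedged sketch (TF 1.x; `tf.python_io.tf_record_iterator` comes from the same `tf.python_io` module the writer above uses):

```python
import tensorflow as tf

# Count the records actually written; with the default split this prints
# 2000 for testset.tfr, since every test image is also stored flipped.
n = sum(1 for _ in tf.python_io.tf_record_iterator('./testset.tfr'))
print(n)
```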
/modeling/loss.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import math
4 |
5 | class Loss():
6 | def __init__(self, cfg):
7 | self.cfg = cfg
8 |
9 | def masked_reconstruction_loss(self, gt, recon):
10 | loss_recon = tf.square(gt - recon)
11 |         mask_values = np.ones((128, 128))
12 |         for j in range(128):
13 |             mask_values[:, j] = (1. + math.cos(math.pi * j / 127.0)) * 0.5  # 1 at the seam -> 0 at the far edge
14 |         mask_values = np.expand_dims(mask_values, 0)
15 |         mask_values = np.expand_dims(mask_values, 3)
16 |         mask1 = tf.constant(1, dtype=tf.float32, shape=[1, 128, 128, 1])  # known left half: full weight
17 |         mask2 = tf.constant(mask_values, dtype=tf.float32, shape=[1, 128, 128, 1])  # predicted right half: cosine decay
18 |         mask = tf.concat([mask1, mask2], axis=2)
19 | loss_recon = loss_recon * mask
20 | loss_recon = tf.reduce_mean(loss_recon)
21 | return loss_recon
22 |
23 | def adversarial_loss(self, dis_fun, real, fake, name):
24 | adversarial_pos = dis_fun(real, name=name)
25 | adversarial_neg = dis_fun(fake, reuse=tf.AUTO_REUSE, name=name)
26 |
27 | loss_adv_D = - tf.reduce_mean(adversarial_pos - adversarial_neg)
28 |
29 | differences = fake - real
30 | alpha = tf.random_uniform(shape=[self.cfg.batch_size_per_gpu, 1, 1, 1])
31 | interpolates = real + tf.multiply(alpha, differences)
32 | gradients = tf.gradients(dis_fun(
33 | interpolates, reuse=tf.AUTO_REUSE, name=name), [interpolates])[0]
34 | slopes = tf.sqrt(tf.reduce_sum(
35 | tf.square(gradients), [1, 2, 3]) + 1e-10)
36 |         gradients_penalty = tf.reduce_mean((slopes - 1.) ** 2)  # WGAN-GP gradient penalty
37 | loss_adv_D += self.cfg.lambda_gp * gradients_penalty
38 |
39 | loss_adv_G = -tf.reduce_mean(adversarial_neg)
40 |
41 | return loss_adv_D, loss_adv_G
42 |
43 | def global_adversarial_loss(self, dis_fun, real, fake):
44 | return self.adversarial_loss(dis_fun, real, fake, 'DIS')
45 |
46 | def local_adversarial_loss(self, dis_fun, real, fake):
47 | return self.adversarial_loss(dis_fun, real, fake, 'DIS2')
48 |
49 |
50 | def global_and_local_adv_loss(self, model, gt, recon):
51 |
52 | left_half_gt = tf.slice(gt, [0, 0, 0, 0], [self.cfg.batch_size_per_gpu, 128, 128, 3])
53 | right_half_gt = tf.slice(gt, [0, 0, 128, 0], [self.cfg.batch_size_per_gpu, 128, 128, 3])
54 | right_half_recon = tf.slice(recon, [0, 0, 128, 0], [self.cfg.batch_size_per_gpu, 128, 128, 3])
55 | real = gt
56 | fake = tf.concat([left_half_gt, right_half_recon], axis=2)
57 | global_D, global_G = self.global_adversarial_loss(model.build_adversarial_global, real, fake)
58 |
59 | real = right_half_gt
60 | fake = right_half_recon
61 | local_D, local_G = self.local_adversarial_loss(model.build_adversarial_local, real, fake)
62 |
63 | loss_adv_D = global_D + local_D
64 | loss_adv_G = self.cfg.beta * global_G + (1 - self.cfg.beta) * local_G
65 |
66 | return loss_adv_G, loss_adv_D
67 |
68 |
69 |
70 | def average_losses(self, loss):
71 | tf.add_to_collection('losses', loss)
72 |
73 | # Assemble all of the losses for the current tower only.
74 | losses = tf.get_collection('losses')
75 |
76 | # Calculate the total loss for the current tower.
77 | regularization_losses = tf.get_collection(
78 | tf.GraphKeys.REGULARIZATION_LOSSES)
79 | total_loss = tf.add_n(
80 | losses + regularization_losses, name='total_loss')
81 |
82 | # Compute the moving average of all individual losses and the total
83 | # loss.
84 | loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
85 | loss_averages_op = loss_averages.apply(losses + [total_loss])
86 |
87 | with tf.control_dependencies([loss_averages_op]):
88 | total_loss = tf.identity(total_loss)
89 | return total_loss
90 |
91 | def average_gradients(self, tower_grads):
92 | average_grads = []
93 | for grad_and_vars in zip(*tower_grads):
94 | # Note that each grad_and_vars looks like the following:
95 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
96 | grads = []
97 | # Average over the 'tower' dimension.
98 | g, _ = grad_and_vars[0]
99 |
100 | for g, _ in grad_and_vars:
101 | expanded_g = tf.expand_dims(g, 0)
102 | grads.append(expanded_g)
103 | grad = tf.concat(grads, axis=0)
104 | grad = tf.reduce_mean(grad, 0)
105 |
106 | # Keep in mind that the Variables are redundant because they are shared
107 | # across towers. So .. we will just return the first tower's pointer to
108 | # the Variable.
109 | v = grad_and_vars[0][1]
110 | grad_and_var = (grad, v)
111 | average_grads.append(grad_and_var)
112 | # clip
113 | if self.cfg.clip_gradient:
114 | gradients, variables = zip(*average_grads)
115 | gradients = [
116 | None if gradient is None else tf.clip_by_average_norm(gradient, self.cfg.clip_gradient_value)
117 | for gradient in gradients]
118 | average_grads = zip(gradients, variables)
119 | return average_grads
120 |
121 | def feed_all_gpu(self, inp_dict, gpu_num, payload_per_gpu, images, params):
122 | for i in range(gpu_num):
123 | gt = params[i]
124 | start_pos = i * payload_per_gpu
125 | stop_pos = (i + 1) * payload_per_gpu
126 | inp_dict[gt] = images[start_pos:stop_pos]
127 | return inp_dict
128 |
129 |
130 |
--------------------------------------------------------------------------------
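
A small numpy sketch of the reconstruction mask built in `masked_reconstruction_loss` above: the known left half keeps weight 1 everywhere, while the predicted right half decays from 1 at the seam to 0 at the far edge, so errors near the seam dominate the loss.

```python
import math
import numpy as np

right = np.ones((128, 128))
for j in range(128):
    right[:, j] = (1. + math.cos(math.pi * j / 127.0)) * 0.5

print(right[0, 0])     # 1.0: full weight at the seam
print(right[0, 64])    # ~0.49: roughly half weight midway
print(right[0, 127])   # 0.0: no weight at the far edge

mask = np.concatenate([np.ones((128, 128)), right], axis=1)  # (128, 256)
```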
/eval_model.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | from glob import glob
4 | import numpy as np
5 | from PIL import Image
6 | import tensorflow as tf
7 | from tensorflow.python.training.moving_averages import assign_moving_average
8 | import tensorflow.contrib.layers as ly
9 | from modeling.model import Model
10 | from modeling.loss import Loss
11 | from dataset.parse import parse_trainset, parse_testset
12 | import argparse
13 | import math
14 |
15 | parser = argparse.ArgumentParser(description='Model testing.')
16 | # experiment
17 | parser.add_argument('--date', type=str, default='0817')
18 | parser.add_argument('--exp-index', type=int, default=2)
19 | parser.add_argument('--f', action='store_true', default=False)
20 |
21 | # gpu
22 | parser.add_argument('--start-gpu', type=int, default=0)
23 | parser.add_argument('--num-gpu', type=int, default=1)
24 |
25 | # dataset
26 | parser.add_argument('--trainset-path', type=str, default='./dataset/trainset.tfr')
27 | parser.add_argument('--testset-path', type=str, default='./dataset/testset.tfr')
28 | parser.add_argument('--trainset-length', type=int, default=5041)
29 | parser.add_argument('--testset-length', type=int, default=2000) # we flip every image in testset
30 |
31 | # training
32 | parser.add_argument('--base-lr', type=float, default=0.0001)
33 | parser.add_argument('--batch-size', type=int, default=20)
34 | parser.add_argument('--weight-decay', type=float, default=0.00002)
35 | parser.add_argument('--epoch', type=int, default=1500)
36 | parser.add_argument('--lr-decay-epoch', type=int, default=1000)
37 | parser.add_argument('--critic-steps', type=int, default=3)
38 | parser.add_argument('--warmup-steps', type=int, default=1000)
39 | parser.add_argument('--workers', type=int, default=2)
40 | parser.add_argument('--clip-gradient', action='store_true', default=False)
41 | parser.add_argument('--clip-gradient-value', type=float, default=0.1)
42 |
43 |
44 | # modeling
45 | parser.add_argument('--beta', type=float, default=0.9)
46 | parser.add_argument('--lambda-gp', type=float, default=10)
47 | parser.add_argument('--lambda-rec', type=float, default=0.998)
48 |
49 | # checkpoint
50 | parser.add_argument('--log-path', type=str, default='./logs/')
51 | parser.add_argument('--checkpoint-path', type=str, default=None)
52 | parser.add_argument('--resume-step', type=int, default=0)
53 |
54 |
55 | args = parser.parse_args()
56 |
57 |
58 | # prepare path
59 | base_path = args.log_path
60 | exp_date = args.date
61 | if exp_date is None:
62 | print('Exp date error!')
63 | import sys
64 | sys.exit()
65 | exp_name = exp_date + '/' + str(args.exp_index)
66 | print("Start Exp:", exp_name)
67 | output_path = base_path + exp_name + '/'
68 | model_path = output_path + 'models/'
69 | tensorboard_path = output_path + 'log/'
70 | result_path = output_path + 'results/'
71 |
72 | if not os.path.exists(model_path):
73 | os.makedirs(model_path)
74 | if not os.path.exists(tensorboard_path):
75 | os.makedirs(tensorboard_path)
76 | if not os.path.exists(result_path):
77 | os.makedirs(result_path)
78 | elif not args.f:
79 | if args.checkpoint_path is None:
80 | print('Exp exist!')
81 | import sys
82 | sys.exit()
83 | else:
84 | import shutil
85 | shutil.rmtree(model_path)
86 | os.makedirs(model_path)
87 | shutil.rmtree(tensorboard_path)
88 | os.makedirs(tensorboard_path)
89 |
90 | # prepare gpu
91 | num_gpu = args.num_gpu
92 | start_gpu = args.start_gpu
93 | gpu_id = str(start_gpu)
94 | for i in range(num_gpu - 1):
95 | gpu_id = gpu_id + ',' + str(start_gpu + i + 1)
96 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
97 | args.batch_size_per_gpu = int(args.batch_size / args.num_gpu)
98 |
99 |
100 |
101 |
102 | model = Model(args)
103 | loss = Loss(args)
104 |
105 | config = tf.ConfigProto(allow_soft_placement=True)
106 | config.gpu_options.allow_growth = True
107 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
108 |
109 | print("Start building model...")
110 | with tf.Session(config=config) as sess:
111 | with tf.device('/cpu:0'):
112 | learning_rate = tf.placeholder(tf.float32, [])
113 | lambda_rec = tf.placeholder(tf.float32, [])
114 |
115 | train_op_G = tf.train.AdamOptimizer(
116 | learning_rate=learning_rate, beta1=0.5, beta2=0.9)
117 | train_op_D = tf.train.AdamOptimizer(
118 | learning_rate=learning_rate, beta1=0.5, beta2=0.9)
119 |
120 |
121 | trainset = tf.data.TFRecordDataset(filenames=[args.trainset_path])
122 | trainset = trainset.shuffle(args.trainset_length)
123 | trainset = trainset.map(parse_trainset, num_parallel_calls=args.workers)
124 | trainset = trainset.batch(args.batch_size).repeat()
125 |
126 | train_iterator = trainset.make_one_shot_iterator()
127 | train_im = train_iterator.get_next()
128 |
129 | testset = tf.data.TFRecordDataset(filenames=[args.testset_path])
130 | testset = testset.map(parse_testset, num_parallel_calls=args.workers)
131 | testset = testset.batch(args.batch_size).repeat()
132 |
133 | test_iterator = testset.make_one_shot_iterator()
134 | test_im = test_iterator.get_next()
135 |
136 | print('build model on gpu tower')
137 | models = []
138 | params = []
139 | for gpu_id in range(num_gpu):
140 | with tf.device('/gpu:%d' % gpu_id):
141 | print('tower_%d' % gpu_id)
142 | with tf.name_scope('tower_%d' % gpu_id):
143 | with tf.variable_scope('cpu_variables', reuse=gpu_id > 0):
144 |
145 | groundtruth = tf.placeholder(
146 | tf.float32, [args.batch_size_per_gpu, 128, 256, 3], name='groundtruth')
147 | left_gt = tf.slice(groundtruth, [0, 0, 0, 0], [args.batch_size_per_gpu, 128, 128, 3])
148 |
149 |
150 | reconstruction_ori, reconstruction = model.build_reconstruction(left_gt)
151 | right_recon = tf.slice(reconstruction, [0, 0, 128, 0], [args.batch_size_per_gpu, 128, 128, 3])
152 |
153 | loss_rec = loss.masked_reconstruction_loss(groundtruth, reconstruction)
154 | loss_adv_G, loss_adv_D = loss.global_and_local_adv_loss(model, groundtruth, reconstruction)
155 |
156 | reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
157 | loss_G = loss_adv_G * (1 - lambda_rec) + loss_rec * lambda_rec + sum(reg_losses)
158 | loss_D = loss_adv_D
159 |
160 | var_G = list(filter(lambda x: x.name.startswith(
161 | 'cpu_variables/GEN'), tf.trainable_variables()))
162 | var_D = list(filter(lambda x: x.name.startswith(
163 | 'cpu_variables/DIS'), tf.trainable_variables()))
164 |
165 |
166 | grad_g = train_op_G.compute_gradients(
167 | loss_G, var_list=var_G)
168 | grad_d = train_op_D.compute_gradients(
169 | loss_D, var_list=var_D)
170 |
171 | models.append((reconstruction, right_recon))
172 | params.append(groundtruth)
173 |
174 | print('Done.')
175 |
176 | print('Start reducing towers on cpu...')
177 |
178 | reconstructions, right_recons = zip(*models)
179 | groundtruths = params
180 |
181 | with tf.device('/gpu:0'):
182 |
183 | reconstructions = tf.concat(reconstructions, axis=0)
184 | right_recons = tf.concat(right_recons, axis=0)
185 |
186 | print('Done.')
187 |
188 |
189 | iters = 0
190 | saver = tf.train.Saver(max_to_keep=5)
191 | if args.checkpoint_path is None:
192 | sess.run(tf.global_variables_initializer())
193 | else:
194 | print('Start loading checkpoint...')
195 | saver.restore(sess, args.checkpoint_path)
196 | iters = args.resume_step
197 | print('Done.')
198 |
199 |
200 |
201 |
202 | print('run eval...')
203 |
204 |
205 |     stitch_mask1 = np.ones((args.batch_size, 128, 128, 3))  # linear cross-fade weights for seam blending
206 |     for i in range(128):
207 |         stitch_mask1[:, :, i, :] = 1. / 127. * (127. - i)  # 1 -> 0, left to right
208 |     stitch_mask2 = stitch_mask1[:, :, ::-1, :]  # 0 -> 1, the complementary weights
209 |
210 |
211 | ii = 0
212 |
213 | for _ in range(math.floor(args.testset_length / args.batch_size)):
214 | test_oris = sess.run([test_im])[0]
215 | origins1 = test_oris.copy()
216 |
217 | oris = None
218 | # oris
219 | print('oris ' + str(ii))
220 | for _ in range(4):
221 | inp_dict = {}
222 | inp_dict = loss.feed_all_gpu(inp_dict, args.num_gpu, args.batch_size_per_gpu, test_oris, params)
223 |
224 | if oris is None:
225 | reconstruction_vals, prediction_vals = sess.run(
226 | [reconstructions, right_recons],
227 | feed_dict=inp_dict)
228 |
229 | oris = reconstruction_vals
230 | pred1 = oris[:, :, :128, :]
231 | pred2 = oris[:, :, -128:, :]
232 | gt = origins1[:, :, :128, :]
233 | p1_m0 = np.concatenate((gt, pred2), axis=2)
234 | p1_m1 = np.concatenate((gt * stitch_mask1 + pred1 * stitch_mask2, pred2), axis=2)
235 | else:
236 | reconstruction_vals, prediction_vals = sess.run(
237 |                     [reconstructions, right_recons],  # the multi-GPU concat, not the last tower's tensor
238 | feed_dict=inp_dict)
239 | A = oris[:, :, -128:, :]
240 | B = reconstruction_vals[:, :, :128, :]
241 | C = A * stitch_mask1 + B * stitch_mask2
242 | oris = np.concatenate((oris[:, :, :-128, :], C, prediction_vals), axis=2)
243 |                 test_oris = np.concatenate((prediction_vals, prediction_vals), axis=2)  # the model only reads the left half; the duplicated right half is a placeholder
244 | predictions1 = oris
245 |
246 | jj = ii
247 | for ori, m0, m1, endless in zip(origins1, p1_m0, p1_m1, predictions1):
248 | name = str(jj) + '.jpg'
249 | ori = (255. * (ori + 1) / 2.).astype(np.uint8)
250 | Image.fromarray(ori).save(os.path.join(
251 | result_path, 'ori_' + name))
252 |
253 | m0 = (255. * (m0 + 1) / 2.).astype(np.uint8)
254 | Image.fromarray(m0).save(os.path.join(
255 | result_path, 'm0_' + name))
256 |
257 | m1 = (255. * (m1 + 1) / 2.).astype(np.uint8)
258 | Image.fromarray(m1).save(os.path.join(
259 | result_path, 'm1_' + name))
260 |
261 | endless = (255. * (endless + 1) / 2.).astype(np.uint8)
262 | Image.fromarray(endless).save(os.path.join(
263 | result_path, 'endless_' + name))
264 | jj += 1
265 |
266 |
267 | ii += args.batch_size
268 |
--------------------------------------------------------------------------------
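
A standalone numpy sketch of the seam blending performed in the loop above: two overlapping 128-pixel strips are linearly cross-faded, which is what produces the "m1" and "endless" outputs (the helper name `stitch` is ours).

```python
import numpy as np

# w[i] = (127 - i) / 127 reproduces stitch_mask1 along the width axis.
w = np.linspace(1., 0., 128)

def stitch(a, b):
    """Blend two overlapping (H, 128, 3) strips: a fades out, b fades in."""
    return a * w[None, :, None] + b * w[::-1][None, :, None]

left = np.zeros((128, 128, 3))
right = np.ones((128, 128, 3))
seam = stitch(left, right)
print(seam[0, 0, 0], seam[0, -1, 0])  # 0.0 at the left edge, 1.0 at the right
```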
/train_model.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | from glob import glob
4 | import numpy as np
5 | from PIL import Image
6 | import tensorflow as tf
7 | from tensorflow.python.training.moving_averages import assign_moving_average
8 | import tensorflow.contrib.layers as ly
9 | from modeling.model import Model
10 | from modeling.loss import Loss
11 | from dataset.parse import parse_trainset, parse_testset
12 | import argparse
13 |
14 | parser = argparse.ArgumentParser(description='Model training.')
15 | # experiment
16 | parser.add_argument('--date', type=str, default='0817')
17 | parser.add_argument('--exp-index', type=int, default=2)
18 | parser.add_argument('--f', action='store_true', default=False)
19 |
20 | # gpu
21 | parser.add_argument('--start-gpu', type=int, default=0)
22 | parser.add_argument('--num-gpu', type=int, default=2)
23 |
24 | # dataset
25 | parser.add_argument('--trainset-path', type=str, default='./dataset/trainset.tfr')
26 | parser.add_argument('--testset-path', type=str, default='./dataset/testset.tfr')
27 | parser.add_argument('--trainset-length', type=int, default=5041)
28 | parser.add_argument('--testset-length', type=int, default=2000) # we flip every image in testset
29 |
30 | # training
31 | parser.add_argument('--base-lr', type=float, default=0.0001)
32 | parser.add_argument('--batch-size', type=int, default=32)
33 | parser.add_argument('--weight-decay', type=float, default=0.00002)
34 | parser.add_argument('--epoch', type=int, default=1500)
35 | parser.add_argument('--lr-decay-epoch', type=int, default=1000)
36 | parser.add_argument('--critic-steps', type=int, default=3)
37 | parser.add_argument('--warmup-steps', type=int, default=1000)
38 | parser.add_argument('--workers', type=int, default=2)
39 | parser.add_argument('--clip-gradient', action='store_true', default=False)
40 | parser.add_argument('--clip-gradient-value', type=float, default=0.1)
41 |
42 |
43 | # modeling
44 | parser.add_argument('--beta', type=float, default=0.9)
45 | parser.add_argument('--lambda-gp', type=float, default=10)
46 | parser.add_argument('--lambda-rec', type=float, default=0.998)
47 |
48 | # checkpoint
49 | parser.add_argument('--log-path', type=str, default='./logs/')
50 | parser.add_argument('--checkpoint-path', type=str, default=None)
51 | parser.add_argument('--resume-step', type=int, default=0)
52 |
53 |
54 | args = parser.parse_args()
55 |
56 |
57 | # prepare path
58 | base_path = args.log_path
59 | exp_date = args.date
60 | if exp_date is None:
61 | print('Exp date error!')
62 | import sys
63 | sys.exit()
64 | exp_name = exp_date + '/' + str(args.exp_index)
65 | print("Start Exp:", exp_name)
66 | output_path = base_path + exp_name + '/'
67 | model_path = output_path + 'models/'
68 | tensorboard_path = output_path + 'log/'
69 | result_path = output_path + 'results/'
70 |
71 | if not os.path.exists(model_path):
72 | os.makedirs(model_path)
73 | if not os.path.exists(tensorboard_path):
74 | os.makedirs(tensorboard_path)
75 | if not os.path.exists(result_path):
76 | os.makedirs(result_path)
77 | elif not args.f:
78 | if args.checkpoint_path is None:
79 | print('Exp exist!')
80 | import sys
81 | sys.exit()
82 | else:
83 | import shutil
84 | shutil.rmtree(model_path)
85 | os.makedirs(model_path)
86 | shutil.rmtree(tensorboard_path)
87 | os.makedirs(tensorboard_path)
88 |
89 | # prepare gpu
90 | num_gpu = args.num_gpu
91 | start_gpu = args.start_gpu
92 | gpu_id = str(start_gpu)
93 | for i in range(num_gpu - 1):
94 | gpu_id = gpu_id + ',' + str(start_gpu + i + 1)
95 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
96 | args.batch_size_per_gpu = int(args.batch_size / args.num_gpu)
97 |
98 |
99 |
100 |
101 | model = Model(args)
102 | loss = Loss(args)
103 |
104 | config = tf.ConfigProto(allow_soft_placement=True)
105 | config.gpu_options.allow_growth = True
106 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
107 |
108 | print("Start building model...")
109 | with tf.Session(config=config) as sess:
110 | with tf.device('/cpu:0'):
111 | learning_rate = tf.placeholder(tf.float32, [])
112 | lambda_rec = tf.placeholder(tf.float32, [])
113 |
114 | train_op_G = tf.train.AdamOptimizer(
115 | learning_rate=learning_rate, beta1=0.5, beta2=0.9)
116 | train_op_D = tf.train.AdamOptimizer(
117 | learning_rate=learning_rate, beta1=0.5, beta2=0.9)
118 |
119 |
120 | trainset = tf.data.TFRecordDataset(filenames=[args.trainset_path])
121 | trainset = trainset.shuffle(args.trainset_length)
122 | trainset = trainset.map(parse_trainset, num_parallel_calls=args.workers)
123 | trainset = trainset.batch(args.batch_size).repeat()
124 |
125 | train_iterator = trainset.make_one_shot_iterator()
126 | train_im = train_iterator.get_next()
127 |
128 | testset = tf.data.TFRecordDataset(filenames=[args.testset_path])
129 | testset = testset.map(parse_testset, num_parallel_calls=args.workers)
130 | testset = testset.batch(args.batch_size).repeat()
131 |
132 | test_iterator = testset.make_one_shot_iterator()
133 | test_im = test_iterator.get_next()
134 |
135 | print('build model on gpu tower')
136 | models = []
137 | params = []
138 | for gpu_id in range(num_gpu):
139 | with tf.device('/gpu:%d' % gpu_id):
140 | print('tower_%d' % gpu_id)
141 | with tf.name_scope('tower_%d' % gpu_id):
142 | with tf.variable_scope('cpu_variables', reuse=gpu_id > 0):
143 |
144 | groundtruth = tf.placeholder(
145 | tf.float32, [args.batch_size_per_gpu, 128, 256, 3], name='groundtruth')
146 | left_gt = tf.slice(groundtruth, [0, 0, 0, 0], [args.batch_size_per_gpu, 128, 128, 3])
147 |
148 |
149 | reconstruction_ori, reconstruction = model.build_reconstruction(left_gt)
150 | right_recon = tf.slice(reconstruction, [0, 0, 128, 0], [args.batch_size_per_gpu, 128, 128, 3])
151 |
152 | loss_rec = loss.masked_reconstruction_loss(groundtruth, reconstruction)
153 | loss_adv_G, loss_adv_D = loss.global_and_local_adv_loss(model, groundtruth, reconstruction)
154 |
155 | reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
156 | loss_G = loss_adv_G * (1 - lambda_rec) + loss_rec * lambda_rec + sum(reg_losses)
157 | loss_D = loss_adv_D
158 |
159 | var_G = list(filter(lambda x: x.name.startswith(
160 | 'cpu_variables/GEN'), tf.trainable_variables()))
161 | var_D = list(filter(lambda x: x.name.startswith(
162 | 'cpu_variables/DIS'), tf.trainable_variables()))
163 |
164 |
165 | grad_g = train_op_G.compute_gradients(
166 | loss_G, var_list=var_G)
167 | grad_d = train_op_D.compute_gradients(
168 | loss_D, var_list=var_D)
169 |
170 | models.append((grad_g, grad_d, loss_G, loss_D, loss_adv_G, loss_rec, reconstruction))
171 | params.append(groundtruth)
172 |
173 | print('Done.')
174 |
175 | print('Start reducing towers on cpu...')
176 |
177 | grad_gs, grad_ds, loss_Gs, loss_Ds, loss_adv_Gs, loss_recs, reconstructions = zip(*models)
178 | groundtruths = params
179 |
180 | with tf.device('/gpu:0'):
181 | aver_loss_g = tf.reduce_mean(loss_Gs)
182 | aver_loss_d = tf.reduce_mean(loss_Ds)
183 | aver_loss_ag = tf.reduce_mean(loss_adv_Gs)
184 | aver_loss_rec = tf.reduce_mean(loss_recs)
185 |
186 | train_op_G = train_op_G.apply_gradients(
187 | loss.average_gradients(grad_gs))
188 | train_op_D = train_op_D.apply_gradients(
189 | loss.average_gradients(grad_ds))
190 |
191 | groundtruths = tf.concat(groundtruths, axis=0)
192 | reconstructions = tf.concat(reconstructions, axis=0)
193 |
194 | tf.summary.scalar('loss_g', aver_loss_g)
195 | tf.summary.scalar('loss_d', aver_loss_d)
196 | tf.summary.scalar('loss_ag', aver_loss_ag)
197 | tf.summary.scalar('loss_rec', aver_loss_rec)
198 | tf.summary.image('groundtruth', groundtruths, 2)
199 | tf.summary.image('reconstruction', reconstructions, 2)
200 |
201 | merged = tf.summary.merge_all()
202 | writer = tf.summary.FileWriter(tensorboard_path, sess.graph)
203 |
204 | print('Done.')
205 |
206 |
207 | iters = 0
208 | saver = tf.train.Saver(max_to_keep=5)
209 | if args.checkpoint_path is None:
210 | sess.run(tf.global_variables_initializer())
211 | else:
212 | print('Start loading checkpoint...')
213 | saver.restore(sess, args.checkpoint_path)
214 | iters = args.resume_step
215 | print('Done.')
216 |
217 |
218 |
219 |
220 | print('Start training...')
221 |
222 | for epoch in range(args.epoch):
223 |
224 | if epoch > args.lr_decay_epoch:
225 | learning_rate_val = args.base_lr / 10
226 | else:
227 | learning_rate_val = args.base_lr
228 |
229 | for start, end in zip(
230 | range(0, args.trainset_length, args.batch_size),
231 | range(args.batch_size, args.trainset_length, args.batch_size)):
232 |
233 | if iters == 0 and args.checkpoint_path is None:
234 | print('Start pretraining G!')
235 | for t in range(args.warmup_steps):
236 | if t % 20 == 0:
237 | print("Step:", t)
238 | images = sess.run([train_im])[0]
239 | if len(images) < args.batch_size:
240 | images = sess.run([train_im])[0]
241 |
242 | inp_dict = {}
243 | inp_dict = loss.feed_all_gpu(inp_dict, args.num_gpu, args.batch_size_per_gpu, images, params)
244 | inp_dict[learning_rate] = learning_rate_val
245 | inp_dict[lambda_rec] = 1.
246 |
247 | _ = sess.run(
248 | [train_op_G],
249 | feed_dict=inp_dict)
250 | print('Pre-train G Done!')
251 |
252 |             if (iters < 25 and args.checkpoint_path is None) or iters % 500 == 0:
253 |                 n_cir = 30  # extra critic steps early on and at every 500th iteration
254 |             else:
255 |                 n_cir = args.critic_steps
256 |
257 | for t in range(n_cir):
258 | images = sess.run([train_im])[0]
259 | if len(images) < args.batch_size:
260 | images = sess.run([train_im])[0]
261 |
262 | inp_dict = {}
263 | inp_dict = loss.feed_all_gpu(inp_dict, args.num_gpu, args.batch_size_per_gpu, images, params)
264 | inp_dict[learning_rate] = learning_rate_val
265 | inp_dict[lambda_rec] = args.lambda_rec
266 |
267 | _ = sess.run(
268 | [train_op_D],
269 | feed_dict=inp_dict)
270 |
271 |
272 | if iters % 50 == 0:
273 |
274 | _, g_val, ag_val, rs, d_val = sess.run(
275 | [train_op_G, aver_loss_g, aver_loss_ag, merged, aver_loss_d],
276 | feed_dict=inp_dict)
277 | writer.add_summary(rs, iters)
278 |
279 | else:
280 |
281 | _, g_val, ag_val, d_val = sess.run(
282 | [train_op_G, aver_loss_g, aver_loss_ag, aver_loss_d],
283 | feed_dict=inp_dict)
284 | if iters % 20 == 0:
285 | print("Iter:", iters, 'loss_g:', g_val, 'loss_d:', d_val, 'loss_adv_g:', ag_val)
286 |
287 | iters += 1
288 |
289 | saver.save(sess, model_path, global_step=iters)
290 |
291 | # testing
292 | if epoch > 0:
293 | ii = 0
294 | g_vals = 0
295 | d_vals = 0
296 | ag_vals = 0
297 | n_batchs = 0
298 | for _ in range(int(args.testset_length / args.batch_size)):
299 | test_oris = sess.run([test_im])[0]
300 | if len(test_oris) < args.batch_size:
301 | test_oris = sess.run([test_im])[0]
302 |
303 | inp_dict = {}
304 | inp_dict = loss.feed_all_gpu(inp_dict, args.num_gpu, args.batch_size_per_gpu, test_oris, params)
305 | inp_dict[learning_rate] = learning_rate_val
306 | inp_dict[lambda_rec] = args.lambda_rec
307 |
308 | reconstruction_vals, g_val, d_val, ag_val = sess.run(
309 | [reconstruction, aver_loss_g, aver_loss_d, aver_loss_ag],
310 | feed_dict=inp_dict)
311 |
312 | g_vals += g_val
313 | d_vals += d_val
314 | ag_vals += ag_val
315 | n_batchs += 1
316 |
317 | # Save test result every 100 epochs
318 | if epoch % 100 == 0:
319 |
320 | for rec_val, test_ori in zip(reconstruction_vals, test_oris):
321 | rec_hid = (255. * (rec_val + 1) /
322 | 2.).astype(np.uint8)
323 | test_ori = (255. * (test_ori + 1) /
324 | 2.).astype(np.uint8)
325 | Image.fromarray(rec_hid).save(os.path.join(
326 | result_path, 'img_' + str(ii) + '.' + str(int(iters / 100)) + '.jpg'))
327 | if epoch == 0:
328 | Image.fromarray(test_ori).save(
329 | os.path.join(result_path, 'img_' + str(ii) + '.' + str(int(iters / 100)) + '.ori.jpg'))
330 | ii += 1
331 | g_vals /= n_batchs
332 | d_vals /= n_batchs
333 | ag_vals /= n_batchs
334 |
335 | summary = tf.Summary()
336 | summary.value.add(tag='eval/g',
337 | simple_value=g_vals)
338 | summary.value.add(tag='eval/d',
339 | simple_value=d_vals)
340 | summary.value.add(tag='eval/ag',
341 | simple_value=ag_vals)
342 | writer.add_summary(summary, iters)
343 |
344 | print("=========================================================================")
345 |             print('loss_g:', g_vals, 'loss_d:', d_vals, 'loss_adv_g:', ag_vals)
346 | print("=========================================================================")
347 |
348 | if np.isnan(reconstruction_vals.min()) or np.isnan(reconstruction_vals.max()):
349 | print("NaN detected!!")
350 |
--------------------------------------------------------------------------------
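
A schematic of the alternating update schedule implemented above, factored into a helper for clarity (the helper name is ours; the constants 25, 500, and 30 come straight from the training loop):

```python
def n_critic_steps(iters, critic_steps=3, from_scratch=True):
    """How many discriminator updates to run before one generator update."""
    if (iters < 25 and from_scratch) or iters % 500 == 0:
        return 30          # extra critic steps early on and periodically
    return critic_steps    # default: 3 (the --critic-steps flag)
```

Together with the `--warmup-steps` generator pretraining (lambda_rec fixed to 1), this follows the common WGAN-GP practice of keeping the critic well-trained relative to the generator.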
/modeling/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib.layers as ly
3 |
4 |
5 | class Model():
6 | def __init__(self, cfg):
7 | self.cfg = cfg
8 |
9 | def new_atrous_conv_layer(self, bottom, filter_shape, rate, name=None):
10 | with tf.variable_scope(name):
11 | regularizer = tf.contrib.layers.l2_regularizer(self.cfg.weight_decay)
12 | initializer = tf.contrib.layers.xavier_initializer()
13 | W = tf.get_variable(
14 | "W",
15 | shape=filter_shape,
16 | regularizer=regularizer,
17 | initializer=initializer)
18 |
19 | x = tf.nn.atrous_conv2d(
20 | bottom, W, rate, padding='SAME')
21 | return x
22 |
23 | def identity_block(self, X_input, kernel_size, filters, stage, block, is_relu=False):
24 |
25 | if is_relu:
26 | activation_fn=tf.nn.relu
27 |
28 | else:
29 | activation_fn=self.leaky_relu
30 |
31 | normalizer_fn = ly.instance_norm
32 |
33 |
34 | # defining name basis
35 | conv_name_base = 'res' + str(stage) + block + '_branch'
36 |
37 | with tf.variable_scope("id_block_stage" + str(stage) + block):
38 | filter1, filter2, filter3 = filters
39 | X_shortcut = X_input
40 | regularizer = tf.contrib.layers.l2_regularizer(self.cfg.weight_decay)
41 | initializer = tf.contrib.layers.xavier_initializer()
42 |
43 | # First component of main path
44 | x = tf.layers.conv2d(X_input, filter1,
45 | kernel_size=(1, 1), strides=(1, 1), name=conv_name_base + '2a', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False)
46 | x = normalizer_fn(x)
47 | x = activation_fn(x)
48 |
49 | # Second component of main path
50 | x = tf.layers.conv2d(x, filter2, (kernel_size, kernel_size),
51 | padding='same', name=conv_name_base + '2b', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False)
52 | x = normalizer_fn(x)
53 | x = activation_fn(x)
54 |
55 | # Third component of main path
56 | x = tf.layers.conv2d(x, filter3, kernel_size=(
57 | 1, 1), name=conv_name_base + '2c', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False)
58 | x = normalizer_fn(x)
59 |
60 | # Final step: Add shortcut value to main path, and pass it through
61 | x = tf.add(x, X_shortcut)
62 | x = activation_fn(x)
63 |
64 | return x
65 |
66 | def convolutional_block(self, X_input, kernel_size, filters, stage, block, stride=2, is_relu=False):
67 |
68 | if is_relu:
69 | activation_fn=tf.nn.relu
70 |
71 | else:
72 | activation_fn=self.leaky_relu
73 |
74 | normalizer_fn = ly.instance_norm
75 |
76 | # defining name basis
77 | conv_name_base = 'res' + str(stage) + block + '_branch'
78 |
79 | with tf.variable_scope("conv_block_stage" + str(stage) + block):
80 |
81 | regularizer = tf.contrib.layers.l2_regularizer(self.cfg.weight_decay)
82 | initializer = tf.contrib.layers.xavier_initializer()
83 | # initializer = tf.variance_scaling_initializer(scale=1.0,mode='fan_in')
84 |
85 | # Retrieve Filters
86 | filter1, filter2, filter3 = filters
87 |
88 | # Save the input value
89 | X_shortcut = X_input
90 |
91 | # First component of main path
92 | x = tf.layers.conv2d(X_input, filter1,
93 | kernel_size=(1, 1),
94 | strides=(1, 1),
95 | name=conv_name_base + '2a', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False)
96 | x = normalizer_fn(x)
97 | x = activation_fn(x)
98 |
99 | # Second component of main path
100 | x = tf.layers.conv2d(x, filter2, (kernel_size, kernel_size), strides=(stride, stride), name=conv_name_base +
101 | '2b', padding='same', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False)
102 | x = normalizer_fn(x)
103 | x = activation_fn(x)
104 |
105 | # Third component of main path
106 | x = tf.layers.conv2d(x, filter3, (1, 1), name=conv_name_base + '2c',
107 | kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False)
108 | x = normalizer_fn(x)
109 |
110 |
111 | # SHORTCUT PATH
112 | X_shortcut = tf.layers.conv2d(X_shortcut, filter3, (1, 1),
113 | strides=(stride, stride), name=conv_name_base + '1', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False)
114 | X_shortcut = normalizer_fn(X_shortcut)
115 |
116 | # Final step: Add shortcut value to main path, and pass it through
117 | # a RELU activation
118 | x = tf.add(X_shortcut, x)
119 | x = activation_fn(x)
120 |
121 | return x
122 |
123 | def leaky_relu(self, x, name=None, leak=0.2):
124 | f1 = 0.5 * (1 + leak)
125 | f2 = 0.5 * (1 - leak)
126 | return f1 * x + f2 * abs(x)
127 |
128 | def in_lrelu(self, x, name=None):
129 | x = tf.contrib.layers.instance_norm(x)
130 | x = self.leaky_relu(x)
131 | return x
132 |
133 | def in_relu(self, x, name=None):
134 | x = tf.contrib.layers.instance_norm(x)
135 | x = tf.nn.relu(x)
136 | return x
137 |
138 | def rct(self, x):
139 | regularizer = tf.contrib.layers.l2_regularizer(self.cfg.weight_decay)
140 | output_size = x.get_shape().as_list()[3]
141 | size = 512
142 | layer_num = 2
143 | activation_fn = tf.tanh
144 | x = ly.conv2d(x, size, 1, stride=1, activation_fn=None,
145 | normalizer_fn=None, padding='SAME', weights_regularizer=regularizer, biases_initializer=None)
146 | x = self.in_lrelu(x)
147 |         x = tf.transpose(x, [0, 2, 1, 3])  # move the width axis forward
148 |         x = tf.reshape(x, [-1, 4, 4 * size])
149 |         x = tf.transpose(x, [1, 0, 2])  # time-major: 4 steps, one per width position
150 |         # encoder_inputs = x
151 |         x = tf.reshape(x, [-1, 4 * size])
152 |         x_split = tf.split(x, 4, 0)  # 4 tensors of shape (batch, 4 * size)
153 |
154 | ys = []
155 | with tf.variable_scope('LSTM'):
156 | with tf.variable_scope('encoder'):
157 | lstm_cell = tf.contrib.rnn.LSTMCell(
158 | 4 * size, activation=activation_fn)
159 | lstm_cell = tf.contrib.rnn.MultiRNNCell(
160 | [lstm_cell] * layer_num, state_is_tuple=True)
161 |
162 | init_state = lstm_cell.zero_state(self.cfg.batch_size_per_gpu, dtype=tf.float32)
163 | now, _state = lstm_cell(x_split[0], init_state)
164 | now, _state = lstm_cell(x_split[1], _state)
165 | now, _state = lstm_cell(x_split[2], _state)
166 | now, _state = lstm_cell(x_split[3], _state)
167 |
168 | with tf.variable_scope('decoder'):
169 | lstm_cell = tf.contrib.rnn.BasicLSTMCell(
170 | 4 * size, activation=activation_fn)
171 | lstm_cell2 = tf.contrib.rnn.MultiRNNCell(
172 | [lstm_cell] * layer_num, state_is_tuple=True)
173 |                 # predict the 4 width positions of the unseen half
174 | now, _state = lstm_cell2(x_split[3], _state)
175 | ys.append(tf.reshape(now, [-1, 4, 1, size]))
176 | now, _state = lstm_cell2(now, _state)
177 | ys.append(tf.reshape(now, [-1, 4, 1, size]))
178 | now, _state = lstm_cell2(now, _state)
179 | ys.append(tf.reshape(now, [-1, 4, 1, size]))
180 | now, _state = lstm_cell2(now, _state)
181 | ys.append(tf.reshape(now, [-1, 4, 1, size]))
182 |
183 |
184 | y = tf.concat(ys, axis=2)
185 |
186 | y = ly.conv2d(y, output_size, 1, stride=1, activation_fn=None,
187 | normalizer_fn=None, padding='SAME', weights_regularizer=regularizer, biases_initializer=None)
188 | y = self.in_lrelu(y)
189 | return y
190 |
191 |
192 |
193 | def shc(self, x, shortcut, channels):
194 | regularizer = tf.contrib.layers.l2_regularizer(self.cfg.weight_decay)
195 |         x = ly.conv2d(x, channels // 2, 1, stride=1, activation_fn=tf.nn.relu,
196 |                       normalizer_fn=tf.contrib.layers.instance_norm, padding='SAME', weights_regularizer=regularizer)
197 |         x = ly.conv2d(x, channels // 2, 3, stride=1, activation_fn=tf.nn.relu,
198 |                       normalizer_fn=tf.contrib.layers.instance_norm, padding='SAME', weights_regularizer=regularizer)
199 | x = ly.conv2d(x, channels, 1, stride=1, activation_fn=None,
200 | normalizer_fn=tf.contrib.layers.instance_norm, padding='SAME', weights_regularizer=regularizer)
201 | return tf.add(shortcut, x)
202 |
203 |
204 | def grb(self, x, filters, rate, name):
205 | activation_fn = tf.nn.relu
206 | normalizer_fn = ly.instance_norm
207 | shortcut = x
208 | x1 = self.new_atrous_conv_layer(x, [3, 1, filters, filters], rate, name+'_a1')
209 | x1 = normalizer_fn(x1)
210 | x1 = activation_fn(x1)
211 | x1 = self.new_atrous_conv_layer(x1, [1, 7, filters, filters], rate, name+'_a2')
212 | x1 = normalizer_fn(x1)
213 |
214 | x2 = self.new_atrous_conv_layer(x, [1, 7, filters, filters], rate, name+'_b1')
215 | x2 = normalizer_fn(x2)
216 | x2 = activation_fn(x2)
217 | x2 = self.new_atrous_conv_layer(x2, [3, 1, filters, filters], rate, name+'_b2')
218 | x2 = normalizer_fn(x2)
219 |
220 | x = tf.add(shortcut, x1)
221 | x = tf.add(x, x2)
222 | x = activation_fn(x)
223 | return x
224 |
225 | def build_reconstruction(self, images, reuse=None):
226 |
227 | with tf.variable_scope('GEN', reuse=reuse):
228 | x = images
229 | normalizer_fn = ly.instance_norm
230 | regularizer = tf.contrib.layers.l2_regularizer(self.cfg.weight_decay)
231 | initializer = tf.contrib.layers.xavier_initializer()
232 | # stage 1
233 |
234 | x = tf.layers.conv2d(x, filters=64, kernel_size=(4, 4), strides=(
235 | 2, 2), name='conv0', kernel_regularizer=regularizer, padding='same', kernel_initializer=initializer, use_bias=False)
236 | x = self.in_lrelu(x)
237 | short_cut0 = x
238 | x = tf.layers.conv2d(x, filters=128, kernel_size=(4, 4), strides=(
239 | 2, 2), name='conv1', padding='same', kernel_regularizer=regularizer, kernel_initializer=initializer, use_bias=False)
240 | x = self.in_lrelu(x)
241 | short_cut1 = x
242 |
243 | # stage 2
244 | x = self.convolutional_block(x, kernel_size=3, filters=[
245 | 64, 64, 256], stage=2, block='a', stride=2)
246 | x = self.identity_block(
247 | x, 3, [64, 64, 256], stage=2, block='b')
248 | x = self.identity_block(
249 | x, 3, [64, 64, 256], stage=2, block='c')
250 | short_cut2 = x
251 |
252 | # stage 3
253 | x = self.convolutional_block(x, kernel_size=3, filters=[128, 128, 512],
254 | stage=3, block='a', stride=2)
255 | x = self.identity_block(
256 | x, 3, [128, 128, 512], stage=3, block='b')
257 | x = self.identity_block(
258 | x, 3, [128, 128, 512], stage=3, block='c')
259 | x = self.identity_block(
260 | x, 3, [128, 128, 512], stage=3, block='d',)
261 | short_cut3 = x
262 |
263 | # stage 4
264 | x = self.convolutional_block(x, kernel_size=3, filters=[
265 | 256, 256, 1024], stage=4, block='a', stride=2)
266 | x = self.identity_block(
267 | x, 3, [256, 256, 1024], stage=4, block='b')
268 | x = self.identity_block(
269 | x, 3, [256, 256, 1024], stage=4, block='c')
270 | x = self.identity_block(
271 | x, 3, [256, 256, 1024], stage=4, block='d')
272 | x = self.identity_block(
273 | x, 3, [256, 256, 1024], stage=4, block='e')
274 | short_cut4 = x
275 |
276 | # rct transfer
277 | train = self.rct(x)
278 |
279 |
280 | # stage -4
281 | train = tf.concat([short_cut4, train], axis=2)
282 |
283 | train = self.grb(train, 1024, 1, 't4')
284 | train = self.identity_block(
285 | train, 3, [256, 256, 1024], stage=-4, block='b', is_relu=True)
286 | train = self.identity_block(
287 | train, 3, [256, 256, 1024], stage=-4, block='c', is_relu=True)
288 |
289 |
290 | train = ly.conv2d_transpose(train, 512, 4, stride=2,
291 | activation_fn=None, normalizer_fn=normalizer_fn, padding='SAME', weights_initializer=initializer, weights_regularizer=regularizer, biases_initializer=None)
292 | sc, kp = tf.split(train, 2, axis=2)
293 | sc = tf.nn.relu(sc)
294 | merge = tf.concat([short_cut3, sc], axis=3)
295 | merge = self.shc(merge, short_cut3, 512)
296 | merge = self.in_relu(merge)
297 | train = tf.concat(
298 | [merge, kp], axis=2)
299 |
300 |
301 | # stage -3
302 | train = self.grb(train, 512, 2, 't3')
303 | train = self.identity_block(
304 | train, 3, [128, 128, 512], stage=-3, block='b', is_relu=True)
305 | train = self.identity_block(
306 | train, 3, [128, 128, 512], stage=-3, block='c', is_relu=True)
307 | train = self.identity_block(
308 | train, 3, [128, 128, 512], stage=-3, block='d', is_relu=True)
309 |
310 |
311 |
312 | train = ly.conv2d_transpose(train, 256, 4, stride=2,
313 | activation_fn=None, normalizer_fn=normalizer_fn, padding='SAME', weights_initializer=initializer, weights_regularizer=regularizer, biases_initializer=None)
314 | sc, kp = tf.split(train, 2, axis=2)
315 | sc = tf.nn.relu(sc)
316 | merge = tf.concat([short_cut2, sc], axis=3)
317 | merge = self.shc(merge, short_cut2, 256)
318 | merge = self.in_relu(merge)
319 | train = tf.concat(
320 | [merge, kp], axis=2)
321 |
322 | # stage -2
323 | train = self.grb(train, 256, 4, 't2')
324 | train = self.identity_block(
325 | train, 3, [64, 64, 256], stage=-2, block='b', is_relu=True)
326 | train = self.identity_block(
327 | train, 3, [64, 64, 256], stage=-2, block='c', is_relu=True)
328 | train = self.identity_block(
329 | train, 3, [64, 64, 256], stage=-2, block='d', is_relu=True)
330 | train = self.identity_block(
331 | train, 3, [64, 64, 256], stage=-2, block='e', is_relu=True)
332 |
333 | train = ly.conv2d_transpose(train, 128, 4, stride=2,
334 | activation_fn=None, normalizer_fn=normalizer_fn, padding='SAME', weights_initializer=initializer, weights_regularizer=regularizer, biases_initializer=None)
335 | sc, kp = tf.split(train, 2, axis=2)
336 | sc = tf.nn.relu(sc)
337 | merge = tf.concat([short_cut1, sc], axis=3)
338 | merge = self.shc(merge, short_cut1, 128)
339 | merge = self.in_relu(merge)
340 | train = tf.concat(
341 | [merge, kp], axis=2)
342 |
343 |
344 | # stage -1
345 |
346 | train = ly.conv2d_transpose(train, 64, 4, stride=2,
347 | activation_fn=None, normalizer_fn=normalizer_fn, padding='SAME', weights_initializer=initializer, weights_regularizer=regularizer, biases_initializer=None)
348 | sc, kp = tf.split(train, 2, axis=2)
349 | sc = tf.nn.relu(sc)
350 | merge = tf.concat([short_cut0, sc], axis=3)
351 | merge = self.shc(merge, short_cut0, 64)
352 | merge = self.in_relu(merge)
353 | train = tf.concat(
354 | [merge, kp], axis=2)
355 |
356 | # stage -0
357 | recon = ly.conv2d_transpose(train, 3, 4, stride=2,
358 | activation_fn=None, padding='SAME', weights_initializer=initializer, weights_regularizer=regularizer, biases_initializer=None)
359 |
360 | return recon, tf.nn.tanh(recon)
361 |
362 | def build_adversarial_global(self, img, reuse=None, name=None):
363 | bs = img.get_shape().as_list()[0]
364 | with tf.variable_scope(name, reuse=reuse):
365 |
366 | def lrelu(x, leak=0.2, name="lrelu"):
367 | with tf.variable_scope(name):
368 | f1 = 0.5 * (1 + leak)
369 | f2 = 0.5 * (1 - leak)
370 | return f1 * x + f2 * abs(x)
371 |
372 | size = 128
373 | normalizer_fn = ly.instance_norm
374 | activation_fn = lrelu
375 |
376 |             img = ly.conv2d(img, num_outputs=size // 2, kernel_size=4,
377 | stride=2, activation_fn=activation_fn)
378 | img = ly.conv2d(img, num_outputs=size, kernel_size=4,
379 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn)
380 | img = ly.conv2d(img, num_outputs=size * 2, kernel_size=4,
381 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn)
382 | img = ly.conv2d(img, num_outputs=size * 4, kernel_size=4,
383 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn)
384 | img = ly.conv2d(img, num_outputs=size * 4, kernel_size=4,
385 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn)
386 |
387 | logit = ly.fully_connected(tf.reshape(
388 | img, [bs, -1]), 1, activation_fn=None)
389 |
390 | return logit
391 |
392 | def build_adversarial_local(self, img, reuse=None, name=None):
393 | bs = img.get_shape().as_list()[0]
394 | with tf.variable_scope(name, reuse=reuse):
395 |
396 | def lrelu(x, leak=0.2, name="lrelu"):
397 | with tf.variable_scope(name):
398 | f1 = 0.5 * (1 + leak)
399 | f2 = 0.5 * (1 - leak)
400 | return f1 * x + f2 * abs(x)
401 |
402 | size = 128
403 | normalizer_fn = ly.instance_norm
404 | activation_fn = lrelu
405 |
406 |             img = ly.conv2d(img, num_outputs=size // 2, kernel_size=4,
407 | stride=2, activation_fn=activation_fn)
408 | img = ly.conv2d(img, num_outputs=size, kernel_size=4,
409 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn)
410 | img = ly.conv2d(img, num_outputs=size * 2, kernel_size=4,
411 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn)
412 | img = ly.conv2d(img, num_outputs=size * 2, kernel_size=4,
413 | stride=2, activation_fn=activation_fn, normalizer_fn=normalizer_fn)
414 |
415 | logit = ly.fully_connected(tf.reshape(
416 | img, [bs, -1]), 1, activation_fn=None)
417 |
418 | return logit
419 |
420 |
421 |
--------------------------------------------------------------------------------
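
To close, a hedged sketch of driving the generator on its own (TF 1.x). `cfg` only needs the fields the class actually reads, `weight_decay` and `batch_size_per_gpu` (names taken from the argparse flags above); the batch size of 4 is arbitrary:

```python
import argparse
import tensorflow as tf
from modeling.model import Model

cfg = argparse.Namespace(weight_decay=0.00002, batch_size_per_gpu=4)
model = Model(cfg)

# Feed the known left half (128x128); the generator returns the full
# 128x256 canvas, pre- and post-tanh.
left_half = tf.placeholder(tf.float32, [4, 128, 128, 3])
recon_logits, recon = model.build_reconstruction(left_half)
print(recon.get_shape().as_list())  # [4, 128, 256, 3]
```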