95 |
96 | # Code
97 | ## Painter
98 | To evaluate the generalization ability of our inpainting models, we carry out object removal experiments in user scenarios. We develop an interactive image removal and completion tool with OpenCV and Tkinter. You may download the checkpoint of the inpainting model pretrained on the Places2 training and validation data from **[here](https://pan.baidu.com/s/1SBbfR94KWG5UMm_FClmdMQ)** with pass code: **uiqn**.
99 |
100 | Or from [Google Drive](https://drive.google.com/drive/folders/1ReSArrra8NOQv8dlU2QK0DE0P5qoalCT?usp=sharing).
101 |
102 | Run painter.py from the command line (we implement our model with TensorFlow 1.15.2 and Python 3.7):
103 | > python painter.py --checkpoint checkpoint/places2 --save_path imgs
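
Before launching the tool, you can quickly verify that the downloaded checkpoint is readable. A minimal sketch (assuming the `checkpoint/places2` path used in the command above):

```python
import tensorflow as tf  # tensorflow 1.15.2

# locate the latest checkpoint file and list a few of its variables
ckpt = tf.train.latest_checkpoint('checkpoint/places2')
print('checkpoint:', ckpt)
for name, shape in tf.train.list_variables('checkpoint/places2')[:5]:
    print(name, shape)
```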
104 |
105 | For object removal experiments, the tool works like this:
106 |
107 |
108 |
109 |
110 |
111 |
112 | ## Citation
113 | ```bibtex
114 | @inproceedings{jie2020inpainting,
115 | title={Learning to Incorporate Structure Knowledge for Image Inpainting},
116 | author={Yang, Jie and Qi, Zhiquan and Shi, Yong},
117 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
118 | volume={34},
119 | number={7},
120 | pages={12605-12612},
121 | year={2020}
122 | }
123 | ```
124 | ## License
125 | CC BY-NC 4.0 (Attribution-NonCommercial 4.0 International). The software is for educational and academic research purposes only.
126 |
--------------------------------------------------------------------------------
/src/val_inpaint_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import tensorflow as tf
5 | import numpy as np
7 | import pandas as pd
8 | import re
9 |
10 | from math import ceil
11 | from scipy.misc import imsave  # note: scipy.misc.imsave was removed in SciPy 1.2; requires an older SciPy
12 | from inpaint_model import InpaintModel
13 | from config import Config, select_gpu
14 | from utils_fn import show_all_variables, load_test_data, load_test_mask, create_test_mask, dataset_len, load_test_img_edge
15 |
16 | from frechet_inception_distance import calculate_fid_given_paths
17 | from metrics import uqi_vif
18 |
19 | # For reproducible result
20 | np.random.seed(0)
21 | tf.set_random_seed(0)
22 |
23 | # with tf.device('/cpu:0'):
24 | """
25 | Testing
26 | """
27 | # Load config file for running an inpainting model
28 | args = Config('inpaint_config.yml')
29 |
30 | # GPU config
31 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_ID)
32 | os.environ["CUDA_VISIBLE_DEVICES"] = select_gpu()
33 | config_gpu = tf.ConfigProto()
34 | config_gpu.gpu_options.allow_growth = True # allow memory grow
35 |
36 | # log setting
37 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
38 | logger = logging.getLogger("YOUNG")
39 | logger.setLevel(level=logging.INFO)
40 |
41 | """ Input Data (images and masks) """
42 | # images
43 | if args.CUSTOM_DATASET:
44 | images, image_iterator = load_test_img_edge(args)
45 | else:
46 | images = tf.placeholder(tf.float32, [args.BATCH_SIZE, args.IMG_SHAPES[0], args.IMG_SHAPES[1], args.IMG_SHAPES[2]],
47 | name='real_images')
48 | # test masks
49 | if args.MASK_MODE == 'irregular':
50 | masks, mask_iterator = load_test_mask(args)
51 | else:
52 | masks = tf.placeholder(tf.float32, [args.TEST_NUM, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1],
53 | name='test_regular_masks')
54 |
55 | """ Build Testing Inpaint Model"""
56 | # Testing model
57 | model = InpaintModel(args)
58 | logger.info("Build Testing Inpaint Model")
59 | model.build_test_model(images, masks, args)
60 |
61 | """ Testing Logic"""
62 | with tf.Session(config=config_gpu) as sess:
63 |
64 | # Saver to restore model: to restore variables
65 | # TODO: we can choose variables to store and steps to keep (max_to_keep)
66 | saver = tf.train.Saver()
67 |
68 | # Model dir
69 | # If restore a specific model
70 | args.MODEL_DIR = args.MODEL_RESTORE
71 |
72 | # Result dirs
73 | # (1) result/model_dir/inpainted_images
74 | # (2) result/model_dir/masked_images
75 | # (3) result/model_dir/sample_images
76 | result_dir = os.path.join(args.RESULT_DIR, args.MODEL_DIR)
77 | if not os.path.exists(result_dir):
78 | os.makedirs(result_dir)
79 |
80 | # (1) result/model_dir/inpainted_images
81 | inpainted_dir = os.path.join(result_dir, 'inpainted_images')
82 | if not os.path.exists(inpainted_dir):
83 | os.makedirs(inpainted_dir)
84 | # (2) result/model_dir/masked_images
85 | masked_dir = os.path.join(result_dir, 'masked_images')
86 | if not os.path.exists(masked_dir):
87 | os.makedirs(masked_dir)
88 | # (3) result/model_dir/sample_images
89 | sample_dir = os.path.join(result_dir, 'sample_images')
90 | if not os.path.exists(sample_dir):
91 | os.makedirs(sample_dir)
92 | # (4) result/model_dir/masks
93 | mask_dir = os.path.join(result_dir, 'masks')
94 | if not os.path.exists(mask_dir):
95 | os.makedirs(mask_dir)
96 | # (5) result/model_dir/inpainted_samples
97 | inpainted_sample_dir = os.path.join(result_dir, 'inpainted_samples')
98 | if not os.path.exists(inpainted_sample_dir):
99 | os.makedirs(inpainted_sample_dir)
100 |
101 | # Model Checkpoint dir
102 | checkpoint_dir = os.path.join(args.CHECKPOINT_DIR, args.MODEL_DIR)
103 |
104 | # Testing data info
105 | with open(args.DATA_FLIST[args.DATASET][1]) as f:
106 | fnames = f.read().splitlines()
107 |
108 | data_len = len(fnames)
109 | max_test_step = ceil(data_len / args.TEST_NUM) # TEST_NUM can be 1 or a batch like 8
110 | max_test_step = min(max_test_step, ceil(args.MAX_TEST_NUM / args.TEST_NUM)) # max test number of images
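# illustrative example: with data_len = 1000, TEST_NUM = 8 and MAX_TEST_NUM = 100,
# ceil(1000 / 8) = 125 and ceil(100 / 8) = 13, so max_test_step = min(125, 13) = 13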
111 |
112 | # Training data info
113 | max_step = dataset_len(args) // args.BATCH_SIZE # max step for each epoch
114 | last_step = int(args.EPOCH * max_step) # total steps
115 | # Parameters
116 | imgh = args.IMG_SHAPES[0]
117 | imgw = args.IMG_SHAPES[1]
118 |
119 | # Try to restore model
120 | # Initialize all the variables
121 | tf.global_variables_initializer().run()
122 | # Show network architecture
123 | show_all_variables()
124 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint
125 | if ckpt and ckpt.model_checkpoint_path:
126 | # print ckpt name with dir
127 | logger.info("Latest ckpt: {}".format(ckpt.model_checkpoint_path))
128 | logger.info("All ckpt: {}".format(ckpt.all_model_checkpoint_paths))
129 | # ckpt base name
130 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
131 | # restore
132 | # saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name)) # restore
133 | vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
134 | assign_ops = []
135 | for var in vars_list:
136 | vname = var.name
137 | from_name = vname
138 | try:
139 | var_value = tf.contrib.framework.load_variable(os.path.join(checkpoint_dir, ckpt_name), from_name)
140 | assign_ops.append(tf.assign(var, var_value))
141 | except Exception:
142 | continue
143 | sess.run(assign_ops)
144 | print('Model loaded.')
145 |
146 | counter = int(next(re.finditer(r"\d+", ckpt_name)).group(0))
147 | logger.info(" [*] Success to read {}".format(ckpt_name))
148 | else:
counter = 0  # counter must be defined even when no checkpoint is found
149 | logger.info(" [*] Failed to find a checkpoint")
150 | # Existing Training info
151 | current_epoch = counter // max_step
152 | current_step = counter % max_step
153 | logger.info('Evaluating epoch {}, step {}.'.format(current_epoch, current_step))
154 |
155 | # Testing start
156 | # For saving evaluation results
157 | if not os.path.exists(os.path.join(result_dir, 'evaluation.csv')):
158 | with open(os.path.join(result_dir, 'evaluation.csv'), mode='a') as f:
159 | f.write("epoch, step, l1, pnsr, ssim, fid, uqi, vif\n")
160 | mask_size = []
161 | l1_list = []
162 | psnr_list = []
163 | ssim_list = []
164 |
165 | count = 1
166 | sess.run(image_iterator.initializer)
167 | if args.MASK_MODE == 'irregular':
168 | sess.run(mask_iterator.initializer)
169 | for step in range(1, max_test_step+1):
170 | time_start = time.time()
171 |
172 | try:
173 | if args.MASK_MODE == 'irregular':
174 | raw_x, raw_x_incomplete, raw_x_complete, mask, l1, psnr, ssim = sess.run([model.raw_x, model.raw_x_incomplete,
175 | model.raw_x_complete, model.mask,
176 | model.l1, model.psnr, model.ssim])
177 | else:
178 | mask = create_test_mask(imgw, imgh, imgw // 2, imgh // 2, args)
179 | raw_x, raw_x_incomplete, raw_x_complete, mask, l1, psnr, ssim = sess.run([model.raw_x, model.raw_x_incomplete,
180 | model.raw_x_complete, model.mask,
181 | model.l1, model.psnr, model.ssim],
182 | feed_dict={masks: mask})
183 | except tf.errors.OutOfRangeError:
184 | break
185 |
186 | # visualize holes: set masked pixel values to 255 (white)
187 | ones_x = np.ones_like(raw_x_incomplete)
188 | raw_x_incomplete = raw_x_incomplete + ones_x*mask*255
189 |
190 | for i in range(args.TEST_NUM):
191 | # save result
192 | imsave(os.path.join(sample_dir, args.DATASET+"{}.png".format(count)), raw_x[i])
193 | imsave(os.path.join(inpainted_dir, args.DATASET+"{}.png".format(count)), raw_x_complete[i])
194 | imsave(os.path.join(masked_dir, args.DATASET+"{}.png".format(count)), raw_x_incomplete[i])
195 | imsave(os.path.join(mask_dir, args.DATASET+"{}.png".format(count)), mask[i, :, :, 0]) # mask is a grayscale image
196 |
197 | # mask size
198 | mask_size.append(mask[i].sum())
199 | l1_list.append(l1[i])
200 | psnr_list.append(psnr[i])
201 | ssim_list.append(ssim[i])
202 |
203 | if step == 1:
204 | imsave(os.path.join(inpainted_sample_dir, args.DATASET + "{}_{}.png".format(count, current_epoch)), raw_x_complete[i])
205 |
206 | count += 1
207 |
208 | time_cost = time.time() - time_start
209 | time_remaining = (max_test_step - step) * time_cost
210 | logger.info(
211 | 'step {}/{}, image {}/{}, cost {:.2f}s, remaining {:.2f}s.'.format(step, max_test_step, count, data_len, time_cost,
212 | time_remaining))
213 |
214 | # Final evaluation
215 | # df_evaluation = pd.DataFrame(data=np.array([l1_list, psnr_list, ssim_list, mask_size]).T,
216 | # columns=["l1", "psnr", "ssim", "mask"])
217 | # df_evaluation.to_csv(os.path.join(result_dir, 'evaluation.csv'), index=False)
218 | logger.info("Saving Finished.")
219 |
220 | logger.info("Evaluating Results..")
221 | # Evaluation result
222 | # l1, psnr, ssim, fid
223 |
224 | # fid score
225 | logger.info("FID score")
226 | fid_value = calculate_fid_given_paths([sample_dir, inpainted_dir], 'inception')
227 | print("FID: ", fid_value)
228 |
229 | # print(df_evaluation.mean(axis=0))
230 | # UQI and VIF
231 | uqi, vif = uqi_vif(sample_dir, inpainted_dir)
232 | # uqi, vif = 0., 0.
233 | # df_evaluation = pd.concat(df_evaluation, pd.DataFrame(data={"epoch": current_epoch, "step": current_step,
234 | # "l1": np.array(l1_list).mean(),
235 | # "psnr": np.array(psnr_list).mean(),
236 | # "ssim": np.array(ssim_list).mean(),
237 | # "fid": fid_value}), axis=0)
238 | # df_evaluation.to_csv(os.path.join(result_dir, 'evaluation.csv'), index=False)
239 | with open(os.path.join(result_dir, 'evaluation.csv'), mode='a') as f:
240 | f.write("{}, {}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}\n".format(current_epoch, current_step,
241 | np.array(l1_list).mean(),
242 | np.array(psnr_list).mean(),
243 | np.array(ssim_list).mean(),
244 | fid_value,
245 | uqi,
246 | vif))
247 |
248 | logger.info("Evaluation Finished.")
--------------------------------------------------------------------------------
/src/frechet_inception_distance.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Copyright 2017 Martin Heusel
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Adapted from the original implementation by Martin Heusel.
18 | # Source https://github.com/bioinf-jku/TTUR/blob/master/fid.py
19 |
20 | ''' Calculates the Frechet Inception Distance (FID) to evaluate GANs.
21 |
22 | The FID metric calculates the distance between two distributions of images.
23 | Typically, we have summary statistics (mean & covariance matrix) of one
24 | of these distributions, while the 2nd distribution is given by a GAN.
25 |
26 | When run as a stand-alone program, it compares the distribution of
27 | images that are stored as PNG/JPEG at a specified location with a
28 | distribution given by summary statistics (in pickle format).
29 |
30 | The FID is calculated by assuming that X_1 and X_2 are the activations of
31 | the pool_3 layer of the inception net for generated samples and real world
32 | samples respectively.
33 |
34 | See --help to see further details.
35 | '''
36 |
37 | from __future__ import absolute_import, division, print_function
38 | import numpy as np
39 | import scipy as sp
40 | import os
41 | import gzip, pickle
42 | import tensorflow as tf
43 | from scipy.misc import imread  # note: removed in SciPy >= 1.2
44 | import pathlib
45 | import urllib
46 |
47 |
48 | class InvalidFIDException(Exception):
49 | pass
50 |
51 |
52 | def create_inception_graph(pth):
53 | """Creates a graph from saved GraphDef file."""
54 | # Creates graph from saved graph_def.pb.
55 | with tf.gfile.FastGFile( pth, 'rb') as f:
56 | graph_def = tf.GraphDef()
57 | graph_def.ParseFromString( f.read())
58 | _ = tf.import_graph_def( graph_def, name='FID_Inception_Net')
59 | #-------------------------------------------------------------------------------
60 |
61 |
62 | # code for handling inception net derived from
63 | # https://github.com/openai/improved-gan/blob/master/inception_score/model.py
64 | def _get_inception_layer(sess):
65 | """Prepares inception net for batched usage and returns pool_3 layer. """
66 | layername = 'FID_Inception_Net/pool_3:0'
67 | pool3 = sess.graph.get_tensor_by_name(layername)
68 | ops = pool3.graph.get_operations()
69 | for op_idx, op in enumerate(ops):
70 | for o in op.outputs:
71 | shape = o.get_shape()
72 | if shape._dims is not None:
73 | shape = [s.value for s in shape]
74 | new_shape = []
75 | for j, s in enumerate(shape):
76 | if s == 1 and j == 0:
77 | new_shape.append(None)
78 | else:
79 | new_shape.append(s)
80 | try:
81 | o._shape = tf.TensorShape(new_shape)
82 | except ValueError:
83 | o._shape_val = tf.TensorShape(new_shape) # EDIT: added for compatibility with tensorflow 1.6.0
84 | return pool3
85 | #-------------------------------------------------------------------------------
86 |
87 |
88 | def get_activations(images, sess, batch_size=50, verbose=False):
89 | """Calculates the activations of the pool_3 layer for all images.
90 |
91 | Params:
92 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values
93 | must lie between 0 and 255.
94 | -- sess : current session
95 | -- batch_size : the images numpy array is split into batches with batch size
96 | batch_size. A reasonable batch size depends on the available hardware.
97 | -- verbose : If set to True and parameter out_step is given, the number of calculated
98 | batches is reported.
99 | Returns:
100 | -- A numpy array of dimension (num images, 2048) that contains the
101 | activations of the given tensor when feeding inception with the query tensor.
102 | """
103 | inception_layer = _get_inception_layer(sess)
104 | d0 = images.shape[0]
105 | if batch_size > d0:
106 | print("warning: batch size is bigger than the data size. setting batch size to data size")
107 | batch_size = d0
108 | n_batches = d0//batch_size
109 | n_used_imgs = n_batches*batch_size
110 | pred_arr = np.empty((n_used_imgs,2048))
111 | for i in range(n_batches):
112 | if verbose:
113 | print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
114 | start = i*batch_size
115 | end = start + batch_size
116 | batch = images[start:end]
117 | pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch})
118 | pred_arr[start:end] = pred.reshape(batch_size,-1)
119 | if verbose:
120 | print(" done")
121 | return pred_arr
122 | #-------------------------------------------------------------------------------
123 |
124 |
125 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2):
126 | """Numpy implementation of the Frechet Distance.
127 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
128 | and X_2 ~ N(mu_2, C_2) is
129 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
130 |
131 | Params:
132 | -- mu1 : Numpy array containing the activations of the pool_3 layer of the
133 | inception net ( like returned by the function 'get_predictions')
134 | -- mu2 : The sample mean over activations of the pool_3 layer, precalculated
135 | on a representative data set.
136 | -- sigma2: The covariance matrix over activations of the pool_3 layer,
137 | precalculated on a representative data set.
138 |
139 | Returns:
140 | -- dist : The Frechet Distance.
141 |
142 | Raises:
143 | -- InvalidFIDException if NaN occurs.
144 | """
145 | m = np.square(mu1 - mu2).sum()
146 | #s = sp.linalg.sqrtm(np.dot(sigma1, sigma2)) # EDIT: commented out
147 | s, _ = sp.linalg.sqrtm(np.dot(sigma1, sigma2), disp=False) # EDIT: added
148 | dist = m + np.trace(sigma1+sigma2 - 2*s)
149 | #if np.isnan(dist): # EDIT: commented out
150 | # raise InvalidFIDException("nan occured in distance calculation.") # EDIT: commented out
151 | #return dist # EDIT: commented out
152 | return np.real(dist) # EDIT: added
153 | #-------------------------------------------------------------------------------
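# Illustrative sanity check: identical statistics give (numerically) zero distance,
# e.g.
#   mu, sigma = np.zeros(4), np.eye(4)
#   calculate_frechet_distance(mu, sigma, mu, sigma)  # -> ~0.0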
154 |
155 |
156 | def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
157 | """Calculation of the statistics used by the FID.
158 | Params:
159 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values
160 | must lie between 0 and 255.
161 | -- sess : current session
162 | -- batch_size : the images numpy array is split into batches with batch size
163 | batch_size. A reasonable batch size depends on the available hardware.
164 | -- verbose : If set to True and parameter out_step is given, the number of calculated
165 | batches is reported.
166 | Returns:
167 | -- mu : The mean over samples of the activations of the pool_3 layer of
168 | the inception model.
169 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
170 | the inception model.
171 | """
172 | act = get_activations(images, sess, batch_size, verbose)
173 | mu = np.mean(act, axis=0)
174 | sigma = np.cov(act, rowvar=False)
175 | return mu, sigma
176 | #-------------------------------------------------------------------------------
177 |
178 |
179 | #-------------------------------------------------------------------------------
180 | # The following functions aren't needed for calculating the FID
181 | # they're just here to make this module work as a stand-alone script
182 | # for calculating FID scores
183 | #-------------------------------------------------------------------------------
184 | def check_or_download_inception(inception_path):
185 | ''' Checks if the path to the inception file is valid, or downloads
186 | the file if it is not present. '''
187 | INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
188 | if inception_path is None:
189 | inception_path = '/tmp'
190 | inception_path = pathlib.Path(inception_path)
191 | model_file = inception_path / 'classify_image_graph_def.pb'
192 | if not model_file.exists():
193 | print("Downloading Inception model")
194 | from urllib import request
195 | import tarfile
196 | fn, _ = request.urlretrieve(INCEPTION_URL)
197 | with tarfile.open(fn, mode='r') as f:
198 | f.extract('classify_image_graph_def.pb', str(model_file.parent))
199 | return str(model_file)
200 |
201 |
202 | def _handle_path(path, sess):
203 | if path.endswith('.npz'):
204 | f = np.load(path)
205 | m, s = f['mu'][:], f['sigma'][:]
206 | f.close()
207 | else:
208 | path = pathlib.Path(path)
209 | files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
210 | x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
211 | m, s = calculate_activation_statistics(x, sess)
212 | return m, s
213 |
214 |
215 | def calculate_fid_given_paths(paths, inception_path):
216 | ''' Calculates the FID of two paths. '''
217 | inception_path = check_or_download_inception(inception_path)
218 |
219 | for p in paths:
220 | if not os.path.exists(p):
221 | raise RuntimeError("Invalid path: %s" % p)
222 |
223 | # os.environ["CUDA_VISIBLE_DEVICES"] = '1'  # do not hardcode the GPU here; it overrides the caller's selection
224 |
225 | create_inception_graph(str(inception_path))
226 | with tf.Session() as sess:
227 | sess.run(tf.global_variables_initializer())
228 | m1, s1 = _handle_path(paths[0], sess)
229 | m2, s2 = _handle_path(paths[1], sess)
230 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
231 | return fid_value
232 |
233 |
234 | if __name__ == "__main__":
235 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
236 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
237 | parser.add_argument("path", type=str, nargs=2,
238 | help='Path to the generated images or to .npz statistic files')
239 | parser.add_argument("-i", "--inception", type=str, default=None,
240 | help='Path to Inception model (will be downloaded if not provided)')
241 | parser.add_argument("--gpu", default="", type=str,
242 | help='GPU to use (leave blank for CPU only)')
243 | args = parser.parse_args()
244 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
245 | fid_value = calculate_fid_given_paths(args.path, args.inception)
246 | print("FID: ", fid_value)
247 |
248 | #----------------------------------------------------------------------------
249 | # EDIT: added
250 |
251 | class API:
252 | def __init__(self, num_images, image_shape, image_dtype, minibatch_size):
253 | import config
254 | self.network_dir = os.path.join(config.result_dir, '_inception_fid')
255 | self.network_file = check_or_download_inception(self.network_dir)
256 | self.sess = tf.get_default_session()
257 | create_inception_graph(self.network_file)
258 |
259 | def get_metric_names(self):
260 | return ['FID']
261 |
262 | def get_metric_formatting(self):
263 | return ['%-10.4f']
264 |
265 | def begin(self, mode):
266 | assert mode in ['warmup', 'reals', 'fakes']
267 | self.activations = []
268 |
269 | def feed(self, mode, minibatch):
270 | act = get_activations(minibatch.transpose(0,2,3,1), self.sess, batch_size=minibatch.shape[0])
271 | self.activations.append(act)
272 |
273 | def end(self, mode):
274 | act = np.concatenate(self.activations)
275 | mu = np.mean(act, axis=0)
276 | sigma = np.cov(act, rowvar=False)
277 | if mode in ['warmup', 'reals']:
278 | self.mu_real = mu
279 | self.sigma_real = sigma
280 | fid = calculate_frechet_distance(mu, sigma, self.mu_real, self.sigma_real)
281 | return [fid]
282 |
283 | #----------------------------------------------------------------------------
284 |
--------------------------------------------------------------------------------
/painter/painter.py:
--------------------------------------------------------------------------------
1 | from tkinter import *
2 | from PIL import Image, ImageTk, ImageDraw
3 | import tkinter.filedialog as tkFileDialog
4 | import numpy as np
5 | import cv2
6 | import os
7 | import subprocess
8 | import argparse
9 | import tensorflow as tf
10 | from config import Config
11 | from skimage import feature
12 | from skimage.color import rgb2gray
13 |
14 | from inpaint_model import InpaintModel
15 |
16 | # os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in subprocess.Popen(
17 | # "nvidia-smi -q -d Memory | grep -A4 GPU | grep Free", shell=True, stdout=subprocess.PIPE).stdout.readlines()]))
18 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
19 |
20 | class Paint(object):
21 | MARKER_COLOR = 'white'
22 |
23 | def __init__(self, config):
24 | self.config = config
25 | print("******************************",self.config.CHECKPOINT)
26 |
27 | self.root = Tk()
28 |
29 | self.rect_button = Button(self.root, text='rectangle', command=self.use_rect, width=12, height=3)
30 | self.rect_button.grid(row=0, column=2)
31 |
32 | self.poly_button = Button(self.root, text='stroke', command=self.use_poly, width=12, height=3)
33 | self.poly_button.grid(row=1, column=2)
34 |
35 | self.revoke_button = Button(self.root, text='revoke', command=self.revoke, width=12, height=3)
36 | self.revoke_button.grid(row=2, column=2)
37 |
38 | self.clear_button = Button(self.root, text='clear', command=self.clear, width=12, height=3)
39 | self.clear_button.grid(row=3, column=2)
40 |
41 | self.c = Canvas(self.root, bg='white', width=config.IMG_SHAPES[1]+8, height=config.IMG_SHAPES[0])
42 | self.c.grid(row=0, column=0, rowspan=8)
43 |
44 | self.out = Canvas(self.root, bg='white', width=config.IMG_SHAPES[1]+8, height=config.IMG_SHAPES[0])
45 | self.out.grid(row=0, column=1, rowspan=8)
46 |
47 | self.save_button = Button(self.root, text="save", command=self.save, width=12, height=3)
48 | self.save_button.grid(row=6, column=2)
49 |
50 | self.load_button = Button(self.root, text='load', command=self.load, width=12, height=3)
51 | self.load_button.grid(row=5, column=2)
52 |
53 | self.fill_button = Button(self.root, text='fill', command=self.fill, width=12, height=3)
54 | self.fill_button.grid(row=7, column=2)
55 | self.filename = None
56 |
57 | self.setup()
58 | self.root.mainloop()
59 |
60 | def setup(self):
61 | self.old_x = None
62 | self.old_y = None
63 | self.start_x = None
64 | self.start_y = None
65 | self.end_x = None
66 | self.end_y = None
67 | self.eraser_on = False
68 | self.active_button = self.rect_button
69 | self.isPainting = False
70 | self.c.bind('<B1-Motion>', self.paint)
71 | self.c.bind('<ButtonRelease-1>', self.reset)
72 | self.c.bind('<Button-1>', self.beginPaint)
73 | self.c.bind('<Enter>', self.icon2pen)
74 | self.c.bind('<Leave>', self.icon2mice)
75 | self.mode = 'rect'
76 | self.rect_buf = None
77 | self.line_buf = None
78 | assert self.mode in ['rect', 'poly']
79 | self.paint_color = self.MARKER_COLOR
80 | self.mask_candidate = []
81 | self.rect_candidate = []
82 | self.im_h = None
83 | self.im_w = None
84 | self.mask = None
85 | self.result = None
86 | self.blank = None
87 | self.line_width = 8
88 |
89 | # painter model
90 | self.model = InpaintModel(self.config)
91 | self.reuse = False
92 | sess_config = tf.ConfigProto()
93 | sess_config.gpu_options.allow_growth = False
94 | self.sess = tf.Session(config=sess_config)
95 |
96 | self.input_image_tf = tf.placeholder(dtype=tf.float32,
97 | shape=[1, self.config.IMG_SHAPES[0], self.config.IMG_SHAPES[1], 3])
98 | self.input_mask_tf = tf.placeholder(dtype=tf.float32,
99 | shape=[1, self.config.IMG_SHAPES[0], self.config.IMG_SHAPES[1], 1])
100 | self.input_edge_tf = tf.placeholder(dtype=tf.float32,
101 | shape=[1, self.config.IMG_SHAPES[0], self.config.IMG_SHAPES[1], 1])
102 | output = self.model.evaluate(self.input_image_tf, self.input_edge_tf, self.input_mask_tf,
103 | args=self.config, reuse=self.reuse)
104 | # output = (output + 1) * 127.5
105 | output = tf.minimum(tf.maximum(output[:, :, :, ::-1], 0), 255)
106 | # self.output = tf.cast(output, tf.uint8)
107 | self.output = output
108 |
109 | # load pretrained model
110 | vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
111 | assign_ops = list(map(lambda x: tf.assign(x, tf.contrib.framework.load_variable(self.config.CHECKPOINT, x.name)),
112 | vars_list))
113 | self.sess.run(assign_ops)
114 | print('Model loaded.')
115 |
116 | def checkResp(self):
117 | assert len(self.mask_candidate) == len(self.rect_candidate)
118 |
119 | def load(self):
120 | self.filename = tkFileDialog.askopenfilename(initialdir='./imgs',
121 | title="Select file",
122 | filetypes=(("png files", "*.png"), ("jpg files", "*.jpg"),
123 | ("all files", "*.*")))
124 | self.filename_ = self.filename.split('/')[-1][:-4]
125 | self.filepath = '/'.join(self.filename.split('/')[:-1])
126 | print(self.filename_, self.filepath)
127 | try:
128 | photo = Image.open(self.filename)
129 | self.image = cv2.imread(self.filename)
130 | except Exception:
131 | print('failed to load image')
132 | else:
133 | self.im_w, self.im_h = photo.size
134 | self.mask = np.zeros((self.im_h, self.im_w, 1)).astype(np.uint8)
135 | print(photo.size)
136 | self.displayPhoto = photo
137 | self.displayPhoto = self.displayPhoto.resize((self.im_w, self.im_h))
138 | self.draw = ImageDraw.Draw(self.displayPhoto)
139 | self.photo_tk = ImageTk.PhotoImage(image=self.displayPhoto)
140 | self.c.create_image(0, 0, image=self.photo_tk, anchor=NW)
141 | self.rect_candidate.clear()
142 | self.mask_candidate.clear()
143 | if self.blank is None:
144 | if not os.path.exists('imgs/blank.png'):
145 | self.blank = Image.new(mode='L', size=(1000,1000), color=1)
146 | else:
147 | self.blank = Image.open('imgs/blank.png')
148 | self.blank = self.blank.resize((self.im_w, self.im_h))
149 | self.blank_tk = ImageTk.PhotoImage(image=self.blank)
150 | self.out.create_image(0, 0, image=self.blank_tk, anchor=NW)
151 |
152 | def save(self):
153 | img = np.array(self.displayPhoto)
154 | cv2.imwrite(os.path.join(self.filepath, 'tmp.png'), img)
155 |
156 | if self.mode == 'rect':
157 | self.mask[:,:,:] = 0
158 | for rect in self.mask_candidate:
159 | self.mask[rect[1]:rect[3], rect[0]:rect[2], :] = 1
160 |
161 | self.save_filename = tkFileDialog.asksaveasfilename(initialdir=self.config.SAVEPATH,
162 | title="Select file",
163 | filetypes=(("png files", "*.png"), ("jpg files", "*.jpg"),
164 | ("all files", "*.*")))
165 | self.save_filename_ = self.save_filename.split('/')[-1][:-4]
166 | self.save_filepath = '/'.join(self.save_filename.split('/')[:-1])
167 |
168 | cv2.imwrite(os.path.join(self.save_filepath, self.save_filename_ + '_mask.png'), self.mask * 255)
169 | cv2.imwrite(os.path.join(self.save_filepath, self.save_filename_ + '_result.png'), self.result[0][:, :, ::-1])
170 | cv2.imwrite(os.path.join(self.save_filepath, self.save_filename_ + '_masked.png'),
171 | self.result[0][:, :, ::-1] * (1 - self.mask) + self.mask * 255)
172 |
173 | def fill(self):
174 | if self.mode == 'rect':
175 | for rect in self.mask_candidate:
176 | self.mask[rect[1]:rect[3], rect[0]:rect[2], :] = 1
177 | image = np.expand_dims(self.image, 0)
178 | mask = np.expand_dims(self.mask, 0)
179 |
180 | img_gray = rgb2gray(self.image)
181 | edge = feature.canny(img_gray, sigma=1.5).astype(np.float32)
182 | edge = np.reshape(edge, (1, self.config.IMG_SHAPES[0], self.config.IMG_SHAPES[1], 1))  # use configured shape instead of hardcoded 256
183 | # print(image.shape)
184 | # print(mask.shape)
185 | # print(edge.shape)
186 |
187 | self.result = self.sess.run(self.output, feed_dict={self.input_image_tf: image * 1.0,
188 | self.input_mask_tf: mask * 1.0,
189 | self.input_edge_tf: edge * 1.0})
190 | cv2.imwrite('./imgs/tmp.png', self.result[0][:, :, ::-1]) # self.output has batch size = 1, so self.result[0]
191 |
192 | photo = Image.open('./imgs/tmp.png')
193 | self.displayPhotoResult = photo
194 | self.displayPhotoResult = self.displayPhotoResult.resize((self.im_w, self.im_h))
195 | self.photo_tk_result = ImageTk.PhotoImage(image=self.displayPhotoResult)
196 | self.out.create_image(0, 0, image=self.photo_tk_result, anchor=NW)
197 | return
198 |
199 | def use_rect(self):
200 | self.activate_button(self.rect_button)
201 | self.mode = 'rect'
202 |
203 | def use_poly(self):
204 | self.activate_button(self.poly_button)
205 | self.mode = 'poly'
206 |
207 | def revoke(self):
208 | if len(self.rect_candidate) > 0:
209 | self.c.delete(self.rect_candidate[-1])
210 | self.rect_candidate.remove(self.rect_candidate[-1])
211 | self.mask_candidate.remove(self.mask_candidate[-1])
212 | self.checkResp()
213 |
214 | def clear(self):
215 | self.mask = np.zeros((self.im_h, self.im_w, 1)).astype(np.uint8)
216 | if self.mode == 'poly':
217 | photo = Image.open(self.filename)
218 | self.image = cv2.imread(self.filename)
219 | self.displayPhoto = photo
220 | self.displayPhoto = self.displayPhoto.resize((self.im_w, self.im_h))
221 | self.draw = ImageDraw.Draw(self.displayPhoto)
222 | self.photo_tk = ImageTk.PhotoImage(image=self.displayPhoto)
223 | self.c.create_image(0, 0, image=self.photo_tk, anchor=NW)
224 | else:
225 | if self.rect_candidate is None or len(self.rect_candidate) == 0:
226 | return
227 | for item in self.rect_candidate:
228 | self.c.delete(item)
229 | self.rect_candidate.clear()
230 | self.mask_candidate.clear()
231 | self.checkResp()
232 |
233 | #TODO: reset canvas
234 | #TODO: undo and redo
235 | #TODO: draw triangle, rectangle, oval, text
236 |
237 | def activate_button(self, some_button, eraser_mode=False):
238 | self.active_button.config(relief=RAISED)
239 | some_button.config(relief=SUNKEN)
240 | self.active_button = some_button
241 | self.eraser_on = eraser_mode
242 |
243 | def beginPaint(self, event):
244 | self.start_x = event.x
245 | self.start_y = event.y
246 | self.isPainting = True
247 |
248 | def paint(self, event):
249 | if self.start_x and self.start_y and self.mode == 'rect':
250 | self.end_x = max(min(event.x, self.im_w), 0)
251 | self.end_y = max(min(event.y, self.im_h), 0)
252 | rect = self.c.create_rectangle(self.start_x, self.start_y, self.end_x, self.end_y, fill=self.paint_color)
253 | if self.rect_buf is not None:
254 | self.c.delete(self.rect_buf)
255 | self.rect_buf = rect
256 | elif self.old_x and self.old_y and self.mode == 'poly':
257 | line = self.c.create_line(self.old_x, self.old_y, event.x, event.y,
258 | width=self.line_width, fill=self.paint_color, capstyle=ROUND,
259 | smooth=True, splinesteps=36)
260 | cv2.line(self.mask, (self.old_x, self.old_y), (event.x, event.y), (1), self.line_width)
261 | self.old_x = event.x
262 | self.old_y = event.y
263 |
264 | def reset(self, event):
265 | self.old_x, self.old_y = None, None
266 | if self.mode == 'rect':
267 | self.isPainting = False
268 | rect = self.c.create_rectangle(self.start_x, self.start_y, self.end_x, self.end_y,
269 | fill=self.paint_color)
270 | if self.rect_buf is not None:
271 | self.c.delete(self.rect_buf)
272 | self.rect_buf = None
273 | self.rect_candidate.append(rect)
274 |
275 | x1, y1, x2, y2 = min(self.start_x, self.end_x), min(self.start_y, self.end_y),\
276 | max(self.start_x, self.end_x), max(self.start_y, self.end_y)
277 | # upper-left corner, lower-right corner
278 | self.mask_candidate.append((x1, y1, x2, y2))
279 | print(self.mask_candidate[-1])
280 |
281 | def icon2pen(self, event):
282 | return
283 |
284 | def icon2mice(self, event):
285 | return
286 |
287 |
288 | if __name__ == '__main__':
289 | config = Config('inpaint_config.yml')
290 | config.mode = 'silent'
291 | parser = argparse.ArgumentParser()
292 | parser.add_argument('--checkpoint', type=str, help='path to the model checkpoint')
293 | parser.add_argument('--save_path', type=str, help='path to save results')
294 | args = parser.parse_args()
295 | config.CHECKPOINT = args.checkpoint
296 | config.SAVEPATH = args.save_path
297 | # print("@############################################", config.CHECKPIONT)
298 | ge = Paint(config)
299 |
--------------------------------------------------------------------------------
/src/train_inpaint_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import tensorflow as tf
5 | import numpy as np
7 | import re
8 |
9 | from inpaint_model import InpaintModel
10 | from config import Config, select_gpu
11 | from utils_fn import (show_all_variables, load_mask, create_mask,
12 | save_images, load_validation_data, load_validation_mask, create_validation_mask,
13 | dataset_len, load_img_scale_edge, load_val_img_scale_edge)
14 |
15 | # Reproducible result
16 | np.random.seed(0)
17 | tf.set_random_seed(0)
18 |
19 |
20 | def multi_gpu_setting(model, args):
21 | gpu_num = args.NUM_GPUS
22 | batch_size = args.BATCH_SIZE
23 |
24 | with tf.device("/cpu:0"):
25 | """ Input Data (images and masks) """
26 | # images and edges
27 | if args.CUSTOM_DATASET:
28 | images_edges = load_img_scale_edge(args)
29 | else:
30 | images_edges = tf.placeholder(tf.float32,
31 | [args.BATCH_SIZE * gpu_num, args.IMG_SHAPES[0], args.IMG_SHAPES[1], args.IMG_SHAPES[2]],
32 | name='real_images')
33 | images_, edges_, edges_128_, edges_64_ = images_edges # a tuple
34 | images_ = tf.reshape(images_,
35 | [args.BATCH_SIZE * gpu_num, args.IMG_SHAPES[0], args.IMG_SHAPES[1], args.IMG_SHAPES[2]])
36 | edges_ = tf.reshape(edges_, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1])
37 | edges_128_ = tf.reshape(edges_128_, [-1, 128, 128, 1])
38 | edges_64_ = tf.reshape(edges_64_, [-1, 64, 64, 1])
39 |
40 | # masks
41 | if args.MASK_MODE == 'irregular':
42 | masks = load_mask(args)
43 | else:
44 | masks = tf.placeholder(tf.float32, [1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1],
45 | name='regular_masks')
46 | _masks = tf.reshape(masks, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1])
47 |
48 | # opt
49 | g_optimizer = tf.train.AdamOptimizer(learning_rate=args.G_LR, beta1=0., beta2=0.9)
50 | d_optimizer = tf.train.AdamOptimizer(learning_rate=args.D_LR, beta1=0., beta2=0.9)
51 |
52 | # update grad
53 | tower_g_grads = []
54 | tower_d_grads = []
55 |
56 | with tf.variable_scope(tf.get_variable_scope()):
57 | for i in range(gpu_num): # GPU IDs
58 | with tf.device("/gpu:%d" % i):
59 | with tf.name_scope("tower_%d" % i):
60 | _images = images_[i * batch_size: (i + 1) * batch_size]
61 | _edges = edges_[i * batch_size: (i + 1) * batch_size]
62 | _edges_128 = edges_128_[i * batch_size: (i + 1) * batch_size]
63 | _edges_64 = edges_64_[i * batch_size: (i + 1) * batch_size]
64 | print(_images.shape)
65 | print(_masks.shape)
66 | print(_edges.shape)
67 | print(_edges_64.shape)
68 | model.build_graph_with_losses(_images, _masks, _edges, _edges_128, _edges_64, args, reuse=tf.AUTO_REUSE)
69 | tf.get_variable_scope().reuse_variables()
70 | # scale 256
71 | _g256_grads = g_optimizer.compute_gradients(model.g_loss, var_list=model.total_g_vars)
72 | _d256_grads = d_optimizer.compute_gradients(model.d_loss, var_list=model.total_d_vars)
73 | tower_g_grads.append(_g256_grads)
74 | with open("tower_{}_g.txt".format(i), 'w') as f:
75 | for g in _g256_grads:
76 | f.write("g:"+str(g)+'\n')
77 | tower_d_grads.append(_d256_grads)
78 | with open("tower_{}_d.txt".format(i), 'w') as f:
79 | for g in _d256_grads:
80 | f.write("d:"+str(g)+'\n')
81 |
82 |
83 | # average grads
84 | g_grads = average_gradients(tower_g_grads)
85 | d_grads = average_gradients(tower_d_grads)
86 |
87 | # train op
88 | g_train_op = g_optimizer.apply_gradients(g_grads)
89 | d_train_op = d_optimizer.apply_gradients(d_grads)
90 |
91 | # summary of the model on the last gpu device
92 | all_sum_256 = model.all_sum # only keep the final summary
93 |
94 | # return train ops and inputs
95 | return g_train_op, d_train_op, images_edges, masks, all_sum_256
96 |
97 |
98 | def average_gradients(tower_grads):
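# tower_grads: one entry per GPU, each a list of (gradient, variable) pairs
# from compute_gradients; zip(*tower_grads) regroups the pairs per variable
# so each variable's gradient can be averaged across towers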
99 | average_grads = []
100 | for grad_and_vars in zip(*tower_grads):
101 | grads = []
102 | for g, _ in grad_and_vars:
103 | expend_g = tf.expand_dims(g, 0)
104 | grads.append(expend_g)
105 | grad = tf.concat(grads, 0)
106 | grad = tf.reduce_mean(grad, 0)
107 | v = grad_and_vars[0][1]
108 | grad_and_var = (grad, v)
109 | average_grads.append(grad_and_var)
110 | return average_grads
111 |
112 |
113 | def single_gpu_setting(model, args):
114 | gpu_num = args.NUM_GPUS
115 | assert(gpu_num == 1)
116 |
117 | """ Input Data (images and masks) """
118 | # images and edges
119 | if args.CUSTOM_DATASET:
120 | images_edges = load_img_scale_edge(args)
121 | else:
122 | images_edges = tf.placeholder(tf.float32,
123 | [args.BATCH_SIZE, args.IMG_SHAPES[0], args.IMG_SHAPES[1], args.IMG_SHAPES[2]],
124 | name='real_images')
125 | images_, edges_, edges_128_, edges_64_ = images_edges # a tuple
126 | images = tf.reshape(images_,
127 | [args.BATCH_SIZE * gpu_num, args.IMG_SHAPES[0], args.IMG_SHAPES[1], args.IMG_SHAPES[2]])
128 | edges = tf.reshape(edges_, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1])
129 | edges_128 = tf.reshape(edges_128_, [-1, 128, 128, 1])
130 | edges_64 = tf.reshape(edges_64_, [-1, 64, 64, 1])
131 |
132 | # masks
133 | if args.MASK_MODE == 'irregular':
134 | masks = load_mask(args)
135 | else:
136 | masks = tf.placeholder(tf.float32, [1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1],
137 | name='regular_masks')
138 | masks = tf.reshape(masks, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1])
139 |
140 | # build model with losses
141 | model.build_graph_with_losses(images, masks, edges, edges_128, edges_64, args, reuse=False)
142 |
143 | # train op
144 | g_train_op = tf.train.AdamOptimizer(learning_rate=args.G_LR, beta1=0., beta2=0.9).minimize(
145 | model.g_loss, var_list=model.total_g_vars)
146 | d_train_op = tf.train.AdamOptimizer(learning_rate=args.D_LR, beta1=0., beta2=0.9).minimize(
147 | model.d_loss, var_list=model.total_d_vars)
148 |
149 | # summary
150 | all_sum_256 = model.all_sum
151 |
152 | # return train ops and inputs
153 | return g_train_op, d_train_op, images_edges, masks, all_sum_256
154 |
155 |
156 | def main():
157 | """
158 | Training
159 | """
160 | # Load config file for run an inpainting model
161 | args = Config('inpaint_config.yml')
162 |
163 | # GPU config
164 | gpu_ids = args.GPU_ID
165 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(gpu) for gpu in gpu_ids]) # default "2"
166 | # os.environ["CUDA_VISIBLE_DEVICES"] = select_gpu()
167 | config_gpu = tf.ConfigProto()
168 | config_gpu.gpu_options.allow_growth = True # allow memory grow
169 | config_gpu.allow_soft_placement = True
170 | # config_gpu.log_device_placement = True
171 |
172 | # log setting
173 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
174 | logger = logging.getLogger("YOUNG")
175 | logger.setLevel(level=logging.INFO)
176 |
177 | """ Build Inpaint Model with Loss and Optimizer"""
178 | # Model and training setting
179 | model = InpaintModel(args)
180 | if args.NUM_GPUS > 1 or len(args.GPU_ID) > 1: # multi-gpu
181 | logger.info("Build Inpaint Model with Loss and Optimizer in Multi-GPU setting.")
182 | g_train256_op, d_train256_op, images_edges, masks, all_sum_256 = multi_gpu_setting(model, args)
183 | else: # cpu or single gpu
184 | logger.info("Build Inpaint Model with Loss and Optimizer in Single-GPU or CPU setting.")
185 | g_train256_op, d_train256_op, images_edges, masks, all_sum_256 = single_gpu_setting(model, args)
186 |
187 | # If validation?
188 | if args.VAL:
189 | logger.info("Build Validation Model.")
190 | with tf.device('/cpu:0'):
191 | # images
192 | images_edges_val, img_iterator_val = load_val_img_scale_edge(args)
193 | # masks
194 | if args.MASK_MODE == 'irregular':
195 | masks_val, mask_iterator_val = load_validation_mask(args)
196 | else:
197 | masks_val = tf.placeholder(tf.float32, [args.VAL_NUM, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1],
198 | name='val_regular_masks')
199 | model.build_validation_model(images_edges_val, masks_val, args)
200 |
201 | """ Train Logic"""
202 | with tf.Session(config=config_gpu) as sess:
203 |
204 | # Model dir
205 | # If restore a specific model
206 | if args.MODEL_RESTORE == '':
207 | args.MODEL_DIR = '-'.join(time.asctime().split()) + "_GPU" + '-'.join([str(gpu) for gpu in gpu_ids]) + \
208 | "_" + args.DATASET + "_" + args.GAN_TYPE + \
209 | '_' + str(args.GAN_LOSS_TYPE) + str(args.PATCH_GAN_ALPHA) + \
210 | "_" + "L1" + str(args.L1_FORE_ALPHA) + "_" + str(args.L1_BACK_ALPHA) + \
211 | "_" + "C" + str(args.CONTENT_FORE_ALPHA) + "_" + "S" + str(args.STYLE_FORE_ALPHA) +\
212 | "_" + "T" + str(args.TV_ALPHA) + "_" + args.PADDING + '_Deep_MT' +\
213 | "_" + str(args.ALPHA)
214 | else:
215 | args.MODEL_DIR = args.MODEL_RESTORE
216 |
217 | # Checkpoint dir
218 | checkpoint_dir = os.path.join(args.CHECKPOINT_DIR, args.MODEL_DIR)
219 | if not os.path.exists(checkpoint_dir):
220 | os.makedirs(checkpoint_dir)
221 |
222 | # Sample dir
223 | sample_dir = os.path.join(args.SAMPLE_DIR, args.MODEL_DIR)
224 | if not os.path.exists(sample_dir):
225 | os.makedirs(sample_dir)
226 |
227 | # Summary writer
228 | writer = tf.summary.FileWriter(args.LOG_DIR + '/' + args.MODEL_DIR, sess.graph)
229 |
230 | # Saver to save model: to save variables
231 | # TODO: we can choose variables to store and steps to keep (max_to_keep)
232 | saver = tf.train.Saver()
233 |
234 | # Initialize all the variables
235 | tf.global_variables_initializer().run()
236 | # Show network architecture
237 | show_all_variables()
238 |
239 | # Try to restore model
240 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # get checkpoint and restore training
241 | if ckpt and ckpt.model_checkpoint_path:
242 | # print ckpt name with dir
243 | logger.info("Latest ckpt: {}".format(ckpt.model_checkpoint_path))
244 | logger.info("All ckpt: {}".format(ckpt.all_model_checkpoint_paths))
245 | # ckpt base name
246 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
247 | # restore
248 | # saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name)) # restore
249 | vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
250 |
251 | assign_ops = []
252 | for var in vars_list:
253 | vname = var.name
254 | from_name = vname
255 | try:
256 | var_value = tf.contrib.framework.load_variable(os.path.join(checkpoint_dir, ckpt_name), from_name)
257 | assign_ops.append(tf.assign(var, var_value))
258 | except Exception:
259 | continue
260 | sess.run(assign_ops)
261 | print('Model loaded.')
262 |
263 | counter = int(next(re.finditer(r"\d+", ckpt_name)).group(0))
264 | logger.info(" [*] Success to read {}".format(ckpt_name))
265 | else:
266 | counter = 0
267 | logger.info(" [*] Failed to find a checkpoint")
268 |
269 | # Parameters
270 | imgh = args.IMG_SHAPES[0]
271 | imgw = args.IMG_SHAPES[1]
272 |
273 | max_step = dataset_len(args) // (args.BATCH_SIZE * args.NUM_GPUS) # max step for each epoch
274 | last_step = int(args.EPOCH * max_step) # total steps
275 | max_iter = last_step * args.BATCH_SIZE * args.NUM_GPUS # max iteration when batch size is 1
276 |
277 | # continue to train
278 | if counter < last_step:
279 | current_epoch = counter // max_step
280 | current_step = counter % max_step + 1 # TODO: may not be right here?
281 | logger.info("Start Training...")
282 | logger.info(
283 | "Total Epoch {}, Iteration per Epoch {}, Max Iteration {}, Max Iteration (batch_size=1) {}.".format(
284 | args.EPOCH, max_step, last_step, max_iter))
285 | logger.info("Epoch Start {} at step {}".format(current_epoch, current_step))
286 |
287 | # not continue to train
288 | else:
289 | current_step = 0
290 | current_epoch = args.EPOCH
291 |
292 | count = 1 + counter
293 | for epoch in range(current_epoch, args.EPOCH):
294 | logger.info("Epoch {}:".format(epoch))
295 | time_start = time.time()
296 | time_s = time_start
297 | for step in range(current_step, max_step+1):
298 |
299 | # save
300 | if count % args.SAVE_FREQ == 0 or count == last_step:
301 | saver.save(sess, os.path.join(checkpoint_dir, model.model_name + '.model'), global_step=count, write_meta_graph=False)
302 |
303 | if args.MASK_MODE == 'irregular':
304 | # logs
305 | if count % args.LOG_FREQ == 0 or count == last_step:
306 | all_sum = sess.run(model.all_sum)
307 | writer.add_summary(all_sum, count)
308 | # train step
309 | sess.run([d_train256_op, g_train256_op])
310 | else:
311 | mask = create_mask(imgw, imgh, imgw // 2, imgh // 2, delta=0) # random block with hole size (imgw // 2, imgh // 2)
312 | # logs
313 | if count % args.LOG_FREQ == 0 or count == last_step:
314 | all_sum = sess.run(model.all_sum, feed_dict={masks: mask})
315 | writer.add_summary(all_sum, count)
316 | # train step
317 | sess.run([d_train256_op, g_train256_op], feed_dict={masks: mask})
318 |
319 | # validation
320 | if args.VAL:
321 | if count % args.VAL_FREQ == 0 or count == last_step:
322 | sess.run(img_iterator_val.initializer)
323 |
324 | if args.MASK_MODE == 'irregular':
325 | sess.run(mask_iterator_val.initializer)
326 | try:
327 | val_all_sum = sess.run(model.val_all_sum_256)
328 |
329 | writer.add_summary(val_all_sum, count)
330 | except tf.errors.OutOfRangeError:
331 | break
332 | else:
333 | try:
334 | if args.STATIC_VIEW:
335 | mask = create_validation_mask(imgw, imgh, imgw // 2, imgh // 2, args, imgw // 4, imgh // 4)
336 | else:
337 | mask = create_validation_mask(imgw, imgh, imgw // 2, imgh // 2, args, delta=0)
338 | val_all_sum = sess.run(model.val_all_sum_256, feed_dict={masks_val: mask})
339 |
340 | writer.add_summary(val_all_sum, count)
341 | except tf.errors.OutOfRangeError:
342 | break
343 |
344 | # logger info
345 | if count % args.PRINT_FREQ == 0 or count == last_step:
346 | time_cost = (time.time() - time_start) / args.PRINT_FREQ
347 | time_remaining = (last_step - count) * time_cost / 3600.
348 | logger.info('epoch {}/{}, step {}/{}, cost {:.2f}s, remaining {:.2f}h.'.format(epoch, args.EPOCH, step, max_step, time_cost,time_remaining))
349 | time_start = time.time()
350 |
351 | current_step = 0
352 | count += 1
353 |
354 | logger.info('epoch {}/{}, cost {:.2f}min.'.format(epoch, args.EPOCH, (time.time() - time_s)/60))
355 |
356 | logger.info("Finish.")
357 |
358 |
359 |
360 | if __name__ == "__main__":
361 |
362 | main()
363 |
--------------------------------------------------------------------------------
/src/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.framework.python.ops import add_arg_scope
3 |
4 |
5 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
6 | weight_regularizer = None
7 |
8 | @add_arg_scope
9 | def gen_conv(x, cnum, ksize, stride=1, rate=1, name='conv', IN=True, reuse=False,
10 | padding='SAME', activation=tf.nn.elu, use_bias=True, training=True, sn=False):
11 | """Define conv for generator.
12 |
13 | Args:
14 | x: Input.
15 | cnum: Channel number.
16 | ksize: Kernel size.
17 | stride: Convolution stride.
18 | rate: Rate for dilated conv.
19 | name: Name of layers.
20 | padding: Padding mode; defaults to 'SAME'.
21 | activation: Activation function after convolution.
22 | training: If current graph is for training or inference, used for bn.
23 |
24 | Returns:
25 | tf.Tensor: output
26 |
27 | """
28 | assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
29 | if padding == 'SYMMETRIC' or padding == 'REFLECT':
30 | """
31 | Padding layer.
32 | Dilated kernel size: k_r = ksize + (rate - 1)*(ksize - 1)
33 | Padding size: o = i + 2p - k_r and o = i, so p = rate * (ksize - 1) / 2 (when i and o has the same image shape)
34 | """
35 | p = int(rate*(ksize-1)/2)
36 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
37 | padding = 'VALID'
38 |
39 | # if spectral normalization
40 | if sn:
41 | with tf.variable_scope(name, reuse=reuse):
42 | w = tf.get_variable("kernel", shape=[ksize, ksize, x.get_shape()[-1], cnum], initializer=weight_init,
43 | regularizer=weight_regularizer)
44 |
45 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
46 | strides=[1, stride, stride, 1], padding=padding, dilations=[1, rate, rate, 1])
47 | if use_bias:
48 | bias = tf.get_variable("bias", [cnum], initializer=tf.constant_initializer(0.0))
49 | x = tf.nn.bias_add(x, bias)
50 | else:
51 | x = tf.layers.conv2d(inputs=x, filters=cnum, activation=None,
52 | kernel_size=ksize, strides=stride,
53 | dilation_rate=rate, padding=padding,
54 | kernel_initializer=None,
55 | kernel_regularizer=weight_regularizer,
56 | use_bias=use_bias)
57 | if IN:
58 | x = tf.contrib.layers.instance_norm(x) # instance norm, applied before the non-linear activation
59 | if activation is not None:
60 | x = activation(x)
61 | return x
62 |
63 | @add_arg_scope
64 | def gen_deconv(x, cnum, ksize=4, stride=2, rate=1, method='deconv', IN=True,
65 | activation=tf.nn.relu, name='upsample', padding='SAME', sn=False, training=True, reuse=False):
66 | """Define deconv for generator.
67 | The deconv is defined to be a x2 resize_nearest_neighbor operation with
68 | additional gen_conv operation.
69 |
70 | Args:
71 | x: Input.
72 | cnum: Channel number.
73 | name: Name of layers.
74 | training: If current graph is for training or inference, used for bn.
75 |
76 | Returns:
77 | tf.Tensor: output
78 |
79 | """
80 | with tf.variable_scope(name, reuse=reuse):
81 | if method == 'nearest':
82 | x = resize(x, func=tf.image.resize_nearest_neighbor) # tf.image.resize_bilinear ?
83 | x = gen_conv(
84 | x, cnum, 3, 1, name=name+'_conv', padding=padding,
85 | training=training, IN=IN)
86 | elif method == 'bilinear':
87 | x = resize(x, func=tf.image.resize_bilinear)
88 | x = gen_conv(
89 | x, cnum, 3, 1, name=name + '_conv', padding=padding,
90 | training=training, IN=IN)
91 | elif method == 'bicubic':
92 | x = resize(x, func=tf.image.resize_bicubic)
93 | x = gen_conv(
94 | x, cnum, 3, 1, name=name + '_conv', padding=padding,
95 | training=training, IN=IN) # default instance normalization, see function gen_conv()
96 | else:
97 | # assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
98 | # if padding == 'SYMMETRIC' or padding == 'REFLECT':
99 | # p = int(rate * (ksize - 1) / 2)
100 | # p = 0
101 | # x = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], mode=padding)
102 | padding = 'SAME'
103 | x = tf.layers.conv2d_transpose(x, cnum, kernel_size=ksize, strides=stride,
104 | activation=None, padding=padding)
105 | if IN:
106 | x = tf.contrib.layers.instance_norm(x) # if instance norm?
107 | if activation is not None:
108 | x = activation(x)
109 | return x
110 |
111 | def resize(x, scale=2, to_shape=None, align_corners=True, dynamic=False,
112 | func=tf.image.resize_bilinear, name='resize'):
113 | if dynamic:
114 | xs = tf.cast(tf.shape(x), tf.float32)
115 | new_xs = [tf.cast(xs[1]*scale, tf.int32),
116 | tf.cast(xs[2]*scale, tf.int32)]
117 | else:
118 | xs = x.get_shape().as_list()
119 | new_xs = [int(xs[1]*scale), int(xs[2]*scale)]
120 | with tf.variable_scope(name):
121 | if to_shape is None:
122 | x = func(x, new_xs, align_corners=align_corners)
123 | else:
124 | x = func(x, [to_shape[0], to_shape[1]],
125 | align_corners=align_corners)
126 | return x
127 |
128 | # yj
129 | @add_arg_scope
130 | def resnet_blocks(x, cnum, ksize, stride, rate, block_num, name, IN=True,
131 | padding='REFLECT', activation=tf.nn.elu, training=True):
132 | for block in range(block_num):
133 | # x = resnet_block12(x, cnum, ksize, stride, rate, name+"_"+str(block), padding, activation, training=training)
134 | x = resnet_block21(x, cnum, ksize, stride, rate, name + "_" + str(block), padding=padding,
135 | activation=activation, training=training)
136 | return x
137 |
138 | # yj
139 | def resnet_block21(x, cnum, ksize, stride, rate, name, IN=True,
140 | padding='SAME', activation=tf.nn.relu, training=True):
141 | xin = x
142 | assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
143 | if padding == 'SYMMETRIC' or padding == 'REFLECT':
144 | p = int(rate*(ksize-1)/2)
145 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
146 | padding1 = 'VALID'
147 | else:
148 | padding1 = padding
149 | x = tf.layers.conv2d(
150 | x, cnum, ksize, stride, dilation_rate=rate,
151 | activation=None, padding=padding1, name=name+"0")
152 | if IN:
153 | x = tf.contrib.layers.instance_norm(x) # if instance norm?
154 | if activation is not None:
155 | x = activation(x)
156 |
157 | rate = 1
158 | if padding == 'SYMMETRIC' or padding == 'REFLECT':
159 | p = int(rate*(ksize-1)/2)
160 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
161 | padding2 = 'VALID'
162 | else:
163 | padding2 = padding
164 | x = tf.layers.conv2d(
165 | x, cnum, ksize, stride, dilation_rate=rate,
166 | activation=None, padding=padding2, name=name+"1")
167 | if IN:
168 | x = tf.contrib.layers.instance_norm(x) # if instance norm?
169 | return xin + x
170 |
171 | # yj
172 | def resnet_block12(x, cnum, ksize, stride, rate, name, IN=True,
173 | padding='REFLECT', activation=tf.nn.elu, training=True):
174 | xin = x
175 | rate = 1
176 | assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
177 | if padding == 'SYMMETRIC' or padding == 'REFLECT':
178 | p = int(rate*(ksize-1)/2)
179 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
180 | padding1 = 'VALID'
181 | else:
182 | padding1 = padding
183 | x = tf.layers.conv2d(
184 | x, cnum, ksize, stride, dilation_rate=rate,
185 | activation=None, padding=padding1, name=name+"0")
186 | if IN:
187 | x = tf.contrib.layers.instance_norm(x) # if instance norm?
188 | if activation is not None:
189 | x = activation(x)
190 |
191 | rate = 2
192 | if padding == 'SYMMETRIC' or padding == 'REFLECT':
193 | p = int(rate*(ksize-1)/2)
194 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
195 | padding2 = 'VALID'
196 | else:
197 | padding2 = padding
198 | x = tf.layers.conv2d(
199 | x, cnum, ksize, stride, dilation_rate=rate,
200 | activation=None, padding=padding2, name=name+"1")
201 | if IN:
202 | x = tf.contrib.layers.instance_norm(x) # if instance norm?
203 |
204 | return xin + x
205 |
206 |
207 | def torgb(x, cnum, ksize, stride, rate, name, activation=tf.nn.tanh, padding="SAME"):
208 | x = tf.layers.conv2d(
209 | x, cnum, ksize, stride, dilation_rate=rate,
210 | activation=activation, padding=padding, name=name)
211 | # x = tf.clip_by_value(x, -1., 1.)
212 | return x
213 |
214 |
215 | def dis_conv(x, cnum, ksize=5, stride=2, rate=1, activation=tf.nn.leaky_relu, name='conv',
216 | padding='SAME', use_bias=True, sn=True, training=True, reuse=False):
217 | """Define conv for discriminator.
218 | Activation is set to leaky_relu.
219 |
220 | Args:
221 | x: Input.
222 | cnum: Channel number.
223 | ksize: Kernel size.
224 | stride: Convolution stride.
225 | name: Name of layers.
226 | training: If current graph is for training or inference, used for bn.
227 |
228 | Returns:
229 | tf.Tensor: output
230 |
231 | """
232 | # if using spectral normalization
233 | if sn:
234 | with tf.variable_scope(name, reuse=reuse):
235 | w = tf.get_variable("kernel", shape=[ksize, ksize, x.get_shape()[-1], cnum], initializer=weight_init,
236 | regularizer=weight_regularizer)
237 |
238 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
239 | strides=[1, stride, stride, 1], padding=padding, dilations=[1, rate, rate, 1])
240 | if use_bias:
241 | bias = tf.get_variable("bias", [cnum], initializer=tf.constant_initializer(0.0))
242 | x = tf.nn.bias_add(x, bias)
243 | if activation is not None:
244 | x = activation(x)
245 | else:
246 | x = tf.layers.conv2d(inputs=x, filters=cnum, activation=activation,
247 | kernel_size=ksize, strides=stride,
248 | dilation_rate=rate, padding=padding,
249 | kernel_initializer=None,
250 | kernel_regularizer=None,
251 | use_bias=use_bias,
252 | reuse=reuse)
253 | return x
254 |
255 | def flatten(x, name='flatten'):
256 | """Flatten wrapper.
257 | """
258 | with tf.variable_scope(name):
259 | return tf.contrib.layers.flatten(x)
260 |
261 | def out_complete(out, x_incomplete, mask, res):
262 | mask = tf.image.resize_images(mask, (res, res))
263 | x_incomplete = tf.image.resize_images(x_incomplete, (res, res))
264 | x_complete = out * mask + x_incomplete * (1. - mask)
265 | return x_complete
266 |
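# Usage sketch (illustrative, not part of the original file): out_complete pastes the
# generated content into the hole region only, keeping known pixels from the input.
# The tensor names and the 256 resolution below are assumptions.
#
#   out = generator_output                    # [b, 256, 256, 3], values in [-1, 1]
#   x_incomplete = image * (1. - mask)        # known pixels kept, holes zeroed
#   x_complete = out_complete(out, x_incomplete, mask, res=256)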
267 |
268 | # linear embedding
269 | @add_arg_scope
270 | def conv(x, channels, kernel=3, stride=1, pad=0, pad_type='REFLECT', use_bias=True, sn=False, scope='conv_0',
271 | reuse=False, training=False, padding=None):
272 | with tf.variable_scope(scope, reuse=reuse):
273 | if pad_type == 'zero':
274 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
275 | if pad_type in ('reflect', 'REFLECT'):
276 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='REFLECT')
277 |
278 | if sn :
279 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
280 | regularizer=weight_regularizer)
281 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
282 | strides=[1, stride, stride, 1], padding='VALID')
283 | if use_bias :
284 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
285 | x = tf.nn.bias_add(x, bias)
286 |
287 | else :
288 | x = tf.layers.conv2d(inputs=x, filters=channels,
289 | kernel_size=kernel, kernel_initializer=weight_init,
290 | kernel_regularizer=weight_regularizer,
291 | strides=stride, use_bias=use_bias, reuse=reuse)
292 | return x
293 |
294 | def spectral_norm(w, iteration=1):
295 | w_shape = w.shape.as_list()
296 | w = tf.reshape(w, [-1, w_shape[-1]])
297 |
298 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)
299 |
300 | u_hat = u
301 | v_hat = None
302 | for i in range(iteration):
303 | """
304 | power iteration
305 | Usually iteration = 1 will be enough
306 | """
307 | v_ = tf.matmul(u_hat, tf.transpose(w))
308 | v_hat = l2_norm(v_)
309 |
310 | u_ = tf.matmul(v_hat, w)
311 | u_hat = l2_norm(u_)
312 |
313 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
314 | w_norm = w / sigma
315 |
316 | with tf.control_dependencies([u.assign(u_hat)]):
317 | w_norm = tf.reshape(w_norm, w_shape)
318 |
319 | return w_norm
320 |
321 | def l2_norm(v, eps=1e-12):
322 | return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)
323 |
324 | def hw_flatten(x) :
325 | return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]])
326 |
327 | def max_pooling(x, pool_size=2):
328 | x = tf.layers.max_pooling2d(x, pool_size=pool_size, strides=pool_size, padding='SAME')
329 | return x
330 |
331 |
332 | def avg_pooling(x, pool_size=2):
333 | x = tf.layers.average_pooling2d(x, pool_size=pool_size, strides=pool_size, padding='SAME')
334 | return x
335 |
336 |
337 |
338 | """##### attention #####"""
339 | def attention(x, channels, neighbors=1, use_bias=True, sn=False, down_scale = 2, pool_scale=2,
340 | name='attention_pooling', training=True, padding='REFLECT', reuse=False):
341 | if neighbors > 1:
342 | x = attention_with_neighbors(x, channels, down_scale=down_scale, pool_scale=pool_scale, name=name)
343 | else:
344 | x = attention_with_pooling(x, channels, down_scale=down_scale, pool_scale=pool_scale, name=name)
345 | return x
346 |
347 | @add_arg_scope
348 | def attention_with_pooling(x, channels, ksize=4, use_bias=True, sn=False, down_scale = 2, pool_scale=2,
349 | name='attention_pooling', training=True, padding='REFLECT', reuse=False):
350 | with tf.variable_scope(name, reuse=reuse):
351 | x_origin = x
352 |
353 | # down sampling
354 | if down_scale > 1:
355 | x = gen_conv(x, channels, ksize, stride=down_scale, activation=tf.nn.relu, name='attention_down_sample',reuse=reuse)
356 |
357 | # attention
358 | f = conv(x, channels // 16, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='f_conv', reuse=reuse) # [bs, h, w, c']
359 | f = max_pooling(f, pool_scale)
360 | # f = avg_pooling(f)
361 |
362 | g = conv(x, channels // 16, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='g_conv',reuse=reuse) # [bs, h, w, c']
363 |
364 | h = conv(x, channels // 16, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='h_conv',reuse=reuse) # [bs, h, w, c]
365 | h = max_pooling(h, pool_scale)
366 | # h = avg_pooling(h) [4,65536,4096]
367 |
368 | # N = h * w
369 | s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # [bs, N, N]
370 |
371 | beta = tf.nn.softmax(s) # attention map
372 |
373 | o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C]
374 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
375 |
376 | o = tf.reshape(o, shape=[x.shape[0], x.shape[1], x.shape[2], channels // 16]) # [bs, h, w, c']
377 | # o = conv(o, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='attn_conv_up') # from bottleneck
378 |
379 | # up sampling
380 | if down_scale > 1:
381 | o = gen_deconv(o, channels, ksize, method='deconv', stride=down_scale, activation=tf.nn.relu, name='attention_down_upsample',reuse=reuse)
382 |
383 | x = gamma * o + x_origin
384 |
385 | return x
386 |
387 | # attention consider neighbors
388 | @add_arg_scope
389 | def attention_with_neighbors(x, channels, ksize=3, use_bias=True, sn=False, stride=2,
390 | down_scale = 2, pool_scale=2, name='attention_pooling',
391 | training=True, padding='REFLECT', reuse=False):
392 | with tf.variable_scope(name, reuse=reuse):
393 | x1 = x
394 |
395 | # downsample input feature maps if needed due to limited GPU memory
396 | # down sampling
397 | if down_scale > 1:
398 | x1 = gen_conv(x1, channels, ksize, stride=down_scale, activation=tf.nn.relu, name='attention_down_sample',
399 | reuse=reuse)
400 | # get shapes
401 | int_x1s = x1.get_shape().as_list()
402 | # extract patches from high-level feature maps for matching and attending
403 | x1_groups = tf.split(x1, int_x1s[0], axis=0)
404 | w = tf.extract_image_patches(
405 | x1, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1], padding='SAME')
406 | w = tf.reshape(w, [int_x1s[0], -1, ksize, ksize, int_x1s[3]])
407 | w = tf.transpose(w, [0, 2, 3, 4, 1]) # transpose to [b, ksize, ksize, c, hw/4]
408 | w_groups = tf.split(w, int_x1s[0], axis=0)
409 |
410 | # matching and attending hole and non-hole patches
411 | y = []
412 | scale = 10.
413 | # w_groups: patches extracted from x1; x1_groups: the feature map itself
414 | for xi, wi in zip(x1_groups, w_groups):
415 | # matching on high-level feature maps
416 | wi = wi[0]
417 | wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2])), 1e-4)
418 | yi = tf.nn.conv2d(xi, wi_normed, strides=[1, 1, 1, 1], padding="SAME")
419 | yi = tf.reshape(yi, [1, int_x1s[1], int_x1s[2], (int_x1s[1] // stride) * (int_x1s[2] // stride)])
420 | yi = tf.nn.softmax(yi * scale, 3)
421 | # non local mean
422 | wi_center = tf.transpose(wi, [0, 1, 3, 2])
423 | yi = tf.nn.conv2d(yi, wi_center, strides=[1, 1, 1, 1], padding="SAME") / 4.
424 |
425 | # filter: [height, width, output_channels, in_channels]
426 | y.append(yi)
427 | y = tf.concat(y, axis=0)
428 | y.set_shape(int_x1s)
429 | # up sampling
430 | if down_scale > 1:
431 | y = gen_deconv(y, channels, ksize, method='deconv', stride=down_scale, activation=tf.nn.relu,
432 | name='attention_down_upsample', reuse=reuse)
433 |
434 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
435 | x = gamma * y + x
436 | x = tf.layers.conv2d(x, channels, 3, 1, dilation_rate=1, activation=tf.nn.relu, padding='SAME')
437 | return x
438 |
439 | def normalize(x) :
440 | return x/127.5 - 1
441 |
442 | def imsave(images, size, path):
443 | return scipy.misc.imsave(path, merge(images, size))
444 |
445 | def inverse_transform(images):
446 | return (images+1.)*127.5
--------------------------------------------------------------------------------
/painter/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.framework.python.ops import add_arg_scope
3 |
4 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
5 | weight_regularizer = None
6 |
7 | @add_arg_scope
8 | def gen_conv(x, cnum, ksize, stride=1, rate=1, name='conv', IN=True, reuse=False,
9 | padding='SAME', activation=tf.nn.elu, use_bias=True, training=True, sn=False):
10 | """Define conv for generator.
11 |
12 | Args:
13 | x: Input.
14 | cnum: Channel number.
15 | ksize: Kernel size.
16 | stride: Convolution stride.
17 | rate: Rate for dilated conv.
18 | name: Name of layers.
19 | padding: Padding mode, one of 'SAME', 'SYMMETRIC' or 'REFLECT' (default 'SAME').
20 | activation: Activation function after convolution.
21 | training: If current graph is for training or inference, used for bn.
22 |
23 | Returns:
24 | tf.Tensor: output
25 |
26 | """
27 | assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
28 | if padding == 'SYMMETRIC' or padding == 'REFLECT':
29 | """
30 | Padding layer.
31 | Dilated kernel size: k_r = ksize + (rate - 1)*(ksize - 1)
32 | Padding size: o = i + 2p - k_r + 1 with o = i, so p = rate * (ksize - 1) / 2 (so input and output have the same spatial size)
33 | """
34 | p = int(rate*(ksize-1)/2)
35 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
36 | padding = 'VALID'
37 |
38 | # if spectrum normalization
39 | if sn:
40 | with tf.variable_scope(name, reuse=reuse):
41 | w = tf.get_variable("kernel", shape=[ksize, ksize, x.get_shape()[-1], cnum], initializer=weight_init,
42 | regularizer=weight_regularizer)
43 |
44 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
45 | strides=[1, stride, stride, 1], padding=padding, dilations=[1, rate, rate, 1])
46 | if use_bias:
47 | bias = tf.get_variable("bias", [cnum], initializer=tf.constant_initializer(0.0))
48 | x = tf.nn.bias_add(x, bias)
49 | else:
50 | x = tf.layers.conv2d(inputs=x, filters=cnum, activation=None,
51 | kernel_size=ksize, strides=stride,
52 | dilation_rate=rate, padding=padding,
53 | kernel_initializer=None,
54 | kernel_regularizer=weight_regularizer,
55 | use_bias=use_bias)
56 | if IN:
57 | x = tf.contrib.layers.instance_norm(x) # instance normalization goes before the non-linear activation
58 | if activation is not None:
59 | x = activation(x)
60 | return x
61 |
62 | @add_arg_scope
63 | def gen_deconv(x, cnum, ksize=4, stride=2, rate=1, method='deconv',IN=True,
64 | activation=tf.nn.relu, name='upsample', padding='SAME', sn=False, training=True, reuse=False):
65 | """Define deconv for generator.
66 | The deconv is defined as a x2 resize (nearest, bilinear or bicubic) followed
67 | by a gen_conv, or as a transposed convolution when method='deconv'.
68 |
69 | Args:
70 | x: Input.
71 | cnum: Channel number.
72 | name: Name of layers.
73 | training: If current graph is for training or inference, used for bn.
74 |
75 | Returns:
76 | tf.Tensor: output
77 |
78 | """
79 | with tf.variable_scope(name, reuse=reuse):
80 | if method == 'nearest':
81 | x = resize(x, func=tf.image.resize_nearest_neighbor) # or tf.image.resize_bilinear
82 | x = gen_conv(
83 | x, cnum, 3, 1, name=name+'_conv', padding=padding,
84 | training=training, IN=IN)
85 | elif method == 'bilinear':
86 | x = resize(x, func=tf.image.resize_bilinear)
87 | x = gen_conv(
88 | x, cnum, 3, 1, name=name + '_conv', padding=padding,
89 | training=training, IN=IN)
90 | elif method == 'bicubic':
91 | x = resize(x, func=tf.image.resize_bicubic)
92 | x = gen_conv(
93 | x, cnum, 3, 1, name=name + '_conv', padding=padding,
94 | training=training, IN=IN) # default instance normalization, see function gen_conv()
95 | else:
96 | # assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
97 | # if padding == 'SYMMETRIC' or padding == 'REFLECT':
98 | # p = int(rate * (ksize - 1) / 2)
99 | # p = 0
100 | # x = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], mode=padding)
101 | padding = 'SAME'
102 | x = tf.layers.conv2d_transpose(x, cnum, kernel_size=ksize, strides=stride,
103 | activation=None, padding=padding)
104 | if IN:
105 | x = tf.contrib.layers.instance_norm(x) # apply instance normalization if enabled
106 | if activation is not None:
107 | x = activation(x)
108 | return x
109 |
110 | def resize(x, scale=2, to_shape=None, align_corners=True, dynamic=False,
111 | func=tf.image.resize_bilinear, name='resize'):
112 | if dynamic:
113 | xs = tf.cast(tf.shape(x), tf.float32)
114 | new_xs = [tf.cast(xs[1]*scale, tf.int32),
115 | tf.cast(xs[2]*scale, tf.int32)]
116 | else:
117 | xs = x.get_shape().as_list()
118 | new_xs = [int(xs[1]*scale), int(xs[2]*scale)]
119 | with tf.variable_scope(name):
120 | if to_shape is None:
121 | x = func(x, new_xs, align_corners=align_corners)
122 | else:
123 | x = func(x, [to_shape[0], to_shape[1]],
124 | align_corners=align_corners)
125 | return x
126 |
127 | # yj
128 | @add_arg_scope
129 | def resnet_blocks(x, cnum, ksize, stride, rate, block_num, name, IN=True,
130 | padding='REFLECT', activation=tf.nn.elu, training=True):
131 | for block in range(block_num):
132 | # x = resnet_block12(x, cnum, ksize, stride, rate, name+"_"+str(block), padding, activation, training=training)
133 | x = resnet_block21(x, cnum, ksize, stride, rate, name + "_" + str(block), padding=padding,
134 | activation=activation, training=training)
135 | return x
136 |
137 | # yj
138 | def resnet_block21(x, cnum, ksize, stride, rate, name, IN=True,
139 | padding='SAME', activation=tf.nn.relu, training=True):
140 | xin = x
141 | assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
142 | if padding == 'SYMMETRIC' or padding == 'REFLECT':
143 | p = int(rate*(ksize-1)/2)
144 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
145 | padding1 = 'VALID'
146 | else:
147 | padding1 = padding
148 | x = tf.layers.conv2d(
149 | x, cnum, ksize, stride, dilation_rate=rate,
150 | activation=None, padding=padding1, name=name+"0")
151 | if IN:
152 | x = tf.contrib.layers.instance_norm(x) # apply instance normalization if enabled
153 | if activation is not None:
154 | x = activation(x)
155 |
156 | rate = 1
157 | if padding == 'SYMMETRIC' or padding == 'REFLECT':
158 | p = int(rate*(ksize-1)/2)
159 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
160 | padding2 = 'VALID'
161 | else:
162 | padding2 = padding
163 | x = tf.layers.conv2d(
164 | x, cnum, ksize, stride, dilation_rate=rate,
165 | activation=None, padding=padding2, name=name+"1")
166 | if IN:
167 | x = tf.contrib.layers.instance_norm(x) # apply instance normalization if enabled
168 | return xin + x
169 |
170 | # yj
171 | def resnet_block12(x, cnum, ksize, stride, rate, name, IN=True,
172 | padding='REFLECT', activation=tf.nn.elu, training=True):
173 | xin = x
174 | rate = 1
175 | assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
176 | if padding == 'SYMMETRIC' or padding == 'REFLECT':
177 | p = int(rate*(ksize-1)/2)
178 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
179 | padding1 = 'VALID'
180 | else:
181 | padding1 = padding
182 | x = tf.layers.conv2d(
183 | x, cnum, ksize, stride, dilation_rate=rate,
184 | activation=None, padding=padding1, name=name+"0")
185 | if IN:
186 | x = tf.contrib.layers.instance_norm(x) # apply instance normalization if enabled
187 | if activation is not None:
188 | x = activation(x)
189 |
190 | rate = 2
191 | if padding == 'SYMMETRIC' or padding == 'REFLECT':
192 | p = int(rate*(ksize-1)/2)
193 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
194 | padding2 = 'VALID'
195 | else:
196 | padding2 = padding
197 | x = tf.layers.conv2d(
198 | x, cnum, ksize, stride, dilation_rate=rate,
199 | activation=None, padding=padding2, name=name+"1")
200 | if IN:
201 | x = tf.contrib.layers.instance_norm(x) # apply instance normalization if enabled
202 |
203 | return xin + x
204 |
205 |
206 | # TODO: torgb -- would a 1x1 conv with bias be enough? linear output vs. tanh activation
207 | def torgb(x, cnum, ksize, stride, rate, name, activation=tf.nn.tanh, padding="SAME"):
208 | x = tf.layers.conv2d(
209 | x, cnum, ksize, stride, dilation_rate=rate,
210 | activation=activation, padding=padding, name=name)
211 | # x = tf.clip_by_value(x, -1., 1.)
212 | return x
213 |
214 |
215 | def dis_conv(x, cnum, ksize=5, stride=2, rate=1, activation=tf.nn.leaky_relu, name='conv',
216 | padding='SAME', use_bias=True, sn=True, training=True, reuse=False):
217 | """Define conv for discriminator.
218 | Activation is set to leaky_relu.
219 |
220 | Args:
221 | x: Input.
222 | cnum: Channel number.
223 | ksize: Kernel size.
224 | stride: Convolution stride.
225 | name: Name of layers.
226 | training: If current graph is for training or inference, used for bn.
227 |
228 | Returns:
229 | tf.Tensor: output
230 |
231 | """
232 | # if using spectral normalization
233 | if sn:
234 | with tf.variable_scope(name, reuse=reuse):
235 | w = tf.get_variable("kernel", shape=[ksize, ksize, x.get_shape()[-1], cnum], initializer=weight_init,
236 | regularizer=weight_regularizer)
237 |
238 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
239 | strides=[1, stride, stride, 1], padding=padding, dilations=[1, rate, rate, 1])
240 | if use_bias:
241 | bias = tf.get_variable("bias", [cnum], initializer=tf.constant_initializer(0.0))
242 | x = tf.nn.bias_add(x, bias)
243 | if activation is not None:
244 | x = activation(x)
245 | else:
246 | x = tf.layers.conv2d(inputs=x, filters=cnum, activation=activation,
247 | kernel_size=ksize, strides=stride,
248 | dilation_rate=rate, padding=padding,
249 | kernel_initializer=None,
250 | kernel_regularizer=None,
251 | use_bias=use_bias,
252 | reuse=reuse)
253 | return x
254 |
255 | def flatten(x, name='flatten'):
256 | """Flatten wrapper.
257 | """
258 | with tf.variable_scope(name):
259 | return tf.contrib.layers.flatten(x)
260 |
261 | def out_complete(out, x_incomplete, mask, res):
262 | mask = tf.image.resize_images(mask, (res, res))
263 | x_incomplete = tf.image.resize_images(x_incomplete, (res, res))
264 | x_complete = out * mask + x_incomplete * (1. - mask)
265 | return x_complete
266 |
267 |
268 | # linear embedding
269 | @add_arg_scope
270 | def conv(x, channels, kernel=3, stride=1, pad=0, pad_type='REFLECT', use_bias=True, sn=False, scope='conv_0', reuse=False, training=False, padding=None):
271 | with tf.variable_scope(scope, reuse=reuse):
272 | if pad_type == 'zero':
273 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
274 | if pad_type in ('reflect', 'REFLECT'):
275 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='REFLECT')
276 |
277 | if sn :
278 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
279 | regularizer=weight_regularizer)
280 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
281 | strides=[1, stride, stride, 1], padding='VALID')
282 | if use_bias :
283 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
284 | x = tf.nn.bias_add(x, bias)
285 |
286 | else :
287 | x = tf.layers.conv2d(inputs=x, filters=channels,
288 | kernel_size=kernel, kernel_initializer=weight_init,
289 | kernel_regularizer=weight_regularizer,
290 | strides=stride, use_bias=use_bias, reuse=reuse)
291 | return x
292 |
293 | def spectral_norm(w, iteration=1):
294 | w_shape = w.shape.as_list()
295 | w = tf.reshape(w, [-1, w_shape[-1]])
296 |
297 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)
298 |
299 | u_hat = u
300 | v_hat = None
301 | for i in range(iteration):
302 | """
303 | power iteration
304 | Usually iteration = 1 will be enough
305 | """
306 | v_ = tf.matmul(u_hat, tf.transpose(w))
307 | v_hat = l2_norm(v_)
308 |
309 | u_ = tf.matmul(v_hat, w)
310 | u_hat = l2_norm(u_)
311 |
312 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
313 | w_norm = w / sigma
314 |
315 | with tf.control_dependencies([u.assign(u_hat)]):
316 | w_norm = tf.reshape(w_norm, w_shape)
317 |
318 | return w_norm
319 |
320 | def l2_norm(v, eps=1e-12):
321 | return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)
322 |
323 | def hw_flatten(x) :
324 | return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]])
325 |
326 | def max_pooling(x, pool_size=2):
327 | x = tf.layers.max_pooling2d(x, pool_size=pool_size, strides=pool_size, padding='SAME')
328 | return x
329 |
330 |
331 | def avg_pooling(x, pool_size=2):
332 | x = tf.layers.average_pooling2d(x, pool_size=pool_size, strides=pool_size, padding='SAME')
333 | return x
334 |
335 | # ATN layer
336 | import tensorflow as tf
337 | from tensorflow.contrib.framework.python.ops import add_arg_scope
338 |
339 | @add_arg_scope
340 | def AtnConv(x1, x2, mask=None, ksize=3, stride=1, rate=2,
341 | softmax_scale=10., training=True, rescale=False):
342 | r""" Attention transfer networks implementation in tensorflow
343 |
344 | Attention transfer networks are introduced in the publication:
345 | Learning Pyramid-Context Encoder Networks for High-Quality Image Inpainting, Zeng et al.
346 | https://arxiv.org/pdf/1904.07475.pdf
347 | https://github.com/researchmm/PEN-Net-for-Inpainting
348 | inspired by:
349 | Generative Image Inpainting with Contextual Attention, Yu et al.
350 | https://github.com/JiahuiYu/generative_inpainting/blob/master/inpaint_ops.py
351 | https://arxiv.org/abs/1801.07892
352 | Args:
353 | x1: low-level feature map with larger size [b, h, w, c].
354 | x2: high-level feature map with smaller size [b, h/2, w/2, c].
355 | mask: Input mask, 1 for missing regions, 0 for known regions.
356 | ksize: Kernel size for attention transfer networks.
357 | stride: Stride for extracting patches from feature map.
358 | rate: Dilation for matching.
359 | softmax_scale: Scaled softmax for attention.
360 | training: Indicating if current graph is training or inference.
361 | rescale: Indicating if input feature maps need to be downsampled.
362 | Returns:
363 | tf.Tensor: reconstructed feature map
364 | """
365 | # downsample input feature maps if needed due to limited GPU memory
366 | if rescale:
367 | x1 = resize(x1, scale=1. / 2, func=tf.image.resize_nearest_neighbor)
368 | x2 = resize(x2, scale=1. / 2, func=tf.image.resize_nearest_neighbor)
369 | # get shapes
370 | raw_x1s = tf.shape(x1)
371 | int_x1s = x1.get_shape().as_list()
372 | int_x2s = x2.get_shape().as_list()
373 |
374 | # extract patches from low-level feature maps for reconstruction
375 | kernel = 2 * rate
376 | raw_w = tf.extract_image_patches(
377 | x1, [1, kernel, kernel, 1], [1, rate * stride, rate * stride, 1], [1, 1, 1, 1], padding='SAME')
378 | raw_w = tf.reshape(raw_w, [int_x1s[0], -1, kernel, kernel, int_x1s[3]])
379 | raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1]) # transpose to [b, kernel, kernel, c, hw]
380 | raw_w_groups = tf.split(raw_w, int_x1s[0], axis=0)
381 |
382 | # extract patches from high-level feature maps for matching and attending
383 | x2_groups = tf.split(x2, int_x2s[0], axis=0)
384 | w = tf.extract_image_patches(
385 | x2, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1], padding='SAME')
386 | w = tf.reshape(w, [int_x2s[0], -1, ksize, ksize, int_x2s[3]])
387 | w = tf.transpose(w, [0, 2, 3, 4, 1]) # transpose to [b, ksize, ksize, c, hw/4]
388 | w_groups = tf.split(w, int_x2s[0], axis=0)
389 |
390 | # resize and extract patches from masks
391 | mask = resize(mask, to_shape=int_x2s[1:3], func=tf.image.resize_nearest_neighbor)
392 | m = tf.extract_image_patches(
393 | mask, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1], padding='SAME')
394 | m = tf.reshape(m, [1, -1, ksize, ksize, 1])
395 | m = tf.transpose(m, [0, 2, 3, 4, 1]) # transpose to [1, ksize, ksize, 1, hw/4]
396 | m = m[0]
397 | mm = tf.cast(tf.equal(tf.reduce_mean(m, axis=[0, 1, 2], keep_dims=True), 0.), tf.float32)
398 |
399 | # matching and attending hole and non-hole patches
400 | y = []
401 | scale = softmax_scale
402 | # high level patches: w_groups, low level patches: raw_w_groups, x2_groups: high level feature map
403 | for xi, wi, raw_wi in zip(x2_groups, w_groups, raw_w_groups):
404 | # matching on high-level feature maps
405 | wi = wi[0]
406 | wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2])), 1e-4)
407 | yi = tf.nn.conv2d(xi, wi_normed, strides=[1, 1, 1, 1], padding="SAME")
408 | yi = tf.reshape(yi, [1, int_x2s[1], int_x2s[2], (int_x2s[1] // stride) * (int_x2s[2] // stride)])
409 | # apply softmax to obtain attention score
410 | yi *= mm # mask
411 | yi = tf.nn.softmax(yi * scale, 3)
412 | yi *= mm # mask yi: score maps, score maps for non-hole regions are zeros through masks
413 | # transfer non-hole features into holes according to the attention score
414 | wi_center = raw_wi[0]
415 | yi = tf.nn.conv2d_transpose(yi, wi_center, tf.concat([[1], raw_x1s[1:]], axis=0),
416 | strides=[1, rate * stride, rate * stride, 1]) / 4. # filter: [height, width, output_channels, in_channels]
417 | y.append(yi)
418 | y = tf.concat(y, axis=0)
419 | y.set_shape(int_x1s)
420 | # refine filled feature map after matching and attending
421 | y1 = tf.layers.conv2d(y, int_x1s[-1] // 4, 3, 1, dilation_rate=1, activation=tf.nn.relu, padding='SAME')
422 | y2 = tf.layers.conv2d(y, int_x1s[-1] // 4, 3, 1, dilation_rate=2, activation=tf.nn.relu, padding='SAME')
423 | y3 = tf.layers.conv2d(y, int_x1s[-1] // 4, 3, 1, dilation_rate=4, activation=tf.nn.relu, padding='SAME')
424 | y4 = tf.layers.conv2d(y, int_x1s[-1] // 4, 3, 1, dilation_rate=8, activation=tf.nn.relu, padding='SAME')
425 | y = tf.concat([y1, y2, y3, y4], axis=3)
426 | if rescale:
427 | y = resize(y, scale=2., func=tf.image.resize_nearest_neighbor)
428 | return y
429 |
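# Usage sketch (illustrative, not part of the original file): AtnConv fills holes in
# the low-level map x1 by matching patches on the high-level map x2 under the mask.
# The shapes below are assumptions consistent with the docstring (x1 twice the size of x2).
#
#   y = AtnConv(low_feat,                     # x1: [b, 64, 64, c]
#               high_feat,                    # x2: [b, 32, 32, c]
#               mask=hole_mask,               # [1, h, w, 1], 1 = missing
#               ksize=3, stride=1, rate=2)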
430 |
431 | """##### our-attention #####"""
432 | def attention(x, channels, neighbors=1, use_bias=True, sn=False, down_scale = 2, pool_scale=2,
433 | name='attention_pooling', training=True, padding='REFLECT', reuse=False):
434 | if neighbors > 1:
435 | x = attention_with_neighbors(x, channels, down_scale=down_scale, pool_scale=pool_scale, name=name)
436 | else:
437 | x = attention_with_pooling(x, channels, down_scale=down_scale, pool_scale=pool_scale, name=name)
438 | return x
439 |
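# Usage sketch (illustrative): the dispatcher above selects pooled self-attention
# (neighbors == 1) or patch-based attention (neighbors > 1). The feature map and
# channel count below are assumptions.
#
#   feat = attention(feat, channels=256, neighbors=1, down_scale=2, pool_scale=2,
#                    name='attn_64')          # gamma-gated residual: gamma * o + x
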
440 | @add_arg_scope
441 | def attention_with_pooling(x, channels, ksize=4, use_bias=True, sn=False, down_scale = 2, pool_scale=2, name='attention_pooling', training=True, padding='REFLECT', reuse=False):
442 | with tf.variable_scope(name, reuse=reuse):
443 | x_origin = x
444 |
445 | # down sampling
446 | if down_scale > 1:
447 | x = gen_conv(x, channels, ksize, stride=down_scale, activation=tf.nn.relu, name='attention_down_sample',reuse=reuse)
448 |
449 | # attention
450 | f = conv(x, channels // 16, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='f_conv', reuse=reuse) # [bs, h, w, c']
451 | f = max_pooling(f, pool_scale)
452 | # f = avg_pooling(f)
453 |
454 | g = conv(x, channels // 16, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='g_conv',reuse=reuse) # [bs, h, w, c']
455 |
456 | h = conv(x, channels // 16, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='h_conv',reuse=reuse) # [bs, h, w, c]
457 | h = max_pooling(h, pool_scale)
458 | # h = avg_pooling(h) [4,65536,4096]
459 |
460 | # N = h * w
461 | s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # [bs, N, N]
462 |
463 | beta = tf.nn.softmax(s) # attention map
464 |
465 | o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C]
466 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
467 |
468 | o = tf.reshape(o, shape=[x.shape[0], x.shape[1], x.shape[2], channels // 16]) # [bs, h, w, c']
469 | # o = conv(o, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='attn_conv_up') # from bottleneck
470 |
471 | # up sampling
472 | if down_scale > 1:
473 | o = gen_deconv(o, channels, ksize, method='deconv', stride=down_scale, activation=tf.nn.relu, name='attention_down_upsample',reuse=reuse)
474 |
475 | x = gamma * o + x_origin
476 |
477 | return x
478 |
479 | # attention consider neighbors
480 | @add_arg_scope
481 | def attention_with_neighbors(x, channels, ksize=3, use_bias=True, sn=False, stride=2,
482 | down_scale = 2, pool_scale=2, name='attention_pooling',
483 | training=True, padding='REFLECT', reuse=False):
484 | with tf.variable_scope(name, reuse=reuse):
485 | x1 = x
486 |
487 | # downsample input feature maps if needed due to limited GPU memory
488 | # down sampling
489 | if down_scale > 1:
490 | x1 = gen_conv(x1, channels, ksize, stride=down_scale, activation=tf.nn.relu, name='attention_down_sample',
491 | reuse=reuse)
492 | # get shapes
493 | int_x1s = x1.get_shape().as_list()
494 | # extract patches from high-level feature maps for matching and attending
495 | x1_groups = tf.split(x1, int_x1s[0], axis=0)
496 | w = tf.extract_image_patches(
497 | x1, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1], padding='SAME')
498 | w = tf.reshape(w, [int_x1s[0], -1, ksize, ksize, int_x1s[3]])
499 | w = tf.transpose(w, [0, 2, 3, 4, 1]) # transpose to [b, ksize, ksize, c, hw/4]
500 | w_groups = tf.split(w, int_x1s[0], axis=0)
501 |
502 | # matching and attending hole and non-hole patches
503 | y = []
504 | scale = 10.
505 | # w_groups: patches extracted from x1; x1_groups: the feature map itself
506 | for xi, wi in zip(x1_groups, w_groups):
507 | # matching on high-level feature maps
508 | wi = wi[0]
509 | wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2])), 1e-4)
510 | yi = tf.nn.conv2d(xi, wi_normed, strides=[1, 1, 1, 1], padding="SAME")
511 | yi = tf.reshape(yi, [1, int_x1s[1], int_x1s[2], (int_x1s[1] // stride) * (int_x1s[2] // stride)])
512 | yi = tf.nn.softmax(yi * scale, 3)
513 | # non local mean
514 | wi_center = tf.transpose(wi, [0, 1, 3, 2])
515 | yi = tf.nn.conv2d(yi, wi_center, strides=[1, 1, 1, 1], padding="SAME") / 4.
516 |
517 | # filter: [height, width, output_channels, in_channels]
518 | y.append(yi)
519 | y = tf.concat(y, axis=0)
520 | y.set_shape(int_x1s)
521 | # up sampling
522 | if down_scale > 1:
523 | y = gen_deconv(y, channels, ksize, method='deconv', stride=down_scale, activation=tf.nn.relu,
524 | name='attention_down_upsample', reuse=reuse)
525 |
526 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
527 | x = gamma * y + x
528 | x = tf.layers.conv2d(x, channels, 3, 1, dilation_rate=1, activation=tf.nn.relu, padding='SAME')
529 | return x
530 |
531 | def normalize(x) :
532 | return x/127.5 - 1
533 |
534 | def imsave(images, size, path):
535 | return scipy.misc.imsave(path, merge(images, size)) # note: scipy.misc and merge() are not imported/defined in this file
536 |
537 | def inverse_transform(images):
538 | return (images+1.)*127.5
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | import vgg_network
2 | from logging import exception
3 | import tensorflow as tf
4 | import numpy as np
5 | from easydict import EasyDict as edict
6 |
7 | from sys import stdout
8 | from functools import reduce
9 | from vgg_network import VGG
10 |
11 |
12 | # loss config
13 | config = edict()
14 | config.W = edict()
15 |
16 | # TODO: content
17 | # weights
18 | config.W.Content = 1.
19 |
20 | config.Content = edict()
21 | config.Content.feat_layers = {'relu1_1': 0.2, 'relu2_1': 0.2,'relu3_1': 0.2,'relu4_1': 0.2,'relu5_1': 0.2}
22 |
23 | # TODO: style
24 | config.W.Style = 1.
25 | config.Style = edict()
26 | config.Style.feat_layers = {'relu1_1': 0.2, 'relu2_1': 0.2,'relu3_1': 0.2,'relu4_1': 0.2,'relu5_1': 0.2}
27 |
28 |
29 | class LossCalculator:
30 |
31 | def __init__(self, vgg_dir, real_image):
32 | self.vgg_model = VGG(vgg_dir)
33 | self.vgg_real = self.vgg_model.net(real_image)
34 |
35 | def content_loss(self, content_fake, layers=None):
36 | # compute content loss
37 | vgg_fake = self.vgg_model.net(content_fake) # dict: net[name] = current_layer
38 | if config.W.Content > 0:
39 | if layers is not None:
40 | config.Content.feat_layers = layers
41 | content_loss_list = [w * self._content_loss_helper(self.vgg_real[layer], vgg_fake[layer])
42 | for layer, w in config.Content.feat_layers.items()]
43 | content_loss = tf.reduce_sum(content_loss_list)
44 | else:
45 | zero_tensor = tf.constant(0.0, dtype=tf.float32)
46 | content_loss = zero_tensor
47 | return content_loss
48 |
49 | def style_loss(self, style_fake, layers=None):
50 | vgg_fake = self.vgg_model.net(style_fake) # dict: net[name] = current_layer
51 | # image = tf.placeholder('float32', shape=style.shape)
52 | # style_net = self.vgg.net(image)
53 |
54 | if config.W.Style > 0:
55 | if layers is not None:
56 | config.Style.feat_layers = layers
57 | style_loss_list = [w * self._style_loss_helper(self.vgg_real[layer], vgg_fake[layer])
58 | for layer, w in config.Style.feat_layers.items()]
59 | style_loss = tf.reduce_sum(style_loss_list)
60 | else:
61 | zero_tensor = tf.constant(0.0, dtype=tf.float32)
62 | style_loss = zero_tensor
63 | return style_loss
64 |
65 | # def _calculate_input_gram_matrix_for(self, layer):
66 | # image_feature = self.network[layer]
67 | # _, height, width, number = map(lambda i: i.value, image_feature.get_shape())
68 | # size = height * width * number
69 | # image_feature = tf.reshape(image_feature, (-1, number))
70 | # return tf.matmul(tf.transpose(image_feature), image_feature) / size
71 |
72 |
73 | def _content_loss_helper(self, vgg_A, vgg_B):
74 | N, fH, fW, fC = vgg_A.shape.as_list()
75 | feature_size = N * fH * fW * fC
76 | content_loss = 2 * tf.nn.l2_loss(vgg_A - vgg_B) / feature_size
77 | return content_loss
78 |
79 | def _style_loss_helper(self, vgg_A, vgg_B):
80 | N, fH, fW, fC = vgg_A.shape.as_list()
81 | feature_size = N * fH * fW * fC
82 | gram_A = self._compute_gram(vgg_A)
83 | gram_B = self._compute_gram(vgg_B)
84 | style_loss = 2 * tf.nn.l2_loss(gram_A - gram_B) / feature_size
85 | return style_loss
86 |
87 | def _compute_gram(self, feature):
88 | # https://github.com/fullfanta/real_time_style_transfer/blob/master/train.py
89 | shape = tf.shape(feature)
90 | psi = tf.reshape(feature, [shape[0], shape[1] * shape[2], shape[3]])
91 | # psi_t = tf.transpose(psi, perm=[0, 2, 1])
92 | gram = tf.matmul(psi, psi, transpose_a=True)
93 | gram = tf.div(gram, tf.cast(shape[1] * shape[2] * shape[3], tf.float32))
94 | return gram
95 |
96 | def tv_loss(self, image):
97 | # total variation denoising
98 | tv_y_size = _tensor_size(image[:,1:,:,:])
99 | tv_x_size = _tensor_size(image[:,:,1:,:])
100 | shape = image.shape.as_list()
101 | tv_loss = 2 * (
102 | (tf.nn.l2_loss(image[:,1:,:,:] - image[:,:shape[1]-1,:,:]) /
103 | tv_y_size) +
104 | (tf.nn.l2_loss(image[:,:,1:,:] - image[:,:,:shape[2]-1,:]) /
105 | tv_x_size))
106 |
107 | return tv_loss
108 |
109 | # TODO: l1_loss(x, x_complete_256)
110 | def l1_loss(self, image, predict, mask, type='foreground'):
111 | error = tf.abs(predict - image)
112 | if type == 'foreground':
113 | loss = tf.reduce_sum(mask * error) / tf.reduce_sum(mask) # * tf.reduce_sum(1. - mask) for balance?
114 | elif type == 'background':
115 | loss = tf.reduce_sum((1. - mask) * error) / tf.reduce_sum(1. - mask)
116 | else:
117 | loss = tf.reduce_sum(mask * tf.abs(predict - image)) / tf.reduce_sum(mask)
118 | return loss
119 |
120 | # TODO:
121 | def adversarial_loss(self):
122 | pass
123 |
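# Usage sketch (illustrative; the VGG weight path and image tensors are assumptions):
# LossCalculator runs VGG once on the real image and reuses those features for both
# the content and style terms.
#
#   calc = LossCalculator('vgg19.npy', real_image)   # hypothetical weight file
#   perceptual = (config.W.Content * calc.content_loss(fake_image)
#                 + config.W.Style * calc.style_loss(fake_image))
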
124 | def _tensor_size(tensor):
125 | from operator import mul
126 | return reduce(mul, (d.value for d in tensor.get_shape()), 1)
127 |
128 | def gan_wgan_loss(pos, neg, name='gan_wgan_loss'):
129 | """
130 | wgan loss function for GANs.
131 |
132 | - Wasserstein GAN: https://arxiv.org/abs/1701.07875
133 | """
134 | with tf.variable_scope(name):
135 | d_loss = tf.reduce_mean(neg-pos)
136 | g_loss = -tf.reduce_mean(neg)
137 | # scalar_summary('d_loss', d_loss)
138 | # scalar_summary('g_loss', g_loss)
139 | # scalar_summary('pos_value_avg', tf.reduce_mean(pos))
140 | # scalar_summary('neg_value_avg', tf.reduce_mean(neg))
141 | return g_loss, d_loss
142 |
143 | def patch_gan_loss(pos, neg, name='patch_gan_loss', loss_type='gan'):
144 | """
145 | patch gan loss
146 | """
147 | with tf.variable_scope(name):
148 | if loss_type == 'gan':
149 | g_loss = tf.reduce_mean(
150 | tf.nn.sigmoid_cross_entropy_with_logits(logits=neg, labels=tf.ones_like(neg))) # generator loss
151 |
152 | d_loss_fake = tf.reduce_mean(
153 | tf.nn.sigmoid_cross_entropy_with_logits(logits=neg, labels=tf.zeros_like(neg)))
154 | d_loss_real = tf.reduce_mean(
155 | tf.nn.sigmoid_cross_entropy_with_logits(logits=pos, labels=tf.ones_like(pos)))
156 | d_loss = d_loss_fake + d_loss_real # discriminator loss
157 |
158 | if loss_type == 'hinge':
159 | d_loss_real = tf.reduce_mean(tf.nn.relu(1.0 - pos))
160 | d_loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + neg))
161 | d_loss = d_loss_real + d_loss_fake
162 |
163 | g_loss = -tf.reduce_mean(neg)
164 |
165 | return g_loss, d_loss, d_loss_real, d_loss_fake
166 |
167 | def random_interpolates(x, y, alpha=None):
168 | """
169 | x: first dimension as batch_size
170 | y: first dimension as batch_size
171 | alpha: [BATCH_SIZE, 1]
172 | """
173 | shape = x.get_shape().as_list()
174 | x = tf.reshape(x, [shape[0], -1])
175 | y = tf.reshape(y, [shape[0], -1])
176 | if alpha is None:
177 | alpha = tf.random_uniform(shape=[shape[0], 1])
178 | interpolates = x + alpha*(y - x)
179 | return tf.reshape(interpolates, shape)
180 |
181 |
182 | def gradients_penalty(x, y, mask=None, norm=1.):
183 | """Improved Training of Wasserstein GANs
184 |
185 | - https://arxiv.org/abs/1704.00028
186 | """
187 | gradients = tf.gradients(y, x)[0]
188 | if mask is None:
189 | mask = tf.ones_like(gradients)
190 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients) * mask, axis=[1, 2, 3]))
191 | return tf.reduce_mean(tf.square(slopes - norm))
192 |
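# WGAN-GP wiring sketch (illustrative; the discriminator function and the real/fake
# tensors are assumptions, and lambda = 10 is the WGAN-GP paper's common choice):
#
#   interp = random_interpolates(real, fake)         # x + alpha * (y - x)
#   d_interp = discriminator(interp, reuse=True)     # hypothetical D
#   g_loss, d_loss = gan_wgan_loss(d_real, d_fake)
#   d_loss += 10. * gradients_penalty(interp, d_interp, mask=mask)
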
193 | from tensorflow.python.ops import array_ops
194 | def focal_loss(prediction_tensor, target_tensor, weights=None, alpha=0.25, gamma=2):
195 | r"""Compute focal loss for predictions.
196 | Multi-labels Focal loss formula:
197 | FL = -alpha * (z-p)^gamma * log(p) - (1-alpha) * p^gamma * log(1-p),
198 | where alpha = 0.25, gamma = 2, p = sigmoid(x), z = target_tensor.
199 | ref: https://github.com/ailias/Focal-Loss-implement-on-Tensorflow
200 | if z == 1, J = -alpha * (1 - p)^gamma * log(p)
201 | if z != 1, J = -(1 - alpha) * p^gamma * log(1 - p)
202 | Args:
203 | prediction_tensor: A float tensor of shape [batch_size, num_anchors,
204 | num_classes] representing the predicted logits for each class
205 | target_tensor: A float tensor of shape [batch_size, num_anchors,
206 | num_classes] representing one-hot encoded classification targets
207 | weights: A float tensor of shape [batch_size, num_anchors]
208 | alpha: A scalar tensor for focal loss alpha hyper-parameter
209 | gamma: A scalar tensor for focal loss gamma hyper-parameter
210 | Returns:
211 | loss: A (scalar) tensor representing the value of the loss function
212 | """
213 | sigmoid_p = tf.nn.sigmoid(prediction_tensor)
214 | zeros = array_ops.zeros_like(sigmoid_p, dtype=sigmoid_p.dtype)
215 |
216 | # For positive predictions, only the front part of the loss contributes; the back part is 0;
217 | # target_tensor > zeros <=> z = 1, so the positive coefficient = z - p.
218 | pos_p_sub = array_ops.where(target_tensor > zeros, target_tensor - sigmoid_p, zeros)
219 |
220 | # For negative predictions, only the back part of the loss contributes; the front part is 0;
221 | # target_tensor > zeros <=> z = 1, so the negative coefficient is 0 there (and p where z = 0).
222 | neg_p_sub = array_ops.where(target_tensor > zeros, zeros, sigmoid_p)
223 | per_entry_cross_ent = - alpha * (pos_p_sub ** gamma) * tf.log(tf.clip_by_value(sigmoid_p, 1e-8, 1.0)) \
224 | - (1 - alpha) * (neg_p_sub ** gamma) * tf.log(tf.clip_by_value(1.0 - sigmoid_p, 1e-8, 1.0))
225 | return tf.reduce_mean(per_entry_cross_ent)
226 |
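# Usage sketch (illustrative; the edge logits/labels tensors are assumptions):
# focal loss down-weights easy examples so sparse edge pixels dominate less.
#
#   loss = focal_loss(edge_logits,            # [b, h, w, 1], raw logits
#                     edge_labels,            # [b, h, w, 1], 0/1 targets
#                     alpha=0.25, gamma=2)
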
227 | def sigmoid_cross_entropy_balanced_fore(logits, label, mask, name='cross_entropy_loss'):
228 | """
229 | Implements Equation [2] in https://arxiv.org/pdf/1504.06375.pdf
230 | Compute edge pixels for each training sample and set as pos_weights to
231 | tf.nn.weighted_cross_entropy_with_logits
232 | """
233 | y = tf.cast(label, tf.float32)
234 |
235 | count_neg = tf.reduce_sum(mask * (1. - y))
236 | count_pos = tf.reduce_sum(mask * y)
237 |
238 | # Equation [2]
239 | beta = count_neg / (count_neg + count_pos)
240 |
241 | # Equation [2] divide by 1 - beta
242 | pos_weight = beta / (1 - beta)
243 |
244 | cost = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=y, pos_weight=pos_weight)
245 |
246 | # Multiply by 1 - beta
247 | # cost = tf.reduce_mean(cost * (1 - beta))
248 | # N, H, W, C = logits.get_shape().as_list()
249 | size = count_neg + count_pos # total pixel count
250 | cost = tf.reduce_sum(cost * (1 - beta)) / size
251 |
252 | # check if image has no edge pixels return 0 else return complete error function
253 | return tf.where(tf.equal(count_pos, 0.0), 0.0, cost, name=name)
254 |
255 | def sigmoid_cross_entropy_balanced_back(logits, label, name='cross_entropy_loss'):
256 | """
257 | Implements Equation [2] in https://arxiv.org/pdf/1504.06375.pdf
258 | Compute edge pixels for each training sample and set as pos_weights to
259 | tf.nn.weighted_cross_entropy_with_logits
260 | """
261 | y = tf.cast(label, tf.float32)
262 |
263 | count_neg = tf.reduce_sum(1. - y)
264 | count_pos = tf.reduce_sum(y)
265 |
266 | # Equation [2]
267 | beta = count_neg / (count_neg + count_pos)
268 |
269 | # Equation [2] divide by 1 - beta
270 | pos_weight = beta / (1 - beta)
271 |
272 | cost = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=y, pos_weight=pos_weight)
273 |
274 | # Multiply by 1 - beta
275 | cost = tf.reduce_mean(cost * (1 - beta))
276 |
277 | # check if image has no edge pixels return 0 else return complete error function
278 | return tf.where(tf.equal(count_pos, 0.0), 0.0, cost, name=name)
279 |
280 |
281 | """
282 | id-mrf
283 | """
284 | from enum import Enum
285 |
286 | class Distance(Enum):
287 | L2 = 0
288 | DotProduct = 1
289 |
290 | class CSFlow:
291 | def __init__(self, sigma=float(0.1), b=float(1.0)):
292 | self.b = b
293 | self.sigma = sigma
294 |
295 | def __calculate_CS(self, scaled_distances, axis_for_normalization=3):
296 | self.scaled_distances = scaled_distances
297 | self.cs_weights_before_normalization = tf.exp((self.b - scaled_distances) / self.sigma, name='weights_before_normalization')
298 | self.cs_NHWC = CSFlow.sum_normalize(self.cs_weights_before_normalization, axis_for_normalization)
299 |
300 | def reversed_direction_CS(self):
301 | cs_flow_opposite = CSFlow(self.sigma, self.b)
302 | cs_flow_opposite.raw_distances = self.raw_distances
303 | work_axis = [1, 2]
304 | relative_dist = cs_flow_opposite.calc_relative_distances(axis=work_axis)
305 | cs_flow_opposite.__calculate_CS(relative_dist, work_axis)
306 | return cs_flow_opposite
307 |
308 | # --
309 | @staticmethod
310 | def create_using_L2(I_features, T_features, sigma=float(0.1), b=float(1.0)):
311 | cs_flow = CSFlow(sigma, b)
312 | with tf.name_scope('CS'):
313 | sT = T_features.shape.as_list()
314 | sI = I_features.shape.as_list()
315 |
316 | Ivecs = tf.reshape(I_features, (sI[0], -1, sI[3]))
317 | Tvecs = tf.reshape(T_features, (sI[0], -1, sT[3]))
318 | r_Ts = tf.reduce_sum(Tvecs * Tvecs, 2)
319 | r_Is = tf.reduce_sum(Ivecs * Ivecs, 2)
320 | raw_distances_list = []
321 |
322 | N, _, _, _ = T_features.shape.as_list()
323 | for i in range(N):
324 | Ivec, Tvec, r_T, r_I = Ivecs[i], Tvecs[i], r_Ts[i], r_Is[i]
325 | A = tf.matmul(Tvec,tf.transpose(Ivec))
326 | cs_flow.A = A
327 | # A = tf.matmul(Tvec, tf.transpose(Ivec))
328 | r_T = tf.reshape(r_T, [-1, 1]) # turn to column vector
329 | dist = r_T - 2 * A + r_I
330 | cs_shape = sI[:3] + [dist.shape[0].value]
331 | cs_shape[0] = 1
332 | dist = tf.reshape(tf.transpose(dist), cs_shape)
333 | # protecting against numerical problems, dist should be positive
334 | dist = tf.maximum(float(0.0), dist)
335 | # dist = tf.sqrt(dist)
336 | raw_distances_list += [dist]
337 |
338 | cs_flow.raw_distances = tf.convert_to_tensor([tf.squeeze(raw_dist, axis=0) for raw_dist in raw_distances_list])
339 |
340 | relative_dist = cs_flow.calc_relative_distances()
341 | cs_flow.__calculate_CS(relative_dist)
342 | return cs_flow
343 |
344 | #--
345 | @staticmethod
346 | def create_using_dotP(I_features, T_features, sigma=float(1.0), b=float(1.0), args=None):
347 | cs_flow = CSFlow(sigma, b)
348 | with tf.name_scope('CS'):
349 | # prepare feature before calculating cosine distance
350 | T_features, I_features = cs_flow.center_by_T(T_features, I_features)
351 | with tf.name_scope('TFeatures'):
352 | T_features = CSFlow.l2_normalize_channelwise(T_features)
353 | with tf.name_scope('IFeatures'):
354 | I_features = CSFlow.l2_normalize_channelwise(I_features)
355 | # work separately on each example in the batch
356 | cosine_dist_l = []
357 | N, _, _, _ = T_features.shape.as_list()
358 | for i in range(N):
359 | T_features_i = tf.expand_dims(T_features[i, :, :, :], 0)
360 | I_features_i = tf.expand_dims(I_features[i, :, :, :], 0)
361 | patches_i = cs_flow.patch_decomposition(T_features_i, args)
362 | # use every patch in patches_i as a kernel convolved over I_features to obtain the
363 | # cosine distance between each patch and I_features
364 | cosine_dist_i = tf.nn.conv2d(I_features_i, patches_i, strides=[1, 1, 1, 1],
365 | padding='VALID', use_cudnn_on_gpu=True, name='cosine_dist')
366 | cosine_dist_l.append(cosine_dist_i)
367 |
368 | cs_flow.cosine_dist = tf.concat(cosine_dist_l, axis = 0)
369 |
370 | cosine_dist_zero_to_one = -(cs_flow.cosine_dist - 1) / 2
371 | cs_flow.raw_distances = cosine_dist_zero_to_one
372 |
373 | relative_dist = cs_flow.calc_relative_distances()
374 | cs_flow.__calculate_CS(relative_dist)
375 | return cs_flow
376 |
377 | def calc_relative_distances(self, axis=3):
378 | epsilon = 1e-5
379 | div = tf.reduce_min(self.raw_distances, axis=axis, keep_dims=True)
380 | # div = tf.reduce_mean(self.raw_distances, axis=axis, keep_dims=True)
381 | relative_dist = self.raw_distances / (div + epsilon)
382 | return relative_dist
383 |
384 | def weighted_average_dist(self, axis=3):
385 | if not hasattr(self, 'raw_distances'):
386 | raise AttributeError('raw_distances attribute does not exist; cannot calculate weighted average L2')
387 |
388 | multiply = self.raw_distances * self.cs_NHWC
389 | return tf.reduce_sum(multiply, axis=axis, name='weightedDistPerPatch')
390 |
391 | # --
392 | @staticmethod
393 | def create(I_features, T_features, distance : Distance, nnsigma=float(1.0), b=float(1.0), args=None):
394 | if distance.value == Distance.DotProduct.value:
395 | cs_flow = CSFlow.create_using_dotP(I_features, T_features, nnsigma, b, args)
396 | elif distance.value == Distance.L2.value:
397 | cs_flow = CSFlow.create_using_L2(I_features, T_features, nnsigma, b)
398 | else:
399 | raise "not supported distance " + distance.__str__()
400 | return cs_flow
401 |
402 | @staticmethod
403 | def sum_normalize(cs, axis=3):
404 | reduce_sum = tf.reduce_sum(cs, axis, keep_dims=True, name='sum')
405 | return tf.divide(cs, reduce_sum, name='sumNormalized')
406 |
407 | def center_by_T(self, T_features, I_features):
408 | # assuming both input are of the same size
409 |
410 | # calculate stats over [batch, height, width], expecting a 1x1xDepth tensor
411 | axes = [0, 1, 2]
412 | self.meanT, self.varT = tf.nn.moments(
413 | T_features, axes, name='TFeatures/moments')
414 | # we do not divide by std since it makes the histogram
415 | # of the final cs very thin, so the NN weights
416 | # are not distinctive, giving similar values for all patches.
417 | # stdT = tf.sqrt(varT, "stdT")
418 | # correct places with std zero
419 | # stdT[tf.less(stdT, tf.constant(0.001))] = tf.constant(1)
420 | with tf.name_scope('TFeatures/centering'):
421 | self.T_features_centered = T_features - self.meanT
422 | with tf.name_scope('IFeatures/centering'):
423 | self.I_features_centered = I_features - self.meanT
424 |
425 | return self.T_features_centered, self.I_features_centered
426 |
427 | @staticmethod
428 | def l2_normalize_channelwise(features):
429 | norms = tf.norm(features, ord='euclidean', axis=3, name='norm')
430 | # expanding the norms tensor to support broadcast division
431 | norms_expanded = tf.expand_dims(norms, 3)
432 | features = tf.divide(features, norms_expanded, name='normalized')
433 | return features
434 |
435 | def patch_decomposition(self, T_features, args=None):
436 | # patch decomposition
437 | if args is None:
438 | patch_size = 1
439 | stride_size = 1
440 | else:
441 | patch_size = args.PATCH_SIZE
442 | stride_size = args.STRIDE_SIZE
443 | patches_as_depth_vectors = tf.extract_image_patches(
444 | images=T_features, ksizes=[1, patch_size, patch_size, 1],
445 | strides=[1, stride_size, stride_size, 1], rates=[1, 1, 1, 1], padding='VALID',
446 | name='patches_as_depth_vectors')
447 |
448 | out_channels = int(patches_as_depth_vectors.shape[3].value / patch_size / patch_size)
449 | self.patches_NHWC = tf.reshape(
450 | patches_as_depth_vectors,
451 | shape=[-1, patch_size, patch_size, out_channels],
452 | name='patches_PHWC') # out_channels = depth / patch_size**2; exact here because patch_size = 1
453 |
454 | self.patches_HWCN = tf.transpose(
455 | self.patches_NHWC,
456 | perm=[1, 2, 3, 0],
457 | name='patches_HWCP') # tf.nn.conv2d-ready format (every patch as a kernel)
458 |
459 | return self.patches_HWCN
460 |
461 |
462 | def mrf_loss(T_features, I_features, distance=Distance.DotProduct, nnsigma=float(1.0), args=None):
463 | T_features = tf.convert_to_tensor(T_features, dtype=tf.float32)
464 | I_features = tf.convert_to_tensor(I_features, dtype=tf.float32)
465 |
466 | with tf.name_scope('cx'):
467 | cs_flow = CSFlow.create(I_features, T_features, distance, nnsigma)
468 | # cs_NHWC is already sum-normalized over the patch axis
469 | height_width_axis = [1, 2]
470 | # keep the best match per target patch:
471 | cs = cs_flow.cs_NHWC
472 | k_max_NC = tf.reduce_max(cs, axis=height_width_axis)
473 | CS = tf.reduce_mean(k_max_NC, axis=[1])
474 | CS_as_loss = 1 - CS
475 | CS_loss = -tf.log(1 - CS_as_loss)
476 | CS_loss = tf.reduce_mean(CS_loss)
477 | return CS_loss
478 |
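# ID-MRF usage sketch (illustrative; the VGG feature dicts are assumptions): mrf_loss
# compares generated and real feature maps via the contextual-similarity flow above;
# id_mrf_reg_feat below adds optional quarter-cropping and random pooling first.
#
#   loss = mrf_loss(vgg_real['relu3_1'], vgg_fake['relu3_1'],
#                   distance=Distance.DotProduct, nnsigma=0.5)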
479 |
480 | def random_sampling(tensor_in, n, indices=None):
481 | N, H, W, C = tf.convert_to_tensor(tensor_in).shape.as_list()
482 | S = H * W
483 | tensor_NSC = tf.reshape(tensor_in, [N, S, C])
484 | all_indices = list(range(S))
485 | shuffled_indices = tf.random_shuffle(all_indices)
486 | indices = tf.gather(shuffled_indices, list(range(n)), axis=0) if indices is None else indices
487 | res = tf.gather(tensor_NSC, indices, axis=1)
488 | return res, indices
489 |
490 |
491 | def random_pooling(feats, output_1d_size=100):
492 | is_input_tensor = type(feats) is tf.Tensor
493 |
494 | if is_input_tensor:
495 | feats = [feats]
496 |
497 | # convert all inputs to tensors
498 | feats = [tf.convert_to_tensor(feats_i) for feats_i in feats]
499 |
500 | N, H, W, C = feats[0].shape.as_list()
501 | feats_sampled_0, indices = random_sampling(feats[0], output_1d_size ** 2)
502 | res = [feats_sampled_0]
503 | for i in range(1, len(feats)):
504 | feats_sampled_i, _ = random_sampling(feats[i], -1, indices)
505 | res.append(feats_sampled_i)
506 |
507 | res = [tf.reshape(feats_sampled_i, [N, output_1d_size, output_1d_size, C]) for feats_sampled_i in res]
508 | if is_input_tensor:
509 | return res[0]
510 | return res
511 |
512 |
513 | def crop_quarters(feature_tensor):
514 | N, fH, fW, fC = feature_tensor.shape.as_list()
515 | quarters_list = []
516 | quarter_size = [N, round(fH / 2), round(fW / 2), fC]
517 | quarters_list.append(tf.slice(feature_tensor, [0, 0, 0, 0], quarter_size))
518 | quarters_list.append(tf.slice(feature_tensor, [0, round(fH / 2), 0, 0], quarter_size))
519 | quarters_list.append(tf.slice(feature_tensor, [0, 0, round(fW / 2), 0], quarter_size))
520 | quarters_list.append(tf.slice(feature_tensor, [0, round(fH / 2), round(fW / 2), 0], quarter_size))
521 | feature_tensor = tf.concat(quarters_list, axis=0)
522 | return feature_tensor
523 |
524 |
525 | def id_mrf_reg_feat(feat_A, feat_B, config, args):
526 | if config.crop_quarters is True:
527 | feat_A = crop_quarters(feat_A)
528 | feat_B = crop_quarters(feat_B)
529 |
530 | N, fH, fW, fC = feat_A.shape.as_list()
531 | if fH * fW <= config.max_sampling_1d_size ** 2:
532 | print(' #### Skipping pooling ....')
533 | else:
534 | print(' #### pooling %d**2 out of %dx%d' % (config.max_sampling_1d_size, fH, fW))
535 | feat_A, feat_B = random_pooling([feat_A, feat_B], output_1d_size=config.max_sampling_1d_size)
536 |
537 | return mrf_loss(feat_A, feat_B, distance=config.Dist, nnsigma=config.nn_stretch_sigma, args=args)
538 |
539 |
540 | from easydict import EasyDict as edict
541 | # scale of im_src and im_dst: [-1, 1]
542 | def grad_matching_loss(im_src, im_dst, config):
543 |
544 | match_config = edict()
545 | match_config.crop_quarters = False
546 | match_config.max_sampling_1d_size = 65
547 | match_config.Dist = Distance.DotProduct
548 | match_config.nn_stretch_sigma = 0.5 # 0.1
549 |
550 | match_loss = id_mrf_reg_feat(im_src, im_dst, match_config, config)
551 |
552 | match_loss = tf.reduce_sum(match_loss)
553 |
554 | return match_loss
555 |
556 |
557 | """
558 | Salient Edge
559 | """
560 | import cv2
561 | def gaussian_kernel_2d_opencv(kernel_size=3, sigma=0):
562 | """
563 | ref: https://blog.csdn.net/qq_16013649/article/details/78784791
564 | ref: tensorflow
565 | (1) https://stackoverflow.com/questions/52012657/how-to-make-a-2d-gaussian-filter-in-tensorflow
566 | (2) https://github.com/tensorflow/tensorflow/issues/2826
567 | """
568 | kx = cv2.getGaussianKernel(kernel_size,sigma)
569 | ky = cv2.getGaussianKernel(kernel_size,sigma)
570 | return np.multiply(kx,np.transpose(ky))
571 |
572 | def priority_loss_mask(mask, ksize=5, sigma=1, iteration=2):
573 | gaussian_kernel = gaussian_kernel_2d_opencv(kernel_size=ksize, sigma=sigma)
574 | gaussian_kernel = np.reshape(gaussian_kernel, (ksize, ksize, 1, 1))
575 | mask_priority = tf.convert_to_tensor(mask, dtype=tf.float32)
576 | for i in range(iteration):
577 | mask_priority = tf.nn.conv2d(mask_priority, gaussian_kernel, strides=[1,1,1,1], padding='SAME')
578 |
579 | return mask_priority
580 |
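# Usage sketch (illustrative; gt_img is an assumption): repeated Gaussian blurring
# spreads a binary edge map into a soft priority weight, so pixels near edges get
# larger loss weights (see canny_edge below).
#
#   gt_edge = tf.py_func(canny_edge, [gt_img], tf.float32, stateful=False)
#   edge_weight = priority_loss_mask(gt_edge, ksize=5, sigma=1, iteration=2)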
581 |
582 | # structure loss
583 | from skimage import feature
584 | from skimage.color import rgb2gray
585 |
586 | """
587 | Structure loss
588 | """
589 | import cv2
590 |
591 | def canny_edge(images, sigma=1.5):
592 | """
593 | Extract edges in tensorflow.
594 | example:
595 | input = tf.placeholder(dtype=tf.float32, shape=[None, 900, 900, 3])
596 | output = tf.py_func(canny_edge, [input], tf.float32, stateful=False)
597 |
598 | :param images:
599 | :param sigma:
600 | :return:
601 | """
602 | edges = []
603 | for i in range(len(images)):
604 | grey_img = rgb2gray(images[i])
605 | edge = feature.canny(grey_img, sigma=sigma)
606 | edges.append(np.expand_dims(edge, axis=0))
607 | edges = np.concatenate(edges, axis=0)
608 | return np.expand_dims(edges, axis=3).astype(np.float32)
609 |
610 |
611 | def pyramid_structure_loss(image, predicts, edge_alpha, grad_alpha):
612 | _, H, W, _ = image.get_shape().as_list()
613 | loss = 0.
614 | for predict in predicts:
615 | _, h, w, _ = predict.get_shape().as_list()
616 | if h != H:
617 | gt_img = tf.image.resize_nearest_neighbor(image, size=(h, w))
618 | # gt_mask = tf.image.resize_nearest_neighbor(mask, size=(h, w))
619 |
620 | # grad
621 | gt_grad = tf.image.sobel_edges(gt_img)
622 | gt_grad = tf.reshape(gt_grad, [-1, h, w, 6]) # 6 channel
623 | grad_error = tf.abs(predict - gt_grad)
624 |
625 | # edge
626 | gt_edge = tf.py_func(canny_edge, [gt_img], tf.float32, stateful=False)
627 | edge_priority = priority_loss_mask(gt_edge, ksize=5, sigma=1, iteration=2)
628 | else:
629 | gt_img = image
630 | # gt_mask = mask
631 |
632 | # grad
633 | gt_grad = tf.image.sobel_edges(gt_img)
634 | gt_grad = tf.reshape(gt_grad, [-1, H, W, 6]) # 6 channel
635 | grad_error = tf.abs(predict - gt_grad)
636 |
637 | # edge
638 | gt_edge = tf.py_func(canny_edge, [gt_img], tf.float32, stateful=False)
639 | edge_priority = priority_loss_mask(gt_edge, ksize=5, sigma=1, iteration=2)
640 |
641 | grad_loss = tf.reduce_mean(grad_alpha * grad_error)
642 | edge_weight = edge_alpha * edge_priority
643 | # print("edge_weight", edge_weight.shape)
644 | # print("grad_error", grad_error.shape)
645 | edge_loss = tf.reduce_sum(edge_weight * grad_error) / tf.reduce_sum(edge_weight) / 6. # 6 channel
646 |
647 | loss = loss + grad_loss + edge_loss
648 |
649 | return loss
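# Wiring sketch (assumed shapes, not from the original file): `predicts` is a
# list of 6-channel Sobel-gradient predictions at increasing resolutions, each
# compared against gradients/edges of the ground truth resized to match; the
# alpha weights below are illustrative, not the paper's values.
#
#   image = tf.placeholder(tf.float32, [4, 256, 256, 3])
#   predicts = [tf.placeholder(tf.float32, [4, s, s, 6]) for s in (64, 128, 256)]
#   loss = pyramid_structure_loss(image, predicts, edge_alpha=5., grad_alpha=0.1)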
--------------------------------------------------------------------------------
/src/utils_fn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import numpy as np
4 | import scipy
5 | from scipy.misc import imread
6 | from scipy import ndimage
7 | from scipy.misc import imresize
8 |
9 | import skimage
10 | from skimage import feature
11 | from skimage.color import rgb2gray
12 |
13 | import tensorflow as tf
14 | import tensorflow.contrib.slim as slim
15 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
16 |
17 | import cv2
18 | # free form mask (generated by algorithm)
19 | def np_free_form_mask(maxVertex, maxLength, maxBrushWidth, maxAngle, h, w):
20 | mask = np.zeros((h, w, 1), np.float32)
21 | numVertex = np.random.randint(maxVertex + 1)
22 | startY = np.random.randint(h)
23 | startX = np.random.randint(w)
24 | brushWidth = 0
25 | for i in range(numVertex):
26 | angle = np.random.randint(maxAngle + 1)
27 | angle = angle / 360.0 * 2 * np.pi
28 | if i % 2 == 0:
29 | angle = 2 * np.pi - angle
30 | length = np.random.randint(maxLength + 1)
31 | brushWidth = np.random.randint(10, maxBrushWidth + 1) // 2 * 2
32 | nextY = startY + length * np.cos(angle)
33 | nextX = startX + length * np.sin(angle)
34 |
35 | nextY = np.maximum(np.minimum(nextY, h - 1), 0).astype(np.int)
36 | nextX = np.maximum(np.minimum(nextX, w - 1), 0).astype(np.int)
37 |
38 | cv2.line(mask, (startY, startX), (nextY, nextX), 1, brushWidth)  # stroke segment; points are passed as (y, x)
39 | cv2.circle(mask, (startY, startX), brushWidth // 2, 2)  # round cap at the stroke start (value 2 is clamped to 1 by the callers)
40 |
41 | startY, startX = nextY, nextX
42 | cv2.circle(mask, (startY, startX), brushWidth // 2, 2)  # round cap at the updated point
43 | return mask
44 |
45 |
46 | def free_form_mask_tf(parts, maxVertex=16, maxLength=60, maxBrushWidth=14, maxAngle=360, im_size=(256, 256), name='fmask'):
47 | """
48 | Free-form brush-stroke mask (TF graph version).
49 | ref: NeurIPS 2018, generative multi-column CNN inpainting
50 | """
51 | # mask = np.zeros((im_size[0], im_size[1], 1), dtype=np.float32)
52 | with tf.variable_scope(name):
53 | mask = tf.Variable(tf.zeros([1, im_size[0], im_size[1], 1]), name='free_mask')
54 | maxVertex = tf.constant(maxVertex, dtype=tf.int32)
55 | maxLength = tf.constant(maxLength, dtype=tf.int32)
56 | maxBrushWidth = tf.constant(maxBrushWidth, dtype=tf.int32)
57 | maxAngle = tf.constant(maxAngle, dtype=tf.int32)
58 | h = tf.constant(im_size[0], dtype=tf.int32)
59 | w = tf.constant(im_size[1], dtype=tf.int32)
60 | for i in range(parts):
61 | p = tf.py_func(np_free_form_mask, [maxVertex, maxLength, maxBrushWidth, maxAngle, h, w], tf.float32)
62 | p = tf.reshape(p, [1, im_size[0], im_size[1], 1])
63 | mask = mask + p
64 | mask = tf.minimum(mask, 1.0)
65 | return mask
66 |
67 | def free_form_mask(parts, maxVertex=16, maxLength=60, maxBrushWidth=14, maxAngle=360, im_size=(256, 256)):
68 | h, w = im_size[0], im_size[1]
69 | mask = np.zeros((h, w, 1), dtype=np.float32)
70 | for i in range(parts):
71 | p = np_free_form_mask(maxVertex, maxLength, maxBrushWidth, maxAngle, h, w)
72 | p = np.reshape(p, [1, h, w, 1])
73 | mask = mask + p
74 | mask = np.minimum(mask, 1.0)
75 | return mask
76 |
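# Quick NumPy check (illustrative): draw one free-form mask; `parts` is the
# number of overlaid brush strokes, and values are clamped to [0, 1].
#
#   m = free_form_mask(parts=8, im_size=(256, 256))  # shape (1, 256, 256, 1)
#   assert m.min() >= 0.0 and m.max() <= 1.0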
77 | class ImageData:
78 |
79 | def __init__(self, args=None):
80 | """
81 | image size
82 | """
83 | self.img_size = args.IMG_SHAPES[0]
84 | self.channels = args.IMG_SHAPES[2]
85 | self.sigma = args.SIGMA
86 | # self.level = args.DOWN_LEVEL
87 | self.mode = 'rect'
88 |
89 | # TODO: different images with different preprocessing method
90 | def image_processing(self, filename):
91 | """
92 | """
93 | x = tf.read_file(filename)  # read raw file bytes; decoding happens below
94 | img = tf.image.decode_jpeg(x, channels=self.channels) # read image and decode it. tf.image.decode_image
95 | img = tf.image.resize_images(img, [self.img_size, self.img_size])
96 | img = tf.cast(img, tf.float32) / 127.5 - 1 # scale to [-1, 1]
97 | return img
98 |
99 | def image_processing2(self, filename):
100 | img = imread(filename,mode='RGB')
101 | imgh, imgw = img.shape[0:2]
102 | if imgh != imgw:
103 | # center crop
104 | side = np.minimum(imgh, imgw)
105 | j = (imgh - side) // 2
106 | i = (imgw - side) // 2
107 | img = img[j:j + side, i:i + side, ...]
108 |
109 | img = scipy.misc.imresize(img, [self.img_size, self.img_size])
110 |
111 | img = img.astype(np.float32) / 127.5 - 1 # scale to [-1, 1]
112 | return img
113 |
114 | def image_edge_processing(self, filename):
115 | img = imread(filename,mode='RGB')
116 | imgh, imgw = img.shape[0:2]
117 | if imgh != imgw:
118 | # center crop
119 | side = np.minimum(imgh, imgw)
120 | j = (imgh - side) // 2
121 | i = (imgw - side) // 2
122 | img = img[j:j + side, i:i + side, ...]
123 |
124 | img = scipy.misc.imresize(img, [self.img_size, self.img_size])
125 |
126 |
127 | # edge
128 | img_gray = rgb2gray(img) # with the channel dimension removed
129 | edge = feature.canny(img_gray, sigma=self.sigma).astype(np.float32)
130 |
131 | img = img.astype(np.float32) / 127.5 - 1 # scale to [-1, 1]
132 | return img, edge
133 |
134 | def image_edge_scale_processing(self, filename):
135 | img = imread(filename,mode='RGB')
136 | imgh, imgw = img.shape[0:2]
137 | if imgh != imgw:
138 | # center crop
139 | side = np.minimum(imgh, imgw)
140 | j = (imgh - side) // 2
141 | i = (imgw - side) // 2
142 | img = img[j:j + side, i:i + side, ...]
143 |
144 | img = scipy.misc.imresize(img, [self.img_size, self.img_size])
145 |
146 |
147 | # edge
148 | img_gray = rgb2gray(img) # with the channel dimension removed
149 | edge_256 = feature.canny(img_gray, sigma=self.sigma).astype(np.float32)
150 | img_gray = rgb2gray(imresize(img, [128, 128], interp='nearest'))
151 | edge_128 = feature.canny(img_gray, sigma=self.sigma).astype(np.float32)
152 | img_gray = rgb2gray(imresize(img, [64, 64], interp='nearest'))
153 | edge_64 = feature.canny(img_gray, sigma=self.sigma).astype(np.float32)
154 | # img_gray = rgb2gray(imresize(img, [32, 32]), interp='nearest')
155 | # edge_32 = feature.canny(img_gray, sigma=self.sigma).astype(np.float32)
156 |
157 | img = img.astype(np.float32) / 127.5 - 1 # scale to [-1, 1]
158 | return img, edge_256, edge_128, edge_64
159 |
160 | def mask_processing(self, filename):
161 | x = tf.read_file(filename) # read mask filename
162 | mask = tf.image.decode_png(x, channels=1) # read image and decode it. tf.image.decode_image
163 | mask = tf.image.resize_images(mask, [self.img_size, self.img_size])
164 | return mask
165 |
166 | def mask_processing2(self, filename):
167 | """
168 | For training
169 | """
170 | mask = imread(filename)
171 |
172 | # mask: hole = 1, data augmentation
173 | # mask = (mask > 0).astype(np.float32)
174 | # print(mask.max())
175 | # print(mask.min())
176 | mask[mask <= 127] = 0
177 | mask[mask > 127] = 1
178 |
179 | # print(mask.max())
180 | # print(mask.min())
181 | # resize
182 | #mask = scipy.misc.imresize(mask, (self.img_size, self.img_size))
183 |
184 | # random dilation (25% probability); further mask augmentation is done externally
185 | if np.random.randint(0, 4) == 0:
186 | mask = ndimage.binary_dilation(mask, iterations=np.random.randint(1,6)).astype(np.float32)
187 | mask = mask[np.newaxis, :, :, np.newaxis]
188 |
189 | # 5% probability: generate a fixed-size rectangular mask
190 | if np.random.randint(0, 20) == 0:
191 | mask = create_mask(self.img_size, self.img_size, self.img_size // 2, self.img_size // 2, delta=0)
192 |
193 | # 10% probability: generate a free-form mask (ref: NeurIPS 2018 multi-column network)
194 | if np.random.randint(0, 10) == 0:
195 | mask = free_form_mask(parts=8, im_size=(self.img_size, self.img_size),
196 | maxBrushWidth=20, maxLength=80, maxVertex=16)
197 | return mask.astype(np.float32)
198 |
199 | def mask_processing3(self, filename):
200 | """
201 | For validation and test
202 | """
203 | mask = imread(filename)
204 | # mask = skimage.io.imread(filename)
205 |
206 | # mask: hole = 1
207 | # mask = (mask > 0).astype(np.float32)
208 | mask[mask <= 127] = 0
209 | mask[mask > 127] = 1
210 |
211 | # resize
212 | # mask = scipy.misc.imresize(mask, (self.img_size, self.img_size))
213 |
214 | mask = mask[np.newaxis, :, :, np.newaxis]
215 |
216 | return mask.astype(np.float32)
217 |
218 |
219 | def load_data(args):
220 | """
221 | Load image data
222 | """
223 | # training data: 0, as file list
224 | # image files
225 | with open(args.DATA_FLIST[args.DATASET][0]) as f:
226 | fnames = f.read().splitlines()
227 |
228 | # TODO: create input dataset (images and masks)
229 | inputs = tf.data.Dataset.from_tensor_slices(fnames) # a tf dataset object (op)
230 | if args.NUM_GPUS == 1:
231 | device = '/gpu:0' # to which gpu. prefetch_to_device(device, batch_size)
232 | else:
233 | device = '/cpu:0'
234 | dataset_num = len(fnames)
235 | # TODO: dataset with preprocessing (images and masks)
236 | Image_Data_Class = ImageData(args=args)
237 |
238 | # inputs = inputs.apply(shuffle_and_repeat(dataset_num)).apply(
239 | # map_and_batch(Image_Data_Class.image_processing, args.BATCH_SIZE, num_parallel_batches=16,
240 | # drop_remainder=True)).apply(prefetch_to_device(gpu_device, args.BATCH_SIZE))
241 | inputs = inputs.apply(shuffle_and_repeat(dataset_num)).map(lambda filename: tf.py_func(Image_Data_Class.image_processing2, [filename], [tf.float32]), num_parallel_calls=3)
242 | inputs = inputs.batch(args.BATCH_SIZE*args.NUM_GPUS, drop_remainder=True).apply(prefetch_to_device(device, args.BATCH_SIZE))
243 | inputs_iterator = inputs.make_one_shot_iterator() # iterator: each get_next() call visits one element (batch) of the dataset
244 |
245 | images = inputs_iterator.get_next() # an iteration get a batch of data
246 |
247 | return images
248 |
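# Consumption sketch for TF 1.x (not from the original file): the one-shot
# iterator needs no explicit initialization. Because the map() above wraps
# image_processing2 in tf.py_func with a one-element output list, sess.run
# returns a tuple holding a single [B, H, W, 3] array.
#
#   images = load_data(args)
#   with tf.Session() as sess:
#       (batch,) = sess.run(images)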
249 | def load_mask(args):
250 | # mask files
251 | with open(args.TRAIN_MASK_FLIST) as f:
252 | fnames = f.read().splitlines()
253 |
254 | # TODO: create input dataset (masks)
255 | inputs = tf.data.Dataset.from_tensor_slices(fnames) # a tf dataset object (op)
256 |
257 | if args.NUM_GPUS == 1:
258 | device = '/gpu:0' # to which gpu. prefetch_to_device(device, batch_size)
259 | else:
260 | device = '/cpu:0'
261 |
262 | dataset_num = len(fnames)
263 | # TODO: dataset with preprocessing (masks)
264 | Image_Data_Class = ImageData(args=args)
265 |
266 | # inputs = inputs.apply(shuffle_and_repeat(dataset_num)).apply(
267 | # map_and_batch(Image_Data_Class.image_processing, args.BATCH_SIZE, num_parallel_batches=16,
268 | # drop_remainder=True)).apply(prefetch_to_device(gpu_device, args.BATCH_SIZE))
269 | inputs = inputs.apply(shuffle_and_repeat(dataset_num)).map(lambda filename: tf.py_func(
270 | Image_Data_Class.mask_processing2, [filename], [tf.float32]), num_parallel_calls=3)
271 | inputs = inputs.batch(1,drop_remainder=True).apply(prefetch_to_device(device, 1))
272 | # inputs = inputs.apply(prefetch_to_device(device))
273 | inputs_iterator = inputs.make_one_shot_iterator() # iterator
274 |
275 | masks = inputs_iterator.get_next() # an iteration get a batch of data
276 |
277 | return masks
278 |
279 | def create_mask(width, height, mask_width, mask_height, x=None, y=None, delta=0):
280 | """
281 | create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0)
282 | delta: margin between mask and image boundary
283 | """
284 | mask = np.zeros((height, width))
285 | mask_x = x if x is not None else np.random.randint(delta, width - mask_width - delta)
286 | mask_y = y if y is not None else np.random.randint(delta, height - mask_height - delta)
287 | mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
288 | mask = mask[np.newaxis, :, :, np.newaxis]
289 | return mask
290 |
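# e.g. (illustrative): a 128x128 hole at a random position in a 256x256
# canvas, returned as an array of shape (1, 256, 256, 1):
#   m = create_mask(256, 256, 128, 128)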
291 | def load_validation_data(args):
292 | """
293 | Load image data
294 | """
295 | # validation data: 1, as file list
296 | # image files
297 | with open(args.DATA_FLIST[args.DATASET][1]) as f:
298 | fnames = f.read().splitlines()
299 |
300 | # TODO: create input dataset (images)
301 | inputs = tf.data.Dataset.from_tensor_slices(fnames) # a tf dataset object (op)
302 |
303 | gpu_device = '/gpu:0' # to which gpu. prefetch_to_device(device, batch_size)
304 | # gpu_device = '/gpu:{}'.format(args.GPU_ID)
305 |
306 | dataset_num = len(fnames)
307 | # TODO: dataset with preprocessing (images)
308 | Image_Data_Class = ImageData(args)
309 | inputs = inputs.map(lambda filename: tf.py_func(Image_Data_Class.image_processing2, [filename], [tf.float32]), num_parallel_calls=3)
310 | inputs = inputs.batch(args.VAL_NUM,drop_remainder=True).apply(prefetch_to_device(gpu_device,1))
311 | inputs_iterator = inputs.make_initializable_iterator() # iterator, need to be initialized
312 |
313 | images = inputs_iterator.get_next() # an iteration get a batch of data
314 |
315 | return images, inputs_iterator
316 |
317 | def load_validation_mask(args):
318 | # mask files
319 | with open(args.VAL_MASK_FLIST) as f:
320 | fnames = f.read().splitlines()
321 |
322 | # TODO: create input dataset (masks)
323 | inputs = tf.data.Dataset.from_tensor_slices(fnames) # a tf dataset object (op)
324 |
325 | gpu_device = '/gpu:0' # to which gpu. prefetch_to_device(device, batch_size)
326 | # gpu_device = '/gpu:{}'.format(args.GPU_ID)
327 |
328 | dataset_num = len(fnames)
329 | # TODO: dataset with preprocessing (masks)
330 | Image_Data_Class = ImageData(args=args)
331 |
332 | inputs = inputs.map(lambda filename: tf.py_func(Image_Data_Class.mask_processing3, [filename], [tf.float32]), num_parallel_calls=3)
333 | inputs = inputs.batch(args.VAL_NUM,drop_remainder=True).apply(prefetch_to_device(gpu_device, 1))
334 | # inputs = inputs.apply(prefetch_to_device(gpu_device))
335 | inputs_iterator = inputs.make_initializable_iterator()
336 |
337 | masks = inputs_iterator.get_next() # an iteration get a batch of data
338 |
339 | return masks, inputs_iterator
340 |
341 | def create_validation_mask(width, height, mask_width, mask_height, args, x=None, y=None, delta=0):
342 | """
343 | create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0)
344 | """
345 | masks = np.zeros((args.VAL_NUM, height, width))
346 | for i in range(args.VAL_NUM):
347 | mask_x = x if x is not None else np.random.randint(delta, width - mask_width - delta)
348 | mask_y = y if y is not None else np.random.randint(delta, height - mask_height - delta)
349 | masks[i,mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
350 | masks = masks[:, :, :, np.newaxis]
351 | return masks
352 |
353 | def load_test_data(args):
354 | """
355 | Load image data
356 | """
357 | # test data: the validation split (index 1) of the file list is reused here
358 | # image files
359 | with open(args.DATA_FLIST[args.DATASET][1]) as f:
360 | fnames = f.read().splitlines()
361 |
362 | # TODO: create input dataset (images)
363 | inputs = tf.data.Dataset.from_tensor_slices(fnames) # a tf dataset object (op)
364 |
365 | gpu_device = '/gpu:0' # to which gpu. prefetch_to_device(device, batch_size)
366 | # gpu_device = '/gpu:{}'.format(args.GPU_ID)
367 |
368 | dataset_num = len(fnames)
369 | # TODO: dataset with preprocessing (images)
370 | Image_Data_Class = ImageData(args=args)
371 | inputs = inputs.map(lambda filename: tf.py_func(Image_Data_Class.image_processing2, [filename], [tf.float32]), num_parallel_calls=3)
372 | inputs = inputs.batch(args.TEST_NUM,drop_remainder=True).apply(prefetch_to_device(gpu_device))
373 | inputs_iterator = inputs.make_initializable_iterator() # iterator, need to be initialized
374 |
375 | images = inputs_iterator.get_next() # an iteration get a batch of data
376 |
377 | return images, inputs_iterator
378 |
379 | def load_test_mask(args):
380 | # mask files
381 | with open(args.TEST_MASK_FLIST) as f:
382 | fnames = f.read().splitlines()
383 |
384 | # TODO: create input dataset (masks)
385 | inputs = tf.data.Dataset.from_tensor_slices(fnames) # a tf dataset object (op)
386 |
387 | gpu_device = '/gpu:0' # to which gpu. prefetch_to_device(device, batch_size)
388 | # gpu_device = '/gpu:{}'.format(args.GPU_ID)
389 |
390 | dataset_num = len(fnames)
391 | # TODO: dataset with preprocessing (masks)
392 | Image_Data_Class = ImageData(args=args)
393 |
394 | inputs = inputs.map(lambda filename: tf.py_func(Image_Data_Class.mask_processing3, [filename], [tf.float32]), num_parallel_calls=3)
395 | inputs = inputs.batch(args.TEST_NUM,drop_remainder=True).apply(prefetch_to_device(gpu_device, 1))
396 | # inputs = inputs.apply(prefetch_to_device(gpu_device))
397 | inputs_iterator = inputs.make_initializable_iterator() # iterator
398 |
399 | masks = inputs_iterator.get_next() # an iteration get a batch of data
400 |
401 | return masks, inputs_iterator
402 |
403 | def create_test_mask(width, height, mask_width, mask_height, args, x=None, y=None, delta=0):
404 | """
405 | create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0)
406 | """
407 | masks = np.zeros((args.TEST_NUM, height, width))
408 | for i in range(args.TEST_NUM):
409 | mask_x = x if x is not None else np.random.randint(delta, width - mask_width - delta)
410 | mask_y = y if y is not None else np.random.randint(delta, height - mask_height - delta)
411 | masks[i,mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
412 | masks = masks[:, :, :, np.newaxis]
413 | return masks
414 |
415 | def dataset_len(args):
416 | with open(args.DATA_FLIST[args.DATASET][0]) as f:
417 | fnames = f.read().splitlines()
418 | return len(fnames)
419 |
420 | def show_all_variables():
421 | """
422 | Show all trainable variables of a TF model.
423 | """
424 | model_vars = tf.trainable_variables()
425 | slim.model_analyzer.analyze_vars(model_vars, print_info=True)
426 |
427 | def normalize(x):
428 | return x / 127.5 - 1 # scale to [-1, 1]
429 |
430 | def save_images(images, size, image_path):
431 | return imsave(inverse_transform(images), size, image_path)
432 |
433 | def merge(images, size):
434 | h, w = images.shape[1], images.shape[2]
435 | if (images.shape[3] in (3,4)):
436 | c = images.shape[3]
437 | img = np.zeros((h * size[0], w * size[1], c))
438 | for idx, image in enumerate(images):
439 | i = idx % size[1]
440 | j = idx // size[1]
441 | img[j * h:j * h + h, i * w:i * w + w, :] = image
442 | return img
443 | elif images.shape[3]==1:
444 | img = np.zeros((h * size[0], w * size[1]))
445 | for idx, image in enumerate(images):
446 | i = idx % size[1]
447 | j = idx // size[1]
448 | img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]
449 | return img
450 | else:
451 | raise ValueError('in merge(images, size) the images parameter must have dimensions NxHxWx1, NxHxWx3 or NxHxWx4')
452 |
453 | def imsave(images, size, path):
454 | return scipy.misc.imsave(path, merge(images, size))
455 |
456 | def inverse_transform(images):
457 | return (images+1.)*127.5
458 |
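# Sketch (illustrative values): tile a batch of [-1, 1] images into a 2x2 grid
# and write it to disk with the helpers above; merge() lays images out
# row-major according to `size`.
#
#   batch = np.random.uniform(-1., 1., (4, 64, 64, 3)).astype(np.float32)
#   save_images(batch, size=(2, 2), image_path='grid.png')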
459 | def load_img_edge(args):
460 | """
461 | Load image data
462 | """
463 | # training data: 0, as file list
464 | # image files
465 | with open(args.DATA_FLIST[args.DATASET][0]) as f:
466 | fnames = f.read().splitlines()
467 |
468 | # TODO: create input dataset (images and masks)
469 | inputs = tf.data.Dataset.from_tensor_slices(fnames) # a tf dataset object (op)
470 | if args.NUM_GPUS == 1:
471 | device = '/gpu:0' # to which gpu. prefetch_to_device(device, batch_size)
472 | # gpu_device = '/gpu:{}'.format(args.GPU_ID)
473 | else:
474 | device = '/cpu:0'
475 | dataset_num = len(fnames)
476 | # TODO: dataset with preprocessing (images and masks)
477 | Image_Data_Class = ImageData(args=args)
478 | # inputs = inputs.apply(shuffle_and_repeat(dataset_num)).apply(
479 | # map_and_batch(Image_Data_Class.image_processing, args.BATCH_SIZE, num_parallel_batches=16,
480 | # drop_remainder=True)).apply(prefetch_to_device(gpu_device, args.BATCH_SIZE))
481 | inputs = inputs.apply(shuffle_and_repeat(dataset_num)).map(lambda filename: tf.py_func(
482 | Image_Data_Class.image_edge_processing, [filename], [tf.float32, tf.float32]), num_parallel_calls=3)
483 | inputs = inputs.batch(args.BATCH_SIZE*args.NUM_GPUS, drop_remainder=True).apply(prefetch_to_device(device, args.BATCH_SIZE))
484 | inputs_iterator = inputs.make_one_shot_iterator() # iterator
485 |
486 | images_edges = inputs_iterator.get_next() # an iteration get a batch of data
487 |
488 | return images_edges
489 |
490 | def load_val_img_edge(args):
491 |
492 | """
493 | Load image data
494 | """
495 | # validation data: 1, as file list
496 | # image files
497 | with open(args.DATA_FLIST[args.DATASET][1]) as f:
498 | fnames = f.read().splitlines()
499 |
500 | # TODO: create input dataset (images)
501 | inputs = tf.data.Dataset.from_tensor_slices(fnames) # a tf dataset object (op)
502 |
503 | gpu_device = '/gpu:0' # to which gpu. prefetch_to_device(device, batch_size)
504 | # gpu_device = '/gpu:{}'.format(args.GPU_ID)
505 |
506 | dataset_num = len(fnames)
507 | # TODO: dataset with preprocessing (images)
508 | Image_Data_Class = ImageData(args)
509 | inputs = inputs.map(lambda filename: tf.py_func(Image_Data_Class.image_edge_processing, [filename], [tf.float32, tf.float32]),
510 | num_parallel_calls=3)
511 | inputs = inputs.batch(args.VAL_NUM, drop_remainder=True).apply(prefetch_to_device(gpu_device))
512 | inputs_iterator = inputs.make_initializable_iterator() # iterator, need to be initialized
513 |
514 | images_edges = inputs_iterator.get_next() # an iteration get a batch of data
515 |
516 | return images_edges, inputs_iterator
517 |
518 | def load_test_img_edge(args):
519 |
520 | """
521 | Load image data
522 | """
523 | # test data: the validation split (index 1) of the file list is reused here
524 | # image files
525 | with open(args.DATA_FLIST[args.DATASET][1]) as f:
526 | fnames = f.read().splitlines()
527 |
528 | # TODO: create input dataset (images)
529 | inputs = tf.data.Dataset.from_tensor_slices(fnames) # a tf dataset object (op)
530 |
531 | gpu_device = '/gpu:0' # to which gpu. prefetch_to_device(device, batch_size)
532 | # gpu_device = '/gpu:{}'.format(args.GPU_ID)
533 |
534 | dataset_num = len(fnames)
535 | # TODO: dataset with preprocessing (images)
536 | Image_Data_Class = ImageData(args)
537 | inputs = inputs.map(lambda filename: tf.py_func(Image_Data_Class.image_edge_processing, [filename], [tf.float32, tf.float32]),
538 | num_parallel_calls=3)
539 | inputs = inputs.batch(args.TEST_NUM, drop_remainder=True).apply(prefetch_to_device(gpu_device))
540 | inputs_iterator = inputs.make_initializable_iterator() # iterator, need to be initialized
541 |
542 | images_edges = inputs_iterator.get_next() # an iteration get a batch of data
543 |
544 | return images_edges, inputs_iterator
545 |
546 | def load_img_scale_edge(args):
547 | """
548 | Load image data
549 | """
550 | # training data: 0, as file list
551 | # image files
552 | with open(args.DATA_FLIST[args.DATASET][0]) as f:
553 | fnames = f.read().splitlines()
554 |
555 | # TODO: create input dataset (images and masks)
556 | inputs = tf.data.Dataset.from_tensor_slices(fnames) # a tf dataset object (op)
557 | if args.NUM_GPUS == 1:
558 | device = '/gpu:0' # to which gpu. prefetch_to_device(device, batch_size)
559 | # gpu_device = '/gpu:{}'.format(args.GPU_ID)
560 | else:
561 | device = '/cpu:0'
562 | dataset_num = len(fnames)
563 | # TODO: dataset with preprocessing (images and masks)
564 | Image_Data_Class = ImageData(args=args)
565 |
566 | # inputs = inputs.apply(shuffle_and_repeat(dataset_num)).apply(
567 | # map_and_batch(Image_Data_Class.image_processing, args.BATCH_SIZE, num_parallel_batches=16,
568 | # drop_remainder=True)).apply(prefetch_to_device(gpu_device, args.BATCH_SIZE))
569 | inputs = inputs.apply(shuffle_and_repeat(dataset_num)).map(lambda filename: tf.py_func(
570 | Image_Data_Class.image_edge_scale_processing, [filename], [tf.float32, tf.float32,tf.float32, tf.float32]), num_parallel_calls=3)
571 | inputs = inputs.batch(args.BATCH_SIZE*args.NUM_GPUS, drop_remainder=True).apply(prefetch_to_device(device, args.BATCH_SIZE*args.NUM_GPUS))
572 | inputs_iterator = inputs.make_one_shot_iterator() # iterator
573 |
574 | images_edges = inputs_iterator.get_next() # an iteration get a batch of data
575 |
576 | return images_edges
577 |
578 | def load_val_img_scale_edge(args):
579 | """
580 | Load image data
581 | """
582 | # validation data: 1, as file list
583 | # image files
584 | with open(args.DATA_FLIST[args.DATASET][1]) as f:
585 | fnames = f.read().splitlines()
586 |
587 | # TODO: create input dataset (images and masks)
588 | inputs = tf.data.Dataset.from_tensor_slices(fnames) # a tf dataset object (op)
589 |
590 | gpu_device = '/gpu:0' # to which gpu. prefetch_to_device(device, batch_size)
591 | # gpu_device = '/gpu:{}'.format(args.GPU_ID)
592 |
593 | dataset_num = len(fnames)
594 | # TODO: dataset with preprocessing (images and masks)
595 | Image_Data_Class = ImageData(args=args)
596 |
597 | # inputs = inputs.apply(shuffle_and_repeat(dataset_num)).apply(
598 | # map_and_batch(Image_Data_Class.image_processing, args.BATCH_SIZE, num_parallel_batches=16,
599 | # drop_remainder=True)).apply(prefetch_to_device(gpu_device, args.BATCH_SIZE))
600 | inputs = inputs.apply(shuffle_and_repeat(dataset_num)).map(lambda filename: tf.py_func(
601 | Image_Data_Class.image_edge_scale_processing, [filename], [tf.float32, tf.float32,tf.float32, tf.float32]), num_parallel_calls=3)
602 | inputs = inputs.batch(args.VAL_NUM,drop_remainder=True).apply(prefetch_to_device(gpu_device, args.BATCH_SIZE))
603 | inputs_iterator = inputs.make_initializable_iterator() # iterator
604 |
605 | images_edges = inputs_iterator.get_next() # an iteration get a batch of data
606 |
607 | return images_edges, inputs_iterator
608 |
609 |
610 |
611 | # random rect mask
612 | def random_bbox(config):
613 | """Generate a random tlhw with configuration.
614 |
615 | Args:
616 | config: Config should have configuration including img_shapes,
617 | margins, mask_shapes, random_mask.
618 |
619 | Returns:
620 | tuple: (top, left, height, width)
621 |
622 | """
623 | img_shape = config.img_shapes
624 | img_height = img_shape[0]
625 | img_width = img_shape[1]
626 | if config.random_mask is True:
627 | maxt = img_height - config.margins[0] - config.mask_shapes[0]
628 | maxl = img_width - config.margins[1] - config.mask_shapes[1]
629 | t = tf.random_uniform(
630 | [], minval=config.margins[0], maxval=maxt, dtype=tf.int32)
631 | l = tf.random_uniform(
632 | [], minval=config.margins[1], maxval=maxl, dtype=tf.int32)
633 | else:
634 | t = config.mask_shapes[0]//2
635 | l = config.mask_shapes[1]//2
636 | h = tf.constant(config.mask_shapes[0])
637 | w = tf.constant(config.mask_shapes[1])
638 | return (t, l, h, w)
639 |
640 |
641 | def bbox2mask(bbox, config, name='mask'):
642 | """Generate mask tensor from bbox.
643 |
644 | Args:
645 | bbox: configuration tuple, (top, left, height, width)
646 | config: Config should have configuration including img_shapes,
647 | max_delta_shapes.
648 |
649 | Returns:
650 | tf.Tensor: output with shape [1, H, W, 1]
651 |
652 | """
653 | def npmask(bbox, height, width, delta_h, delta_w):
654 | mask = np.zeros((1, height, width, 1), np.float32)
655 | h = np.random.randint(delta_h//2+1)
656 | w = np.random.randint(delta_w//2+1)
657 | mask[:, bbox[0]+h:bbox[0]+bbox[2]-h,
658 | bbox[1]+w:bbox[1]+bbox[3]-w, :] = 1.
659 | return mask
660 | with tf.variable_scope(name), tf.device('/cpu:0'):
661 | img_shape = config.img_shapes
662 | height = img_shape[0]
663 | width = img_shape[1]
664 | mask = tf.py_func(
665 | npmask,
666 | [bbox, height, width,
667 | config.max_delta_shapes[0], config.max_delta_shapes[1]],
668 | tf.float32, stateful=False)
669 | mask.set_shape([1] + [height, width] + [1])
670 | return mask
671 |
672 | """
673 | How to use
674 | # generate mask, 1 represents masked point
675 | if config.mask_type == 'rect':
676 | bbox = random_bbox(config)
677 | mask = bbox2mask(bbox, config, name='mask_c')
678 | else:
679 | mask = free_form_mask_tf(parts=8, im_size=(config.img_shapes[0], config.img_shapes[1]),
680 | maxBrushWidth=20, maxLength=80, maxVertex=16)
681 | """
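# The snippet above, fleshed out as a runnable TF 1.x sketch (assumes a config
# object exposing the lower-case fields used here: mask_type, img_shapes,
# mask_shapes, margins, random_mask, max_delta_shapes):
#
#   bbox = random_bbox(config)
#   mask = bbox2mask(bbox, config, name='mask_c')  # tf.Tensor, [1, H, W, 1]
#   with tf.Session() as sess:
#       m = sess.run(mask)  # numpy array with 1s inside the (shrunken) bbox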
--------------------------------------------------------------------------------