├── painter ├── README.md ├── imgs │ ├── README.md │ ├── 256x256 │ │ ├── README.md │ │ ├── inter_scene1.png │ │ ├── inter_scene2.png │ │ └── inter_scene3.png │ ├── inter_scene1_mask.png │ ├── inter_scene1_masked.png │ └── inter_scene1_result.png ├── checkpoint │ └── checkpoint.txt ├── inpaint_config.yml ├── structure_loss.py ├── config.py ├── inpaint_model.py ├── painter.py └── ops.py ├── src ├── utils │ ├── __init__.py │ ├── FetchManager.py │ └── helper.py ├── data │ └── places2 │ │ └── README.md ├── inception │ └── classify_image_graph_def.pb ├── checkpoint │ ├── places2 │ │ └── checkpoint │ └── celeba │ │ └── checkpoint ├── auto_validation.sh ├── flist │ └── flist.py ├── evaluation │ ├── evaluation.py │ └── metrics.py ├── README.md ├── metrics.py ├── vgg_network.py ├── inpaint_config.yml ├── config.py ├── val_inpaint_model.py ├── frechet_inception_distance.py ├── train_inpaint_model.py ├── ops.py ├── loss.py └── utils_fn.py ├── project-images ├── sobel.jpg ├── sobel.pdf ├── painter.jpg ├── removal.jpg ├── ablation.jpg ├── attention.jpg ├── painter-a.jpg ├── painter-b.jpg ├── architecture.jpg ├── architecture.pdf ├── attention-ok.pdf ├── quality-compare-celeba.jpg └── quality-compare-place.jpg ├── structure_loss.py └── README.md /painter/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /painter/imgs/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /painter/imgs/256x256/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/data/places2/README.md: -------------------------------------------------------------------------------- 1 | Create xx.flist here (using flist/flist.py) 2 | -------------------------------------------------------------------------------- /src/inception/classify_image_graph_def.pb: -------------------------------------------------------------------------------- 1 | You need to download the Inception network weights (this file is only a placeholder) 2 | -------------------------------------------------------------------------------- /painter/checkpoint/checkpoint.txt: -------------------------------------------------------------------------------- 1 | site: https://pan.baidu.com/s/1SBbfR94KWG5UMm_FClmdMQ 2 | code: uiqn 3 | -------------------------------------------------------------------------------- /project-images/sobel.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/sobel.jpg -------------------------------------------------------------------------------- /project-images/sobel.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/sobel.pdf -------------------------------------------------------------------------------- /project-images/painter.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/painter.jpg 
-------------------------------------------------------------------------------- /project-images/removal.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/removal.jpg -------------------------------------------------------------------------------- /painter/inpaint_config.yml: -------------------------------------------------------------------------------- 1 | # parameters 2 | PADDING: 'REFLECT' # 'REFLECT' 'SAME' 3 | IMG_SHAPES: [256, 256, 3] 4 | -------------------------------------------------------------------------------- /project-images/ablation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/ablation.jpg -------------------------------------------------------------------------------- /project-images/attention.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/attention.jpg -------------------------------------------------------------------------------- /project-images/painter-a.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/painter-a.jpg -------------------------------------------------------------------------------- /project-images/painter-b.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/painter-b.jpg -------------------------------------------------------------------------------- /project-images/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/architecture.jpg -------------------------------------------------------------------------------- /project-images/architecture.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/architecture.pdf -------------------------------------------------------------------------------- /project-images/attention-ok.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/attention-ok.pdf -------------------------------------------------------------------------------- /painter/imgs/inter_scene1_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/painter/imgs/inter_scene1_mask.png -------------------------------------------------------------------------------- /src/checkpoint/places2/checkpoint: -------------------------------------------------------------------------------- 1 | google: https://drive.google.com/drive/folders/1ReSArrra8NOQv8dlU2QK0DE0P5qoalCT?usp=sharing 2 | -------------------------------------------------------------------------------- /painter/imgs/inter_scene1_masked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/painter/imgs/inter_scene1_masked.png 
-------------------------------------------------------------------------------- /painter/imgs/inter_scene1_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/painter/imgs/inter_scene1_result.png -------------------------------------------------------------------------------- /src/auto_validation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | while true 3 | do 4 | nohup python val_inpaint_model.py 5 | sleep 7200; 6 | done 7 | -------------------------------------------------------------------------------- /painter/imgs/256x256/inter_scene1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/painter/imgs/256x256/inter_scene1.png -------------------------------------------------------------------------------- /painter/imgs/256x256/inter_scene2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/painter/imgs/256x256/inter_scene2.png -------------------------------------------------------------------------------- /painter/imgs/256x256/inter_scene3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/painter/imgs/256x256/inter_scene3.png -------------------------------------------------------------------------------- /project-images/quality-compare-celeba.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/quality-compare-celeba.jpg -------------------------------------------------------------------------------- /project-images/quality-compare-place.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoungGod/sturcture-inpainting/HEAD/project-images/quality-compare-place.jpg -------------------------------------------------------------------------------- /src/checkpoint/celeba/checkpoint: -------------------------------------------------------------------------------- 1 | site: https://pan.baidu.com/s/1SBbfR94KWG5UMm_FClmdMQ 2 | code: uiqn 3 | 4 | or google: https://drive.google.com/drive/folders/1ReSArrra8NOQv8dlU2QK0DE0P5qoalCT?usp=sharing 5 | -------------------------------------------------------------------------------- /src/utils/FetchManager.py: -------------------------------------------------------------------------------- 1 | class FetchManager: 2 | def __init__(self, sess, fetches): 3 | self.fetches = fetches 4 | self.sess = sess 5 | 6 | def fetch(self, feed_dictionary, additional_fetches=[]): 7 | fetches = self.fetches + additional_fetches 8 | evaluation = self.sess.run(fetches, feed_dictionary) 9 | return {k: v for k, v in zip(fetches, evaluation)} -------------------------------------------------------------------------------- /src/flist/flist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--path', type=str, required=True, help='path to the dataset')  # was: default=train_path (undefined name) 7 | parser.add_argument('--output', type=str, required=True, help='path to the file list')  # was: default=output_path (undefined name) 8 | args = 
parser.parse_args() 9 | 10 | ext = {'.jpg', '.png'} 11 | 12 | # print("root path:", os.listdir(args.path)) 13 | 14 | images = [] 15 | for root, dirs, files in os.walk(args.path): 16 | print('loading ' + root) 17 | for file in files: 18 | if os.path.splitext(file)[1] in ext: 19 | images.append(os.path.join(root, file)) 20 | 21 | # images = sorted(images) 22 | np.random.shuffle(images) 23 | np.savetxt(args.output, images, fmt='%s') 24 | -------------------------------------------------------------------------------- /src/evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from metrics import evaluate 3 | import os 4 | from glob import glob 5 | 6 | # Test model dirs 7 | 8 | 9 | # Evaluation Data dirs: test images and masks 10 | result_dir = './result' 11 | model_names = os.listdir(result_dir) 12 | data_dir = os.path.join(result_dir, model_names[0]) 13 | 14 | 15 | # """##### Evaluation #####""" 16 | print('Start Evaluation...') 17 | # For saving evaluation results 18 | with open('evaluation.csv', mode='a') as f: 19 | f.write("model, l1/mae, psnr, ssim, fid, uqi, vif\n") 20 | 21 | 22 | # Our model 23 | print("Our Models:") 24 | for model_name in model_names: 25 | data_dir = os.path.join(result_dir, model_name) 26 | path_true = data_dir+'/sample_images' 27 | path_pred = data_dir+'/inpainted_images' 28 | 29 | l1, psnr, ssim, fid, uqi, vif = evaluate(path_true, path_pred) 30 | print("l1/mae:{:.4f}, psnr:{:.4f}, ssim:{:.4f}, fid:{:.4f}, uqi:{:.4f}, vif:{:.4f}".format(l1, psnr, ssim, fid, uqi, vif)) 31 | with open('evaluation.csv', mode='a') as f: 32 | f.write("{}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}\n".format(model_name.split('-')[-1], l1, psnr, ssim, fid, uqi, vif)) 33 | 34 | with open('evaluation.csv', mode='a') as f: 35 | f.write("\n") 36 | -------------------------------------------------------------------------------- /src/utils/helper.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------- 2 | # code credits: https://github.com/CQFIO/PhotographicImageSynthesis 3 | # --------------------------------------------------- 4 | import numpy as np 5 | import scipy 6 | from config import * 7 | import tensorflow as tf 8 | 9 | 10 | def read_image(file_name, resize=True, fliplr=False): 11 | image = scipy.misc.imread(file_name) 12 | if resize: 13 | image = scipy.misc.imresize(image, size=config.TRAIN.resize, interp='bilinear', mode=None) 14 | if fliplr: 15 | image = np.fliplr(image) 16 | image = np.float32(image) 17 | return np.expand_dims(image, axis=0) 18 | 19 | 20 | def save_image(output, file_name): 21 | output = np.minimum(np.maximum(output, 0.0), 255.0) 22 | scipy.misc.toimage(output.squeeze(axis=0), cmin=0, cmax=255).save(file_name) 23 | return 24 | 25 | 26 | def write_loss_in_txt(loss_list, epoch): 27 | target = open(config.TRAIN.out_dir + "/%04d/score.txt" % epoch, 'w') 28 | target.write("%f" % np.mean(loss_list[np.where(loss_list)])) 29 | target.close() 30 | 31 | 32 | def random_crop_together(im1, im2, size): 33 | images = tf.concat([im1, im2], axis=0) 34 | images_croped = tf.random_crop(images, size=size) 35 | im1, im2 = tf.split(images_croped, 2, axis=0) 36 | return im1, im2 37 | 38 | -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | # Source code 2 | Here is an early version of the source 
code used to develop our inpainting model. 3 | 4 | Some of the files are too large to upload. We strongly recommend downloading the src from [google drive](https://drive.google.com/file/d/1mh5t17vaR1GcL44iMyavqRDMbTmo665Z/view?usp=sharing) 5 | or [baidu](https://pan.baidu.com/s/1eXc2elmsY2t__mJRKI_l2g) with password: w5r4. 6 | 7 | 8 | 9 | # Environment 10 | TensorFlow 1.12 and other dependencies (install as needed) 11 | 12 | # Config file 13 | inpaint_config.yml 14 | 15 | # Data 16 | Prepare *xxx.flist* in **/data** using **flist/flist.py** 17 | 18 | # Mask 19 | 1. download from https://drive.google.com/file/d/140bV9FlOnnBbG4L4OiiqMmAbOQ09bQH7/view?usp=sharing 20 | 2. prepare *xxx.flist* in **/data**; 21 | 3. or partially generate masks from scratch, as defined in *utils_fn.py* 22 | 23 | # Train in terminal 24 | **IMPORTANT:** configure *inpaint_config.yml* correctly 25 | ``` 26 | python train_inpaint_model.py 27 | ``` 28 | 29 | # Validation or test in terminal 30 | 1. set TEST_NUM in *inpaint_config.yml* 31 | 2. set MODEL_RESTORE in *inpaint_config.yml* 32 | ``` 33 | python val_inpaint_model.py 34 | ``` 35 | 36 | # Tensorboard 37 | ``` 38 | python -m tensorboard.main --logdir=./logs 39 | ``` 40 | 41 | # Evaluate final metrics 42 | After running ```python val_inpaint_model.py```, run 43 | ``` 44 | python evaluation/evaluation.py 45 | ``` 46 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def psnr(x, pred_x, max_val=255): 4 | """ 5 | PSNR 6 | """ 7 | val = tf.reduce_mean(tf.image.psnr(x, pred_x, max_val=max_val)) 8 | return val 9 | 10 | def ssmi(x, pred_x, max_val=255): 11 | """ 12 | SSIM 13 | """ 14 | val = tf.reduce_mean(tf.image.ssim(x, pred_x, max_val=max_val)) 15 | return val 16 | 17 | def mm_ssmi(x, pred_x, max_val=255): 18 | """ 19 | MS-SSIM 20 | """ 21 | val = tf.reduce_mean(tf.image.ssim_multiscale(x, pred_x, max_val=max_val)) 22 | return val 23 | 24 | def avg_l1(x, pred_x): 25 | val = tf.reduce_mean(tf.abs(x - pred_x)) 26 | return val 27 | 28 | def tv_loss(pred_x): 29 | N, H, W, C = pred_x.shape.as_list() 30 | size = H*W*C 31 | val = tf.reduce_mean(tf.image.total_variation(pred_x)) / size 32 | return val 33 | 34 | import numpy as np 35 | from glob import glob 36 | from ntpath import basename 37 | from scipy.misc import imread 38 | from skimage.color import rgb2gray 39 | from sewar.full_ref import uqi 40 | from sewar.full_ref import vifp 41 | 42 | def uqi_vif(path_true, path_pred): 43 | 44 | UQI = [] 45 | VIF = [] 46 | names = [] 47 | index = 1 48 | 49 | files = list(glob(path_true + '/*.jpg')) + list(glob(path_true + '/*.png')) 50 | for fn in sorted(files): 51 | name = basename(str(fn)) 52 | names.append(name) 53 | 54 | img_gt = (imread(str(fn)) / 255.0).astype(np.float32) 55 | img_pred = (imread(path_pred + '/' + basename(str(fn))) / 255.0).astype(np.float32) 56 | 57 | img_gt = rgb2gray(img_gt) 58 | img_pred = rgb2gray(img_pred) 59 | 60 | UQI.append(uqi(img_gt, img_pred)) 61 | VIF.append(vifp(img_gt, img_pred)) 62 | if np.mod(index, 100) == 0: 63 | print( 64 | str(index) + ' images processed', 65 | "UQI: %.4f" % round(np.mean(UQI), 4), 66 | "VIF: %.4f" % round(np.mean(VIF), 4), 67 | ) 68 | index += 1 69 | 70 | UQI = np.mean(UQI) 71 | VIF = np.mean(VIF) 72 | 73 | return UQI, VIF -------------------------------------------------------------------------------- /src/vgg_network.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import scipy.io 4 | 5 | 6 | class VGG: 7 | LAYERS = ( 8 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 9 | 10 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 11 | 12 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 13 | 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 14 | 15 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 16 | 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 17 | 18 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 19 | 'relu5_3', 'conv5_4', 'relu5_4' 20 | ) 21 | 22 | def __init__(self, data_path): 23 | self.data_path = data_path 24 | self.data = scipy.io.loadmat(data_path) 25 | mean = self.data['normalization'][0][0][0] 26 | self.mean_pixel = np.mean(mean, axis=(0, 1)).astype('float32') 27 | self.weights = self.data['layers'][0] 28 | 29 | 30 | 31 | def preprocess(self, image): 32 | return image - self.mean_pixel 33 | 34 | 35 | def unprocess(self, image): 36 | return image + self.mean_pixel 37 | 38 | def net(self, input_image): 39 | net = {} 40 | current_layer = input_image 41 | for i, name in enumerate(self.LAYERS): 42 | if _is_convolutional_layer(name): 43 | kernels, bias = self.weights[i][0][0][0][0] 44 | # matconvnet: weights are [width, height, in_channels, out_channels] 45 | # tensorflow: weights are [height, width, in_channels, out_channels] 46 | kernels = np.transpose(kernels, (1, 0, 2, 3)) 47 | bias = bias.reshape(-1) 48 | current_layer = _conv_layer_from(current_layer, kernels, bias) 49 | elif _is_relu_layer(name): 50 | current_layer = tf.nn.relu(current_layer) 51 | elif _is_pooling_layer(name): 52 | current_layer = _pooling_layer_from(current_layer) 53 | net[name] = current_layer 54 | 55 | assert len(net) == len(self.LAYERS) 56 | return net 57 | 58 | 59 | def _is_convolutional_layer(name): 60 | return name[:4] == 'conv' 61 | 62 | def _is_relu_layer(name): 63 | return name[:4] == 'relu' 64 | 65 | def _is_pooling_layer(name): 66 | return name[:4] == 'pool' 67 | 68 | def _conv_layer_from(input, weights, bias): 69 | conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1), 70 | padding='SAME') 71 | return tf.nn.bias_add(conv, bias) 72 | 73 | def _pooling_layer_from(input): 74 | return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), 75 | padding='SAME') 76 | 77 | -------------------------------------------------------------------------------- /src/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import matplotlib.pyplot as plt 5 | 6 | from glob import glob 7 | from ntpath import basename 8 | from scipy.misc import imread 9 | from skimage.measure import compare_ssim 10 | from skimage.measure import compare_psnr 11 | from skimage.color import rgb2gray 12 | from frechet_inception_distance import calculate_fid_given_paths 13 | from sewar.full_ref import uqi 14 | from sewar.full_ref import vifp 15 | 16 | def compare_mae(img_true, img_test): 17 | img_true = img_true.astype(np.float32) 18 | img_test = img_test.astype(np.float32) 19 | return np.sum(np.abs(img_true - img_test)) / np.sum(img_true + img_test) 20 | 21 | def evaluate(path_true, path_pred, inception_dir = '../inception'): 22 | 23 | psnr = [] 24 | ssim = [] 25 | mae = [] 26 | UQI = [] 27 | VIF = [] 28 | names = [] 29 | index = 1 30 | 31 | files = list(glob(path_true + '/*.jpg')) + list(glob(path_true + '/*.png')) 
32 | for fn in sorted(files): 33 | name = basename(str(fn)) 34 | names.append(name) 35 | 36 | img_gt = (imread(str(fn)) / 255.0).astype(np.float32) 37 | img_pred = (imread(path_pred + '/' + basename(str(fn))) / 255.0).astype(np.float32) 38 | 39 | img_gt = rgb2gray(img_gt) 40 | img_pred = rgb2gray(img_pred) 41 | 42 | # plt.subplot('121') 43 | # plt.imshow(img_gt) 44 | # plt.title('Ground truth') 45 | # plt.subplot('122') 46 | # plt.imshow(img_pred) 47 | # plt.title('Output') 48 | # plt.show() 49 | 50 | psnr.append(compare_psnr(img_gt, img_pred, data_range=1)) 51 | ssim.append(compare_ssim(img_gt, img_pred, data_range=1, win_size=51)) 52 | mae.append(compare_mae(img_gt, img_pred)) 53 | UQI.append(uqi(img_gt, img_pred)) 54 | VIF.append(vifp(img_gt, img_pred)) 55 | if np.mod(index, 100) == 0: 56 | print( 57 | str(index) + ' images processed', 58 | "PSNR: %.4f" % round(np.mean(psnr), 4), 59 | "SSIM: %.4f" % round(np.mean(ssim), 4), 60 | "MAE: %.4f" % round(np.mean(mae), 4), 61 | "UQI: %.4f" % round(np.mean(UQI), 4), 62 | "VIF: %.4f" % round(np.mean(VIF), 4) 63 | ) 64 | index += 1 65 | 66 | psnr = np.mean(psnr) 67 | mae = np.mean(mae) 68 | ssim = np.mean(ssim) 69 | UQI = np.mean(UQI) 70 | VIF = np.mean(VIF) 71 | # print( 72 | # "PSNR: %.4f" % round(psnr, 4), 73 | # "PSNR Variance: %.4f" % round(np.var(psnr), 4), 74 | # "SSIM: %.4f" % round(ssim, 4), 75 | # "SSIM Variance: %.4f" % round(np.var(ssim), 4), 76 | # "MAE: %.4f" % round(mae, 4), 77 | # "MAE Variance: %.4f" % round(np.var(mae), 4) 78 | # ) 79 | 80 | fid_value = calculate_fid_given_paths([path_true, path_pred], inception_dir)  # use the inception_dir argument (default '../inception') 81 | return mae, psnr, ssim, fid_value, UQI, VIF -------------------------------------------------------------------------------- /structure_loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from skimage import feature 4 | from skimage.color import rgb2gray 5 | 6 | """ 7 | Structure loss 8 | """ 9 | import cv2 10 | 11 | 12 | def gaussian_kernel_2d_opencv(kernel_size=3,sigma=0): 13 | """ 14 | ref: https://blog.csdn.net/qq_16013649/article/details/78784791 15 | ref: tensorflow 16 | (1) https://stackoverflow.com/questions/52012657/how-to-make-a-2d-gaussian-filter-in-tensorflow 17 | (2) https://github.com/tensorflow/tensorflow/issues/2826 18 | """ 19 | kx = cv2.getGaussianKernel(kernel_size,sigma) 20 | ky = cv2.getGaussianKernel(kernel_size,sigma) 21 | return np.multiply(kx,np.transpose(ky)) 22 | 23 | 24 | def canny_edge(images, sigma=1.5): 25 | """ 26 | Extract edges in tensorflow. 
27 | example: 28 | input = tf.placeholder(dtype=tf.float32, shape=[None, 900, 900, 3]) 29 | output = tf.py_func(canny_edge, [input], tf.float32, stateful=False) 30 | 31 | :param images: 32 | :param sigma: 33 | :return: 34 | """ 35 | edges = [] 36 | for i in range(len(images)): 37 | grey_img = rgb2gray(images[i]) 38 | edge = feature.canny(grey_img, sigma=sigma) 39 | edges.append(np.expand_dims(edge, axis=0)) 40 | edges = np.concatenate(edges, axis=0) 41 | return np.expand_dims(edges, axis=3).astype(np.float32) 42 | 43 | 44 | def priority_loss_mask(mask, ksize=5, sigma=1, iteration=2): 45 | gaussian_kernel = gaussian_kernel_2d_opencv(kernel_size=ksize, sigma=sigma) 46 | gaussian_kernel = np.reshape(gaussian_kernel, (ksize, ksize, 1, 1)) 47 | mask_priority = tf.convert_to_tensor(mask, dtype=tf.float32) 48 | for i in range(iteration): 49 | mask_priority = tf.nn.conv2d(mask_priority, gaussian_kernel, strides=[1,1,1,1], padding='SAME') 50 | 51 | return mask_priority 52 | 53 | 54 | def pyramid_structure_loss(image, predicts, edge_alpha, grad_alpha): 55 | _, H, W, _ = image.get_shape().as_list() 56 | loss = 0. 57 | for predict in predicts: 58 | _, h, w, _ = predict.get_shape().as_list() 59 | if h != H: 60 | gt_img = tf.image.resize_nearest_neighbor(image, size=(h, w)) 61 | # gt_mask = tf.image.resize_nearest_neighbor(mask, size=(h, w)) 62 | 63 | # grad 64 | gt_grad = tf.image.sobel_edges(gt_img) 65 | gt_grad = tf.reshape(gt_grad, [-1, h, w, 6]) # 6 channel 66 | grad_error = tf.abs(predict - gt_grad) 67 | 68 | # edge 69 | gt_edge = tf.py_func(canny_edge, [gt_img], tf.float32, stateful=False) 70 | edge_priority = priority_loss_mask(gt_edge, ksize=5, sigma=1, iteration=2) 71 | else: 72 | gt_img = image 73 | # gt_mask = mask 74 | 75 | # grad 76 | gt_grad = tf.image.sobel_edges(gt_img) 77 | gt_grad = tf.reshape(gt_grad, [-1, H, W, 6]) # 6 channel 78 | grad_error = tf.abs(predict - gt_grad) 79 | 80 | # edge 81 | gt_edge = tf.py_func(canny_edge, [gt_img], tf.float32, stateful=False) 82 | edge_priority = priority_loss_mask(gt_edge, ksize=5, sigma=1, iteration=2) 83 | 84 | grad_loss = tf.reduce_mean(grad_alpha * grad_error) 85 | edge_weight = edge_alpha * edge_priority 86 | # print("edge_weight", edge_weight.shape) 87 | # print("grad_error", grad_error.shape) 88 | edge_loss = tf.reduce_sum(edge_weight * grad_error) / tf.reduce_sum(edge_weight) / 6. # 6 channel 89 | 90 | loss = loss + grad_loss + edge_loss 91 | 92 | return loss -------------------------------------------------------------------------------- /painter/structure_loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from skimage import feature 4 | from skimage.color import rgb2gray 5 | 6 | """ 7 | Structure loss 8 | """ 9 | import cv2 10 | 11 | 12 | def gaussian_kernel_2d_opencv(kernel_size=3,sigma=0): 13 | """ 14 | ref: https://blog.csdn.net/qq_16013649/article/details/78784791 15 | ref: tensorflow 16 | (1) https://stackoverflow.com/questions/52012657/how-to-make-a-2d-gaussian-filter-in-tensorflow 17 | (2) https://github.com/tensorflow/tensorflow/issues/2826 18 | """ 19 | kx = cv2.getGaussianKernel(kernel_size,sigma) 20 | ky = cv2.getGaussianKernel(kernel_size,sigma) 21 | return np.multiply(kx,np.transpose(ky)) 22 | 23 | 24 | def canny_edge(images, sigma=1.5): 25 | """ 26 | Extract edges in tensorflow. 
27 | example: 28 | input = tf.placeholder(dtype=tf.float32, shape=[None, 900, 900, 3]) 29 | output = tf.py_func(canny_edge, [input], tf.float32, stateful=False) 30 | 31 | :param images: 32 | :param sigma: 33 | :return: 34 | """ 35 | edges = [] 36 | for i in range(len(images)): 37 | grey_img = rgb2gray(images[i]) 38 | edge = feature.canny(grey_img, sigma=sigma) 39 | edges.append(np.expand_dims(edge, axis=0)) 40 | edges = np.concatenate(edges, axis=0) 41 | return np.expand_dims(edges, axis=3).astype(np.float32) 42 | 43 | 44 | def priority_loss_mask(mask, ksize=5, sigma=1, iteration=2): 45 | gaussian_kernel = gaussian_kernel_2d_opencv(kernel_size=ksize, sigma=sigma) 46 | gaussian_kernel = np.reshape(gaussian_kernel, (ksize, ksize, 1, 1)) 47 | mask_priority = tf.convert_to_tensor(mask, dtype=tf.float32) 48 | for i in range(iteration): 49 | mask_priority = tf.nn.conv2d(mask_priority, gaussian_kernel, strides=[1,1,1,1], padding='SAME') 50 | 51 | return mask_priority 52 | 53 | 54 | def pyramid_structure_loss(image, predicts, edge_alpha, grad_alpha): 55 | _, H, W, _ = image.get_shape().as_list() 56 | loss = 0. 57 | for predict in predicts: 58 | _, h, w, _ = predict.get_shape().as_list() 59 | if h != H: 60 | gt_img = tf.image.resize_nearest_neighbor(image, size=(h, w)) 61 | # gt_mask = tf.image.resize_nearest_neighbor(mask, size=(h, w)) 62 | 63 | # grad 64 | gt_grad = tf.image.sobel_edges(gt_img) 65 | gt_grad = tf.reshape(gt_grad, [-1, h, w, 6]) # 6 channel 66 | grad_error = tf.abs(predict - gt_grad) 67 | 68 | # edge 69 | gt_edge = tf.py_func(canny_edge, [gt_img], tf.float32, stateful=False) 70 | edge_priority = priority_loss_mask(gt_edge, ksize=5, sigma=1, iteration=2) 71 | else: 72 | gt_img = image 73 | # gt_mask = mask 74 | 75 | # grad 76 | gt_grad = tf.image.sobel_edges(gt_img) 77 | gt_grad = tf.reshape(gt_grad, [-1, H, W, 6]) # 6 channel 78 | grad_error = tf.abs(predict - gt_grad) 79 | 80 | # edge 81 | gt_edge = tf.py_func(canny_edge, [gt_img], tf.float32, stateful=False) 82 | edge_priority = priority_loss_mask(gt_edge, ksize=5, sigma=1, iteration=2) 83 | 84 | grad_loss = tf.reduce_mean(grad_alpha * grad_error) 85 | edge_weight = edge_alpha * edge_priority 86 | # print("edge_weight", edge_weight.shape) 87 | # print("grad_error", grad_error.shape) 88 | edge_loss = tf.reduce_sum(edge_weight * grad_error) / tf.reduce_sum(edge_weight) / 6. 
# 6 channel 89 | 90 | loss = loss + grad_loss + edge_loss 91 | 92 | return loss -------------------------------------------------------------------------------- /painter/config.py: -------------------------------------------------------------------------------- 1 | """config utilities for yml file""" 2 | import logging 3 | import os 4 | import yaml 5 | 6 | logger = logging.getLogger() 7 | 8 | class LoaderMeta(type): 9 | """Constructor for supporting `!include`.""" 10 | 11 | def __new__(mcs, __name__, __bases__, __dict__): 12 | """Add include constructor to class.""" 13 | # register the include constructor on the class 14 | cls = super().__new__(mcs, __name__, __bases__, __dict__) 15 | cls.add_constructor('!include', cls.construct_include) 16 | return cls 17 | 18 | 19 | class Loader(yaml.Loader, metaclass=LoaderMeta): 20 | """YAML Loader with `!include` constructor.""" 21 | 22 | def __init__(self, stream): 23 | try: 24 | self._root = os.path.split(stream.name)[0] 25 | except AttributeError: 26 | self._root = os.path.curdir 27 | super().__init__(stream) 28 | 29 | def construct_include(self, node): 30 | """Include file referenced at node.""" 31 | filename = os.path.abspath( 32 | os.path.join(self._root, self.construct_scalar(node))) 33 | extension = os.path.splitext(filename)[1].lstrip('.') 34 | with open(filename, 'r') as f: 35 | if extension in ('yaml', 'yml'): 36 | return yaml.load(f, Loader) 37 | else: 38 | return ''.join(f.readlines()) 39 | 40 | 41 | class DictAsMember(dict): 42 | """Dict as member trick.""" 43 | 44 | def __getattr__(self, name): 45 | value = self[name] 46 | if isinstance(value, dict): 47 | value = DictAsMember(value) 48 | return value 49 | 50 | class Config(dict): 51 | """Config with yaml file. 52 | 53 | This class is used to config model hyper-parameters, global constants, and 54 | other settings with yaml file. All settings in yaml file will be 55 | automatically logged into file. 56 | 57 | Args: 58 | filename(str): File name. 59 | 60 | Examples: 61 | 62 | yaml file ``model.yml``:: 63 | 64 | NAME: 'neuralgym' 65 | ALPHA: 1.0 66 | DATASET: '/mnt/data/imagenet' 67 | """ 68 | 69 | def __init__(self, filename=None): 70 | assert os.path.exists(filename), "ERROR: Config File doesn't exist." 
71 | try: 72 | with open(filename, 'r') as f: 73 | self._cfg_dict = yaml.load(f, Loader) 74 | # parent of IOError, OSError *and* WindowsError where available 75 | except EnvironmentError: 76 | logger.error('Please check the file with name of "%s"', filename) 77 | logger.info(' APP CONFIG '.center(80, '-')) 78 | self.show() 79 | logger.info(''.center(80, '-')) 80 | 81 | def __getattr__(self, name): 82 | value = self._cfg_dict[name] 83 | if isinstance(value, dict): 84 | value = DictAsMember(value) 85 | return value 86 | 87 | def show(self, cfg_dict=None, indent=0): 88 | if cfg_dict is None: 89 | cfg_dict = self._cfg_dict 90 | for key in cfg_dict: 91 | value = cfg_dict[key] 92 | if isinstance(value, dict): 93 | str_list = [' '] * indent 94 | str_list.append(str(key)) 95 | str_list.append(': ') 96 | logger.info(''.join(str_list)) 97 | indent = indent + 1 98 | indent = self.show(value, indent) 99 | else: 100 | str_list = [' '] * indent 101 | str_list.append(str(key)) 102 | str_list.append(': ') 103 | str_list.append(str(value)) 104 | logger.info(''.join(str_list)) 105 | return indent - 1 106 | 107 | if __name__ == "__main__": 108 | 109 | config = Config('inpaint_config.yml') 110 | print(config.DATASET) 111 | print(config.IMG_SHAPES) 112 | print(config.MASK_MODE == 'irregular') 113 | -------------------------------------------------------------------------------- /src/inpaint_config.yml: -------------------------------------------------------------------------------- 1 | # parameters 2 | CUSTOM_DATASET: True 3 | MASK_MODE: 'irregular' # 'irregular' 4 | DATASET: 'places2' # 'places2', 'celeba_align' 5 | VGG_DIR: 'imagenet-vgg-verydeep-19.mat' 6 | 7 | RANDOM_CROP: False 8 | 9 | LOG_DIR: logs 10 | CHECKPOINT_DIR: checkpoint 11 | MODEL_DIR: '' 12 | SAMPLE_DIR: sample 13 | RESULT_DIR: result 14 | MODEL_RESTORE: 'places2' # 'places2', 'celeba_align'; to train from scratch set '' 15 | 16 | 17 | # parameters in other related papers 18 | # 1. Pconv: Ltotal = Lvalid+6Lhole+0.05Lperceptual+120(Lstyleout+Lstylecomp)+0.1Ltv (pool1, pool2, pool3) 19 | # 2. EdgeConnect: l1 = 1, Ladv = 0.1, Lperceptual=0.1, Lstyle=250 (relu_1,...relu_5) 20 | # 3. SC-FEGAN: l1 = 1, Ladv=0.001, Lperceptual=0.05, Lstyle=120, Ltv=0.1 (pool1, pool2, pool3) 21 | 22 | # l1 loss 23 | L1_FORE_ALPHA: 1. # may be weighted more, e.g. 1.5 24 | L1_BACK_ALPHA: 1. 25 | 26 | L1_SCALE: 0. # l1 loss on the down-scaled images (not used by default; we only regularize the structures) 27 | # we found that l1 on pixels and l1 on gradients contradict each other to some extent 28 | 29 | # content, style loss 30 | BACKGROUND_LOSS: True # for content and style loss 31 | CONTENT_FORE_ALPHA: 0.1 # layers: see loss.py 32 | CONTENT_BACK_ALPHA: 0.1 33 | STYLE_FORE_ALPHA: 250. # layers: see loss.py 34 | STYLE_BACK_ALPHA: 250. 35 | 36 | # tv loss 37 | TV_ALPHA: 0. 38 | 39 | 40 | # gan loss 41 | GAN_TYPE: 'patch_gan' # 'wgan_gp' 42 | GAN_LOSS_TYPE: 'hinge' # 'hinge', 'gan' 43 | SN: True 44 | PATCH_GAN_ALPHA: 0.4 # weight: best tuned in range [0.1, 0.8] 45 | GP_ALPHA: 0. 46 | 47 | # edge, grad loss 48 | SIGMA: 1.5 # edge info 49 | ALPHA: 1. # weight of the auxiliary edge task, chosen relative to the main task's weight 50 | EDGE_ALPHA: 20. # edge weight = EDGE_ALPHA * priority; priority in (0, 1) 51 | # grad reconstruction 52 | GRAD_ALPHA: 0.1 # grad weight 53 | 54 | 55 | # other losses (not used by default) 56 | # grad matching 57 | GRAD_MATCHING_ALPHA: 0.0 58 | PATCH_SIZE: 5 59 | STRIDE_SIZE: 3 60 | # image matching 61 | IMG_MATCHING_ALPHA: 0. 
62 | 63 | 64 | # training 65 | RANDOM_SEED: False 66 | PADDING: 'REFLECT' # 'REFLECT' 'SAME' 67 | 68 | G_LR: 0.00001 69 | D_LR: 0.00001 70 | 71 | BATCH_SIZE: 4 # batch size 72 | 73 | NUM_GPUS: 1 # number of gpus, support multi-gpu setting 74 | GPU_ID: [1] # list of gpu ids [..] 75 | 76 | EPOCH: 10 # training epochs 77 | PRINT_FREQ: 50 # print training info in steps 78 | SAVE_FREQ: 2000 # saving checkpoint (in steps) 79 | LOG_FREQ: 2000 # logs, viewed in tensorboard (in steps) 80 | VIZ_MAX_OUT: 8 # middle results, viewed in tensorboard 81 | 82 | # validation and test 83 | VAL: False 84 | VAL_NUM: 8 85 | STATIC_VIEW: True 86 | VAL_FREQ: 8000 87 | 88 | TEST_NUM: 6 89 | MAX_TEST_NUM: 1000 # max number of test images 90 | 91 | # image data dir 92 | DATA_FLIST: 93 | # https://github.com/JiahuiYu/progressive_growing_of_gans_tf 94 | celeba_align: [ 95 | 'data/celeba_align/train_shuffled.flist', 96 | 'data/celeba_align/validation_shuffled.flist' 97 | ] 98 | # http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html, please use RANDOM_CROP: True 99 | celeba_hq_sample: [ 100 | 'data/celeba_hq_sample/train_shuffled.flist', 101 | 'data/celeba_hq_sample/validation_shuffled.flist' 102 | ] 103 | # http://places2.csail.mit.edu/, please download the high-resolution dataset and use RANDOM_CROP: True 104 | places2: [ 105 | 'data/places2/train_shuffled.flist', 106 | 'data/places2/validation_shuffled.flist' 107 | ] 108 | facade: [ 109 | 'data/facade/train_shuffled.flist', 110 | 'data/facade/validation_shuffled.flist' 111 | ] 112 | dtd: [ 113 | 'data/facade/train_shuffled.flist', 114 | 'data/facade/validation_shuffled.flist' 115 | ] 116 | 117 | 118 | # irregular mask data dir 119 | TRAIN_MASK_FLIST: data/mask-auto/img_mask_train.flist 120 | VAL_MASK_FLIST: data/mask-auto/img_mask_val.flist 121 | TEST_MASK_FLIST: data/mask-auto/img_mask_test.flist 122 | 123 | # regular mask 124 | IMG_SHAPES: [256, 256, 3] 125 | HEIGHT: 128 126 | WIDTH: 128 127 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | """config utilities for yml file""" 2 | import logging 3 | import os 4 | import yaml 5 | import pynvml 6 | 7 | logger = logging.getLogger() 8 | 9 | class LoaderMeta(type): 10 | """Constructor for supporting `!include`.""" 11 | 12 | def __new__(mcs, __name__, __bases__, __dict__): 13 | """Add include constructor to class.""" 14 | # register the include constructor on the class 15 | cls = super().__new__(mcs, __name__, __bases__, __dict__) 16 | cls.add_constructor('!include', cls.construct_include) 17 | return cls 18 | 19 | 20 | class Loader(yaml.Loader, metaclass=LoaderMeta): 21 | """YAML Loader with `!include` constructor.""" 22 | 23 | def __init__(self, stream): 24 | try: 25 | self._root = os.path.split(stream.name)[0] 26 | except AttributeError: 27 | self._root = os.path.curdir 28 | super().__init__(stream) 29 | 30 | def construct_include(self, node): 31 | """Include file referenced at node.""" 32 | filename = os.path.abspath( 33 | os.path.join(self._root, self.construct_scalar(node))) 34 | extension = os.path.splitext(filename)[1].lstrip('.') 35 | with open(filename, 'r') as f: 36 | if extension in ('yaml', 'yml'): 37 | return yaml.load(f, Loader) 38 | else: 39 | return ''.join(f.readlines()) 40 | 41 | 42 | class DictAsMember(dict): 43 | """Dict as member trick.""" 44 | 45 | def __getattr__(self, name): 46 | value = self[name] 47 | if isinstance(value, dict): 48 | value = DictAsMember(value) 49 | return value 50 | 51 | 
class Config(dict): 52 | """Config with yaml file. 53 | 54 | This class is used to config model hyper-parameters, global constants, and 55 | other settings with yaml file. All settings in yaml file will be 56 | automatically logged into file. 57 | 58 | Args: 59 | filename(str): File name. 60 | 61 | Examples: 62 | 63 | yaml file ``model.yml``:: 64 | 65 | NAME: 'neuralgym' 66 | ALPHA: 1.0 67 | DATASET: '/mnt/data/imagenet' 68 | """ 69 | 70 | def __init__(self, filename=None): 71 | assert os.path.exists(filename), "ERROR: Config File doesn't exist." 72 | try: 73 | with open(filename, 'r') as f: 74 | self._cfg_dict = yaml.load(f, Loader) 75 | # parent of IOError, OSError *and* WindowsError where available 76 | except EnvironmentError: 77 | logger.error('Please check the file with name of "%s"', filename) 78 | logger.info(' APP CONFIG '.center(80, '-')) 79 | self.show() 80 | logger.info(''.center(80, '-')) 81 | 82 | def __getattr__(self, name): 83 | value = self._cfg_dict[name] 84 | if isinstance(value, dict): 85 | value = DictAsMember(value) 86 | return value 87 | 88 | def show(self, cfg_dict=None, indent=0): 89 | if cfg_dict is None: 90 | cfg_dict = self._cfg_dict 91 | for key in cfg_dict: 92 | value = cfg_dict[key] 93 | if isinstance(value, dict): 94 | str_list = [' '] * indent 95 | str_list.append(str(key)) 96 | str_list.append(': ') 97 | logger.info(''.join(str_list)) 98 | indent = indent + 1 99 | indent = self.show(value, indent) 100 | else: 101 | str_list = [' '] * indent 102 | str_list.append(str(key)) 103 | str_list.append(': ') 104 | str_list.append(str(value)) 105 | logger.info(''.join(str_list)) 106 | return indent - 1 107 | 108 | 109 | # GPU 110 | def select_gpu(): 111 | """ 112 | Find the GPU with the least used memory. 113 | 114 | Args: 115 | None 116 | 117 | Returns: 118 | GPU number with min used memory (or with max free memory). string 119 | """ 120 | # (pynvml is already imported at the top of this file) 121 | pynvml.nvmlInit() 122 | 123 | gpu_count = pynvml.nvmlDeviceGetCount() # number of gpus 124 | gpu_devices = list(range(gpu_count)) # serial numbers of gpu devices 125 | 126 | # Select GPU with min used memory 127 | max_memo = 24*1024*1024*1024 # initial sentinel: 24 GB 128 | gpu_selected = gpu_devices[0] 129 | for i in range(len(gpu_devices)): 130 | handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_devices[i]) 131 | meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) 132 | if meminfo.used <= max_memo: 133 | max_memo = meminfo.used 134 | gpu_selected = gpu_devices[i] 135 | 136 | return str(gpu_selected) 137 | 138 | 139 | if __name__ == "__main__": 140 | 141 | config = Config('inpaint_config.yml') 142 | print(config.DATASET) 143 | print(config.IMG_SHAPES) 144 | print(config.MASK_MODE == 'irregular') 145 | -------------------------------------------------------------------------------- /painter/inpaint_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.framework.python.ops import arg_scope 3 | #from utils_fn import * 4 | from ops import * 5 | import time 6 | 7 | class InpaintModel(): 8 | 9 | def __init__(self, args): 10 | self.model_name = "InpaintModel" # name for checkpoint 11 | self.img_size = args.IMG_SHAPES 12 | 13 | # yj 14 | def build_inpaint_net(self, x, edge, grad, mask, args=None, reuse=False, 15 | training=True, padding='SAME', name='inpaint_net'): 16 | """Inpaint network. 
17 | 18 | Args: 19 | x: incomplete image in [-1, 1] with shape of (batch_size, h, w, c) 20 | edge: incomplete edge {0, 1} with shape of (batch_size, h, w) 21 | grad: incomplete grad map with shape of (batch_size, h, w, 6) 22 | mask: mask region {0, 1} 23 | Returns: 24 | complete image, grad map, middle result 25 | """ 26 | x = tf.reshape(x, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], args.IMG_SHAPES[2]]) 27 | mask = tf.reshape(mask, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1]) 28 | edge = tf.reshape(edge, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1]) 29 | # grad = tf.reshape(grad, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 6]) 30 | 31 | xin = x 32 | ones_x = tf.ones_like(x)[:, :, :, 0:1] 33 | x = tf.concat([x, ones_x * edge, ones_x * mask, grad], axis=3) # concat image, edge, mask and grad channels: 3 + 1 + 1 + 6 = 11 input channels 34 | # encoder-decoder network: channel 64-128-256-128-64 35 | cnum = 64 # initial channel 36 | # a decorator: arg_scope([op1, op2,..], xx,..) means: 37 | # attributes or parameters xx defined here become the defaults for op1, op2,.. 38 | with tf.variable_scope(name, reuse=reuse), \ 39 | arg_scope([gen_conv, gen_deconv], 40 | training=training, padding=padding): 41 | # Encoder 42 | # scale 256 channels activation: relu 43 | x = gen_conv(x, cnum, 7, stride=1, activation=tf.nn.relu, name='en_conv1') # 11 -> 64, ksize=7x7, stride=1 44 | # scale 128 45 | x = gen_conv(x, 2 * cnum, 4, stride=2, activation=tf.nn.relu, name='en_conv2') 46 | # scale 64 47 | x = gen_conv(x, 4 * cnum, 4, stride=2, activation=tf.nn.relu, name='en_conv3') 48 | # res block 49 | x = resnet_blocks(x, 4 * cnum, 3, stride=1, rate=2, block_num=8, activation=tf.nn.relu, name='en_64_8') 50 | 51 | # Decoder 52 | # TODO: output scale 64 Down scale = 2 (origin) pool scale = 2 (origin) 53 | # share attention 54 | x = attention(x, 4 * cnum, down_scale=2, pool_scale=2, name='attention_pooling_64') 55 | 56 | # output branch: predicted grad map and image at this scale 57 | x_64 = gen_conv(x, 4 * cnum, 5, stride=1, activation=tf.nn.relu, name='out64_grad_out') 58 | x_grad_out_64 = gen_conv(x_64, 6, 1, stride=1, activation=None, name='grad64') 59 | x_out_64 = gen_conv(x_64, 3, 1, stride=1, activation=tf.nn.tanh, name='out64') 60 | 61 | # scale 64 - 128 62 | x = tf.concat([x, x_64], axis=3) 63 | x = gen_deconv(x, 2 * cnum, 4, method='deconv', activation=tf.nn.relu, name='de128_conv4_upsample') 64 | 65 | # TODO: output scale 128 66 | # share attention 67 | x = attention(x, 2 * cnum, down_scale=2, pool_scale=2, name='attention_pooling_128') 68 | 69 | # output branch: predicted grad map and image at this scale 70 | x_128 = gen_conv(x, 2 * cnum, 5, stride=1, activation=tf.nn.relu, name='out128_grad_out') 71 | x_grad_out_128 = gen_conv(x_128, 6, 1, stride=1, activation=None, name='grad128') 72 | x_out_128 = gen_conv(x_128, 3, 1, stride=1, activation=tf.nn.tanh, name='out128') 73 | 74 | # scale 128 - 256 75 | x = tf.concat([x, x_128], axis=3) 76 | x = gen_deconv(x, cnum, 4, method='deconv', activation=tf.nn.relu, name='de256_conv5_upsample') 77 | 78 | # TODO: output scale 256 79 | # share attention 80 | x = attention(x, cnum, down_scale=2, pool_scale=2, name='attention_pooling_256') 81 | 82 | # output branch: predicted grad map and final image 83 | x = gen_conv(x, cnum, 5, stride=1, activation=tf.nn.relu, name='out256_grad_out') 84 | x_grad = gen_conv(x, 6, 1, stride=1, activation=None, name='grad256') # grad map: no activation 85 | x = gen_conv(x, 3, 1, stride=1, activation=tf.nn.tanh, name='out256') 86 | 87 | return x 88 | 89 | 90 | def evaluate(self, x, edge, mask, args, training=False, reuse=False): 91 | # image, grad map 92 | image = 
normalize(x) 93 | grad = tf.image.sobel_edges(image) # normalization? 94 | grad = tf.reshape(grad, [1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 6]) # 6 channel 95 | 96 | # x for image 97 | x = tf.reshape(image, [1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 98 | args.IMG_SHAPES[2]]) # [1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], args.IMG_SHAPES[2]] 99 | mask = tf.reshape(mask, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1]) 100 | edge = tf.reshape(edge, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1]) 101 | 102 | # incomplete image 103 | x_incomplete = x * (1. - mask) 104 | 105 | # incomplete edge at full scale 106 | input_edge = 1 - edge 107 | edge_incomplete = input_edge * (1 - mask) + mask # 0 (black) marks edge pixels when saved/fed in; 1 (white) marks non-edge 108 | 109 | # grad 110 | grad_incomplete = (1. - mask) * grad 111 | 112 | out_256 = self.build_inpaint_net(x_incomplete, edge_incomplete, grad_incomplete, 113 | mask, args, reuse=reuse,training=training, padding=args.PADDING) 114 | 115 | raw_x = inverse_transform(x) 116 | raw_x_incomplete = raw_x * (1 - mask) 117 | raw_x_complete = raw_x_incomplete + inverse_transform(out_256) * mask 118 | 119 | return raw_x_complete 120 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning to Incorporate Structure Knowledge for Image Inpainting 2 | Introduction and source code of the AAAI 2020 paper *'Learning to Incorporate Structure Knowledge for Image Inpainting'*. You can get the paper from the **[AAAI proceedings](https://aaai.org/ojs/index.php/AAAI/article/view/6951)** or **[here](https://www.researchgate.net/publication/338984531_Learning_to_Incorporate_Structure_Knowledge_for_Image_Inpainting)**. 3 | 4 | ## Citation 5 | ```bibtex 6 | @inproceedings{jie2020inpainting, 7 | title={Learning to Incorporate Structure Knowledge for Image Inpainting}, 8 | author={Yang, Jie and Qi, Zhiquan and Shi, Yong}, 9 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 10 | volume={34}, 11 | number={7}, 12 | pages={12605-12612}, 13 | year={2020} 14 | } 15 | ``` 16 | 17 | # Introduction 18 | This project develops a multi-task learning framework that attempts to incorporate image structure knowledge to assist image inpainting, which has not been well explored in previous works. The primary idea is to train a shared generator to simultaneously complete the corrupted image and the corresponding structures --- edge and gradient --- thus implicitly encouraging the generator to exploit relevant structure knowledge while inpainting. Meanwhile, we also introduce a structure embedding scheme to explicitly embed the learned structure features into the inpainting process, providing possible preconditions for image completion. Specifically, a novel pyramid structure loss is proposed to supervise structure learning and embedding. Moreover, an attention mechanism is developed to further exploit the recurrent structures and patterns in the image to refine the generated structures and contents. Through multi-task learning, structure embedding and attention, our framework takes advantage of the structure knowledge and outperforms several state-of-the-art methods on benchmark datasets quantitatively and qualitatively. 19 | 20 | The overview of our multi-task framework is shown in the figure below. 
It leverages the structure knowledge with multi-task learning (simultaneous image and structure generation), structure embedding and an attention mechanism. 21 | 22 | ![architecture](https://github.com/YoungGod/sturcture-inpainting/blob/master/project-images/architecture.jpg) 23 | 24 | # Pyramid structure loss 25 | We propose a pyramid structure loss to guide the structure generation and embedding, thus incorporating the structure information into the generation process. Here, the gradient and edge information, which is held in Sobel gradient maps as in the figure below, is used as the structure information. 26 | 
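Concretely, the gradient maps are the Sobel responses of the RGB image: `tf.image.sobel_edges` returns a `(B, H, W, C, 2)` tensor (per-channel vertical/horizontal gradients), which the code flattens into 6 channels. A minimal sketch, assuming the 256x256 RGB setting from *inpaint_config.yml*:

```python
import tensorflow as tf

img = tf.placeholder(tf.float32, [None, 256, 256, 3])  # RGB image batch
grad = tf.image.sobel_edges(img)                       # (B, 256, 256, 3, 2): dy, dx per channel
grad6 = tf.reshape(grad, [-1, 256, 256, 6])            # the 6-channel Sobel gradient map used below
```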
28 | ![sobel](https://github.com/YoungGod/sturcture-inpainting/blob/master/project-images/sobel.jpg)
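During training, the generator's three multi-scale gradient predictions are supervised jointly by this loss. A hypothetical wiring sketch (the tensor names follow `build_inpaint_net` in *painter/inpaint_model.py*, and the weights come from *inpaint_config.yml*; the exact call in *train_inpaint_model.py* may differ):

```python
# x_grad_out_64, x_grad_out_128, x_grad: the predicted 6-channel Sobel maps
# at scales 64, 128 and 256 inside build_inpaint_net
predicts = [x_grad_out_64, x_grad_out_128, x_grad]
struct_loss = pyramid_structure_loss(gt_image, predicts,
                                     edge_alpha=args.EDGE_ALPHA,  # 20. in inpaint_config.yml
                                     grad_alpha=args.GRAD_ALPHA)  # 0.1 in inpaint_config.yml
```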
31 | The loss function *pyramid_structure_loss(..)* is realized in **structure_loss.py**. 32 | 33 | ```python 34 | def pyramid_structure_loss(image, predicts, edge_alpha, grad_alpha): 35 | _, H, W, _ = image.get_shape().as_list() 36 | loss = 0. 37 | for predict in predicts: 38 | _, h, w, _ = predict.get_shape().as_list() 39 | if h != H: 40 | gt_img = tf.image.resize_nearest_neighbor(image, size=(h, w)) 41 | 42 | # grad 43 | gt_grad = tf.image.sobel_edges(gt_img) 44 | gt_grad = tf.reshape(gt_grad, [-1, h, w, 6]) # 6 channel 45 | grad_error = tf.abs(predict - gt_grad) 46 | 47 | # edge 48 | gt_edge = tf.py_func(canny_edge, [gt_img], tf.float32, stateful=False) 49 | edge_priority = priority_loss_mask(gt_edge, ksize=5, sigma=1, iteration=2) 50 | else: 51 | gt_img = image 52 | 53 | # grad 54 | gt_grad = tf.image.sobel_edges(gt_img) 55 | gt_grad = tf.reshape(gt_grad, [-1, H, W, 6]) # 6 channel 56 | grad_error = tf.abs(predict - gt_grad) 57 | 58 | # edge 59 | gt_edge = tf.py_func(canny_edge, [gt_img], tf.float32, stateful=False) 60 | edge_priority = priority_loss_mask(gt_edge, ksize=5, sigma=1, iteration=2) 61 | 62 | grad_loss = tf.reduce_mean(grad_alpha * grad_error) 63 | edge_weight = edge_alpha * edge_priority 64 | # print("edge_weight", edge_weight.shape) 65 | # print("grad_error", grad_error.shape) 66 | edge_loss = tf.reduce_sum(edge_weight * grad_error) / tf.reduce_sum(edge_weight) / 6. # 6 channel 67 | 68 | loss = loss + grad_loss + edge_loss 69 | 70 | return loss 71 | ``` 72 | 73 | # Attention Layer 74 | Our attention operation is inspired by the non-local means mechanism, which has been used for denoising and super-resolution. It calculates the response at a position of the output feature map as a weighted sum of the features in the whole input feature map, where the weight (attention score) is measured by feature similarity. When k=1, it works just like self-attention. Through attention, similar features from the surroundings can be transferred to the missing regions to refine the generated contents and structures (e.g. smoothing artifacts and enhancing details). 76 | 77 | 
78 | ![attention](https://github.com/YoungGod/sturcture-inpainting/blob/master/project-images/attention.jpg)
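For reference, here is a minimal sketch of the underlying non-local (k=1, i.e. self-attention) computation. This is an illustrative simplification, not the repo's `attention(..)` op in *ops.py*, which additionally downsamples (`down_scale`) and pools (`pool_scale`) patches; it also assumes static spatial dimensions:

```python
import tensorflow as tf

def nonlocal_attention(x, channels, name='nonlocal_attention'):
    """Response at each position = similarity-weighted sum of features over all positions."""
    with tf.variable_scope(name):
        _, h, w, _ = x.shape.as_list()
        theta = tf.layers.conv2d(x, channels // 2, 1, name='theta')     # query
        phi = tf.layers.conv2d(x, channels // 2, 1, name='phi')         # key
        g = tf.layers.conv2d(x, channels, 1, name='g')                  # value
        theta = tf.reshape(theta, [-1, h * w, channels // 2])
        phi = tf.reshape(phi, [-1, h * w, channels // 2])
        g = tf.reshape(g, [-1, h * w, channels])
        attn = tf.nn.softmax(tf.matmul(theta, phi, transpose_b=True))   # (B, HW, HW) attention scores
        out = tf.reshape(tf.matmul(attn, g), [-1, h, w, channels])
        return x + tf.layers.conv2d(out, channels, 1, name='out')       # residual connection
```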
81 | # Some qualitative results 83 | ## Qualitative 84 | ![qualitative](https://github.com/YoungGod/sturcture-inpainting/blob/master/project-images/quality-compare-celeba.jpg) 85 | ![qualitative](https://github.com/YoungGod/sturcture-inpainting/blob/master/project-images/quality-compare-place.jpg) 86 | 87 | ## Ablation 88 | ![ablation](https://github.com/YoungGod/sturcture-inpainting/blob/master/project-images/ablation.jpg) 89 | 90 | ## Real-life object removal 91 | 
92 | ![removal](https://github.com/YoungGod/sturcture-inpainting/blob/master/project-images/removal.jpg)
95 | # Code 97 | ## Painter 98 | To evaluate the generalization ability of our inpainting models, we carry out object removal experiments in user scenarios. We developed an interactive image removal and completion tool with OpenCV. You may download the checkpoint of the inpainting model pretrained on Places2 training and validation data from **[here](https://pan.baidu.com/s/1SBbfR94KWG5UMm_FClmdMQ)** with password: **uiqn**. 99 | 100 | Or from [google drive](https://drive.google.com/drive/folders/1ReSArrra8NOQv8dlU2QK0DE0P5qoalCT?usp=sharing) 101 | 102 | Run painter.py from the command line (we implemented our model with TensorFlow 1.15.2 and Python 3.7): 103 | > python painter.py --checkpoint checkpoint/places2 --save_path imgs 104 | 105 | When you do object removal experiments, it works like this: 106 | 
107 | ![painter-a](https://github.com/YoungGod/sturcture-inpainting/blob/master/project-images/painter-a.jpg) ![painter-b](https://github.com/YoungGod/sturcture-inpainting/blob/master/project-images/painter-b.jpg)
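For a sense of what *painter.py* drives under the hood, here is a minimal single-image inference sketch built on `InpaintModel.evaluate(..)` from *painter/inpaint_model.py*. This is hypothetical glue code, not the actual script: the placeholder shapes and the dummy `img`/`edge`/`mask` arrays are assumptions, and painter.py instead wires these tensors to its OpenCV UI.

```python
import numpy as np
import tensorflow as tf
from config import Config
from inpaint_model import InpaintModel

args = Config('inpaint_config.yml')
model = InpaintModel(args)

image_ph = tf.placeholder(tf.float32, args.IMG_SHAPES)     # H x W x 3 input image
edge_ph = tf.placeholder(tf.float32, args.IMG_SHAPES[:2])  # edge map; see the 0/1 convention in evaluate()
mask_ph = tf.placeholder(tf.float32, args.IMG_SHAPES[:2])  # 1 inside the holes to fill
completed = model.evaluate(image_ph, edge_ph, mask_ph, args)

img = np.zeros(args.IMG_SHAPES, dtype=np.float32)          # put your image here
edge = np.ones(args.IMG_SHAPES[:2], dtype=np.float32)      # all non-edge (assumption)
mask = np.zeros(args.IMG_SHAPES[:2], dtype=np.float32)
mask[96:160, 96:160] = 1.                                  # a square hole to remove/fill

with tf.Session() as sess:
    tf.train.Saver().restore(sess, tf.train.latest_checkpoint('checkpoint/places2'))
    result = sess.run(completed, {image_ph: img, edge_ph: edge, mask_ph: mask})
```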
111 | ## Citation 113 | ```bibtex 114 | @inproceedings{jie2020inpainting, 115 | title={Learning to Incorporate Structure Knowledge for Image Inpainting}, 116 | author={Yang, Jie and Qi, Zhiquan and Shi, Yong}, 117 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 118 | volume={34}, 119 | number={7}, 120 | pages={12605-12612}, 121 | year={2020} 122 | } 123 | ``` 124 | ## License 125 | CC 4.0 Attribution-NonCommercial International. The software is for educational and academic research purposes only. 126 | -------------------------------------------------------------------------------- /src/val_inpaint_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import tensorflow as tf 5 | import numpy as np 6 | # (duplicate `import time` removed) 7 | import pandas as pd 8 | import re 9 | 10 | from math import ceil 11 | from scipy.misc import imsave 12 | from inpaint_model import InpaintModel 13 | from config import Config, select_gpu 14 | from utils_fn import show_all_variables, load_test_data, load_test_mask, create_test_mask, dataset_len, load_test_img_edge 15 | 16 | from frechet_inception_distance import calculate_fid_given_paths 17 | from metrics import uqi_vif 18 | 19 | # For reproducible results 20 | np.random.seed(0) 21 | tf.set_random_seed(0) 22 | 23 | # with tf.device('/cpu:0'): 24 | """ 25 | Testing 26 | """ 27 | # Load the config file for running an inpainting model 28 | args = Config('inpaint_config.yml') 29 | 30 | # GPU config 31 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_ID) 32 | os.environ["CUDA_VISIBLE_DEVICES"] = select_gpu() 33 | config_gpu = tf.ConfigProto() 34 | config_gpu.gpu_options.allow_growth = True # let GPU memory grow as needed 35 | 36 | # log setting 37 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 38 | logger = logging.getLogger("YOUNG") 39 | logger.setLevel(level=logging.INFO) 40 | 41 | """ Input Data (images and masks) """ 42 | # images 43 | if args.CUSTOM_DATASET: 44 | images, image_iterator = load_test_img_edge(args) 45 | else: 46 | images = tf.placeholder(tf.float32, [args.BATCH_SIZE, args.IMG_SHAPES[0], args.IMG_SHAPES[1], args.IMG_SHAPES[2]], 47 | name='real_images') 48 | # test masks 49 | if args.MASK_MODE == 'irregular': 50 | masks, mask_iterator = load_test_mask(args) 51 | else: 52 | masks = tf.placeholder(tf.float32, [args.TEST_NUM, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1], 53 | name='test_regular_masks') 54 | 55 | """ Build Testing Inpaint Model""" 56 | # Testing model 57 | model = InpaintModel(args) 58 | logger.info("Build Testing Inpaint Model") 59 | model.build_test_model(images, masks, args) 60 | 61 | """ Testing Logic""" 62 | with tf.Session(config=config_gpu) as sess: 63 | 64 | # Saver to restore model: to restore variables 65 | # TODO: we can choose variables to store and steps to keep (max_to_keep) 66 | saver = tf.train.Saver() 67 | 68 | # Model dir 69 | # If restoring a specific model 70 | args.MODEL_DIR = args.MODEL_RESTORE 71 | 72 | # Result dirs 73 | # (1) result/model_dir/inpainted_images 74 | # (2) result/model_dir/masked_images 75 | # (3) result/model_dir/sample_images 76 | result_dir = os.path.join(args.RESULT_DIR, args.MODEL_DIR) 77 | if not os.path.exists(result_dir): 78 | os.makedirs(result_dir) 79 | 80 | # (1) result/model_dir/inpainted_images 81 | inpainted_dir = os.path.join(result_dir, 'inpainted_images') 82 | if not os.path.exists(inpainted_dir): 83 | os.makedirs(inpainted_dir) 84 | # (2) 
result/model_dir/masked_images 85 | masked_dir = os.path.join(result_dir, 'masked_images') 86 | if not os.path.exists(masked_dir): 87 | os.makedirs(masked_dir) 88 | # (3) result/model_dir/sample_images 89 | sample_dir = os.path.join(result_dir, 'sample_images') 90 | if not os.path.exists(sample_dir): 91 | os.makedirs(sample_dir) 92 | # (4) result/model_dir/masks 93 | mask_dir = os.path.join(result_dir, 'masks') 94 | if not os.path.exists(mask_dir): 95 | os.makedirs(mask_dir) 96 | # (5) result/model_dir/inpainted_samples 97 | inpainted_sample_dir = os.path.join(result_dir, 'inpainted_samples') 98 | if not os.path.exists(inpainted_sample_dir): 99 | os.makedirs(inpainted_sample_dir) 100 | 101 | # Model Checkpoint dir 102 | checkpoint_dir = os.path.join(args.CHECKPOINT_DIR, args.MODEL_DIR) 103 | 104 | # Testing data info 105 | with open(args.DATA_FLIST[args.DATASET][1]) as f: 106 | fnames = f.read().splitlines() 107 | 108 | data_len = len(fnames) 109 | max_test_step = ceil(data_len / args.TEST_NUM) # TEST_NUM can be 1 or a batch like 8 110 | max_test_step = min(max_test_step, ceil(args.MAX_TEST_NUM / args.TEST_NUM)) # cap at the max number of test images 111 | 112 | # Training data info 113 | max_step = dataset_len(args) // args.BATCH_SIZE # max step for each epoch 114 | last_step = int(args.EPOCH * max_step) # total steps 115 | # Parameters 116 | imgh = args.IMG_SHAPES[0] 117 | imgw = args.IMG_SHAPES[1] 118 | 119 | # Try to restore model 120 | # Initialize all the variables 121 | tf.global_variables_initializer().run() 122 | # Show network architecture 123 | show_all_variables() 124 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint 125 | if ckpt and ckpt.model_checkpoint_path: 126 | # print ckpt name with dir 127 | logger.info("Latest ckpt: {}".format(ckpt.model_checkpoint_path)) 128 | logger.info("All ckpt: {}".format(ckpt.all_model_checkpoint_paths)) 129 | # ckpt base name 130 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 131 | # restore 132 | # saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name)) # restore 133 | vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 134 | assign_ops = [] 135 | for var in vars_list: 136 | vname = var.name 137 | from_name = vname 138 | try: 139 | var_value = tf.contrib.framework.load_variable(os.path.join(checkpoint_dir, ckpt_name), from_name) 140 | assign_ops.append(tf.assign(var, var_value)) 141 | except Exception: 142 | continue 143 | sess.run(assign_ops) 144 | print('Model loaded.') 145 | 146 | counter = int(next(re.finditer(r"\d+", ckpt_name)).group(0)) 147 | logger.info(" [*] Successfully read {}".format(ckpt_name)) 148 | else: 149 | counter = 0; logger.info(" [*] Failed to find a checkpoint") # default counter so the epoch/step math below still works 150 | # Existing Training info 151 | current_epoch = counter // max_step 152 | current_step = counter % max_step 153 | logger.info('Evaluating epoch {}, step {}.'.format(current_epoch, current_step)) 154 | 155 | # Testing start 156 | # For saving evaluation results 157 | if not os.path.exists(os.path.join(result_dir, 'evaluation.csv')): 158 | with open(os.path.join(result_dir, 'evaluation.csv'), mode='a') as f: 159 | f.write("epoch, step, l1, psnr, ssim, fid, uqi, vif\n") 160 | mask_size = [] 161 | l1_list = [] 162 | psnr_list = [] 163 | ssim_list = [] 164 | 165 | count = 1 166 | sess.run(image_iterator.initializer) 167 | if args.MASK_MODE == 'irregular': 168 | sess.run(mask_iterator.initializer) 169 | for step in range(1, max_test_step+1): 170 | time_start = time.time() 171 | 172 | try: 173 | if args.MASK_MODE == 'irregular': 174 | raw_x,
raw_x_incomplete, raw_x_complete, mask, l1, psnr, ssim = sess.run([model.raw_x, model.raw_x_incomplete, 175 | model.raw_x_complete, model.mask, 176 | model.l1, model.psnr, model.ssim]) 177 | else: 178 | mask = create_test_mask(imgw, imgh, imgw // 2, imgh // 2, args) 179 | raw_x, raw_x_incomplete, raw_x_complete, mask, l1, psnr, ssim = sess.run([model.raw_x, model.raw_x_incomplete, 180 | model.raw_x_complete, model.mask, 181 | model.l1, model.psnr, model.ssim], 182 | feed_dict={masks: mask}) 183 | except tf.errors.OutOfRangeError: 184 | break 185 | 186 | # setting hole pixel value = 255 187 | ones_x = np.ones_like(raw_x_incomplete) 188 | raw_x_incomplete = raw_x_incomplete + ones_x*mask*255 189 | 190 | for i in range(args.TEST_NUM): 191 | # save result 192 | imsave(os.path.join(sample_dir, args.DATASET+"{}.png".format(count)), raw_x[i]) 193 | imsave(os.path.join(inpainted_dir, args.DATASET+"{}.png".format(count)), raw_x_complete[i]) 194 | imsave(os.path.join(masked_dir, args.DATASET+"{}.png".format(count)), raw_x_incomplete[i]) 195 | imsave(os.path.join(mask_dir, args.DATASET+"{}.png".format(count)), mask[i, :, :, 0]) # mask is grey image 196 | 197 | # mask size 198 | mask_size.append(mask[i].sum()) 199 | l1_list.append(l1[i]) 200 | psnr_list.append(psnr[i]) 201 | ssim_list.append(ssim[i]) 202 | 203 | if step == 1: 204 | imsave(os.path.join(inpainted_sample_dir, args.DATASET + "{}_{}.png".format(count, current_epoch)), raw_x_complete[i]) 205 | 206 | count += 1 207 | 208 | time_cost = time.time() - time_start 209 | time_remaining = (max_test_step - step) * time_cost 210 | logger.info( 211 | 'step {}/{}, image {}/{}, cost {:.2f}s, remaining {:.2f}s.'.format(step, max_test_step, count, data_len, time_cost, 212 | time_remaining)) 213 | 214 | # Final evaluation 215 | # df_evaluation = pd.DataFrame(data=np.array([l1_list, psnr_list, ssim_list, mask_size]).T, 216 | # columns=["l1", "psnr", "ssim", "mask"]) 217 | # df_evaluation.to_csv(os.path.join(result_dir, 'evaluation.csv'), index=False) 218 | logger.info("Saving Finished.") 219 | 220 | logger.info("Evaluating Results..") 221 | # Evaluation result 222 | # l1, psnr, ssim, fid 223 | 224 | # fid score 225 | logger.info("FID score") 226 | fid_value = calculate_fid_given_paths([sample_dir, inpainted_dir], 'inception') 227 | print("FID: ", fid_value) 228 | 229 | # print(df_evaluation.mean(axis=0)) 230 | # UQI and VIF 231 | uqi, vif = uqi_vif(sample_dir, inpainted_dir) 232 | # uqi, vif = 0., 0. 
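# EDIT: added — a working pandas variant of the commented-out per-image exports nearby (illustrative only;
# the output filename is made up here, and the lists are the ones filled in the test loop above):
# df_per_image = pd.DataFrame({"l1": l1_list, "psnr": psnr_list, "ssim": ssim_list, "mask": mask_size})
# df_per_image.to_csv(os.path.join(result_dir, 'per_image_evaluation.csv'), index=False)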
233 | # df_evaluation = pd.concat(df_evaluation, pd.DataFrame(data={"epoch": current_epoch, "step": current_step, 234 | # "l1": np.array(l1_list).mean(), 235 | # "psnr": np.array(psnr_list).mean(), 236 | # "ssim": np.array(ssim_list).mean(), 237 | # "fid": fid_value}), axis=0) 238 | # df_evaluation.to_csv(os.path.join(result_dir, 'evaluation.csv'), index=False) 239 | with open(os.path.join(result_dir, 'evaluation.csv'), mode='a') as f: 240 | f.write("{}, {}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}\n".format(current_epoch, current_step, 241 | np.array(l1_list).mean(), 242 | np.array(psnr_list).mean(), 243 | np.array(ssim_list).mean(), 244 | fid_value, 245 | uqi, 246 | vif)) 247 | 248 | logger.info("Evaluation Finished.") -------------------------------------------------------------------------------- /src/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright 2017 Martin Heusel 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Adapted from the original implementation by Martin Heusel. 18 | # Source: https://github.com/bioinf-jku/TTUR/blob/master/fid.py 19 | 20 | ''' Calculates the Frechet Inception Distance (FID) to evaluate GANs. 21 | 22 | The FID metric calculates the distance between two distributions of images. 23 | Typically, we have summary statistics (mean & covariance matrix) of one 24 | of these distributions, while the 2nd distribution is given by a GAN. 25 | 26 | When run as a stand-alone program, it compares the distribution of 27 | images that are stored as PNG/JPEG at a specified location with a 28 | distribution given by summary statistics (in pickle format). 29 | 30 | The FID is calculated by assuming that X_1 and X_2 are the activations of 31 | the pool_3 layer of the inception net for generated samples and real world 32 | samples respectively. 33 | 34 | See --help to see further details. 35 | ''' 36 | 37 | from __future__ import absolute_import, division, print_function 38 | import numpy as np 39 | import scipy as sp 40 | import os 41 | import gzip, pickle 42 | import tensorflow as tf 43 | from scipy.misc import imread 44 | import pathlib 45 | import urllib 46 | 47 | 48 | class InvalidFIDException(Exception): 49 | pass 50 | 51 | 52 | def create_inception_graph(pth): 53 | """Creates a graph from saved GraphDef file.""" 54 | # Creates graph from saved graph_def.pb. 55 | with tf.gfile.FastGFile(pth, 'rb') as f: 56 | graph_def = tf.GraphDef() 57 | graph_def.ParseFromString(f.read()) 58 | _ = tf.import_graph_def(graph_def, name='FID_Inception_Net') 59 | #------------------------------------------------------------------------------- 60 | 61 | 62 | # code for handling inception net derived from 63 | # https://github.com/openai/improved-gan/blob/master/inception_score/model.py 64 | def _get_inception_layer(sess): 65 | """Prepares inception net for batched usage and returns pool_3 layer.
""" 66 | layername = 'FID_Inception_Net/pool_3:0' 67 | pool3 = sess.graph.get_tensor_by_name(layername) 68 | ops = pool3.graph.get_operations() 69 | for op_idx, op in enumerate(ops): 70 | for o in op.outputs: 71 | shape = o.get_shape() 72 | if shape._dims is not None: 73 | shape = [s.value for s in shape] 74 | new_shape = [] 75 | for j, s in enumerate(shape): 76 | if s == 1 and j == 0: 77 | new_shape.append(None) 78 | else: 79 | new_shape.append(s) 80 | try: 81 | o._shape = tf.TensorShape(new_shape) 82 | except ValueError: 83 | o._shape_val = tf.TensorShape(new_shape) # EDIT: added for compatibility with tensorflow 1.6.0 84 | return pool3 85 | #------------------------------------------------------------------------------- 86 | 87 | 88 | def get_activations(images, sess, batch_size=50, verbose=False): 89 | """Calculates the activations of the pool_3 layer for all images. 90 | 91 | Params: 92 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values 93 | must lie between 0 and 256. 94 | -- sess : current session 95 | -- batch_size : the images numpy array is split into batches with batch size 96 | batch_size. A reasonable batch size depends on the disposable hardware. 97 | -- verbose : If set to True and parameter out_step is given, the number of calculated 98 | batches is reported. 99 | Returns: 100 | -- A numpy array of dimension (num images, 2048) that contains the 101 | activations of the given tensor when feeding inception with the query tensor. 102 | """ 103 | inception_layer = _get_inception_layer(sess) 104 | d0 = images.shape[0] 105 | if batch_size > d0: 106 | print("warning: batch size is bigger than the data size. setting batch size to data size") 107 | batch_size = d0 108 | n_batches = d0//batch_size 109 | n_used_imgs = n_batches*batch_size 110 | pred_arr = np.empty((n_used_imgs,2048)) 111 | for i in range(n_batches): 112 | if verbose: 113 | print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True) 114 | start = i*batch_size 115 | end = start + batch_size 116 | batch = images[start:end] 117 | pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch}) 118 | pred_arr[start:end] = pred.reshape(batch_size,-1) 119 | if verbose: 120 | print(" done") 121 | return pred_arr 122 | #------------------------------------------------------------------------------- 123 | 124 | 125 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2): 126 | """Numpy implementation of the Frechet Distance. 127 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 128 | and X_2 ~ N(mu_2, C_2) is 129 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 130 | 131 | Params: 132 | -- mu1 : Numpy array containing the activations of the pool_3 layer of the 133 | inception net ( like returned by the function 'get_predictions') 134 | -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted 135 | on an representive data set. 136 | -- sigma2: The covariance matrix over activations of the pool_3 layer, 137 | precalcualted on an representive data set. 138 | 139 | Returns: 140 | -- dist : The Frechet Distance. 141 | 142 | Raises: 143 | -- InvalidFIDException if nan occures. 
144 | """ 145 | m = np.square(mu1 - mu2).sum() 146 | #s = sp.linalg.sqrtm(np.dot(sigma1, sigma2)) # EDIT: commented out 147 | s, _ = sp.linalg.sqrtm(np.dot(sigma1, sigma2), disp=False) # EDIT: added 148 | dist = m + np.trace(sigma1+sigma2 - 2*s) 149 | #if np.isnan(dist): # EDIT: commented out 150 | # raise InvalidFIDException("nan occured in distance calculation.") # EDIT: commented out 151 | #return dist # EDIT: commented out 152 | return np.real(dist) # EDIT: added 153 | #------------------------------------------------------------------------------- 154 | 155 | 156 | def calculate_activation_statistics(images, sess, batch_size=50, verbose=False): 157 | """Calculation of the statistics used by the FID. 158 | Params: 159 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values 160 | must lie between 0 and 255. 161 | -- sess : current session 162 | -- batch_size : the images numpy array is split into batches with batch size 163 | batch_size. A reasonable batch size depends on the available hardware. 164 | -- verbose : If set to True and parameter out_step is given, the number of calculated 165 | batches is reported. 166 | Returns: 167 | -- mu : The mean over samples of the activations of the pool_3 layer of 168 | the incption model. 169 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 170 | the incption model. 171 | """ 172 | act = get_activations(images, sess, batch_size, verbose) 173 | mu = np.mean(act, axis=0) 174 | sigma = np.cov(act, rowvar=False) 175 | return mu, sigma 176 | #------------------------------------------------------------------------------- 177 | 178 | 179 | #------------------------------------------------------------------------------- 180 | # The following functions aren't needed for calculating the FID 181 | # they're just here to make this module work as a stand-alone script 182 | # for calculating FID scores 183 | #------------------------------------------------------------------------------- 184 | def check_or_download_inception(inception_path): 185 | ''' Checks if the path to the inception file is valid, or downloads 186 | the file if it is not present. ''' 187 | INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 188 | if inception_path is None: 189 | inception_path = '/tmp' 190 | inception_path = pathlib.Path(inception_path) 191 | model_file = inception_path / 'classify_image_graph_def.pb' 192 | if not model_file.exists(): 193 | print("Downloading Inception model") 194 | from urllib import request 195 | import tarfile 196 | fn, _ = request.urlretrieve(INCEPTION_URL) 197 | with tarfile.open(fn, mode='r') as f: 198 | f.extract('classify_image_graph_def.pb', str(model_file.parent)) 199 | return str(model_file) 200 | 201 | 202 | def _handle_path(path, sess): 203 | if path.endswith('.npz'): 204 | f = np.load(path) 205 | m, s = f['mu'][:], f['sigma'][:] 206 | f.close() 207 | else: 208 | path = pathlib.Path(path) 209 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 210 | x = np.array([imread(str(fn)).astype(np.float32) for fn in files]) 211 | m, s = calculate_activation_statistics(x, sess) 212 | return m, s 213 | 214 | 215 | def calculate_fid_given_paths(paths, inception_path): 216 | ''' Calculates the FID of two paths. 
''' 217 | inception_path = check_or_download_inception(inception_path) 218 | 219 | for p in paths: 220 | if not os.path.exists(p): 221 | raise RuntimeError("Invalid path: %s" % p) 222 | 223 | os.environ["CUDA_VISIBLE_DEVICES"] = '1' 224 | 225 | create_inception_graph(str(inception_path)) 226 | with tf.Session() as sess: 227 | sess.run(tf.global_variables_initializer()) 228 | m1, s1 = _handle_path(paths[0], sess) 229 | m2, s2 = _handle_path(paths[1], sess) 230 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 231 | return fid_value 232 | 233 | 234 | if __name__ == "__main__": 235 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 236 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 237 | parser.add_argument("path", type=str, nargs=2, 238 | help='Path to the generated images or to .npz statistic files') 239 | parser.add_argument("-i", "--inception", type=str, default=None, 240 | help='Path to Inception model (will be downloaded if not provided)') 241 | parser.add_argument("--gpu", default="", type=str, 242 | help='GPU to use (leave blank for CPU only)') 243 | args = parser.parse_args() 244 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 245 | fid_value = calculate_fid_given_paths(args.path, args.inception) 246 | print("FID: ", fid_value) 247 | 248 | #---------------------------------------------------------------------------- 249 | # EDIT: added 250 | 251 | class API: 252 | def __init__(self, num_images, image_shape, image_dtype, minibatch_size): 253 | import config 254 | self.network_dir = os.path.join(config.result_dir, '_inception_fid') 255 | self.network_file = check_or_download_inception(self.network_dir) 256 | self.sess = tf.get_default_session() 257 | create_inception_graph(self.network_file) 258 | 259 | def get_metric_names(self): 260 | return ['FID'] 261 | 262 | def get_metric_formatting(self): 263 | return ['%-10.4f'] 264 | 265 | def begin(self, mode): 266 | assert mode in ['warmup', 'reals', 'fakes'] 267 | self.activations = [] 268 | 269 | def feed(self, mode, minibatch): 270 | act = get_activations(minibatch.transpose(0,2,3,1), self.sess, batch_size=minibatch.shape[0]) 271 | self.activations.append(act) 272 | 273 | def end(self, mode): 274 | act = np.concatenate(self.activations) 275 | mu = np.mean(act, axis=0) 276 | sigma = np.cov(act, rowvar=False) 277 | if mode in ['warmup', 'reals']: 278 | self.mu_real = mu 279 | self.sigma_real = sigma 280 | fid = calculate_frechet_distance(mu, sigma, self.mu_real, self.sigma_real) 281 | return [fid] 282 | 283 | #---------------------------------------------------------------------------- 284 | -------------------------------------------------------------------------------- /painter/painter.py: -------------------------------------------------------------------------------- 1 | from tkinter import * 2 | from PIL import Image, ImageTk, ImageDraw 3 | import tkinter.filedialog as tkFileDialog 4 | import numpy as np 5 | import cv2 6 | import os 7 | import subprocess 8 | import argparse 9 | import tensorflow as tf 10 | from config import Config 11 | from skimage import feature 12 | from skimage.color import rgb2gray 13 | 14 | from inpaint_model import InpaintModel 15 | 16 | # os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in subprocess.Popen( 17 | # "nvidia-smi -q -d Memory | grep -A4 GPU | grep Free", shell=True, stdout=subprocess.PIPE).stdout.readlines()])) 18 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 19 | 20 | class Paint(object): 21 | MARKER_COLOR = 'white' 22 | 23 
| def __init__(self, config): 24 | self.config = config 25 | print("******************************", self.config.CHECKPOINT) 26 | 27 | self.root = Tk() 28 | 29 | self.rect_button = Button(self.root, text='rectangle', command=self.use_rect, width=12, height=3) 30 | self.rect_button.grid(row=0, column=2) 31 | 32 | self.poly_button = Button(self.root, text='stroke', command=self.use_poly, width=12, height=3) 33 | self.poly_button.grid(row=1, column=2) 34 | 35 | self.revoke_button = Button(self.root, text='revoke', command=self.revoke, width=12, height=3) 36 | self.revoke_button.grid(row=2, column=2) 37 | 38 | self.clear_button = Button(self.root, text='clear', command=self.clear, width=12, height=3) 39 | self.clear_button.grid(row=3, column=2) 40 | 41 | self.c = Canvas(self.root, bg='white', width=config.IMG_SHAPES[1]+8, height=config.IMG_SHAPES[0]) 42 | self.c.grid(row=0, column=0, rowspan=8) 43 | 44 | self.out = Canvas(self.root, bg='white', width=config.IMG_SHAPES[1]+8, height=config.IMG_SHAPES[0]) 45 | self.out.grid(row=0, column=1, rowspan=8) 46 | 47 | self.save_button = Button(self.root, text="save", command=self.save, width=12, height=3) 48 | self.save_button.grid(row=6, column=2) 49 | 50 | self.load_button = Button(self.root, text='load', command=self.load, width=12, height=3) 51 | self.load_button.grid(row=5, column=2) 52 | 53 | self.fill_button = Button(self.root, text='fill', command=self.fill, width=12, height=3) 54 | self.fill_button.grid(row=7, column=2) 55 | self.filename = None 56 | 57 | self.setup() 58 | self.root.mainloop() 59 | 60 | def setup(self): 61 | self.old_x = None 62 | self.old_y = None 63 | self.start_x = None 64 | self.start_y = None 65 | self.end_x = None 66 | self.end_y = None 67 | self.eraser_on = False 68 | self.active_button = self.rect_button 69 | self.isPainting = False 70 | self.c.bind('<B1-Motion>', self.paint) # drag with the left button to paint 71 | self.c.bind('<ButtonRelease-1>', self.reset) # finish the current rectangle/stroke 72 | self.c.bind('<Button-1>', self.beginPaint) # left click starts painting 73 | self.c.bind('<Enter>', self.icon2pen) # cursor enters the canvas 74 | self.c.bind('<Leave>', self.icon2mice) # cursor leaves the canvas 75 | self.mode = 'rect' 76 | self.rect_buf = None 77 | self.line_buf = None 78 | assert self.mode in ['rect', 'poly'] 79 | self.paint_color = self.MARKER_COLOR 80 | self.mask_candidate = [] 81 | self.rect_candidate = [] 82 | self.im_h = None 83 | self.im_w = None 84 | self.mask = None 85 | self.result = None 86 | self.blank = None 87 | self.line_width = 8 88 | 89 | # painter model 90 | self.model = InpaintModel(self.config) 91 | self.reuse = False 92 | sess_config = tf.ConfigProto() 93 | sess_config.gpu_options.allow_growth = False 94 | self.sess = tf.Session(config=sess_config) 95 | 96 | self.input_image_tf = tf.placeholder(dtype=tf.float32, 97 | shape=[1, self.config.IMG_SHAPES[0], self.config.IMG_SHAPES[1], 3]) 98 | self.input_mask_tf = tf.placeholder(dtype=tf.float32, 99 | shape=[1, self.config.IMG_SHAPES[0], self.config.IMG_SHAPES[1], 1]) 100 | self.input_edge_tf = tf.placeholder(dtype=tf.float32, 101 | shape=[1, self.config.IMG_SHAPES[0], self.config.IMG_SHAPES[1], 1]) 102 | output = self.model.evaluate(self.input_image_tf, self.input_edge_tf, self.input_mask_tf, 103 | args=self.config, reuse=self.reuse) 104 | # output = (output + 1) * 127.5 105 | output = tf.minimum(tf.maximum(output[:, :, :, ::-1], 0), 255) 106 | # self.output = tf.cast(output, tf.uint8) 107 | self.output = output 108 | 109 | # load pretrained model 110 | vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 111 | assign_ops = list(map(lambda x: tf.assign(x, tf.contrib.framework.load_variable(self.config.CHECKPOINT, x.name)), 112 |
vars_list)) 113 | self.sess.run(assign_ops) 114 | print('Model loaded.') 115 | 116 | def checkResp(self): 117 | assert len(self.mask_candidate) == len(self.rect_candidate) 118 | 119 | def load(self): 120 | self.filename = tkFileDialog.askopenfilename(initialdir='./imgs', 121 | title="Select file", 122 | filetypes=(("png files", "*.png"), ("jpg files", "*.jpg"), 123 | ("all files", "*.*"))) 124 | self.filename_ = self.filename.split('/')[-1][:-4] 125 | self.filepath = '/'.join(self.filename.split('/')[:-1]) 126 | print(self.filename_, self.filepath) 127 | try: 128 | photo = Image.open(self.filename) 129 | self.image = cv2.imread(self.filename) 130 | except: 131 | print('do not load image') 132 | else: 133 | self.im_w, self.im_h = photo.size 134 | self.mask = np.zeros((self.im_h, self.im_w, 1)).astype(np.uint8) 135 | print(photo.size) 136 | self.displayPhoto = photo 137 | self.displayPhoto = self.displayPhoto.resize((self.im_w, self.im_h)) 138 | self.draw = ImageDraw.Draw(self.displayPhoto) 139 | self.photo_tk = ImageTk.PhotoImage(image=self.displayPhoto) 140 | self.c.create_image(0, 0, image=self.photo_tk, anchor=NW) 141 | self.rect_candidate.clear() 142 | self.mask_candidate.clear() 143 | if self.blank is None: 144 | if not os.path.exists('imgs/blank.png'): 145 | self.blank = Image.new(mode='L', size=(1000,1000), color=1) 146 | else: 147 | self.blank = Image.open('imgs/blank.png') 148 | self.blank = self.blank.resize((self.im_w, self.im_h)) 149 | self.blank_tk = ImageTk.PhotoImage(image=self.blank) 150 | self.out.create_image(0, 0, image=self.blank_tk, anchor=NW) 151 | 152 | def save(self): 153 | img = np.array(self.displayPhoto) 154 | cv2.imwrite(os.path.join(self.filepath, 'tmp.png'), img) 155 | 156 | if self.mode == 'rect': 157 | self.mask[:,:,:] = 0 158 | for rect in self.mask_candidate: 159 | self.mask[rect[1]:rect[3], rect[0]:rect[2], :] = 1 160 | 161 | self.save_filename = tkFileDialog.asksaveasfilename(initialdir=self.config.SAVEPATH, 162 | title="Select file", 163 | filetypes=(("png files", "*.png"), ("jpg files", "*.jpg"), 164 | ("all files", "*.*"))) 165 | self.save_filename_ = self.save_filename.split('/')[-1][:-4] 166 | self.save_filepath = '/'.join(self.save_filename.split('/')[:-1]) 167 | 168 | cv2.imwrite(os.path.join(self.save_filepath, self.save_filename_ + '_mask.png'), self.mask * 255) 169 | cv2.imwrite(os.path.join(self.save_filepath, self.save_filename_ + '_result.png'), self.result[0][:, :, ::-1]) 170 | cv2.imwrite(os.path.join(self.save_filepath, self.save_filename_ + '_masked.png'), 171 | self.result[0][:, :, ::-1] * (1 - self.mask) + self.mask * 255) 172 | 173 | def fill(self): 174 | if self.mode == 'rect': 175 | for rect in self.mask_candidate: 176 | self.mask[rect[1]:rect[3], rect[0]:rect[2], :] = 1 177 | image = np.expand_dims(self.image, 0) 178 | mask = np.expand_dims(self.mask, 0) 179 | 180 | img_gray = rgb2gray(self.image) 181 | edge = feature.canny(img_gray, sigma=1.5).astype(np.float32) 182 | edge = np.reshape(edge,(1, 256, 256, 1)) 183 | # print(image.shape) 184 | # print(mask.shape) 185 | # print(edge.shape) 186 | 187 | self.result = self.sess.run(self.output, feed_dict={self.input_image_tf: image * 1.0, 188 | self.input_mask_tf: mask * 1.0, 189 | self.input_edge_tf: edge * 1.0}) 190 | cv2.imwrite('./imgs/tmp.png', self.result[0][:, :, ::-1]) # self.output has batch size = 1, so self.result[0] 191 | 192 | photo = Image.open('./imgs/tmp.png') 193 | self.displayPhotoResult = photo 194 | self.displayPhotoResult = self.displayPhotoResult.resize((self.im_w, 
self.im_h)) 195 | self.photo_tk_result = ImageTk.PhotoImage(image=self.displayPhotoResult) 196 | self.out.create_image(0, 0, image=self.photo_tk_result, anchor=NW) 197 | return 198 | 199 | def use_rect(self): 200 | self.activate_button(self.rect_button) 201 | self.mode = 'rect' 202 | 203 | def use_poly(self): 204 | self.activate_button(self.poly_button) 205 | self.mode = 'poly' 206 | 207 | def revoke(self): 208 | if len(self.rect_candidate) > 0: 209 | self.c.delete(self.rect_candidate[-1]) 210 | self.rect_candidate.remove(self.rect_candidate[-1]) 211 | self.mask_candidate.remove(self.mask_candidate[-1]) 212 | self.checkResp() 213 | 214 | def clear(self): 215 | self.mask = np.zeros((self.im_h, self.im_w, 1)).astype(np.uint8) 216 | if self.mode == 'poly': 217 | photo = Image.open(self.filename) 218 | self.image = cv2.imread(self.filename) 219 | self.displayPhoto = photo 220 | self.displayPhoto = self.displayPhoto.resize((self.im_w, self.im_h)) 221 | self.draw = ImageDraw.Draw(self.displayPhoto) 222 | self.photo_tk = ImageTk.PhotoImage(image=self.displayPhoto) 223 | self.c.create_image(0, 0, image=self.photo_tk, anchor=NW) 224 | else: 225 | if self.rect_candidate is None or len(self.rect_candidate) == 0: 226 | return 227 | for item in self.rect_candidate: 228 | self.c.delete(item) 229 | self.rect_candidate.clear() 230 | self.mask_candidate.clear() 231 | self.checkResp() 232 | 233 | #TODO: reset canvas 234 | #TODO: undo and redo 235 | #TODO: draw triangle, rectangle, oval, text 236 | 237 | def activate_button(self, some_button, eraser_mode=False): 238 | self.active_button.config(relief=RAISED) 239 | some_button.config(relief=SUNKEN) 240 | self.active_button = some_button 241 | self.eraser_on = eraser_mode 242 | 243 | def beginPaint(self, event): 244 | self.start_x = event.x 245 | self.start_y = event.y 246 | self.isPainting = True 247 | 248 | def paint(self, event): 249 | if self.start_x and self.start_y and self.mode == 'rect': 250 | self.end_x = max(min(event.x, self.im_w), 0) 251 | self.end_y = max(min(event.y, self.im_h), 0) 252 | rect = self.c.create_rectangle(self.start_x, self.start_y, self.end_x, self.end_y, fill=self.paint_color) 253 | if self.rect_buf is not None: 254 | self.c.delete(self.rect_buf) 255 | self.rect_buf = rect 256 | elif self.old_x and self.old_y and self.mode == 'poly': 257 | line = self.c.create_line(self.old_x, self.old_y, event.x, event.y, 258 | width=self.line_width, fill=self.paint_color, capstyle=ROUND, 259 | smooth=True, splinesteps=36) 260 | cv2.line(self.mask, (self.old_x, self.old_y), (event.x, event.y), (1), self.line_width) 261 | self.old_x = event.x 262 | self.old_y = event.y 263 | 264 | def reset(self, event): 265 | self.old_x, self.old_y = None, None 266 | if self.mode == 'rect': 267 | self.isPainting = False 268 | rect = self.c.create_rectangle(self.start_x, self.start_y, self.end_x, self.end_y, 269 | fill=self.paint_color) 270 | if self.rect_buf is not None: 271 | self.c.delete(self.rect_buf) 272 | self.rect_buf = None 273 | self.rect_candidate.append(rect) 274 | 275 | x1, y1, x2, y2 = min(self.start_x, self.end_x), min(self.start_y, self.end_y),\ 276 | max(self.start_x, self.end_x), max(self.start_y, self.end_y) 277 | # up left corner, low right corner 278 | self.mask_candidate.append((x1, y1, x2, y2)) 279 | print(self.mask_candidate[-1]) 280 | 281 | def icon2pen(self, event): 282 | return 283 | 284 | def icon2mice(self, event): 285 | return 286 | 287 | 288 | if __name__ == '__main__': 289 | config = Config('inpaint_config.yml') 290 | config.mode = 
'silent' 291 | parser = argparse.ArgumentParser() 292 | parser.add_argument('--checkpoint', type=str, help='path to the model checkpoint') 293 | parser.add_argument('--save_path', type=str, help='path to the model checkpoint') 294 | args = parser.parse_args() 295 | config.CHECKPOINT = args.checkpoint 296 | config.SAVEPATH = args.save_path 297 | # print("@############################################", config.CHECKPIONT) 298 | ge = Paint(config) 299 | -------------------------------------------------------------------------------- /src/train_inpaint_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import tensorflow as tf 5 | import numpy as np 6 | import time 7 | import re 8 | 9 | from inpaint_model import InpaintModel 10 | from config import Config, select_gpu 11 | from utils_fn import (show_all_variables, load_mask, create_mask, 12 | save_images, load_validation_data, load_validation_mask, create_validation_mask, 13 | dataset_len, load_img_scale_edge, load_val_img_scale_edge) 14 | 15 | # Reproducible result 16 | np.random.seed(0) 17 | tf.set_random_seed(0) 18 | 19 | 20 | def multi_gpu_setting(model, args): 21 | gpu_num = args.NUM_GPUS 22 | batch_size = args.BATCH_SIZE 23 | 24 | with tf.device("/cpu:0"): 25 | """ Input Data (images and masks) """ 26 | # images and edges 27 | if args.CUSTOM_DATASET: 28 | images_edges = load_img_scale_edge(args) 29 | else: 30 | images_edges = tf.placeholder(tf.float32, 31 | [args.BATCH_SIZE * gpu_num, args.IMG_SHAPES[0], args.IMG_SHAPES[1], args.IMG_SHAPES[2]], 32 | name='real_images') 33 | images_, edges_, edges_128_, edges_64_ = images_edges # a tuple 34 | images_ = tf.reshape(images_, 35 | [args.BATCH_SIZE * gpu_num, args.IMG_SHAPES[0], args.IMG_SHAPES[1], args.IMG_SHAPES[2]]) 36 | edges_ = tf.reshape(edges_, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1]) 37 | edges_128_ = tf.reshape(edges_128_, [-1, 128, 128, 1]) 38 | edges_64_ = tf.reshape(edges_64_, [-1, 64, 64, 1]) 39 | 40 | # masks 41 | if args.MASK_MODE == 'irregular': 42 | masks = load_mask(args) 43 | else: 44 | masks = tf.placeholder(tf.float32, [1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1], 45 | name='regular_masks') 46 | _masks = tf.reshape(masks, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1]) 47 | 48 | # opt 49 | g_optimizer = tf.train.AdamOptimizer(learning_rate=args.G_LR, beta1=0., beta2=0.9) 50 | d_optimizer = tf.train.AdamOptimizer(learning_rate=args.D_LR, beta1=0., beta2=0.9) 51 | 52 | # update grad 53 | tower_g_grads = [] 54 | tower_d_grads = [] 55 | 56 | with tf.variable_scope(tf.get_variable_scope()): 57 | for i in range(gpu_num): # GPU IDs 58 | with tf.device("/gpu:%d" % i): 59 | with tf.name_scope("tower_%d" % i): 60 | _images = images_[i * batch_size: (i + 1) * batch_size] 61 | _edges = edges_[i * batch_size: (i + 1) * batch_size] 62 | _edges_128 = edges_128_[i * batch_size: (i + 1) * batch_size] 63 | _edges_64 = edges_64_[i * batch_size: (i + 1) * batch_size] 64 | print(_images.shape) 65 | print(_masks.shape) 66 | print(_edges.shape) 67 | print(_edges_64) 68 | model.build_graph_with_losses(_images, _masks, _edges, _edges_128, _edges_64, args, reuse=tf.AUTO_REUSE) 69 | tf.get_variable_scope().reuse_variables() 70 | # scale 256 71 | _g256_grads = g_optimizer.compute_gradients(model.g_loss, var_list=model.total_g_vars) 72 | _d256_grads = d_optimizer.compute_gradients(model.d_loss, var_list=model.total_d_vars) 73 | tower_g_grads.append(_g256_grads) 74 | with 
open("tower_{}_g.txt".format(i), 'w') as f: 75 | for g in tower_g_grads[0]: 76 | f.write("g:"+str(g)+'\n') 77 | tower_d_grads.append(_d256_grads) 78 | with open("tower_{}_d.txt".format(i), 'w') as f: 79 | for g in tower_g_grads[0]: 80 | f.write("d:"+str(g)+'\n') 81 | 82 | 83 | # average grads 84 | g_grads = average_gradients(tower_g_grads) 85 | d_grads = average_gradients(tower_d_grads) 86 | 87 | # train op 88 | g_train_op = g_optimizer.apply_gradients(g_grads) 89 | d_train_op = d_optimizer.apply_gradients(d_grads) 90 | 91 | # summary model in the last gpu device 92 | all_sum_256 = model.all_sum # only keep the final summary 93 | 94 | # return train ops and inputs 95 | return g_train_op, d_train_op, images_edges, masks, all_sum_256 96 | 97 | 98 | def average_gradients(tower_grads): 99 | average_grads = [] 100 | for grad_and_vars in zip(*tower_grads): 101 | grads = [] 102 | for g, _ in grad_and_vars: 103 | expend_g = tf.expand_dims(g, 0) 104 | grads.append(expend_g) 105 | grad = tf.concat(grads, 0) 106 | grad = tf.reduce_mean(grad, 0) 107 | v = grad_and_vars[0][1] 108 | grad_and_var = (grad, v) 109 | average_grads.append(grad_and_var) 110 | return average_grads 111 | 112 | 113 | def single_gpu_setting(model, args): 114 | gpu_num = args.NUM_GPUS 115 | assert(gpu_num == 1) 116 | 117 | """ Input Data (images and masks) """ 118 | # images and edges 119 | if args.CUSTOM_DATASET: 120 | images_edges = load_img_scale_edge(args) 121 | else: 122 | images_edges = tf.placeholder(tf.float32, 123 | [args.BATCH_SIZE, args.IMG_SHAPES[0], args.IMG_SHAPES[1], args.IMG_SHAPES[2]], 124 | name='real_images') 125 | images_, edges_, edges_128_, edges_64_ = images_edges # a tuple 126 | images = tf.reshape(images_, 127 | [args.BATCH_SIZE * gpu_num, args.IMG_SHAPES[0], args.IMG_SHAPES[1], args.IMG_SHAPES[2]]) 128 | edges = tf.reshape(edges_, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1]) 129 | edges_128 = tf.reshape(edges_128_, [-1, 128, 128, 1]) 130 | edges_64 = tf.reshape(edges_64_, [-1, 64, 64, 1]) 131 | 132 | # masks 133 | if args.MASK_MODE == 'irregular': 134 | masks = load_mask(args) 135 | else: 136 | masks = tf.placeholder(tf.float32, [1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1], 137 | name='regular_masks') 138 | masks = tf.reshape(masks, [-1, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1]) 139 | 140 | # build model with losses 141 | model.build_graph_with_losses(images, masks, edges, edges_128, edges_64, args, reuse=False) 142 | 143 | # train op 144 | g_train_op = tf.train.AdamOptimizer(learning_rate=args.G_LR, beta1=0., beta2=0.9).minimize( 145 | model.g_loss, var_list=model.total_g_vars) 146 | d_train_op = tf.train.AdamOptimizer(learning_rate=args.D_LR, beta1=0., beta2=0.9).minimize( 147 | model.d_loss, var_list=model.total_d_vars) 148 | 149 | # summary 150 | all_sum_256 = model.all_sum 151 | 152 | # return train ops and inputs 153 | return g_train_op, d_train_op, images_edges, masks, all_sum_256 154 | 155 | 156 | def main(): 157 | """ 158 | Training 159 | """ 160 | # Load config file for run an inpainting model 161 | args = Config('inpaint_config.yml') 162 | 163 | # GPU config 164 | gpu_ids = args.GPU_ID 165 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(gpu) for gpu in gpu_ids]) # default "2" 166 | # os.environ["CUDA_VISIBLE_DEVICES"] = select_gpu() 167 | config_gpu = tf.ConfigProto() 168 | config_gpu.gpu_options.allow_growth = True # allow memory grow 169 | config_gpu.allow_soft_placement = True 170 | # config_gpu.log_device_placement = True 171 | 172 | # log setting 173 | 
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 174 | logger = logging.getLogger("YOUNG") 175 | logger.setLevel(level=logging.INFO) 176 | 177 | """ Build Inpaint Model with Loss and Optimizer""" 178 | # Model and training setting 179 | model = InpaintModel(args) 180 | if args.NUM_GPUS > 1 or len(args.GPU_ID) > 1: # multi-gpu 181 | logger.info("Build Inpaint Model with Loss and Optimizer in Multi-GPU setting.") 182 | g_train256_op, d_train256_op, images_edges, masks, all_sum_256 = multi_gpu_setting(model, args) 183 | else: # cpu or single gpu 184 | logger.info("Build Inpaint Model with Loss and Optimizer in Single-GPU or CPU setting.") 185 | g_train256_op, d_train256_op, images_edges, masks, all_sum_256 = single_gpu_setting(model, args) 186 | 187 | # If validation? 188 | if args.VAL: 189 | logger.info("Build Validation Model.") 190 | with tf.device('/cpu:0'): 191 | # images 192 | images_edges_val, img_iterator_val = load_val_img_scale_edge(args) 193 | # masks 194 | if args.MASK_MODE == 'irregular': 195 | masks_val, mask_iterator_val = load_validation_mask(args) 196 | else: 197 | masks_val = tf.placeholder(tf.float32, [args.VAL_NUM, args.IMG_SHAPES[0], args.IMG_SHAPES[1], 1], 198 | name='val_regular_masks') 199 | model.build_validation_model(images_edges_val, masks_val, args) 200 | 201 | """ Train Logic""" 202 | with tf.Session(config=config_gpu) as sess: 203 | 204 | # Model dir 205 | # If restore a specific model 206 | if args.MODEL_RESTORE == '': 207 | args.MODEL_DIR = '-'.join(time.asctime().split()) + "_GPU" + '-'.join([str(gpu) for gpu in gpu_ids]) + \ 208 | "_" + args.DATASET + "_" + args.GAN_TYPE + \ 209 | '_' + str(args.GAN_LOSS_TYPE) + str(args.PATCH_GAN_ALPHA) + \ 210 | "_" + "L1" + str(args.L1_FORE_ALPHA) + "_" + str(args.L1_BACK_ALPHA) + \ 211 | "_" + "C" + str(args.CONTENT_FORE_ALPHA) + "_" + "S" + str(args.STYLE_FORE_ALPHA) +\ 212 | "_" + "T" + str(args.TV_ALPHA) + "_" + args.PADDING + '_Deep_MT' +\ 213 | "_" + str(args.ALPHA) 214 | else: 215 | args.MODEL_DIR = args.MODEL_RESTORE 216 | 217 | # Checkpoint dir 218 | checkpoint_dir = os.path.join(args.CHECKPOINT_DIR, args.MODEL_DIR) 219 | if not os.path.exists(checkpoint_dir): 220 | os.makedirs(checkpoint_dir) 221 | 222 | # Sample dir 223 | sample_dir = os.path.join(args.SAMPLE_DIR, args.MODEL_DIR) 224 | if not os.path.exists(sample_dir): 225 | os.makedirs(sample_dir) 226 | 227 | # Summary writer 228 | writer = tf.summary.FileWriter(args.LOG_DIR + '/' + args.MODEL_DIR, sess.graph) 229 | 230 | # Saver to save model: to save variables 231 | # TODO: we can choose variables to store and steps to keep (max_to_keep) 232 | saver = tf.train.Saver() 233 | 234 | # Initialize all the variables 235 | tf.global_variables_initializer().run() 236 | # Show network architecture 237 | show_all_variables() 238 | 239 | # Try to restore model 240 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # get checkpoint and restore training 241 | if ckpt and ckpt.model_checkpoint_path: 242 | # print ckpt name with dir 243 | logger.info("Latest ckpt: {}".format(ckpt.model_checkpoint_path)) 244 | logger.info("All ckpt: {}".format(ckpt.all_model_checkpoint_paths)) 245 | # ckpt base name 246 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 247 | # restore 248 | # saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name)) # restore 249 | vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 250 | 251 | assign_ops = [] 252 | for var in vars_list: 253 | vname = var.name 254 | from_name = 
vname 255 | try: 256 | var_value = tf.contrib.framework.load_variable(os.path.join(checkpoint_dir, ckpt_name), from_name) 257 | assign_ops.append(tf.assign(var, var_value)) 258 | except Exception: 259 | continue 260 | sess.run(assign_ops) 261 | print('Model loaded.') 262 | 263 | counter = int(next(re.finditer("\d+", ckpt_name)).group(0)) 264 | logger.info(" [*] Success to read {}".format(ckpt_name)) 265 | else: 266 | counter = 0 267 | logger.info(" [*] Failed to find a checkpoint") 268 | 269 | # Parameters 270 | imgh = args.IMG_SHAPES[0] 271 | imgw = args.IMG_SHAPES[1] 272 | 273 | max_step = dataset_len(args) // (args.BATCH_SIZE * args.NUM_GPUS) # max step for each epoch 274 | last_step = int(args.EPOCH * max_step) # total steps 275 | max_iter = last_step * args.BATCH_SIZE * args.NUM_GPUS # max iteration when batch size is 1 276 | 277 | # continue to train 278 | if counter < last_step: 279 | current_epoch = counter // max_step 280 | current_step = counter % max_step + 1 # TODO: may not right here? 281 | logger.info("Start Training...") 282 | logger.info( 283 | "Total Epoch {}, Iteration per Epoch {}, Max Iteration {}, Max Iteration (batch_size=1) {}.".format( 284 | args.EPOCH, max_step, last_step, max_iter)) 285 | logger.info("Epoch Start {} at step {}".format(current_epoch, current_step)) 286 | 287 | # not continue to train 288 | else: 289 | current_step = 0 290 | current_epoch = args.EPOCH 291 | 292 | count = 1 + counter 293 | for epoch in range(current_epoch, args.EPOCH): 294 | logger.info("Epoch {}:".format(epoch)) 295 | time_start = time.time() 296 | time_s = time_start 297 | for step in range(current_step, max_step+1): 298 | 299 | # save 300 | if count % args.SAVE_FREQ == 0 or count == last_step: 301 | saver.save(sess, os.path.join(checkpoint_dir, model.model_name + '.model'), global_step=count,write_meta_graph=False) 302 | 303 | if args.MASK_MODE == 'irregular': 304 | # logs 305 | if count % args.LOG_FREQ == 0 or count == last_step: 306 | all_sum = sess.run(model.all_sum) 307 | writer.add_summary(all_sum, count) 308 | # train step 309 | sess.run([d_train256_op, g_train256_op]) 310 | else: 311 | mask = create_mask(imgw, imgh, imgw // 2, imgh // 2, delta=0) # random block with hole size (imgw // 2, imgh // 2) 312 | # logs 313 | if count % args.LOG_FREQ == 0 or count == last_step: 314 | all_sum = sess.run(model.all_sum, feed_dict={masks: mask}) 315 | writer.add_summary(all_sum, count) 316 | # train step 317 | sess.run([d_train256_op, g_train256_op], feed_dict={masks: mask}) 318 | 319 | # validation 320 | if args.VAL: 321 | if count % args.VAL_FREQ == 0 or count == last_step: 322 | sess.run(img_iterator_val.initializer) 323 | 324 | if args.MASK_MODE == 'irregular': 325 | sess.run(mask_iterator_val.initializer) 326 | try: 327 | val_all_sum = sess.run(model.val_all_sum_256) 328 | 329 | writer.add_summary(val_all_sum, count) 330 | except tf.errors.OutOfRangeError: 331 | break 332 | else: 333 | try: 334 | if args.STATIC_VIEW: 335 | mask = create_validation_mask(imgw, imgh, imgw // 2, imgh // 2, args, imgw // 4, imgh // 4) 336 | else: 337 | mask = create_validation_mask(imgw, imgh, imgw // 2, imgh // 2, args, delta=0) 338 | val_all_sum = sess.run(model.val_all_sum_256, feed_dict={masks_val: mask}) 339 | 340 | writer.add_summary(val_all_sum, count) 341 | except tf.errors.OutOfRangeError: 342 | break 343 | 344 | # logger info 345 | if count % args.PRINT_FREQ == 0 or count == last_step: 346 | time_cost = (time.time() - time_start) / args.PRINT_FREQ 347 | time_remaining = (last_step - count) * 
time_cost / 3600. 348 | logger.info('epoch {}/{}, step {}/{}, cost {:.2f}s, remaining {:.2f}h.'.format(epoch, args.EPOCH, step, max_step, time_cost, time_remaining)) 349 | time_start = time.time() 350 | 351 | current_step = 0 352 | count += 1 353 | 354 | logger.info('epoch {}/{}, cost {:.2f}min.'.format(epoch, args.EPOCH, (time.time() - time_s)/60)) 355 | 356 | logger.info("Training finished.") 357 | 358 | 359 | 360 | if __name__ == "__main__": 361 | 362 | main() 363 | -------------------------------------------------------------------------------- /src/ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.framework.python.ops import add_arg_scope 3 | 4 | 5 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02) 6 | weight_regularizer = None 7 | 8 | @add_arg_scope 9 | def gen_conv(x, cnum, ksize, stride=1, rate=1, name='conv', IN=True, reuse=False, 10 | padding='SAME', activation=tf.nn.elu, use_bias=True, training=True, sn=False): 11 | """Define conv for generator. 12 | 13 | Args: 14 | x: Input. 15 | cnum: Channel number. 16 | ksize: Kernel size. 17 | stride: Convolution stride. 18 | rate: Dilation rate for dilated conv. 19 | name: Name of layers. 20 | padding: Padding mode; defaults to 'SAME'. 21 | activation: Activation function after convolution. 22 | training: If current graph is for training or inference, used for bn. 23 | 24 | Returns: 25 | tf.Tensor: output 26 | 27 | """ 28 | assert padding in ['SYMMETRIC', 'SAME', 'REFLECT'] 29 | if padding == 'SYMMETRIC' or padding == 'REFLECT': 30 | """ 31 | Padding layer. 32 | Dilated kernel size: k_r = ksize + (rate - 1)*(ksize - 1) 33 | Padding size: o = i + 2p - k_r and o = i, so p = rate * (ksize - 1) / 2 (when i and o have the same image shape) 34 | """ 35 | p = int(rate*(ksize-1)/2) 36 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding) 37 | padding = 'VALID' 38 | 39 | # spectral normalization branch 40 | if sn: 41 | with tf.variable_scope(name, reuse=reuse): 42 | w = tf.get_variable("kernel", shape=[ksize, ksize, x.get_shape()[-1], cnum], initializer=weight_init, 43 | regularizer=weight_regularizer) 44 | 45 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 46 | strides=[1, stride, stride, 1], padding=padding, dilations=[1, rate, rate, 1]) 47 | if use_bias: 48 | bias = tf.get_variable("bias", [cnum], initializer=tf.constant_initializer(0.0)) 49 | x = tf.nn.bias_add(x, bias) 50 | else: 51 | x = tf.layers.conv2d(inputs=x, filters=cnum, activation=None, 52 | kernel_size=ksize, strides=stride, 53 | dilation_rate=rate, padding=padding, 54 | kernel_initializer=None, 55 | kernel_regularizer=weight_regularizer, 56 | use_bias=use_bias) 57 | if IN: 58 | x = tf.contrib.layers.instance_norm(x) # instance norm is applied before the non-linear activation 59 | if activation is not None: 60 | x = activation(x) 61 | return x 62 | 63 | @add_arg_scope 64 | def gen_deconv(x, cnum, ksize=4, stride=2, rate=1, method='deconv', IN=True, 65 | activation=tf.nn.relu, name='upsample', padding='SAME', sn=False, training=True, reuse=False): 66 | """Define deconv for generator. 67 | Depending on `method`, the deconv is either a transposed convolution or an 68 | x2 resize (nearest/bilinear/bicubic) followed by a gen_conv operation. 69 | 70 | Args: 71 | x: Input. 72 | cnum: Channel number. 73 | name: Name of layers. 74 | training: If current graph is for training or inference, used for bn.
75 | 76 | Returns: 77 | tf.Tensor: output 78 | 79 | """ 80 | with tf.variable_scope(name, reuse=reuse): 81 | if method == 'nearest': 82 | x = resize(x, func=tf.image.resize_nearest_neighbor) # tf.image.resize_bilinear ? 83 | x = gen_conv( 84 | x, cnum, 3, 1, name=name+'_conv', padding=padding, 85 | training=training, IN=IN) 86 | elif method == 'bilinear': 87 | x = resize(x, func=tf.image.resize_bilinear) 88 | x = gen_conv( 89 | x, cnum, 3, 1, name=name + '_conv', padding=padding, 90 | training=training, IN=IN) 91 | elif method == 'bicubic': 92 | x = resize(x, func=tf.image.resize_bicubic) 93 | x = gen_conv( 94 | x, cnum, 3, 1, name=name + '_conv', padding=padding, 95 | training=training, IN=IN) # default instance normalization, see function gen_conv() 96 | else: 97 | # assert padding in ['SYMMETRIC', 'SAME', 'REFLECT'] 98 | # if padding == 'SYMMETRIC' or padding == 'REFLECT': 99 | # p = int(rate * (ksize - 1) / 2) 100 | # p = 0 101 | # x = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], mode=padding) 102 | padding = 'SAME' 103 | x = tf.layers.conv2d_transpose(x, cnum, kernel_size=ksize, strides=stride, 104 | activation=None, padding=padding) 105 | if IN: 106 | x = tf.contrib.layers.instance_norm(x) # if instance norm? 107 | if activation is not None: 108 | x = activation(x) 109 | return x 110 | 111 | def resize(x, scale=2, to_shape=None, align_corners=True, dynamic=False, 112 | func=tf.image.resize_bilinear, name='resize'): 113 | if dynamic: 114 | xs = tf.cast(tf.shape(x), tf.float32) 115 | new_xs = [tf.cast(xs[1]*scale, tf.int32), 116 | tf.cast(xs[2]*scale, tf.int32)] 117 | else: 118 | xs = x.get_shape().as_list() 119 | new_xs = [int(xs[1]*scale), int(xs[2]*scale)] 120 | with tf.variable_scope(name): 121 | if to_shape is None: 122 | x = func(x, new_xs, align_corners=align_corners) 123 | else: 124 | x = func(x, [to_shape[0], to_shape[1]], 125 | align_corners=align_corners) 126 | return x 127 | 128 | # yj 129 | @add_arg_scope 130 | def resnet_blocks(x, cnum, ksize, stride, rate, block_num, name, IN=True, 131 | padding='REFLECT', activation=tf.nn.elu, training=True): 132 | for block in range(block_num): 133 | # x = resnet_block12(x, cnum, ksize, stride, rate, name+"_"+str(block), padding, activation, training=training) 134 | x = resnet_block21(x, cnum, ksize, stride, rate, name + "_" + str(block), padding=padding, 135 | activation=activation, training=training) 136 | return x 137 | 138 | # yj 139 | def resnet_block21(x, cnum, ksize, stride, rate, name, IN=True, 140 | padding='SAME', activation=tf.nn.relu, training=True): 141 | xin = x 142 | assert padding in ['SYMMETRIC', 'SAME', 'REFLECT'] 143 | if padding == 'SYMMETRIC' or padding == 'REFLECT': 144 | p = int(rate*(ksize-1)/2) 145 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding) 146 | padding1 = 'VALID' 147 | else: 148 | padding1 = padding 149 | x = tf.layers.conv2d( 150 | x, cnum, ksize, stride, dilation_rate=rate, 151 | activation=None, padding=padding1, name=name+"0") 152 | if IN: 153 | x = tf.contrib.layers.instance_norm(x) # if instance norm? 
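# EDIT: added — worked example of the reflect-padding arithmetic used in gen_conv and in these resnet blocks:
# with ksize=3 and rate=2 the dilated kernel covers k_r = 3 + (2-1)*(3-1) = 5 pixels, so p = rate*(ksize-1)/2 = 2
# pads the input enough that a stride-1 'VALID' conv returns the original spatial size (o = i + 2*2 - 5 + 1 = i).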
154 | if activation is not None: 155 | x = activation(x) 156 | 157 | rate = 1 158 | if padding == 'SYMMETRIC' or padding == 'REFLECT': 159 | p = int(rate*(ksize-1)/2) 160 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding) 161 | padding2 = 'VALID' 162 | else: 163 | padding2 = padding 164 | x = tf.layers.conv2d( 165 | x, cnum, ksize, stride, dilation_rate=rate, 166 | activation=None, padding=padding2, name=name+"1") 167 | if IN: 168 | x = tf.contrib.layers.instance_norm(x) # if instance norm? 169 | return xin + x 170 | 171 | # yj 172 | def resnet_block12(x, cnum, ksize, stride, rate, name, IN=True, 173 | padding='REFLECT', activation=tf.nn.elu, training=True): 174 | xin = x 175 | rate = 1 176 | assert padding in ['SYMMETRIC', 'SAME', 'REFLECT'] 177 | if padding == 'SYMMETRIC' or padding == 'REFLECT': 178 | p = int(rate*(ksize-1)/2) 179 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding) 180 | padding1 = 'VALID' 181 | else: 182 | padding1 = padding 183 | x = tf.layers.conv2d( 184 | x, cnum, ksize, stride, dilation_rate=rate, 185 | activation=None, padding=padding1, name=name+"0") 186 | if IN: 187 | x = tf.contrib.layers.instance_norm(x) # if instance norm? 188 | if activation is not None: 189 | x = activation(x) 190 | 191 | rate = 2 192 | if padding == 'SYMMETRIC' or padding == 'REFLECT': 193 | p = int(rate*(ksize-1)/2) 194 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding) 195 | padding2 = 'VALID' 196 | else: 197 | padding2 = padding 198 | x = tf.layers.conv2d( 199 | x, cnum, ksize, stride, dilation_rate=rate, 200 | activation=None, padding=padding2, name=name+"1") 201 | if IN: 202 | x = tf.contrib.layers.instance_norm(x) # if instance norm? 203 | 204 | return xin + x 205 | 206 | 207 | def torgb(x, cnum, ksize, stride, rate, name, activation=tf.nn.tanh, padding="SAME"): 208 | x = tf.layers.conv2d( 209 | x, cnum, ksize, stride, dilation_rate=rate, 210 | activation=activation, padding=padding, name=name) 211 | # x = tf.clip_by_value(x, -1., 1.) 212 | return x 213 | 214 | 215 | def dis_conv(x, cnum, ksize=5, stride=2, rate=1, activation=tf.nn.leaky_relu, name='conv', 216 | padding='SAME', use_bias=True, sn=True, training=True, reuse=False): 217 | """Define conv for discriminator. 218 | Activation is set to leaky_relu. 219 | 220 | Args: 221 | x: Input. 222 | cnum: Channel number. 223 | ksize: Kernel size. 224 | stride: Convolution stride. 225 | name: Name of layers. 226 | training: If current graph is for training or inference, used for bn. 227 | 228 | Returns: 229 | tf.Tensor: output 230 | 231 | """ 232 | # if spectrum normalization 233 | if sn: 234 | with tf.variable_scope(name, reuse=reuse): 235 | w = tf.get_variable("kernel", shape=[ksize, ksize, x.get_shape()[-1], cnum], initializer=weight_init, 236 | regularizer=weight_regularizer) 237 | 238 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 239 | strides=[1, stride, stride, 1], padding=padding, dilations=[1, rate, rate, 1]) 240 | if use_bias: 241 | bias = tf.get_variable("bias", [cnum], initializer=tf.constant_initializer(0.0)) 242 | x = tf.nn.bias_add(x, bias) 243 | if activation is not None: 244 | x = activation(x) 245 | else: 246 | x = tf.layers.conv2d(inputs=x, filters=cnum, activation=activation, 247 | kernel_size=ksize, strides=stride, 248 | dilation_rate=rate, padding=padding, 249 | kernel_initializer=None, 250 | kernel_regularizer=None, 251 | use_bias=use_bias, 252 | reuse=reuse) 253 | return x 254 | 255 | def flatten(x, name='flatten'): 256 | """Flatten wrapper. 
257 | """ 258 | with tf.variable_scope(name): 259 | return tf.contrib.layers.flatten(x) 260 | 261 | def out_complete(out, x_incomplete, mask, res): 262 | mask = tf.image.resize_images(mask, (res, res)) 263 | x_incomplete = tf.image.resize_images(x_incomplete, (res, res)) 264 | x_complete = out * mask + x_incomplete * (1. - mask) 265 | return x_complete 266 | 267 | 268 | # linear embedding 269 | @add_arg_scope 270 | def conv(x, channels, kernel=3, stride=1, pad=0, pad_type='REFLECT', use_bias=True, sn=False, scope='conv_0', 271 | reuse=False, training=False, padding=None): 272 | with tf.variable_scope(scope, reuse=reuse): 273 | if pad_type == 'zero' : 274 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]]) 275 | if pad_type == 'reflect' : 276 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='REFLECT') 277 | 278 | if sn : 279 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init, 280 | regularizer=weight_regularizer) 281 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 282 | strides=[1, stride, stride, 1], padding='VALID') 283 | if use_bias : 284 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 285 | x = tf.nn.bias_add(x, bias) 286 | 287 | else : 288 | x = tf.layers.conv2d(inputs=x, filters=channels, 289 | kernel_size=kernel, kernel_initializer=weight_init, 290 | kernel_regularizer=weight_regularizer, 291 | strides=stride, use_bias=use_bias, reuse=reuse) 292 | return x 293 | 294 | def spectral_norm(w, iteration=1): 295 | w_shape = w.shape.as_list() 296 | w = tf.reshape(w, [-1, w_shape[-1]]) 297 | 298 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False) 299 | 300 | u_hat = u 301 | v_hat = None 302 | for i in range(iteration): 303 | """ 304 | power iteration 305 | Usually iteration = 1 will be enough 306 | """ 307 | v_ = tf.matmul(u_hat, tf.transpose(w)) 308 | v_hat = l2_norm(v_) 309 | 310 | u_ = tf.matmul(v_hat, w) 311 | u_hat = l2_norm(u_) 312 | 313 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) 314 | w_norm = w / sigma 315 | 316 | with tf.control_dependencies([u.assign(u_hat)]): 317 | w_norm = tf.reshape(w_norm, w_shape) 318 | 319 | return w_norm 320 | 321 | def l2_norm(v, eps=1e-12): 322 | return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps) 323 | 324 | def hw_flatten(x) : 325 | return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]]) 326 | 327 | def max_pooling(x, pool_size=2): 328 | x = tf.layers.max_pooling2d(x, pool_size=pool_size, strides=pool_size, padding='SAME') 329 | return x 330 | 331 | 332 | def avg_pooling(x, pool_size=2): 333 | x = tf.layers.average_pooling2d(x, pool_size=pool_size, strides=pool_size, padding='SAME') 334 | return x 335 | 336 | 337 | 338 | """##### attention #####""" 339 | def attention(x, channels, neighbors=1, use_bias=True, sn=False, down_scale = 2, pool_scale=2, 340 | name='attention_pooling', training=True, padding='REFLECT', reuse=False): 341 | if neighbors > 1: 342 | x = attention_with_neighbors(x, channels, down_scale=down_scale, pool_scale=pool_scale, name=name) 343 | else: 344 | x = attention_with_pooling(x, channels, down_scale=down_scale, pool_scale=pool_scale, name=name) 345 | return x 346 | 347 | @add_arg_scope 348 | def attention_with_pooling(x, channels, ksize=4, use_bias=True, sn=False, down_scale = 2, pool_scale=2, 349 | name='attention_pooling', training=True, padding='REFLECT', reuse=False): 350 | with tf.variable_scope(name, reuse=reuse): 351 | x_origin = x 
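# EDIT: added — illustrative shape walkthrough of the pooled self-attention below, assuming an input x of
# shape [bs, h, w, c] with c = channels and down_scale = pool_scale = 2:
#   down sample:          x -> [bs, h/2, w/2, c]
#   f = 1x1 conv + pool:  [bs, h/4, w/4, c/16] -> flattened [bs, hw/16, c/16]
#   g = 1x1 conv:         [bs, h/2, w/2, c/16] -> flattened [bs, hw/4, c/16]
#   h = 1x1 conv + pool:  [bs, h/4, w/4, c/16] -> flattened [bs, hw/16, c/16]
#   s = g @ f^T:          [bs, hw/4, hw/16], softmaxed into the attention map beta
#   o = beta @ h:         [bs, hw/4, c/16] -> reshaped to [bs, h/2, w/2, c/16]
# o is then deconvolved back to [bs, h, w, c] and blended into the residual gamma * o + x_origin.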
352 | 353 | # down sampling 354 | if down_scale > 1: 355 | x = gen_conv(x, channels, ksize, stride=down_scale, activation=tf.nn.relu, name='attention_down_sample',reuse=reuse) 356 | 357 | # attention 358 | f = conv(x, channels // 16, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='f_conv', reuse=reuse) # [bs, h, w, c'] 359 | f = max_pooling(f, pool_scale) 360 | # f = avg_pooling(f) 361 | 362 | g = conv(x, channels // 16, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='g_conv',reuse=reuse) # [bs, h, w, c'] 363 | 364 | h = conv(x, channels // 16, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='h_conv',reuse=reuse) # [bs, h, w, c] 365 | h = max_pooling(h, pool_scale) 366 | # h = avg_pooling(h) [4,65536,4096] 367 | 368 | # N = h * w 369 | s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # # [bs, N, N] 370 | 371 | beta = tf.nn.softmax(s) # attention map 372 | 373 | o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C] 374 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0)) 375 | 376 | o = tf.reshape(o, shape=[x.shape[0], x.shape[1], x.shape[2], channels // 16]) # [bs, h, w, C] 377 | # o = conv(o, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='attn_conv_up') # from bottleneck 378 | 379 | # up sampling 380 | if down_scale > 1: 381 | o = gen_deconv(o, channels, ksize, method='deconv', stride=down_scale, activation=tf.nn.relu, name='attention_down_upsample',reuse=reuse) 382 | 383 | x = gamma * o + x_origin 384 | 385 | return x 386 | 387 | # attention consider neighbors 388 | @add_arg_scope 389 | def attention_with_neighbors(x, channels, ksize=3, use_bias=True, sn=False, stride=2, 390 | down_scale = 2, pool_scale=2, name='attention_pooling', 391 | training=True, padding='REFLECT', reuse=False): 392 | with tf.variable_scope(name, reuse=reuse): 393 | x1 = x 394 | 395 | # downsample input feature maps if needed due to limited GPU memory 396 | # down sampling 397 | if down_scale > 1: 398 | x1 = gen_conv(x1, channels, ksize, stride=down_scale, activation=tf.nn.relu, name='attention_down_sample', 399 | reuse=reuse) 400 | # get shapes 401 | int_x1s = x1.get_shape().as_list() 402 | # extract patches from high-level feature maps for matching and attending 403 | x1_groups = tf.split(x1, int_x1s[0], axis=0) 404 | w = tf.extract_image_patches( 405 | x1, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1], padding='SAME') 406 | w = tf.reshape(w, [int_x1s[0], -1, ksize, ksize, int_x1s[3]]) 407 | w = tf.transpose(w, [0, 2, 3, 4, 1]) # transpose to [b, ksize, ksize, c, hw/4] # need transpose?? -- 480 408 | w_groups = tf.split(w, int_x1s[0], axis=0) 409 | 410 | # matching and attending hole and non-hole patches 411 | y = [] 412 | scale = 10. 413 | # high level patches: w_groups, low level patches: raw_w_groups, x2_groups: high level feature map 414 | for xi, wi in zip(x1_groups, w_groups): 415 | # matching on high-level feature maps 416 | wi = wi[0] 417 | wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2])), 1e-4) 418 | yi = tf.nn.conv2d(xi, wi_normed, strides=[1, 1, 1, 1], padding="SAME") 419 | yi = tf.reshape(yi, [1, int_x1s[1], int_x1s[2], (int_x1s[1] // stride) * (int_x1s[2] // stride)]) 420 | yi = tf.nn.softmax(yi * scale, 3) 421 | # non local mean 422 | wi_center = tf.transpose(wi, [0, 1, 3, 2]) 423 | yi = tf.nn.conv2d(yi, wi_center, strides=[1, 1, 1, 1], padding="SAME") / 4. 
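            # Editor's sketch: the two conv2d calls above implement a
            # non-local mean. The first matches every location of xi against
            # all extracted patches (wi_normed holds one L2-normalised patch
            # per output channel), softmax turns the scores into attention
            # weights, and the second conv re-projects the weighted patch
            # responses back to feature space; the / 4. roughly compensates
            # for overlapping patch contributions (ksize=3 at stride=2).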
424 | 425 | # filter: [height, width, output_channels, in_channels] 426 | y.append(yi) 427 | y = tf.concat(y, axis=0) 428 | y.set_shape(int_x1s) 429 | # up sampling 430 | if down_scale > 1: 431 | y = gen_deconv(y, channels, ksize, method='deconv', stride=down_scale, activation=tf.nn.relu, 432 | name='attention_down_upsample', reuse=reuse) 433 | 434 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0)) 435 | x = gamma * y + x 436 | x = tf.layers.conv2d(x, channels, 3, 1, dilation_rate=1, activation=tf.nn.relu, padding='SAME') 437 | return x 438 | 439 | def normalize(x) : 440 | return x/127.5 - 1 441 | 442 | def imsave(images, size, path): 443 | return scipy.misc.imsave(path, merge(images, size)) 444 | 445 | def inverse_transform(images): 446 | return (images+1.)*127.5 -------------------------------------------------------------------------------- /painter/ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.framework.python.ops import add_arg_scope 3 | 4 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02) 5 | weight_regularizer = None 6 | 7 | @add_arg_scope 8 | def gen_conv(x, cnum, ksize, stride=1, rate=1, name='conv', IN=True, reuse=False, 9 | padding='SAME', activation=tf.nn.elu, use_bias=True, training=True, sn=False): 10 | """Define conv for generator. 11 | 12 | Args: 13 | x: Input. 14 | cnum: Channel number. 15 | ksize: Kernel size. 16 | Stride: Convolution stride. 17 | Rate: Rate for or dilated conv. 18 | name: Name of layers. 19 | padding: Default to SYMMETRIC. 20 | activation: Activation function after convolution. 21 | training: If current graph is for training or inference, used for bn. 22 | 23 | Returns: 24 | tf.Tensor: output 25 | 26 | """ 27 | assert padding in ['SYMMETRIC', 'SAME', 'REFLECT'] 28 | if padding == 'SYMMETRIC' or padding == 'REFLECT': 29 | """ 30 | Padding layer. 31 | Dilated kernel size: k_r = ksize + (rate - 1)*(ksize - 1) 32 | Padding size: o = i + 2p - k_r and o = i, so p = rate * (ksize - 1) / 2 (when i and o has the same image shape) 33 | """ 34 | p = int(rate*(ksize-1)/2) 35 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding) 36 | padding = 'VALID' 37 | 38 | # if spectrum normalization 39 | if sn: 40 | with tf.variable_scope(name, reuse=reuse): 41 | w = tf.get_variable("kernel", shape=[ksize, ksize, x.get_shape()[-1], cnum], initializer=weight_init, 42 | regularizer=weight_regularizer) 43 | 44 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 45 | strides=[1, stride, stride, 1], padding=padding, dilations=[1, rate, rate, 1]) 46 | if use_bias: 47 | bias = tf.get_variable("bias", [cnum], initializer=tf.constant_initializer(0.0)) 48 | x = tf.nn.bias_add(x, bias) 49 | else: 50 | x = tf.layers.conv2d(inputs=x, filters=cnum, activation=None, 51 | kernel_size=ksize, strides=stride, 52 | dilation_rate=rate, padding=padding, 53 | kernel_initializer=None, 54 | kernel_regularizer=weight_regularizer, 55 | use_bias=use_bias) 56 | if IN: 57 | x = tf.contrib.layers.instance_norm(x) # if instance norm? before non-linear activation!!! 58 | if activation is not None: 59 | x = activation(x) 60 | return x 61 | 62 | @add_arg_scope 63 | def gen_deconv(x, cnum, ksize=4, stride=2, rate=1, method='deconv',IN=True, 64 | activation=tf.nn.relu, name='upsample', padding='SAME', sn=False, training=True, reuse=False): 65 | """Define deconv for generator. 
66 | The deconv is defined to be a x2 resize_nearest_neighbor operation with 67 | additional gen_conv operation. 68 | 69 | Args: 70 | x: Input. 71 | cnum: Channel number. 72 | name: Name of layers. 73 | training: If current graph is for training or inference, used for bn. 74 | 75 | Returns: 76 | tf.Tensor: output 77 | 78 | """ 79 | with tf.variable_scope(name, reuse=reuse): 80 | if method == 'nearest': 81 | x = resize(x, func=tf.image.resize_nearest_neighbor) # tf.image.resize_bilinear ? 82 | x = gen_conv( 83 | x, cnum, 3, 1, name=name+'_conv', padding=padding, 84 | training=training, IN=IN) 85 | elif method == 'bilinear': 86 | x = resize(x, func=tf.image.resize_bilinear) 87 | x = gen_conv( 88 | x, cnum, 3, 1, name=name + '_conv', padding=padding, 89 | training=training, IN=IN) 90 | elif method == 'bicubic': 91 | x = resize(x, func=tf.image.resize_bicubic) 92 | x = gen_conv( 93 | x, cnum, 3, 1, name=name + '_conv', padding=padding, 94 | training=training, IN=IN) # default instance normalization, see function gen_conv() 95 | else: 96 | # assert padding in ['SYMMETRIC', 'SAME', 'REFLECT'] 97 | # if padding == 'SYMMETRIC' or padding == 'REFLECT': 98 | # p = int(rate * (ksize - 1) / 2) 99 | # p = 0 100 | # x = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], mode=padding) 101 | padding = 'SAME' 102 | x = tf.layers.conv2d_transpose(x, cnum, kernel_size=ksize, strides=stride, 103 | activation=None, padding=padding) 104 | if IN: 105 | x = tf.contrib.layers.instance_norm(x) # if instance norm? 106 | if activation is not None: 107 | x = activation(x) 108 | return x 109 | 110 | def resize(x, scale=2, to_shape=None, align_corners=True, dynamic=False, 111 | func=tf.image.resize_bilinear, name='resize'): 112 | if dynamic: 113 | xs = tf.cast(tf.shape(x), tf.float32) 114 | new_xs = [tf.cast(xs[1]*scale, tf.int32), 115 | tf.cast(xs[2]*scale, tf.int32)] 116 | else: 117 | xs = x.get_shape().as_list() 118 | new_xs = [int(xs[1]*scale), int(xs[2]*scale)] 119 | with tf.variable_scope(name): 120 | if to_shape is None: 121 | x = func(x, new_xs, align_corners=align_corners) 122 | else: 123 | x = func(x, [to_shape[0], to_shape[1]], 124 | align_corners=align_corners) 125 | return x 126 | 127 | # yj 128 | @add_arg_scope 129 | def resnet_blocks(x, cnum, ksize, stride, rate, block_num, name, IN=True, 130 | padding='REFLECT', activation=tf.nn.elu, training=True): 131 | for block in range(block_num): 132 | # x = resnet_block12(x, cnum, ksize, stride, rate, name+"_"+str(block), padding, activation, training=training) 133 | x = resnet_block21(x, cnum, ksize, stride, rate, name + "_" + str(block), padding=padding, 134 | activation=activation, training=training) 135 | return x 136 | 137 | # yj 138 | def resnet_block21(x, cnum, ksize, stride, rate, name, IN=True, 139 | padding='SAME', activation=tf.nn.relu, training=True): 140 | xin = x 141 | assert padding in ['SYMMETRIC', 'SAME', 'REFLECT'] 142 | if padding == 'SYMMETRIC' or padding == 'REFLECT': 143 | p = int(rate*(ksize-1)/2) 144 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding) 145 | padding1 = 'VALID' 146 | else: 147 | padding1 = padding 148 | x = tf.layers.conv2d( 149 | x, cnum, ksize, stride, dilation_rate=rate, 150 | activation=None, padding=padding1, name=name+"0") 151 | if IN: 152 | x = tf.contrib.layers.instance_norm(x) # if instance norm? 
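    # Editor's sketch: this is the midpoint of a standard two-conv residual
    # block -- conv (dilation `rate`) + instance norm above, activation below,
    # then a second conv with rate reset to 1. The explicit REFLECT/SYMMETRIC
    # padding keeps the spatial size, so the closing skip connection
    # `xin + x` is only shape-valid for stride == 1 with cnum equal to the
    # input channel count.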
153 |     if activation is not None:
154 |         x = activation(x)
155 | 
156 |     rate = 1
157 |     if padding == 'SYMMETRIC' or padding == 'REFLECT':
158 |         p = int(rate*(ksize-1)/2)
159 |         x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
160 |         padding2 = 'VALID'
161 |     else:
162 |         padding2 = padding
163 |     x = tf.layers.conv2d(
164 |         x, cnum, ksize, stride, dilation_rate=rate,
165 |         activation=None, padding=padding2, name=name+"1")
166 |     if IN:
167 |         x = tf.contrib.layers.instance_norm(x)  # if instance norm?
168 |     return xin + x
169 | 
170 | # yj
171 | def resnet_block12(x, cnum, ksize, stride, rate, name, IN=True,
172 |                    padding='REFLECT', activation=tf.nn.elu, training=True):
173 |     xin = x
174 |     rate = 1
175 |     assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
176 |     if padding == 'SYMMETRIC' or padding == 'REFLECT':
177 |         p = int(rate*(ksize-1)/2)
178 |         x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
179 |         padding1 = 'VALID'
180 |     else:
181 |         padding1 = padding
182 |     x = tf.layers.conv2d(
183 |         x, cnum, ksize, stride, dilation_rate=rate,
184 |         activation=None, padding=padding1, name=name+"0")
185 |     if IN:
186 |         x = tf.contrib.layers.instance_norm(x)  # if instance norm?
187 |     if activation is not None:
188 |         x = activation(x)
189 | 
190 |     rate = 2
191 |     if padding == 'SYMMETRIC' or padding == 'REFLECT':
192 |         p = int(rate*(ksize-1)/2)
193 |         x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
194 |         padding2 = 'VALID'
195 |     else:
196 |         padding2 = padding
197 |     x = tf.layers.conv2d(
198 |         x, cnum, ksize, stride, dilation_rate=rate,
199 |         activation=None, padding=padding2, name=name+"1")
200 |     if IN:
201 |         x = tf.contrib.layers.instance_norm(x)  # if instance norm?
202 | 
203 |     return xin + x
204 | 
205 | 
206 | # TODO: torgb -- is a 1x1 conv with bias enough? (linear output vs. a tanh activation)
207 | def torgb(x, cnum, ksize, stride, rate, name, activation=tf.nn.tanh, padding="SAME"):
208 |     x = tf.layers.conv2d(
209 |         x, cnum, ksize, stride, dilation_rate=rate,
210 |         activation=activation, padding=padding, name=name)
211 |     # x = tf.clip_by_value(x, -1., 1.)
212 |     return x
213 | 
214 | 
215 | def dis_conv(x, cnum, ksize=5, stride=2, rate=1, activation=tf.nn.leaky_relu, name='conv',
216 |              padding='SAME', use_bias=True, sn=True, training=True, reuse=False):
217 |     """Define conv for discriminator.
218 |     Activation is set to leaky_relu.
219 | 
220 |     Args:
221 |         x: Input.
222 |         cnum: Channel number.
223 |         ksize: Kernel size.
224 |         stride: Convolution stride.
225 |         name: Name of layers.
226 |         training: If current graph is for training or inference, used for bn.
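        rate: Dilation rate of the convolution.
        padding: Padding mode, 'SAME' by default.
        use_bias: Whether to add a learnable bias.
        sn: If True, apply spectral normalization to the kernel weights
            (editor's note, documented from the code below).
        reuse: Whether to reuse the enclosing variable scope.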
227 | 228 | Returns: 229 | tf.Tensor: output 230 | 231 | """ 232 | # if spectrum normalization 233 | if sn: 234 | with tf.variable_scope(name, reuse=reuse): 235 | w = tf.get_variable("kernel", shape=[ksize, ksize, x.get_shape()[-1], cnum], initializer=weight_init, 236 | regularizer=weight_regularizer) 237 | 238 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 239 | strides=[1, stride, stride, 1], padding=padding, dilations=[1, rate, rate, 1]) 240 | if use_bias: 241 | bias = tf.get_variable("bias", [cnum], initializer=tf.constant_initializer(0.0)) 242 | x = tf.nn.bias_add(x, bias) 243 | if activation is not None: 244 | x = activation(x) 245 | else: 246 | x = tf.layers.conv2d(inputs=x, filters=cnum, activation=activation, 247 | kernel_size=ksize, strides=stride, 248 | dilation_rate=rate, padding=padding, 249 | kernel_initializer=None, 250 | kernel_regularizer=None, 251 | use_bias=use_bias, 252 | reuse=reuse) 253 | return x 254 | 255 | def flatten(x, name='flatten'): 256 | """Flatten wrapper. 257 | """ 258 | with tf.variable_scope(name): 259 | return tf.contrib.layers.flatten(x) 260 | 261 | def out_complete(out, x_incomplete, mask, res): 262 | mask = tf.image.resize_images(mask, (res, res)) 263 | x_incomplete = tf.image.resize_images(x_incomplete, (res, res)) 264 | x_complete = out * mask + x_incomplete * (1. - mask) 265 | return x_complete 266 | 267 | 268 | # linear embedding 269 | @add_arg_scope 270 | def conv(x, channels, kernel=3, stride=1, pad=0, pad_type='REFLECT', use_bias=True, sn=False, scope='conv_0', reuse=False, training=False, padding=None): 271 | with tf.variable_scope(scope, reuse=reuse): 272 | if pad_type == 'zero' : 273 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]]) 274 | if pad_type == 'reflect' : 275 | x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='REFLECT') 276 | 277 | if sn : 278 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init, 279 | regularizer=weight_regularizer) 280 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 281 | strides=[1, stride, stride, 1], padding='VALID') 282 | if use_bias : 283 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 284 | x = tf.nn.bias_add(x, bias) 285 | 286 | else : 287 | x = tf.layers.conv2d(inputs=x, filters=channels, 288 | kernel_size=kernel, kernel_initializer=weight_init, 289 | kernel_regularizer=weight_regularizer, 290 | strides=stride, use_bias=use_bias, reuse=reuse) 291 | return x 292 | 293 | def spectral_norm(w, iteration=1): 294 | w_shape = w.shape.as_list() 295 | w = tf.reshape(w, [-1, w_shape[-1]]) 296 | 297 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False) 298 | 299 | u_hat = u 300 | v_hat = None 301 | for i in range(iteration): 302 | """ 303 | power iteration 304 | Usually iteration = 1 will be enough 305 | """ 306 | v_ = tf.matmul(u_hat, tf.transpose(w)) 307 | v_hat = l2_norm(v_) 308 | 309 | u_ = tf.matmul(v_hat, w) 310 | u_hat = l2_norm(u_) 311 | 312 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) 313 | w_norm = w / sigma 314 | 315 | with tf.control_dependencies([u.assign(u_hat)]): 316 | w_norm = tf.reshape(w_norm, w_shape) 317 | 318 | return w_norm 319 | 320 | def l2_norm(v, eps=1e-12): 321 | return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps) 322 | 323 | def hw_flatten(x) : 324 | return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]]) 325 | 326 | def max_pooling(x, pool_size=2): 327 | x = tf.layers.max_pooling2d(x, 
pool_size=pool_size, strides=pool_size, padding='SAME') 328 | return x 329 | 330 | 331 | def avg_pooling(x, pool_size=2): 332 | x = tf.layers.average_pooling2d(x, pool_size=pool_size, strides=pool_size, padding='SAME') 333 | return x 334 | 335 | # ATN layer 336 | import tensorflow as tf 337 | from tensorflow.contrib.framework.python.ops import add_arg_scope 338 | 339 | @add_arg_scope 340 | def AtnConv(x1, x2, mask=None, ksize=3, stride=1, rate=2, 341 | softmax_scale=10., training=True, rescale=False): 342 | r""" Attention transfer networks implementation in tensorflow 343 | 344 | Attention transfer networks is introduced in publication: 345 | Learning Pyramid-Context Encoder Networks for High-Quality Image Inpainting, Zeng et al. 346 | https://arxiv.org/pdf/1904.07475.pdf 347 | https://github.com/researchmm/PEN-Net-for-Inpainting 348 | inspired by: 349 | Generative Image Inpainting with Contextual Attention, Yu et al. 350 | https://github.com/JiahuiYu/generative_inpainting/blob/master/inpaint_ops.py 351 | https://arxiv.org/abs/1801.07892 352 | Args: 353 | x1: low-level feature map with larger size [b, h, w, c]. 354 | x2: high-level feature map with smaller size [b, h/2, w/2, c]. 355 | mask: Input mask, 1 for missing regions 0 for known regions. 356 | ksize: Kernel size for attention transfer networks. 357 | stride: Stride for extracting patches from feature map. 358 | rate: Dilation for matching. 359 | softmax_scale: Scaled softmax for attention. 360 | training: Indicating if current graph is training or inference. 361 | rescale: Indicating if input feature maps need to be downsample 362 | Returns: 363 | tf.Tensor: reconstructed feature map 364 | """ 365 | # downsample input feature maps if needed due to limited GPU memory 366 | if rescale: 367 | x1 = resize(x1, scale=1. / 2, func=tf.image.resize_nearest_neighbor) 368 | x2 = resize(x2, scale=1. / 2, func=tf.image.resize_nearest_neighbor) 369 | # get shapes 370 | raw_x1s = tf.shape(x1) 371 | int_x1s = x1.get_shape().as_list() 372 | int_x2s = x2.get_shape().as_list() 373 | 374 | # extract patches from low-level feature maps for reconstruction 375 | kernel = 2 * rate 376 | raw_w = tf.extract_image_patches( 377 | x1, [1, kernel, kernel, 1], [1, rate * stride, rate * stride, 1], [1, 1, 1, 1], padding='SAME') 378 | raw_w = tf.reshape(raw_w, [int_x1s[0], -1, kernel, kernel, int_x1s[3]]) 379 | raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1]) # transpose to [b, kernel, kernel, c, hw] 380 | raw_w_groups = tf.split(raw_w, int_x1s[0], axis=0) 381 | 382 | # extract patches from high-level feature maps for matching and attending 383 | x2_groups = tf.split(x2, int_x2s[0], axis=0) 384 | w = tf.extract_image_patches( 385 | x2, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1], padding='SAME') 386 | w = tf.reshape(w, [int_x2s[0], -1, ksize, ksize, int_x2s[3]]) 387 | w = tf.transpose(w, [0, 2, 3, 4, 1]) # transpose to [b, ksize, ksize, c, hw/4] # need transpose?? 
-- 480 388 | w_groups = tf.split(w, int_x2s[0], axis=0) 389 | 390 | # resize and extract patches from masks 391 | mask = resize(mask, to_shape=int_x2s[1:3], func=tf.image.resize_nearest_neighbor) 392 | m = tf.extract_image_patches( 393 | mask, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1], padding='SAME') 394 | m = tf.reshape(m, [1, -1, ksize, ksize, 1]) 395 | m = tf.transpose(m, [0, 2, 3, 4, 1]) # transpose to [1, ksize, ksize, 1, hw/4] 396 | m = m[0] 397 | mm = tf.cast(tf.equal(tf.reduce_mean(m, axis=[0, 1, 2], keep_dims=True), 0.), tf.float32) 398 | 399 | # matching and attending hole and non-hole patches 400 | y = [] 401 | scale = softmax_scale 402 | # high level patches: w_groups, low level patches: raw_w_groups, x2_groups: high level feature map 403 | for xi, wi, raw_wi in zip(x2_groups, w_groups, raw_w_groups): 404 | # matching on high-level feature maps 405 | wi = wi[0] 406 | wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2])), 1e-4) 407 | yi = tf.nn.conv2d(xi, wi_normed, strides=[1, 1, 1, 1], padding="SAME") 408 | yi = tf.reshape(yi, [1, int_x2s[1], int_x2s[2], (int_x2s[1] // stride) * (int_x2s[2] // stride)]) 409 | # apply softmax to obtain attention score 410 | yi *= mm # mask 411 | yi = tf.nn.softmax(yi * scale, 3) 412 | yi *= mm # mask yi: score maps, score maps for non-hole regions are zeros through masks 413 | # transfer non-hole features into holes according to the atttention score 414 | wi_center = raw_wi[0] 415 | yi = tf.nn.conv2d_transpose(yi, wi_center, tf.concat([[1], raw_x1s[1:]], axis=0), 416 | strides=[1, rate * stride, rate * stride, 1]) / 4. # filter: [height, width, output_channels, in_channels] 417 | y.append(yi) 418 | y = tf.concat(y, axis=0) 419 | y.set_shape(int_x1s) 420 | # refine filled feature map after matching and attending 421 | y1 = tf.layers.conv2d(y, int_x1s[-1] // 4, 3, 1, dilation_rate=1, activation=tf.nn.relu, padding='SAME') 422 | y2 = tf.layers.conv2d(y, int_x1s[-1] // 4, 3, 1, dilation_rate=2, activation=tf.nn.relu, padding='SAME') 423 | y3 = tf.layers.conv2d(y, int_x1s[-1] // 4, 3, 1, dilation_rate=4, activation=tf.nn.relu, padding='SAME') 424 | y4 = tf.layers.conv2d(y, int_x1s[-1] // 4, 3, 1, dilation_rate=8, activation=tf.nn.relu, padding='SAME') 425 | y = tf.concat([y1, y2, y3, y4], axis=3) 426 | if rescale: 427 | y = resize(y, scale=2., func=tf.image.resize_nearest_neighbor) 428 | return y 429 | 430 | 431 | """##### our-attention #####""" 432 | def attention(x, channels, neighbors=1, use_bias=True, sn=False, down_scale = 2, pool_scale=2, 433 | name='attention_pooling', training=True, padding='REFLECT', reuse=False): 434 | if neighbors > 1: 435 | x = attention_with_neighbors(x, channels, down_scale=down_scale, pool_scale=pool_scale, name=name) 436 | else: 437 | x = attention_with_pooling(x, channels, down_scale=down_scale, pool_scale=pool_scale, name=name) 438 | return x 439 | 440 | @add_arg_scope 441 | def attention_with_pooling(x, channels, ksize=4, use_bias=True, sn=False, down_scale = 2, pool_scale=2, name='attention_pooling', training=True, padding='REFLECT', reuse=False): 442 | with tf.variable_scope(name, reuse=reuse): 443 | x_origin = x 444 | 445 | # down sampling 446 | if down_scale > 1: 447 | x = gen_conv(x, channels, ksize, stride=down_scale, activation=tf.nn.relu, name='attention_down_sample',reuse=reuse) 448 | 449 | # attention 450 | f = conv(x, channels // 16, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='f_conv', reuse=reuse) # [bs, h, w, c'] 451 | f = max_pooling(f, 
pool_scale) 452 | # f = avg_pooling(f) 453 | 454 | g = conv(x, channels // 16, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='g_conv',reuse=reuse) # [bs, h, w, c'] 455 | 456 | h = conv(x, channels // 16, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='h_conv',reuse=reuse) # [bs, h, w, c] 457 | h = max_pooling(h, pool_scale) 458 | # h = avg_pooling(h) [4,65536,4096] 459 | 460 | # N = h * w 461 | s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # # [bs, N, N] 462 | 463 | beta = tf.nn.softmax(s) # attention map 464 | 465 | o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C] 466 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0)) 467 | 468 | o = tf.reshape(o, shape=[x.shape[0], x.shape[1], x.shape[2], channels // 16]) # [bs, h, w, C] 469 | # o = conv(o, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='attn_conv_up') # from bottleneck 470 | 471 | # up sampling 472 | if down_scale > 1: 473 | o = gen_deconv(o, channels, ksize, method='deconv', stride=down_scale, activation=tf.nn.relu, name='attention_down_upsample',reuse=reuse) 474 | 475 | x = gamma * o + x_origin 476 | 477 | return x 478 | 479 | # attention consider neighbors 480 | @add_arg_scope 481 | def attention_with_neighbors(x, channels, ksize=3, use_bias=True, sn=False, stride=2, 482 | down_scale = 2, pool_scale=2, name='attention_pooling', 483 | training=True, padding='REFLECT', reuse=False): 484 | with tf.variable_scope(name, reuse=reuse): 485 | x1 = x 486 | 487 | # downsample input feature maps if needed due to limited GPU memory 488 | # down sampling 489 | if down_scale > 1: 490 | x1 = gen_conv(x1, channels, ksize, stride=down_scale, activation=tf.nn.relu, name='attention_down_sample', 491 | reuse=reuse) 492 | # get shapes 493 | int_x1s = x1.get_shape().as_list() 494 | # extract patches from high-level feature maps for matching and attending 495 | x1_groups = tf.split(x1, int_x1s[0], axis=0) 496 | w = tf.extract_image_patches( 497 | x1, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1], padding='SAME') 498 | w = tf.reshape(w, [int_x1s[0], -1, ksize, ksize, int_x1s[3]]) 499 | w = tf.transpose(w, [0, 2, 3, 4, 1]) # transpose to [b, ksize, ksize, c, hw/4] # need transpose?? -- 480 500 | w_groups = tf.split(w, int_x1s[0], axis=0) 501 | 502 | # matching and attending hole and non-hole patches 503 | y = [] 504 | scale = 10. 505 | # high level patches: w_groups, low level patches: raw_w_groups, x2_groups: high level feature map 506 | for xi, wi in zip(x1_groups, w_groups): 507 | # matching on high-level feature maps 508 | wi = wi[0] 509 | wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2])), 1e-4) 510 | yi = tf.nn.conv2d(xi, wi_normed, strides=[1, 1, 1, 1], padding="SAME") 511 | yi = tf.reshape(yi, [1, int_x1s[1], int_x1s[2], (int_x1s[1] // stride) * (int_x1s[2] // stride)]) 512 | yi = tf.nn.softmax(yi * scale, 3) 513 | # non local mean 514 | wi_center = tf.transpose(wi, [0, 1, 3, 2]) 515 | yi = tf.nn.conv2d(yi, wi_center, strides=[1, 1, 1, 1], padding="SAME") / 4. 
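            # Editor's sketch: `scale` (10.) sharpens the softmax so that each
            # location attends mostly to its best-matching patches;
            # tf.nn.softmax(yi * scale, 3) normalises over axis 3, i.e. the
            # (h/stride)*(w/stride) candidate-patch dimension, and the conv
            # with the transposed patch tensor above then acts as a
            # non-local mean over those patches.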
516 | 517 | # filter: [height, width, output_channels, in_channels] 518 | y.append(yi) 519 | y = tf.concat(y, axis=0) 520 | y.set_shape(int_x1s) 521 | # up sampling 522 | if down_scale > 1: 523 | y = gen_deconv(y, channels, ksize, method='deconv', stride=down_scale, activation=tf.nn.relu, 524 | name='attention_down_upsample', reuse=reuse) 525 | 526 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0)) 527 | x = gamma * y + x 528 | x = tf.layers.conv2d(x, channels, 3, 1, dilation_rate=1, activation=tf.nn.relu, padding='SAME') 529 | return x 530 | 531 | def normalize(x) : 532 | return x/127.5 - 1 533 | 534 | def imsave(images, size, path): 535 | return scipy.misc.imsave(path, merge(images, size)) 536 | 537 | def inverse_transform(images): 538 | return (images+1.)*127.5 -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import vgg_network 2 | from logging import exception 3 | import tensorflow as tf 4 | import numpy as np 5 | from easydict import EasyDict as edict 6 | 7 | from sys import stdout 8 | from functools import reduce 9 | from vgg_network import VGG 10 | 11 | 12 | # loss config 13 | config = edict() 14 | config.W = edict() 15 | 16 | # TODO: content 17 | # weights 18 | config.W.Content = 1. 19 | 20 | config.Content = edict() 21 | config.Content.feat_layers = {'relu1_1': 0.2, 'relu2_1': 0.2,'relu3_1': 0.2,'relu4_1': 0.2,'relu5_1': 0.2} 22 | 23 | # TODO: style 24 | config.W.Style = 1. 25 | config.Style = edict() 26 | config.Style.feat_layers = {'relu1_1': 0.2, 'relu2_1': 0.2,'relu3_1': 0.2,'relu4_1': 0.2,'relu5_1': 0.2} 27 | 28 | 29 | class LossCalculator: 30 | 31 | def __init__(self, vgg_dir, real_image): 32 | self.vgg_model = VGG(vgg_dir) 33 | self.vgg_real = self.vgg_model.net(real_image) 34 | 35 | def content_loss(self, content_fake, layers=None): 36 | # compute content loss 37 | vgg_fake = self.vgg_model.net(content_fake) # dict: net[name] = current_layer 38 | if config.W.Content > 0: 39 | if layers is not None: 40 | config.Content.feat_layers = layers 41 | content_loss_list = [w * self._content_loss_helper(self.vgg_real[layer], vgg_fake[layer]) 42 | for layer, w in config.Content.feat_layers.items()] 43 | content_loss = tf.reduce_sum(content_loss_list) 44 | else: 45 | zero_tensor = tf.constant(0.0, dtype=tf.float32) 46 | content_loss = zero_tensor 47 | return content_loss 48 | 49 | def style_loss(self, style_fake, layers=None): 50 | vgg_fake = self.vgg_model.net(style_fake) # dict: net[name] = current_layer 51 | # image = tf.placeholder('float32', shape=style.shape) 52 | # style_net = self.vgg.net(image) 53 | 54 | if config.W.Style > 0: 55 | if layers is not None: 56 | config.Style.feat_layers = layers 57 | style_loss_list = [w * self._style_loss_helper(self.vgg_real[layer], vgg_fake[layer]) 58 | for layer, w in config.Style.feat_layers.items()] 59 | style_loss = tf.reduce_sum(style_loss_list) 60 | else: 61 | zero_tensor = tf.constant(0.0, dtype=tf.float32) 62 | style_loss = zero_tensor 63 | return style_loss 64 | 65 | # def _calculate_input_gram_matrix_for(self, layer): 66 | # image_feature = self.network[layer] 67 | # _, height, width, number = map(lambda i: i.value, image_feature.get_shape()) 68 | # size = height * width * number 69 | # image_feature = tf.reshape(image_feature, (-1, number)) 70 | # return tf.matmul(tf.transpose(image_feature), image_feature) / size 71 | 72 | 73 | def _content_loss_helper(self, vgg_A, 
vgg_B): 74 | N, fH, fW, fC = vgg_A.shape.as_list() 75 | feature_size = N * fH * fW *fC 76 | content_loss = 2 * tf.nn.l2_loss(vgg_A - vgg_B) / feature_size 77 | return content_loss 78 | 79 | def _style_loss_helper(self, vgg_A, vgg_B): 80 | N, fH, fW, fC = vgg_A.shape.as_list() 81 | feature_size = N * fH * fW *fC 82 | gram_A = self._compute_gram(vgg_A) 83 | gram_B = self._compute_gram(vgg_B) 84 | style_loss = 2 * tf.nn.l2_loss(gram_A - gram_B) / feature_size 85 | return style_loss 86 | 87 | def _compute_gram(self, feature): 88 | # https://github.com/fullfanta/real_time_style_transfer/blob/master/train.py 89 | shape = tf.shape(feature) 90 | psi = tf.reshape(feature, [shape[0], shape[1] * shape[2], shape[3]]) 91 | # psi_t = tf.transpose(psi, perm=[0, 2, 1]) 92 | gram = tf.matmul(psi, psi, transpose_a=True) 93 | gram = tf.div(gram, tf.cast(shape[1] * shape[2] * shape[3], tf.float32)) 94 | return gram 95 | 96 | def tv_loss(self, image): 97 | # total variation denoising 98 | tv_y_size = _tensor_size(image[:,1:,:,:]) 99 | tv_x_size = _tensor_size(image[:,:,1:,:]) 100 | shape = image.shape.as_list() 101 | tv_loss = 2 * ( 102 | (tf.nn.l2_loss(image[:,1:,:,:] - image[:,:shape[1]-1,:,:]) / 103 | tv_y_size) + 104 | (tf.nn.l2_loss(image[:,:,1:,:] - image[:,:,:shape[2]-1,:]) / 105 | tv_x_size)) 106 | 107 | return tv_loss 108 | 109 | # TODO: l1_loss(x, x_complete_256) 110 | def l1_loss(self, image, predict, mask, type='foreground'): 111 | error = tf.abs(predict - image) 112 | if type == 'foreground': 113 | loss = tf.reduce_sum(mask * error) / tf.reduce_sum(mask) # * tf.reduce_sum(1. - mask) for balance? 114 | elif type == 'background': 115 | loss = tf.reduce_sum((1. - mask) * error) / tf.reduce_sum(1. - mask) 116 | else: 117 | loss = tf.reduce_sum(mask * tf.abs(predict - image)) / tf.reduce_sum(mask) 118 | return loss 119 | 120 | # TODO: 121 | def adversarial_loss(self): 122 | pass 123 | 124 | def _tensor_size(tensor): 125 | from operator import mul 126 | return reduce(mul, (d.value for d in tensor.get_shape()), 1) 127 | 128 | def gan_wgan_loss(pos, neg, name='gan_wgan_loss'): 129 | """ 130 | wgan loss function for GANs. 
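    With pos = D(real) and neg = D(fake), the critic objective below is
    d_loss = E[D(fake)] - E[D(real)] and the generator objective is
    g_loss = -E[D(fake)] (editor's note).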
131 | 
132 |     - Wasserstein GAN: https://arxiv.org/abs/1701.07875
133 |     """
134 |     with tf.variable_scope(name):
135 |         d_loss = tf.reduce_mean(neg-pos)
136 |         g_loss = -tf.reduce_mean(neg)
137 |         # scalar_summary('d_loss', d_loss)
138 |         # scalar_summary('g_loss', g_loss)
139 |         # scalar_summary('pos_value_avg', tf.reduce_mean(pos))
140 |         # scalar_summary('neg_value_avg', tf.reduce_mean(neg))
141 |         return g_loss, d_loss
142 | 
143 | def patch_gan_loss(pos, neg, name='patch_gan_loss', loss_type='gan'):
144 |     """
145 |     patch gan loss
146 |     """
147 |     with tf.variable_scope(name):
148 |         if loss_type == 'gan':
149 |             g_loss = tf.reduce_mean(
150 |                 tf.nn.sigmoid_cross_entropy_with_logits(logits=neg, labels=tf.ones_like(neg)))  # generator loss
151 | 
152 |             d_loss_fake = tf.reduce_mean(
153 |                 tf.nn.sigmoid_cross_entropy_with_logits(logits=neg, labels=tf.zeros_like(neg)))
154 |             d_loss_real = tf.reduce_mean(
155 |                 tf.nn.sigmoid_cross_entropy_with_logits(logits=pos, labels=tf.ones_like(pos)))
156 |             d_loss = d_loss_fake + d_loss_real  # discriminator loss
157 | 
158 |         if loss_type == 'hinge':
159 |             d_loss_real = tf.reduce_mean(tf.nn.relu(1.0 - pos))
160 |             d_loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + neg))
161 |             d_loss = d_loss_real + d_loss_fake
162 | 
163 |             g_loss = -tf.reduce_mean(neg)
164 | 
165 |         return g_loss, d_loss, d_loss_real, d_loss_fake
166 | 
167 | def random_interpolates(x, y, alpha=None):
168 |     """
169 |     x: first dimension as batch_size
170 |     y: first dimension as batch_size
171 |     alpha: [BATCH_SIZE, 1]
172 |     """
173 |     shape = x.get_shape().as_list()
174 |     x = tf.reshape(x, [shape[0], -1])
175 |     y = tf.reshape(y, [shape[0], -1])
176 |     if alpha is None:
177 |         alpha = tf.random_uniform(shape=[shape[0], 1])
178 |     interpolates = x + alpha*(y - x)
179 |     return tf.reshape(interpolates, shape)
180 | 
181 | 
182 | def gradients_penalty(x, y, mask=None, norm=1.):
183 |     """Improved Training of Wasserstein GANs
184 | 
185 |     - https://arxiv.org/abs/1704.00028
186 |     """
187 |     gradients = tf.gradients(y, x)[0]
188 |     if mask is None:
189 |         mask = tf.ones_like(gradients)
190 |     slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients) * mask, axis=[1, 2, 3]))
191 |     return tf.reduce_mean(tf.square(slopes - norm))
192 | 
193 | from tensorflow.python.ops import array_ops
194 | def focal_loss(prediction_tensor, target_tensor, weights=None, alpha=0.25, gamma=2):
195 |     r"""Compute focal loss for predictions.
196 |     Multi-labels Focal loss formula:
197 |         FL = -alpha * (z-p)^gamma * log(p) - (1-alpha) * p^gamma * log(1-p)
198 |     where alpha = 0.25, gamma = 2, p = sigmoid(x), z = target_tensor (z = 1 for the positive class).
199 |     ref: https://github.com/ailias/Focal-Loss-implement-on-Tensorflow
200 |         if z == 1, J = -a * (1 - p) * log(p)
201 |         if z != 1, J = -(1 - a) * p * log(1 - p)
202 |     Args:
203 |         prediction_tensor: A float tensor of shape [batch_size, num_anchors,
204 |             num_classes] representing the predicted logits for each class
205 |         target_tensor: A float tensor of shape [batch_size, num_anchors,
206 |             num_classes] representing one-hot encoded classification targets
207 |         weights: A float tensor of shape [batch_size, num_anchors]
208 |         alpha: A scalar tensor for focal loss alpha hyper-parameter
209 |         gamma: A scalar tensor for focal loss gamma hyper-parameter
210 |     Returns:
211 |         loss: A (scalar) tensor representing the value of the loss function
212 |     """
213 |     sigmoid_p = tf.nn.sigmoid(prediction_tensor)
214 |     zeros = array_ops.zeros_like(sigmoid_p, dtype=sigmoid_p.dtype)
215 | 
216 |     # For positive predictions, only the first term of the loss contributes (the second is 0);
217 |     # target_tensor > zeros <=> z = 1, so the positive coefficient is z - p.
218 |     pos_p_sub = array_ops.where(target_tensor > zeros, target_tensor - sigmoid_p, zeros)
219 | 
220 |     # For negative predictions, only the second term of the loss contributes (the first is 0);
221 |     # target_tensor <= zeros <=> z = 0, so the negative coefficient is p.
222 |     neg_p_sub = array_ops.where(target_tensor > zeros, zeros, sigmoid_p)
223 |     per_entry_cross_ent = - alpha * (pos_p_sub ** gamma) * tf.log(tf.clip_by_value(sigmoid_p, 1e-8, 1.0)) \
224 |                           - (1 - alpha) * (neg_p_sub ** gamma) * tf.log(tf.clip_by_value(1.0 - sigmoid_p, 1e-8, 1.0))
225 |     return tf.reduce_mean(per_entry_cross_ent)
226 | 
227 | def sigmoid_cross_entropy_balanced_fore(logits, label, mask, name='cross_entropy_loss'):
228 |     """
229 |     Implements Equation [2] in https://arxiv.org/pdf/1504.06375.pdf
230 |     Compute edge pixels for each training sample and set as pos_weight for
231 |     tf.nn.weighted_cross_entropy_with_logits
232 |     """
233 |     y = tf.cast(label, tf.float32)
234 | 
235 |     count_neg = tf.reduce_sum(mask * (1. - y))
236 |     count_pos = tf.reduce_sum(mask * y)
237 | 
238 |     # Equation [2]
239 |     beta = count_neg / (count_neg + count_pos)
240 | 
241 |     # Equation [2] divide by 1 - beta
242 |     pos_weight = beta / (1 - beta)
243 | 
244 |     cost = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=y, pos_weight=pos_weight)
245 | 
246 |     # Multiply by 1 - beta
247 |     # cost = tf.reduce_mean(cost * (1 - beta))
248 |     # N, H, W, C = logits.get_shape().as_list()
249 |     size = count_neg + count_pos  # total pixel count inside the mask
250 |     cost = tf.reduce_sum(cost * (1 - beta)) / size
251 | 
252 |     # if the image has no edge pixels, return 0; else return the complete error
253 |     return tf.where(tf.equal(count_pos, 0.0), 0.0, cost, name=name)
254 | 
255 | def sigmoid_cross_entropy_balanced_back(logits, label, name='cross_entropy_loss'):
256 |     """
257 |     Implements Equation [2] in https://arxiv.org/pdf/1504.06375.pdf
258 |     Compute edge pixels for each training sample and set as pos_weight for
259 |     tf.nn.weighted_cross_entropy_with_logits
260 |     """
261 |     y = tf.cast(label, tf.float32)
262 | 
263 |     count_neg = tf.reduce_sum(1.
- y) 264 | count_pos = tf.reduce_sum(y) 265 | 266 | # Equation [2] 267 | beta = count_neg / (count_neg + count_pos) 268 | 269 | # Equation [2] divide by 1 - beta 270 | pos_weight = beta / (1 - beta) 271 | 272 | cost = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=y, pos_weight=pos_weight) 273 | 274 | # Multiply by 1 - beta 275 | cost = tf.reduce_mean(cost * (1 - beta)) 276 | 277 | # check if image has no edge pixels return 0 else return complete error function 278 | return tf.where(tf.equal(count_pos, 0.0), 0.0, cost, name=name) 279 | 280 | 281 | """ 282 | id-mrf 283 | """ 284 | from enum import Enum 285 | 286 | class Distance(Enum): 287 | L2 = 0 288 | DotProduct = 1 289 | 290 | class CSFlow: 291 | def __init__(self, sigma=float(0.1), b=float(1.0)): 292 | self.b = b 293 | self.sigma = sigma 294 | 295 | def __calculate_CS(self, scaled_distances, axis_for_normalization=3): 296 | self.scaled_distances = scaled_distances 297 | self.cs_weights_before_normalization = tf.exp((self.b - scaled_distances) / self.sigma, name='weights_before_normalization') 298 | self.cs_NHWC = CSFlow.sum_normalize(self.cs_weights_before_normalization, axis_for_normalization) 299 | 300 | def reversed_direction_CS(self): 301 | cs_flow_opposite = CSFlow(self.sigma, self.b) 302 | cs_flow_opposite.raw_distances = self.raw_distances 303 | work_axis = [1, 2] 304 | relative_dist = cs_flow_opposite.calc_relative_distances(axis=work_axis) 305 | cs_flow_opposite.__calculate_CS(relative_dist, work_axis) 306 | return cs_flow_opposite 307 | 308 | # -- 309 | @staticmethod 310 | def create_using_L2(I_features, T_features, sigma=float(0.1), b=float(1.0)): 311 | cs_flow = CSFlow(sigma, b) 312 | with tf.name_scope('CS'): 313 | sT = T_features.shape.as_list() 314 | sI = I_features.shape.as_list() 315 | 316 | Ivecs = tf.reshape(I_features, (sI[0], -1, sI[3])) 317 | Tvecs = tf.reshape(T_features, (sI[0], -1, sT[3])) 318 | r_Ts = tf.reduce_sum(Tvecs * Tvecs, 2) 319 | r_Is = tf.reduce_sum(Ivecs * Ivecs, 2) 320 | raw_distances_list = [] 321 | 322 | N, _, _, _ = T_features.shape.as_list() 323 | for i in range(N): 324 | Ivec, Tvec, r_T, r_I = Ivecs[i], Tvecs[i], r_Ts[i], r_Is[i] 325 | A = tf.matmul(Tvec,tf.transpose(Ivec)) 326 | cs_flow.A = A 327 | # A = tf.matmul(Tvec, tf.transpose(Ivec)) 328 | r_T = tf.reshape(r_T, [-1, 1]) # turn to column vector 329 | dist = r_T - 2 * A + r_I 330 | cs_shape = sI[:3] + [dist.shape[0].value] 331 | cs_shape[0] = 1 332 | dist = tf.reshape(tf.transpose(dist), cs_shape) 333 | # protecting against numerical problems, dist should be positive 334 | dist = tf.maximum(float(0.0), dist) 335 | # dist = tf.sqrt(dist) 336 | raw_distances_list += [dist] 337 | 338 | cs_flow.raw_distances = tf.convert_to_tensor([tf.squeeze(raw_dist, axis=0) for raw_dist in raw_distances_list]) 339 | 340 | relative_dist = cs_flow.calc_relative_distances() 341 | cs_flow.__calculate_CS(relative_dist) 342 | return cs_flow 343 | 344 | #-- 345 | @staticmethod 346 | def create_using_dotP(I_features, T_features, sigma=float(1.0), b=float(1.0), args=None): 347 | cs_flow = CSFlow(sigma, b) 348 | with tf.name_scope('CS'): 349 | # prepare feature before calculating cosine distance 350 | T_features, I_features = cs_flow.center_by_T(T_features, I_features) 351 | with tf.name_scope('TFeatures'): 352 | T_features = CSFlow.l2_normalize_channelwise(T_features) 353 | with tf.name_scope('IFeatures'): 354 | I_features = CSFlow.l2_normalize_channelwise(I_features) 355 | # work seperatly for each example in dim 1 356 | cosine_dist_l = [] 357 | 
            N, _, _, _ = T_features.shape.as_list()
358 |             for i in range(N):
359 |                 T_features_i = tf.expand_dims(T_features[i, :, :, :], 0)
360 |                 I_features_i = tf.expand_dims(I_features[i, :, :, :], 0)
361 |                 patches_i = cs_flow.patch_decomposition(T_features_i, args)
362 |                 # use every patch in patches_i as a kernel to convolve I_features, obtaining the distance
363 |                 # between each patch in patches_i and I_features (is this OK for GPU memory?)
364 |                 cosine_dist_i = tf.nn.conv2d(I_features_i, patches_i, strides=[1, 1, 1, 1],
365 |                                              padding='VALID', use_cudnn_on_gpu=True, name='cosine_dist')
366 |                 cosine_dist_l.append(cosine_dist_i)
367 | 
368 |             cs_flow.cosine_dist = tf.concat(cosine_dist_l, axis = 0)
369 | 
370 |             cosine_dist_zero_to_one = -(cs_flow.cosine_dist - 1) / 2
371 |             cs_flow.raw_distances = cosine_dist_zero_to_one
372 | 
373 |             relative_dist = cs_flow.calc_relative_distances()
374 |             cs_flow.__calculate_CS(relative_dist)
375 |             return cs_flow
376 | 
377 |     def calc_relative_distances(self, axis=3):
378 |         epsilon = 1e-5
379 |         div = tf.reduce_min(self.raw_distances, axis=axis, keep_dims=True)
380 |         # div = tf.reduce_mean(self.raw_distances, axis=axis, keep_dims=True)
381 |         relative_dist = self.raw_distances / (div + epsilon)
382 |         return relative_dist
383 | 
384 |     def weighted_average_dist(self, axis=3):
385 |         if not hasattr(self, 'raw_distances'):
386 |             raise Exception('raw_distances attribute does not exist; cannot calculate the weighted average distance')
387 | 
388 |         multiply = self.raw_distances * self.cs_NHWC
389 |         return tf.reduce_sum(multiply, axis=axis, name='weightedDistPerPatch')
390 | 
391 |     # --
392 |     @staticmethod
393 |     def create(I_features, T_features, distance : Distance, nnsigma=float(1.0), b=float(1.0), args=None):
394 |         if distance.value == Distance.DotProduct.value:
395 |             cs_flow = CSFlow.create_using_dotP(I_features, T_features, nnsigma, b, args)
396 |         elif distance.value == Distance.L2.value:
397 |             cs_flow = CSFlow.create_using_L2(I_features, T_features, nnsigma, b)
398 |         else:
399 |             raise ValueError("not supported distance " + distance.__str__())
400 |         return cs_flow
401 | 
402 |     @staticmethod
403 |     def sum_normalize(cs, axis=3):
404 |         reduce_sum = tf.reduce_sum(cs, axis, keep_dims=True, name='sum')
405 |         return tf.divide(cs, reduce_sum, name='sumNormalized')
406 | 
407 |     def center_by_T(self, T_features, I_features):
408 |         # assuming both inputs are of the same size
409 | 
410 |         # calculate stats over [batch, height, width], expecting a 1x1xDepth tensor
411 |         axes = [0, 1, 2]
412 |         self.meanT, self.varT = tf.nn.moments(
413 |             T_features, axes, name='TFeatures/moments')
414 |         # we do not divide by std since it causes the histogram
415 |         # for the final cs to be very thin, so the NN weights
416 |         # are not distinctive, giving similar values for all patches.
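        # Editor's sketch: both T_features and I_features are centred by the
        # same statistics (meanT), so the cosine similarity computed later
        # compares the two feature sets in a common frame rather than
        # whitening each set independently.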
417 | # stdT = tf.sqrt(varT, "stdT") 418 | # correct places with std zero 419 | # stdT[tf.less(stdT, tf.constant(0.001))] = tf.constant(1) 420 | with tf.name_scope('TFeatures/centering'): 421 | self.T_features_centered = T_features - self.meanT 422 | with tf.name_scope('IFeatures/centering'): 423 | self.I_features_centered = I_features - self.meanT 424 | 425 | return self.T_features_centered, self.I_features_centered 426 | 427 | @staticmethod 428 | def l2_normalize_channelwise(features): 429 | norms = tf.norm(features, ord='euclidean', axis=3, name='norm') 430 | # expanding the norms tensor to support broadcast division 431 | norms_expanded = tf.expand_dims(norms, 3) 432 | features = tf.divide(features, norms_expanded, name='normalized') 433 | return features 434 | 435 | def patch_decomposition(self, T_features, args=None): 436 | # patch decomposition 437 | if args is None: 438 | patch_size = 1 439 | stride_size = 1 440 | else: 441 | patch_size = args.PATCH_SIZE 442 | stride_size = args.STRIDE_SIZE 443 | patches_as_depth_vectors = tf.extract_image_patches( 444 | images=T_features, ksizes=[1, patch_size, patch_size, 1], 445 | strides=[1, stride_size, stride_size, 1], rates=[1, 1, 1, 1], padding='VALID', 446 | name='patches_as_depth_vectors') 447 | 448 | out_channels = int(patches_as_depth_vectors.shape[3].value / patch_size / patch_size) 449 | self.patches_NHWC = tf.reshape( 450 | patches_as_depth_vectors, 451 | shape=[-1, patch_size, patch_size, out_channels], 452 | name='patches_PHWC') # patches_as_depth_vectors.shape[3].value / patch_size / patch_size; because here path_size=1,so it's right 453 | 454 | self.patches_HWCN = tf.transpose( 455 | self.patches_NHWC, 456 | perm=[1, 2, 3, 0], 457 | name='patches_HWCP') # tf.conv2 ready format (every patch as a kernel) 458 | 459 | return self.patches_HWCN 460 | 461 | 462 | def mrf_loss(T_features, I_features, distance=Distance.DotProduct, nnsigma=float(1.0), args=None): 463 | T_features = tf.convert_to_tensor(T_features, dtype=tf.float32) 464 | I_features = tf.convert_to_tensor(I_features, dtype=tf.float32) 465 | 466 | with tf.name_scope('cx'): 467 | cs_flow = CSFlow.create(I_features, T_features, distance, nnsigma) 468 | # sum_normalize: 469 | height_width_axis = [1, 2] 470 | # To: 471 | cs = cs_flow.cs_NHWC 472 | k_max_NC = tf.reduce_max(cs, axis=height_width_axis) 473 | CS = tf.reduce_mean(k_max_NC, axis=[1]) 474 | CS_as_loss = 1 - CS 475 | CS_loss = -tf.log(1 - CS_as_loss) 476 | CS_loss = tf.reduce_mean(CS_loss) 477 | return CS_loss 478 | 479 | 480 | def random_sampling(tensor_in, n, indices=None): 481 | N, H, W, C = tf.convert_to_tensor(tensor_in).shape.as_list() 482 | S = H * W 483 | tensor_NSC = tf.reshape(tensor_in, [N, S, C]) 484 | all_indices = list(range(S)) 485 | shuffled_indices = tf.random_shuffle(all_indices) 486 | indices = tf.gather(shuffled_indices, list(range(n)), axis=0) if indices is None else indices 487 | res = tf.gather(tensor_NSC, indices, axis=1) 488 | return res, indices 489 | 490 | 491 | def random_pooling(feats, output_1d_size=100): 492 | is_input_tensor = type(feats) is tf.Tensor 493 | 494 | if is_input_tensor: 495 | feats = [feats] 496 | 497 | # convert all inputs to tensors 498 | feats = [tf.convert_to_tensor(feats_i) for feats_i in feats] 499 | 500 | N, H, W, C = feats[0].shape.as_list() 501 | feats_sampled_0, indices = random_sampling(feats[0], output_1d_size ** 2) 502 | res = [feats_sampled_0] 503 | for i in range(1, len(feats)): 504 | feats_sampled_i, _ = random_sampling(feats[i], -1, indices) 505 | 
res.append(feats_sampled_i) 506 | 507 | res = [tf.reshape(feats_sampled_i, [N, output_1d_size, output_1d_size, C]) for feats_sampled_i in res] 508 | if is_input_tensor: 509 | return res[0] 510 | return res 511 | 512 | 513 | def crop_quarters(feature_tensor): 514 | N, fH, fW, fC = feature_tensor.shape.as_list() 515 | quarters_list = [] 516 | quarter_size = [N, round(fH / 2), round(fW / 2), fC] 517 | quarters_list.append(tf.slice(feature_tensor, [0, 0, 0, 0], quarter_size)) 518 | quarters_list.append(tf.slice(feature_tensor, [0, round(fH / 2), 0, 0], quarter_size)) 519 | quarters_list.append(tf.slice(feature_tensor, [0, 0, round(fW / 2), 0], quarter_size)) 520 | quarters_list.append(tf.slice(feature_tensor, [0, round(fH / 2), round(fW / 2), 0], quarter_size)) 521 | feature_tensor = tf.concat(quarters_list, axis=0) 522 | return feature_tensor 523 | 524 | 525 | def id_mrf_reg_feat(feat_A, feat_B, config, args): 526 | if config.crop_quarters is True: 527 | feat_A = crop_quarters(feat_A) 528 | feat_B = crop_quarters(feat_B) 529 | 530 | N, fH, fW, fC = feat_A.shape.as_list() 531 | if fH * fW <= config.max_sampling_1d_size ** 2: 532 | print(' #### Skipping pooling ....') 533 | else: 534 | print(' #### pooling %d**2 out of %dx%d' % (config.max_sampling_1d_size, fH, fW)) 535 | feat_A, feat_B = random_pooling([feat_A, feat_B], output_1d_size=config.max_sampling_1d_size) 536 | 537 | return mrf_loss(feat_A, feat_B, distance=config.Dist, nnsigma=config.nn_stretch_sigma, args=args) 538 | 539 | 540 | from easydict import EasyDict as edict 541 | # scale of im_src and im_dst: [-1, 1] 542 | def grad_matching_loss(im_src, im_dst, config): 543 | 544 | match_config = edict() 545 | match_config.crop_quarters = False 546 | match_config.max_sampling_1d_size = 65 547 | match_config.Dist = Distance.DotProduct 548 | match_config.nn_stretch_sigma = 0.5 # 0.1 549 | 550 | match_loss = id_mrf_reg_feat(im_src, im_dst, match_config, config) 551 | 552 | match_loss = tf.reduce_sum(match_loss) 553 | 554 | return match_loss 555 | 556 | 557 | """ 558 | Salient Edge 559 | """ 560 | import cv2 561 | def gaussian_kernel_2d_opencv(kernel_size = 3,sigma = 0): 562 | """ 563 | ref: https://blog.csdn.net/qq_16013649/article/details/78784791 564 | ref: tensorflow 565 | (1) https://stackoverflow.com/questions/52012657/how-to-make-a-2d-gaussian-filter-in-tensorflow 566 | (2) https://github.com/tensorflow/tensorflow/issues/2826 567 | """ 568 | kx = cv2.getGaussianKernel(kernel_size,sigma) 569 | ky = cv2.getGaussianKernel(kernel_size,sigma) 570 | return np.multiply(kx,np.transpose(ky)) 571 | 572 | def priority_loss_mask(mask, ksize=5, sigma=1, iteration=2): 573 | gaussian_kernel = gaussian_kernel_2d_opencv(kernel_size=ksize, sigma=sigma) 574 | gaussian_kernel = np.reshape(gaussian_kernel, (ksize, ksize, 1, 1)) 575 | mask_priority = tf.convert_to_tensor(mask, dtype=tf.float32) 576 | for i in range(iteration): 577 | mask_priority = tf.nn.conv2d(mask_priority, gaussian_kernel, strides=[1,1,1,1], padding='SAME') 578 | 579 | return mask_priority 580 | 581 | 582 | # structure loss 583 | from skimage import feature 584 | from skimage.color import rgb2gray 585 | 586 | """ 587 | Structure loss 588 | """ 589 | import cv2 590 | 591 | def canny_edge(images, sigma=1.5): 592 | """ 593 | Extract edges in tensorflow. 
594 | example: 595 | input = tf.placeholder(dtype=tf.float32, shape=[None, 900, 900, 3]) 596 | output = tf.py_func(canny_edge, [input], tf.float32, stateful=False) 597 | 598 | :param images: 599 | :param sigma: 600 | :return: 601 | """ 602 | edges = [] 603 | for i in range(len(images)): 604 | grey_img = rgb2gray(images[i]) 605 | edge = feature.canny(grey_img, sigma=sigma) 606 | edges.append(np.expand_dims(edge, axis=0)) 607 | edges = np.concatenate(edges, axis=0) 608 | return np.expand_dims(edges, axis=3).astype(np.float32) 609 | 610 | 611 | def pyramid_structure_loss(image, predicts, edge_alpha, grad_alpha): 612 | _, H, W, _ = image.get_shape().as_list() 613 | loss = 0. 614 | for predict in predicts: 615 | _, h, w, _ = predict.get_shape().as_list() 616 | if h != H: 617 | gt_img = tf.image.resize_nearest_neighbor(image, size=(h, w)) 618 | # gt_mask = tf.image.resize_nearest_neighbor(mask, size=(h, w)) 619 | 620 | # grad 621 | gt_grad = tf.image.sobel_edges(gt_img) 622 | gt_grad = tf.reshape(gt_grad, [-1, h, w, 6]) # 6 channel 623 | grad_error = tf.abs(predict - gt_grad) 624 | 625 | # edge 626 | gt_edge = tf.py_func(canny_edge, [gt_img], tf.float32, stateful=False) 627 | edge_priority = priority_loss_mask(gt_edge, ksize=5, sigma=1, iteration=2) 628 | else: 629 | gt_img = image 630 | # gt_mask = mask 631 | 632 | # grad 633 | gt_grad = tf.image.sobel_edges(gt_img) 634 | gt_grad = tf.reshape(gt_grad, [-1, H, W, 6]) # 6 channel 635 | grad_error = tf.abs(predict - gt_grad) 636 | 637 | # edge 638 | gt_edge = tf.py_func(canny_edge, [gt_img], tf.float32, stateful=False) 639 | edge_priority = priority_loss_mask(gt_edge, ksize=5, sigma=1, iteration=2) 640 | 641 | grad_loss = tf.reduce_mean(grad_alpha * grad_error) 642 | edge_weight = edge_alpha * edge_priority 643 | # print("edge_weight", edge_weight.shape) 644 | # print("grad_error", grad_error.shape) 645 | edge_loss = tf.reduce_sum(edge_weight * grad_error) / tf.reduce_sum(edge_weight) / 6. 
# 6 channel 646 | 647 | loss = loss + grad_loss + edge_loss 648 | 649 | return loss -------------------------------------------------------------------------------- /src/utils_fn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import scipy 5 | from scipy.misc import imread 6 | from scipy import ndimage 7 | from scipy.misc import imresize 8 | 9 | import skimage 10 | from skimage import feature 11 | from skimage.color import rgb2gray 12 | 13 | import tensorflow as tf 14 | import tensorflow.contrib.slim as slim 15 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch 16 | 17 | import cv2 18 | # free form mask (generated by algorithm) 19 | def np_free_form_mask(maxVertex, maxLength, maxBrushWidth, maxAngle, h, w): 20 | mask = np.zeros((h, w, 1), np.float32) 21 | numVertex = np.random.randint(maxVertex + 1) 22 | startY = np.random.randint(h) 23 | startX = np.random.randint(w) 24 | brushWidth = 0 25 | for i in range(numVertex): 26 | angle = np.random.randint(maxAngle + 1) 27 | angle = angle / 360.0 * 2 * np.pi 28 | if i % 2 == 0: 29 | angle = 2 * np.pi - angle 30 | length = np.random.randint(maxLength + 1) 31 | brushWidth = np.random.randint(10, maxBrushWidth + 1) // 2 * 2 32 | nextY = startY + length * np.cos(angle) 33 | nextX = startX + length * np.sin(angle) 34 | 35 | nextY = np.maximum(np.minimum(nextY, h - 1), 0).astype(np.int) 36 | nextX = np.maximum(np.minimum(nextX, w - 1), 0).astype(np.int) 37 | 38 | cv2.line(mask, (startY, startX), (nextY, nextX), 1, brushWidth) 39 | cv2.circle(mask, (startY, startX), brushWidth // 2, 2) 40 | 41 | startY, startX = nextY, nextX 42 | cv2.circle(mask, (startY, startX), brushWidth // 2, 2) 43 | return mask 44 | 45 | 46 | def free_form_mask_tf(parts, maxVertex=16, maxLength=60, maxBrushWidth=14, maxAngle=360, im_size=(256, 256), name='fmask'): 47 | """ 48 | Free form mask 49 | rf: NIPS multi-column conv 50 | """ 51 | # mask = np.zeros((im_size[0], im_size[1], 1), dtype=np.float32) 52 | with tf.variable_scope(name): 53 | mask = tf.Variable(tf.zeros([1, im_size[0], im_size[1], 1]), name='free_mask') 54 | maxVertex = tf.constant(maxVertex, dtype=tf.int32) 55 | maxLength = tf.constant(maxLength, dtype=tf.int32) 56 | maxBrushWidth = tf.constant(maxBrushWidth, dtype=tf.int32) 57 | maxAngle = tf.constant(maxAngle, dtype=tf.int32) 58 | h = tf.constant(im_size[0], dtype=tf.int32) 59 | w = tf.constant(im_size[1], dtype=tf.int32) 60 | for i in range(parts): 61 | p = tf.py_func(np_free_form_mask, [maxVertex, maxLength, maxBrushWidth, maxAngle, h, w], tf.float32) 62 | p = tf.reshape(p, [1, im_size[0], im_size[1], 1]) 63 | mask = mask + p 64 | mask = tf.minimum(mask, 1.0) 65 | return mask 66 | 67 | def free_form_mask(parts, maxVertex=16, maxLength=60, maxBrushWidth=14, maxAngle=360, im_size=(256, 256)): 68 | h, w = im_size[0], im_size[1] 69 | mask = np.zeros((h, w, 1), dtype=np.float32) 70 | for i in range(parts): 71 | p = np_free_form_mask(maxVertex, maxLength, maxBrushWidth, maxAngle, h, w) 72 | p = np.reshape(p, [1, h, w, 1]) 73 | mask = mask + p 74 | mask = np.minimum(mask, 1.0) 75 | return mask 76 | 77 | class ImageData: 78 | 79 | def __init__(self, args=None): 80 | """ 81 | image size 82 | """ 83 | self.img_size = args.IMG_SHAPES[0] 84 | self.channels = args.IMG_SHAPES[2] 85 | self.sigma = args.SIGMA 86 | # self.level = args.DOWN_LEVEL 87 | self.mode = 'rect' 88 | 89 | # TODO: different images with different preprocessing method 90 | 
    def image_processing(self, filename):
91 |         """Decode, resize and scale an image to [-1, 1] with TensorFlow ops.
92 |         """
93 |         x = tf.read_file(filename)  # read the raw file bytes
94 |         img = tf.image.decode_jpeg(x, channels=self.channels)  # read image and decode it. tf.image.decode_image
95 |         img = tf.image.resize_images(img, [self.img_size, self.img_size])
96 |         img = tf.cast(img, tf.float32) / 127.5 - 1  # scale to [-1, 1]
97 |         return img
98 | 
99 |     def image_processing2(self, filename):
100 |         img = imread(filename, mode='RGB')
101 |         imgh, imgw = img.shape[0:2]
102 |         if imgh != imgw:
103 |             # center crop
104 |             side = np.minimum(imgh, imgw)
105 |             j = (imgh - side) // 2
106 |             i = (imgw - side) // 2
107 |             img = img[j:j + side, i:i + side, ...]
108 | 
109 |         img = scipy.misc.imresize(img, [self.img_size, self.img_size])
110 |         img = scipy.misc.imresize(img, [self.img_size, self.img_size])
111 |         img = img.astype(np.float32) / 127.5 - 1  # scale to [-1, 1]
112 |         return img
113 | 
114 |     def image_edge_processing(self, filename):
115 |         img = imread(filename, mode='RGB')
116 |         imgh, imgw = img.shape[0:2]
117 |         if imgh != imgw:
118 |             # center crop
119 |             side = np.minimum(imgh, imgw)
120 |             j = (imgh - side) // 2
121 |             i = (imgw - side) // 2
122 |             img = img[j:j + side, i:i + side, ...]
123 | 
124 |         img = scipy.misc.imresize(img, [self.img_size, self.img_size])
125 |         img = scipy.misc.imresize(img, [self.img_size, self.img_size])
126 | 
127 |         # edge
128 |         img_gray = rgb2gray(img)  # with the channel dimension removed
129 |         edge = feature.canny(img_gray, sigma=self.sigma).astype(np.float32)
130 | 
131 |         img = img.astype(np.float32) / 127.5 - 1  # scale to [-1, 1]
132 |         return img, edge
133 | 
134 |     def image_edge_scale_processing(self, filename):
135 |         img = imread(filename, mode='RGB')
136 |         imgh, imgw = img.shape[0:2]
137 |         if imgh != imgw:
138 |             # center crop
139 |             side = np.minimum(imgh, imgw)
140 |             j = (imgh - side) // 2
141 |             i = (imgw - side) // 2
142 |             img = img[j:j + side, i:i + side, ...]
143 | 
144 |         img = scipy.misc.imresize(img, [self.img_size, self.img_size])
145 |         img = scipy.misc.imresize(img, [self.img_size, self.img_size])
146 | 
147 |         # edge
148 |         img_gray = rgb2gray(img)  # with the channel dimension removed
149 |         edge_256 = feature.canny(img_gray, sigma=self.sigma).astype(np.float32)
150 |         img_gray = rgb2gray(imresize(img, [128, 128], interp='nearest'))
151 |         edge_128 = feature.canny(img_gray, sigma=self.sigma).astype(np.float32)
152 |         img_gray = rgb2gray(imresize(img, [64, 64], interp='nearest'))
153 |         edge_64 = feature.canny(img_gray, sigma=self.sigma).astype(np.float32)
154 |         # img_gray = rgb2gray(imresize(img, [32, 32], interp='nearest'))
155 |         # edge_32 = feature.canny(img_gray, sigma=self.sigma).astype(np.float32)
156 | 
157 |         img = img.astype(np.float32) / 127.5 - 1  # scale to [-1, 1]
158 |         return img, edge_256, edge_128, edge_64
159 | 
160 |     def mask_processing(self, filename):
161 |         x = tf.read_file(filename)  # read mask filename
162 |         mask = tf.image.decode_png(x, channels=1)  # read image and decode it; tf.image.decode_image also works
163 |         mask = tf.image.resize_images(mask, [self.img_size, self.img_size])
164 |         return mask
165 | 
166 |     def mask_processing2(self, filename):
167 |         """
168 |         For training
169 |         """
170 |         mask = imread(filename)
171 | 
172 |         # mask: hole = 1, data augmentation
173 |         # mask = (mask > 0).astype(np.float32)
174 |         # print(mask.max())
175 |         # print(mask.min())
176 |         mask[mask <= 127] = 0
177 |         mask[mask > 127] = 1
178 | 
179 |         # print(mask.max())
180 |         # print(mask.min())
181 |         # resize
182 |         # mask = scipy.misc.imresize(mask, (self.img_size, self.img_size))
183 | 
184 |         # random dilation (25%); the masks are additionally augmented externally
185 |         if np.random.randint(0, 4) == 0:
186 |             mask = ndimage.binary_dilation(mask, iterations=np.random.randint(1, 6)).astype(np.float32)
187 |         mask = mask[np.newaxis, :, :, np.newaxis]
188 | 
189 |         # 5% prob: generate a fixed center mask
190 |         if np.random.randint(0, 20) == 0:
191 |             mask = create_mask(256, 256, 256 // 2, 256 // 2, delta=0)
192 | 
193 |         # 10% prob: generate a free-form mask (ref: NeurIPS 2018, multi-column)
194 |         if np.random.randint(0, 10) == 0:
195 |             mask = free_form_mask(parts=8, im_size=(self.img_size, self.img_size),
196 |                                   maxBrushWidth=20, maxLength=80, maxVertex=16)
197 |         return mask.astype(np.float32)
198 | 
199 |     def mask_processing3(self, filename):
200 |         """
201 |         For validation and test
202 |         """
203 |         mask = imread(filename)
204 |         # mask = skimage.io.imread(filename)
205 | 
206 |         # mask: hole = 1
207 |         # mask = (mask > 0).astype(np.float32)
208 |         mask[mask <= 127] = 0
209 |         mask[mask > 127] = 1
210 | 
211 |         # resize
212 |         # mask = scipy.misc.imresize(mask, (self.img_size, self.img_size))
213 | 
214 |         mask = mask[np.newaxis, :, :, np.newaxis]
215 | 
216 |         return mask.astype(np.float32)
217 | 
218 | 
219 | def load_data(args):
220 |     """
221 |     Load image data
222 |     """
223 |     # training data: 0, as file list
224 |     # image files
225 |     with open(args.DATA_FLIST[args.DATASET][0]) as f:
226 |         fnames = f.read().splitlines()
227 | 
228 |     # TODO: create input dataset (images and masks)
229 |     inputs = tf.data.Dataset.from_tensor_slices(fnames)  # a tf dataset object (op)
230 |     if args.NUM_GPUS == 1:
231 |         device = '/gpu:0'  # to which gpu; prefetch_to_device(device, batch_size)
232 |     else:
233 |         device = '/cpu:0'
234 |     dataset_num = len(fnames)
235 |     # TODO: dataset with preprocessing (images and masks)
236 |     Image_Data_Class = ImageData(args=args)
237 | 
238 |     # inputs = inputs.apply(shuffle_and_repeat(dataset_num)).apply(
239 |     #     map_and_batch(Image_Data_Class.image_processing, args.BATCH_SIZE, num_parallel_batches=16,
240 |     #                   drop_remainder=True)).apply(prefetch_to_device(gpu_device, args.BATCH_SIZE))
241 |     inputs = inputs.apply(shuffle_and_repeat(dataset_num)).map(lambda filename: tf.py_func(Image_Data_Class.image_processing2, [filename], [tf.float32]), num_parallel_calls=3)
242 |     inputs = inputs.batch(args.BATCH_SIZE * args.NUM_GPUS, drop_remainder=True).apply(prefetch_to_device(device, args.BATCH_SIZE))
243 |     inputs_iterator = inputs.make_one_shot_iterator()  # iterator; each call fetches one element (a batch) of the dataset
244 | 
245 |     images = inputs_iterator.get_next()  # each iteration yields a batch of data
246 | 
247 |     return images
248 | 
249 | def load_mask(args):
250 |     # mask files
251 |     with open(args.TRAIN_MASK_FLIST) as f:
252 |         fnames = f.read().splitlines()
253 | 
254 |     # TODO: create input dataset (masks)
255 |     inputs = tf.data.Dataset.from_tensor_slices(fnames)  # a tf dataset object (op)
256 | 
257 |     if args.NUM_GPUS == 1:
258 |         device = '/gpu:0'  # to which gpu; prefetch_to_device(device, batch_size)
259 |     else:
260 |         device = '/cpu:0'
261 | 
262 |     dataset_num = len(fnames)
263 |     # TODO: dataset with preprocessing (masks)
264 |     Image_Data_Class = ImageData(args=args)
265 | 
266 |     # inputs = inputs.apply(shuffle_and_repeat(dataset_num)).apply(
267 |     #     map_and_batch(Image_Data_Class.image_processing, args.BATCH_SIZE, num_parallel_batches=16,
268 |     #                   drop_remainder=True)).apply(prefetch_to_device(gpu_device, args.BATCH_SIZE))
269 |     inputs = inputs.apply(shuffle_and_repeat(dataset_num)).map(lambda filename: tf.py_func(
270 |         Image_Data_Class.mask_processing2, [filename], [tf.float32]), num_parallel_calls=3)
271 |     inputs = inputs.batch(1, drop_remainder=True).apply(prefetch_to_device(device, 1))
272 |     # inputs = inputs.apply(prefetch_to_device(device))
273 |     inputs_iterator = inputs.make_one_shot_iterator()  # iterator
274 | 
275 |     masks = inputs_iterator.get_next()  # each iteration yields a batch of data
276 | 
277 |     return masks
278 | 
279 | def create_mask(width, height, mask_width, mask_height, x=None, y=None, delta=0):
280 |     """
281 |     create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0)
282 |     delta: margin between mask and image boundary
283 |     """
284 |     mask = np.zeros((height, width))
285 |     mask_x = x if x is not None else np.random.randint(delta, width - mask_width - delta)
286 |     mask_y = y if y is not None else np.random.randint(delta, height - mask_height - delta)
287 |     mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
288 |     mask = mask[np.newaxis, :, :, np.newaxis]
289 |     return mask
290 | 
291 | def load_validation_data(args):
292 |     """
293 |     Load image data
294 |     """
295 |     # validation data: 1, as file list
296 |     # image files
297 |     with open(args.DATA_FLIST[args.DATASET][1]) as f:
298 |         fnames = f.read().splitlines()
299 | 
300 |     # TODO: create input dataset (images)
301 |     inputs = tf.data.Dataset.from_tensor_slices(fnames)  # a tf dataset object (op)
302 | 
303 |     gpu_device = '/gpu:0'  # to which gpu; prefetch_to_device(device, batch_size)
304 |     # gpu_device = '/gpu:{}'.format(args.GPU_ID)
305 | 
306 |     dataset_num = len(fnames)
307 |     # TODO: dataset with preprocessing (images)
308 |     Image_Data_Class = ImageData(args)
309 |     inputs = inputs.map(lambda filename: tf.py_func(Image_Data_Class.image_processing2, [filename], [tf.float32]), num_parallel_calls=3)
310 |     inputs = inputs.batch(args.VAL_NUM, drop_remainder=True).apply(prefetch_to_device(gpu_device, 1))
311 |     inputs_iterator = inputs.make_initializable_iterator()  # iterator, needs to be initialized
312 | 
313 |     images = inputs_iterator.get_next()  # each iteration yields a batch of data
314 | 
315 |     return images, inputs_iterator
316 | 
317 | def load_validation_mask(args):
318 |     # mask files
319 |     with open(args.VAL_MASK_FLIST) as f:
320 |         fnames = f.read().splitlines()
321 | 
322 |     # TODO: create input dataset (masks)
323 |     inputs = tf.data.Dataset.from_tensor_slices(fnames)  # a tf dataset object (op)
324 | 
325 |     gpu_device = '/gpu:0'  # to which gpu; prefetch_to_device(device, batch_size)
326 |     # gpu_device = '/gpu:{}'.format(args.GPU_ID)
327 | 
328 |     dataset_num = len(fnames)
329 |     # TODO: dataset with preprocessing (masks)
330 |     Image_Data_Class = ImageData(args=args)
331 | 
332 |     inputs = inputs.map(lambda filename: tf.py_func(Image_Data_Class.mask_processing3, [filename], [tf.float32]), num_parallel_calls=3)
333 |     inputs = inputs.batch(args.VAL_NUM, drop_remainder=True).apply(prefetch_to_device(gpu_device, 1))
334 |     # inputs = inputs.apply(prefetch_to_device(gpu_device))
335 |     inputs_iterator = inputs.make_initializable_iterator()
336 | 
337 |     masks = inputs_iterator.get_next()  # each iteration yields a batch of data
338 | 
339 |     return masks, inputs_iterator
340 | 
341 | def create_validation_mask(width, height, mask_width, mask_height, args, x=None, y=None, delta=0):
342 |     """
343 |     create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0)
344 |     """
345 |     masks = np.zeros((args.VAL_NUM, height, width))
346 |     for i in range(args.VAL_NUM):
347 |         mask_x = x if x is not None else np.random.randint(delta, width - mask_width - delta)
348 |         mask_y = y if y is not None else np.random.randint(delta, height - mask_height - delta)
349 |         masks[i, mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
350 |     masks = masks[:, :, :, np.newaxis]
351 |     return masks
352 | 
353 | def load_test_data(args):
354 |     """
355 |     Load image data
356 |     """
357 |     # test data: the validation flist (index 1) is reused here
358 |     # image files
359 |     with open(args.DATA_FLIST[args.DATASET][1]) as f:
360 |         fnames = f.read().splitlines()
361 | 
362 |     # TODO: create input dataset (images)
363 |     inputs = tf.data.Dataset.from_tensor_slices(fnames)  # a tf dataset object (op)
364 | 
365 |     gpu_device = '/gpu:0'  # to which gpu; prefetch_to_device(device, batch_size)
366 |     # gpu_device = '/gpu:{}'.format(args.GPU_ID)
367 | 
368 |     dataset_num = len(fnames)
369 |     # TODO: dataset with preprocessing (images)
370 |     Image_Data_Class = ImageData(args=args)
371 |     inputs = inputs.map(lambda filename: tf.py_func(Image_Data_Class.image_processing2, [filename], [tf.float32]), num_parallel_calls=3)
372 |     inputs = inputs.batch(args.TEST_NUM, drop_remainder=True).apply(prefetch_to_device(gpu_device))
373 |     inputs_iterator = inputs.make_initializable_iterator()  # iterator, needs to be initialized
374 | 
375 |     images = inputs_iterator.get_next()  # each iteration yields a batch of data
376 | 
377 |     return images, inputs_iterator
378 | 
379 | def load_test_mask(args):
380 |     # mask files
381 |     with open(args.TEST_MASK_FLIST) as f:
382 |         fnames = f.read().splitlines()
383 | 
384 |     # TODO: create input dataset (masks)
385 |     inputs = tf.data.Dataset.from_tensor_slices(fnames)  # a tf dataset object (op)
386 | 
387 |     gpu_device = '/gpu:0'  # to which gpu; prefetch_to_device(device, batch_size)
388 |     # gpu_device = '/gpu:{}'.format(args.GPU_ID)
389 | 
390 |     dataset_num = len(fnames)
391 |     # TODO: dataset with preprocessing (masks)
392 |     Image_Data_Class = ImageData(args=args)
393 | 
394 |     inputs = inputs.map(lambda filename: tf.py_func(Image_Data_Class.mask_processing3, [filename], [tf.float32]), num_parallel_calls=3)
395 |     inputs = inputs.batch(args.TEST_NUM, drop_remainder=True).apply(prefetch_to_device(gpu_device, 1))
396 |     # inputs = inputs.apply(prefetch_to_device(gpu_device))
397 |     inputs_iterator = inputs.make_initializable_iterator()  # iterator
398 | 
399 |     masks = inputs_iterator.get_next()  # each iteration yields a batch of data
400 | 
401 |     return masks, inputs_iterator
402 | 
403 | def create_test_mask(width, height, mask_width, mask_height, args, x=None, y=None, delta=0):
404 |     """
405 |     create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0)
406 |     """
407 |     masks = np.zeros((args.TEST_NUM, height, width))
408 |     for i in range(args.TEST_NUM):
409 |         mask_x = x if x is not None else np.random.randint(delta, width - mask_width - delta)
410 |         mask_y = y if y is not None else np.random.randint(delta, height - mask_height - delta)
411 |         masks[i, mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
412 |     masks = masks[:, :, :, np.newaxis]
413 |     return masks
414 | 
415 | def dataset_len(args):
416 |     with open(args.DATA_FLIST[args.DATASET][0]) as f:
417 |         fnames = f.read().splitlines()
418 |     return len(fnames)
419 | 
420 | def show_all_variables():
421 |     """
422 |     Show all trainable variables of a TF model.
423 |     """
424 |     model_vars = tf.trainable_variables()
425 |     slim.model_analyzer.analyze_vars(model_vars, print_info=True)
426 | 
427 | def normalize(x):
428 |     return x/127.5 - 1
429 | 
430 | def save_images(images, size, image_path):
431 |     return imsave(inverse_transform(images), size, image_path)
432 | 
433 | def merge(images, size):
434 |     h, w = images.shape[1], images.shape[2]
435 |     if images.shape[3] in (3, 4):
436 |         c = images.shape[3]
437 |         img = np.zeros((h * size[0], w * size[1], c))
438 |         for idx, image in enumerate(images):
439 |             i = idx % size[1]
440 |             j = idx // size[1]
441 |             img[j * h:j * h + h, i * w:i * w + w, :] = image
442 |         return img
443 |     elif images.shape[3] == 1:
444 |         img = np.zeros((h * size[0], w * size[1]))
445 |         for idx, image in enumerate(images):
446 |             i = idx % size[1]
447 |             j = idx // size[1]
448 |             img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0]
449 |         return img
450 |     else:
451 |         raise ValueError('in merge(images, size), images must have shape NxHxWxC with C in (1, 3, 4)')
452 | 
453 | def imsave(images, size, path):
454 |     return scipy.misc.imsave(path, merge(images, size))
455 | 
456 | def inverse_transform(images):
457 |     return (images + 1.) * 127.5
458 | 
459 | def load_img_edge(args):
460 |     """
461 |     Load image data
462 |     """
463 |     # training data: 0, as file list
464 |     # image files
465 |     with open(args.DATA_FLIST[args.DATASET][0]) as f:
466 |         fnames = f.read().splitlines()
467 | 
468 |     # TODO: create input dataset (images and masks)
469 |     inputs = tf.data.Dataset.from_tensor_slices(fnames)  # a tf dataset object (op)
470 |     if args.NUM_GPUS == 1:
471 |         device = '/gpu:0'  # to which gpu; prefetch_to_device(device, batch_size)
472 |         # gpu_device = '/gpu:{}'.format(args.GPU_ID)
473 |     else:
474 |         device = '/cpu:0'
475 |     dataset_num = len(fnames)
476 |     # TODO: dataset with preprocessing (images and masks)
477 |     Image_Data_Class = ImageData(args=args)
478 |     # inputs = inputs.apply(shuffle_and_repeat(dataset_num)).apply(
479 |     #     map_and_batch(Image_Data_Class.image_processing, args.BATCH_SIZE, num_parallel_batches=16,
480 |     #                   drop_remainder=True)).apply(prefetch_to_device(gpu_device, args.BATCH_SIZE))
481 |     inputs = inputs.apply(shuffle_and_repeat(dataset_num)).map(lambda filename: tf.py_func(
482 |         Image_Data_Class.image_edge_processing, [filename], [tf.float32, tf.float32]), num_parallel_calls=3)
483 |     inputs = inputs.batch(args.BATCH_SIZE * args.NUM_GPUS, drop_remainder=True).apply(prefetch_to_device(device, args.BATCH_SIZE))
484 |     inputs_iterator = inputs.make_one_shot_iterator()  # iterator
485 | 
486 |     images_edges = inputs_iterator.get_next()  # each iteration yields a batch of data
487 | 
488 |     return images_edges
489 | 
490 | def load_val_img_edge(args):
491 | 
492 |     """
493 |     Load image data
494 |     """
495 |     # validation data: 1, as file list
496 |     # image files
497 |     with open(args.DATA_FLIST[args.DATASET][1]) as f:
498 |         fnames = f.read().splitlines()
499 | 
500 |     # TODO: create input dataset (images)
501 |     inputs = tf.data.Dataset.from_tensor_slices(fnames)  # a tf dataset object (op)
502 | 
503 |     gpu_device = '/gpu:0'  # to which gpu; prefetch_to_device(device, batch_size)
504 |     # gpu_device = '/gpu:{}'.format(args.GPU_ID)
505 | 
506 |     dataset_num = len(fnames)
507 |     # TODO: dataset with preprocessing (images)
508 |     Image_Data_Class = ImageData(args)
509 |     inputs = inputs.map(lambda filename: tf.py_func(Image_Data_Class.image_edge_processing, [filename], [tf.float32, tf.float32]),
510 |                         num_parallel_calls=3)
511 |     inputs = inputs.batch(args.VAL_NUM, drop_remainder=True).apply(prefetch_to_device(gpu_device))
512 |     inputs_iterator = inputs.make_initializable_iterator()  # iterator, needs to be initialized
513 | 
514 |     images_edges = inputs_iterator.get_next()  # each iteration yields a batch of data
515 | 
516 |     return images_edges, inputs_iterator
517 | 
518 | def load_test_img_edge(args):
519 | 
520 |     """
521 |     Load image data
522 |     """
523 |     # test data: the validation flist (index 1) is reused here
524 |     # image files
525 |     with open(args.DATA_FLIST[args.DATASET][1]) as f:
526 |         fnames = f.read().splitlines()
527 | 
528 |     # TODO: create input dataset (images)
529 |     inputs = tf.data.Dataset.from_tensor_slices(fnames)  # a tf dataset object (op)
530 | 
531 |     gpu_device = '/gpu:0'  # to which gpu; prefetch_to_device(device, batch_size)
532 |     # gpu_device = '/gpu:{}'.format(args.GPU_ID)
533 | 
534 |     dataset_num = len(fnames)
535 |     # TODO: dataset with preprocessing (images)
536 |     Image_Data_Class = ImageData(args)
537 |     inputs = inputs.map(lambda filename: tf.py_func(Image_Data_Class.image_edge_processing, [filename], [tf.float32, tf.float32]),
538 |                         num_parallel_calls=3)
539 |     inputs = inputs.batch(args.TEST_NUM, drop_remainder=True).apply(prefetch_to_device(gpu_device))
540 |     inputs_iterator = inputs.make_initializable_iterator()  # iterator, needs to be initialized
541 | 
542 |     images_edges = inputs_iterator.get_next()  # each iteration yields a batch of data
543 | 
544 |     return images_edges, inputs_iterator
545 | 
546 | def load_img_scale_edge(args):
547 |     """
548 |     Load image data
549 |     """
550 |     # training data: 0, as file list
551 |     # image files
552 |     with open(args.DATA_FLIST[args.DATASET][0]) as f:
553 |         fnames = f.read().splitlines()
554 | 
555 |     # TODO: create input dataset (images and masks)
556 |     inputs = tf.data.Dataset.from_tensor_slices(fnames)  # a tf dataset object (op)
557 |     if args.NUM_GPUS == 1:
558 |         device = '/gpu:0'  # to which gpu; prefetch_to_device(device, batch_size)
559 |         # gpu_device = '/gpu:{}'.format(args.GPU_ID)
560 |     else:
561 |         device = '/cpu:0'
562 |     dataset_num = len(fnames)
563 |     # TODO: dataset with preprocessing (images and masks)
564 |     Image_Data_Class = ImageData(args=args)
565 | 
566 |     # inputs = inputs.apply(shuffle_and_repeat(dataset_num)).apply(
567 |     #     map_and_batch(Image_Data_Class.image_processing, args.BATCH_SIZE, num_parallel_batches=16,
568 |     #                   drop_remainder=True)).apply(prefetch_to_device(gpu_device, args.BATCH_SIZE))
569 |     inputs = inputs.apply(shuffle_and_repeat(dataset_num)).map(lambda filename: tf.py_func(
570 |         Image_Data_Class.image_edge_scale_processing, [filename], [tf.float32, tf.float32, tf.float32, tf.float32]), num_parallel_calls=3)
571 |     inputs = inputs.batch(args.BATCH_SIZE * args.NUM_GPUS, drop_remainder=True).apply(prefetch_to_device(device, args.BATCH_SIZE * args.NUM_GPUS))
572 |     inputs_iterator = inputs.make_one_shot_iterator()  # iterator
573 | 
574 |     images_edges = inputs_iterator.get_next()  # each iteration yields a batch of data
575 | 
576 |     return images_edges
577 | 
578 | def load_val_img_scale_edge(args):
579 |     """
580 |     Load image data
581 |     """
582 |     # validation data: 1, as file list
583 |     # image files
584 |     with open(args.DATA_FLIST[args.DATASET][1]) as f:
585 |         fnames = f.read().splitlines()
586 | 
587 |     # TODO: create input dataset (images and masks)
588 |     inputs = tf.data.Dataset.from_tensor_slices(fnames)  # a tf dataset object (op)
589 | 
590 |     gpu_device = '/gpu:0'  # to which gpu; prefetch_to_device(device, batch_size)
591 |     # gpu_device = '/gpu:{}'.format(args.GPU_ID)
592 | 
593 |     dataset_num = len(fnames)
594 |     # TODO: dataset with preprocessing (images and masks)
595 |     Image_Data_Class = ImageData(args=args)
596 | 
597 |     # inputs = inputs.apply(shuffle_and_repeat(dataset_num)).apply(
598 |     #     map_and_batch(Image_Data_Class.image_processing, args.BATCH_SIZE, num_parallel_batches=16,
599 |     #                   drop_remainder=True)).apply(prefetch_to_device(gpu_device, args.BATCH_SIZE))
600 |     inputs = inputs.apply(shuffle_and_repeat(dataset_num)).map(lambda filename: tf.py_func(
601 |         Image_Data_Class.image_edge_scale_processing, [filename], [tf.float32, tf.float32, tf.float32, tf.float32]), num_parallel_calls=3)
602 |     inputs = inputs.batch(args.VAL_NUM, drop_remainder=True).apply(prefetch_to_device(gpu_device, args.BATCH_SIZE))
603 |     inputs_iterator = inputs.make_initializable_iterator()  # iterator
604 | 
605 |     images_edges = inputs_iterator.get_next()  # each iteration yields a batch of data
606 | 
607 |     return images_edges, inputs_iterator
608 | 
609 | 
610 | 
611 | # random rect mask
612 | def random_bbox(config):
613 |     """Generate a random tlhw bounding box from the configuration.
614 | 
615 |     Args:
616 |         config: Config object; must provide img_shapes, margins,
617 |             mask_shapes and random_mask.
618 | 
619 |     Returns:
620 |         tuple: (top, left, height, width)
621 | 
622 |     """
623 |     img_shape = config.img_shapes
624 |     img_height = img_shape[0]
625 |     img_width = img_shape[1]
626 |     if config.random_mask is True:
627 |         maxt = img_height - config.margins[0] - config.mask_shapes[0]
628 |         maxl = img_width - config.margins[1] - config.mask_shapes[1]
629 |         t = tf.random_uniform(
630 |             [], minval=config.margins[0], maxval=maxt, dtype=tf.int32)
631 |         l = tf.random_uniform(
632 |             [], minval=config.margins[1], maxval=maxl, dtype=tf.int32)
633 |     else:
634 |         t = config.mask_shapes[0]//2
635 |         l = config.mask_shapes[1]//2
636 |     h = tf.constant(config.mask_shapes[0])
637 |     w = tf.constant(config.mask_shapes[1])
638 |     return (t, l, h, w)
639 | 
640 | 
641 | def bbox2mask(bbox, config, name='mask'):
642 |     """Generate a mask tensor from a bbox.
643 | 
644 |     Args:
645 |         bbox: configuration tuple, (top, left, height, width)
646 |         config: Config object; must provide img_shapes and
647 |             max_delta_shapes.
648 | 
649 |     Returns:
650 |         tf.Tensor: output with shape [1, H, W, 1]
651 | 
652 |     """
653 |     def npmask(bbox, height, width, delta_h, delta_w):
654 |         mask = np.zeros((1, height, width, 1), np.float32)
655 |         h = np.random.randint(delta_h//2+1)
656 |         w = np.random.randint(delta_w//2+1)
657 |         mask[:, bbox[0]+h:bbox[0]+bbox[2]-h,
658 |              bbox[1]+w:bbox[1]+bbox[3]-w, :] = 1.
659 |         return mask
660 |     with tf.variable_scope(name), tf.device('/cpu:0'):
661 |         img_shape = config.img_shapes
662 |         height = img_shape[0]
663 |         width = img_shape[1]
664 |         mask = tf.py_func(
665 |             npmask,
666 |             [bbox, height, width,
667 |              config.max_delta_shapes[0], config.max_delta_shapes[1]],
668 |             tf.float32, stateful=False)
669 |         mask.set_shape([1] + [height, width] + [1])
670 |     return mask
671 | 
672 | """
673 | How to use
674 | # generate mask, 1 represents masked point
675 | if config.mask_type == 'rect':
676 |     bbox = random_bbox(config)
677 |     mask = bbox2mask(bbox, config, name='mask_c')
678 | else:
679 |     mask = free_form_mask_tf(parts=8, im_size=(config.img_shapes[0], config.img_shapes[1]),
680 |                              maxBrushWidth=20, maxLength=80, maxVertex=16)
681 | """
--------------------------------------------------------------------------------
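
The NumPy-side mask helpers in src/utils_fn.py (create_mask, free_form_mask) can be exercised on their own, without building a TF graph. Below is a minimal sketch, not part of the repository: it assumes it is run from src/ so that utils_fn is importable, and the canvas and hole sizes are illustrative.

import numpy as np
from utils_fn import create_mask, free_form_mask

np.random.seed(0)  # reproducible strokes for this sketch

# Rect mask: a 128x128 hole placed at (y=64, x=64) in a 256x256 canvas.
rect = create_mask(256, 256, 128, 128, x=64, y=64)

# Free-form brush mask built from 8 random strokes, clamped to [0, 1].
free = free_form_mask(parts=8, im_size=(256, 256),
                      maxBrushWidth=20, maxLength=80, maxVertex=16)

# Both come back as [1, H, W, 1] arrays with holes marked by 1.
print(rect.shape, free.shape)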
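For the TF-side recipe in the "How to use" docstring, a hedged end-to-end sketch under TF 1.x follows. The Config class here is hypothetical: it simply mirrors the attributes random_bbox and bbox2mask actually read (img_shapes, margins, mask_shapes, max_delta_shapes, random_mask, mask_type), and its values are illustrative rather than taken from the repository's YAML configs.

import tensorflow as tf
from utils_fn import random_bbox, bbox2mask, free_form_mask_tf

class Config:  # hypothetical stand-in for the YAML-driven config object
    img_shapes = [256, 256, 3]
    mask_shapes = [128, 128]
    margins = [0, 0]
    max_delta_shapes = [32, 32]
    random_mask = True
    mask_type = 'rect'

config = Config()
if config.mask_type == 'rect':
    bbox = random_bbox(config)                     # random (top, left, h, w)
    mask = bbox2mask(bbox, config, name='mask_c')  # [1, 256, 256, 1], holes = 1
else:
    mask = free_form_mask_tf(parts=8, im_size=(config.img_shapes[0], config.img_shapes[1]),
                             maxBrushWidth=20, maxLength=80, maxVertex=16)

with tf.Session() as sess:
    # Only the free-form branch creates a tf.Variable canvas,
    # but running the initializer is harmless on the rect branch too.
    sess.run(tf.global_variables_initializer())
    print(sess.run(mask).shape)  # (1, 256, 256, 1)

The rect branch is purely functional (tf.random_uniform plus a stateless tf.py_func), which is why it needs no initialization; the initializer call is kept only so the same sketch also covers the free-form branch.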