├── .gitattributes ├── LICENSE ├── README.md ├── dataset.py ├── mobilenetv2.py ├── mobilenetv2_075.npy ├── network.py ├── test.py ├── train.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 secret_wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | update 2019.10.24: 2 | 3 | The paper's authors have open-sourced the official version, so this version can be considered abandoned. In fact, the bilateral filter part cannot produce reasonable results (I don't use C++ and don't know how to build the official custom op from HDRNet), and there are also other bugs in my code. 4 | 5 | 6 | Instead, I tried to modify Deep Guided Filter (https://github.com/wuhuikai/DeepGuidedFilter) and got pretty good results on the color adjustment/HDR task. Its implementation only involves official TensorFlow ops and does not need any self-defined ops. I am sorry that my code cannot be published because it is related to my job, but I believe it is easy to adapt by yourself. 7 | 8 | ______________________________________________________________________________________________________________________________ 9 | 10 | 11 | update 2019.07.11: 12 | 13 | I believe the code in this repo cannot get you any reasonable result. The cause may be the code itself (the bilateral upsampling), or the dataset (the pairs are not strictly aligned, so pixel-wise losses are not suitable). Recently I tried a deep guided filter with a re-processed dataset and got much better results; a minimal sketch of such a guided filter, built from standard TensorFlow ops only, is given below. I will update my code once I have time, maybe in this repo, maybe in a new one.
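For reference, this is roughly what "official TensorFlow ops only" means here: the core (non-deep) guided filter can be written with nothing more than average pooling. The sketch below is my own minimal illustration, not the DeepGuidedFilter code; the function names, window radius and epsilon are placeholder choices.

```python
import tensorflow as tf

def box_filter(x, r):
    # Mean over a (2r+1)x(2r+1) window; 'SAME' padding only approximates
    # the exact guided filter near image borders.
    k = 2 * r + 1
    return tf.nn.avg_pool(x, ksize=[1, k, k, 1],
                          strides=[1, 1, 1, 1], padding='SAME')

def guided_filter(guide, src, r=8, eps=1e-4):
    # guide, src: [batch, height, width, channels] tensors in [0, 1]
    mean_I = box_filter(guide, r)
    mean_p = box_filter(src, r)
    cov_Ip = box_filter(guide * src, r) - mean_I * mean_p
    var_I = box_filter(guide * guide, r) - mean_I * mean_I
    a = cov_Ip / (var_I + eps)   # per-pixel linear coefficients
    b = mean_p - a * mean_I
    # Smooth the coefficients, then apply them to the guide image.
    return box_filter(a, r) * guide + box_filter(b, r)
```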
14 | 15 | ______________________________________________________________________________________________________________________________ 16 | 17 | # deepupe_tensorflow 18 | 19 | This is an unofficial implementation of the CVPR 2019 paper "Underexposed Photo Enhancement using Deep Illumination Estimation" 20 | ------------- 21 | paper url: http://jiaya.me/papers/photoenhance_cvpr19.pdf 22 | 23 | My testing environment is as below: 24 | 25 | ubuntu 16.04 26 | 27 | tensorflow-gpu 1.12 28 | 29 | cuda 9.0 30 | 31 | cudnn 7.3.1 32 | 33 | pytorch 1.0 (I use the pytorch dataloader) 34 | 35 | ______________________________________________________________________________________________________________________________ 36 | 37 | 38 | I guess this code can run on recent tensorflow versions without major problems, but you may need to modify the dataloader part, as I am using a post-processed hdr_burst dataset with photos arranged in my own way. 39 | 40 | There are some differences between my implementation and the original paper: 41 | 42 | In the paper, VGG19 is used as the feature extractor, but here I use MobileNetV2. Its pre-trained weights expect input data normalized to (-1, 1), but according to my experiments this results in bad visual quality, so you may consider not using the pre-trained weights, or even try another network (see network.py). 43 | 44 | The smoothness loss is also different: I kept getting NaN with the logarithmic operation, so I deleted that part; what remains can be seen as a total-variation loss (tv-loss). 45 | 46 | Also, the bilateral slice op is borrowed from this repo: https://github.com/dragonkao730/BilateralGrid/blob/master/bilaterial_grid.py 47 | 48 | I am still training the model, and will upload the pre-trained weights if I get visually satisfying results. (There are some problems with the code now; I am trying to fix them and will update later.) 49 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | from torch.utils.data import Dataset, DataLoader 6 | from torchvision import transforms 7 | 8 | 9 | class HDRDataset(Dataset): 10 | # Expects image pairs named '{idx:04d}.jpg' / '{idx:04d}_gt.jpg' inside data_dir; 11 | # the first 3000 pairs form the training split, the remaining pairs the test split. 12 | def __init__(self, data_dir, mode='train', transform=None): 13 | total_len = len(os.listdir(data_dir))//2 14 | self.data_dir = data_dir 15 | if mode == 'train': 16 | self.data_len = 3000 17 | 18 | elif mode == 'test': 19 | self.data_len = total_len - 3000 20 | 21 | self.transform = transform 22 | 23 | 24 | def __len__(self): 25 | return self.data_len 26 | 27 | 28 | def __getitem__(self, idx): 29 | image_path = os.path.join(self.data_dir, '{}.jpg'.format(str(idx).zfill(4))) 30 | label_path = os.path.join(self.data_dir, '{}_gt.jpg'.format(str(idx).zfill(4))) 31 | image = cv2.imread(image_path) 32 | label = cv2.imread(label_path) 33 | h1, w1 = np.shape(image)[:2] 34 | h2, w2 = np.shape(label)[:2] 35 | # center-crop whichever of the pair is larger so both share the same size 36 | if h1 > h2: 37 | dh = (h1-h2)//2 38 | image = image[dh: dh+h2, :, :] 39 | elif h1 < h2: 40 | dh = (h2-h1)//2 41 | label = label[dh: dh+h1, :, :] 42 | 43 | if w1 > w2: 44 | dw = (w1-w2)//2 45 | image = image[:, dw: dw+w2, :] 46 | elif w1 < w2: 47 | dw = (w2-w1)//2 48 | label = label[:, dw: dw+w1, :] 49 | 50 | try: 51 | assert np.shape(image) == np.shape(label) 52 | except: 53 | print(np.shape(image), np.shape(label)) 54 | 55 | 56 | if self.transform is not None: 57 | image, label = self.transform(image, label) 58 | 59 | return image, label
60 | 61 | 62 | 63 | class TrainTransform(): 64 | 65 | def __init__(self, output_size): 66 | assert isinstance(output_size, (int, tuple)) 67 | if isinstance(output_size, int): 68 | self.output_size = (output_size, output_size) 69 | else: 70 | assert len(output_size) == 2 71 | self.output_size = output_size 72 | 73 | 74 | def __call__(self, image, label): 75 | 76 | new_h, new_w = self.output_size 77 | 78 | h, w = np.shape(image)[:2] 79 | offset_h = np.random.randint(0, h - new_h) 80 | offset_w = np.random.randint(0, w - new_w) 81 | 82 | image = image[offset_h: offset_h + new_h, 83 | offset_w: offset_w + new_w] 84 | label = label[offset_h: offset_h + new_h, 85 | offset_w: offset_w + new_w] 86 | 87 | flip_prop = np.random.randint(0, 100) 88 | if flip_prop > 50: 89 | image = cv2.flip(image, 1) 90 | label = cv2.flip(label, 1) 91 | 92 | image = image.astype(np.float32)/255.0 93 | label = label.astype(np.float32)/255.0 94 | 95 | 96 | return image, label 97 | 98 | 99 | 100 | class TestTransform(): 101 | 102 | def __init__(self): 103 | pass 104 | 105 | 106 | def __call__(self, image, label): 107 | 108 | image = image.astype(np.float32)/255.0 109 | #label = label.astype(np.float32)/127.5 - 1 110 | 111 | return image, label 112 | 113 | 114 | 115 | def get_train_loader(image_size, batch_size, data_dir): 116 | transform = TrainTransform(image_size) 117 | dataset = HDRDataset(data_dir, transform=transform) 118 | dataloader = DataLoader(dataset, batch_size=batch_size, 119 | shuffle=True, num_workers=8) 120 | return dataloader 121 | 122 | 123 | 124 | def get_test_loader(data_dir): 125 | transform = TestTransform() 126 | dataset = HDRDataset(data_dir, mode='test', transform=transform) 127 | dataloader = DataLoader(dataset, batch_size=1, 128 | shuffle=False, num_workers=0) 129 | return dataloader 130 | 131 | 132 | 133 | if __name__ == '__main__': 134 | data_dir = '/media/wangxinrui/新加卷/hdr+burst/hdr_burst' 135 | dataloader = get_train_loader(512, 8, data_dir) 136 | for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)): 137 | print(np.shape(batch[0]), np.shape(batch[1])) 138 | #pass 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | import numpy as np 4 | 5 | 6 | def init_conv(x, out_channel, scope='init_conv', is_training=True): 7 | with tf.variable_scope(scope): 8 | with slim.arg_scope([slim.convolution2d], 9 | normalizer_fn=slim.batch_norm, 10 | activation_fn=tf.nn.relu6): 11 | with slim.arg_scope([slim.batch_norm], 12 | is_training=is_training, 13 | center=True, scale=True): 14 | x = slim.convolution2d(x, out_channel, [3, 3], stride=2, padding='same', 15 | biases_initializer=None, biases_regularizer=None) 16 | #x = slim.batch_norm(x, scale=True, is_training=is_training) 17 | return x 18 | 19 | 20 | 21 | def init_resblock(x, out_channel, stride, 22 | scope='init_resblock', is_training=True): 23 | with tf.variable_scope(scope): 24 | 25 | with slim.arg_scope([slim.separable_convolution2d], 26 | normalizer_fn=slim.batch_norm, 27 | activation_fn=tf.nn.relu6): 28 | with slim.arg_scope([slim.batch_norm], 29 | is_training=is_training, 30 | center=True, scale=True): 31 | x = slim.separable_convolution2d(x, None, [3, 3], depth_multiplier=1, stride=stride, 32 | biases_initializer=None, biases_regularizer=None) 33 | 34 | with 
slim.arg_scope([slim.convolution2d], 35 | normalizer_fn=slim.batch_norm): 36 | with slim.arg_scope([slim.batch_norm], 37 | is_training=is_training, 38 | center=True, scale=True): 39 | x = slim.convolution2d(x, out_channel, [1, 1], 40 | stride=1, padding='same', 41 | biases_initializer=None, 42 | biases_regularizer=None) 43 | 44 | return x 45 | 46 | 47 | def inverte_resblock(x, in_channel, out_channel, stride, expand_radio=6, 48 | res_connect=True, scope='inverte_resblock', is_training=True): 49 | mid_channel = in_channel * expand_radio 50 | with tf.variable_scope(scope): 51 | if res_connect: 52 | short_cut = x 53 | 54 | with slim.arg_scope([slim.convolution2d], 55 | normalizer_fn=slim.batch_norm, 56 | activation_fn=tf.nn.relu6): 57 | with slim.arg_scope([slim.batch_norm], 58 | is_training=is_training, 59 | center=True, scale=True): 60 | x = slim.convolution2d(x, mid_channel, [1, 1], stride=1, padding='same', 61 | biases_initializer=None, biases_regularizer=None) 62 | 63 | with slim.arg_scope([slim.separable_convolution2d], 64 | normalizer_fn=slim.batch_norm, 65 | activation_fn=tf.nn.relu6): 66 | with slim.arg_scope([slim.batch_norm], 67 | is_training=is_training, 68 | center=True, scale=True): 69 | x = slim.separable_convolution2d(x, None, [3, 3], depth_multiplier=1, stride=stride, 70 | biases_initializer=None, biases_regularizer=None) 71 | 72 | with slim.arg_scope([slim.convolution2d], 73 | normalizer_fn=slim.batch_norm): 74 | with slim.arg_scope([slim.batch_norm], 75 | is_training=is_training, 76 | center=True, scale=True): 77 | x = slim.convolution2d(x, out_channel, [1, 1], 78 | stride=1, padding='same', 79 | biases_initializer=None, 80 | biases_regularizer=None) 81 | 82 | if res_connect: 83 | return x + short_cut 84 | else: 85 | return x 86 | 87 | 88 | def backbone(inputs, width=0.75, scope='backbone', is_training=True, reuse=False): 89 | 90 | if width == 1: 91 | channel = np.array([32, 16, 24, 32, 64, 96, 160, 320]) 92 | elif width == 0.75: 93 | channel = np.array([24, 16, 24, 24, 48, 72, 120, 240]) 94 | elif width == 0.5: 95 | channel = np.array([16, 8, 16, 16, 32, 48, 80, 160]) 96 | 97 | with tf.variable_scope(scope, reuse=reuse): 98 | 99 | x = init_conv(inputs, channel[0], is_training=is_training) 100 | x1_out = init_resblock(x, channel[1], stride=1, is_training=is_training) 101 | 102 | x = inverte_resblock(x1_out, in_channel=channel[1], out_channel=channel[2], stride=2, 103 | res_connect=False, scope='invert1_1', is_training=is_training) 104 | x2_out = inverte_resblock(x, in_channel=channel[2], out_channel=channel[2], stride=1, 105 | res_connect=True, scope='invert1_2', is_training=is_training) 106 | 107 | x = inverte_resblock(x2_out, in_channel=channel[2], out_channel=channel[3], stride=2, 108 | res_connect=False, scope='invert2_1', is_training=is_training) 109 | x = inverte_resblock(x, in_channel=channel[3], out_channel=channel[3], stride=1, 110 | res_connect=True, scope='invert2_2', is_training=is_training) 111 | x3_out = inverte_resblock(x, in_channel=channel[3], out_channel=channel[3], stride=1, 112 | res_connect=True, scope='invert2_3', is_training=is_training) 113 | 114 | x = inverte_resblock(x3_out, in_channel=channel[3], out_channel=channel[4], stride=2, 115 | res_connect=False, scope='invert3_1', is_training=is_training) 116 | for i in range(3): 117 | x = inverte_resblock(x, in_channel=channel[4], out_channel=channel[4], stride=1, 118 | res_connect=True, scope='invert3_{}'.format(i+2), 119 | is_training=is_training) 120 | 121 | x = inverte_resblock(x, 
in_channel=channel[4], out_channel=channel[5], stride=1, 122 | res_connect=False, scope='invert4_1', is_training=is_training) 123 | x = inverte_resblock(x, in_channel=channel[5], out_channel=channel[5], stride=1, 124 | res_connect=True, scope='invert4_2', is_training=is_training) 125 | 126 | x4_out = inverte_resblock(x, in_channel=channel[5], out_channel=channel[5], stride=1, 127 | res_connect=True, scope='invert4_3', is_training=is_training) 128 | 129 | 130 | x = inverte_resblock(x4_out, in_channel=channel[5], out_channel=channel[6], stride=2, 131 | res_connect=False, scope='invert5_1', is_training=is_training) 132 | 133 | for j in range(2): 134 | x = inverte_resblock(x, in_channel=channel[6], out_channel=channel[6], 135 | stride=1, res_connect=True, scope='invert5_{}'.format(j+2), 136 | is_training=is_training) 137 | 138 | x5_out = inverte_resblock(x, in_channel=channel[6], out_channel=channel[7], stride=1, 139 | res_connect=False, scope='invert5_4', is_training=is_training) 140 | 141 | return x5_out 142 | 143 | 144 | if __name__ == '__main__': 145 | 146 | 147 | inputs = tf.placeholder(tf.float32, [1, 256, 256, 3]) 148 | outputs = backbone(inputs, 0.5) 149 | for xxx in outputs: 150 | print(xxx.get_shape().as_list()) 151 | 152 | ''' 153 | sess = tf.Session() 154 | sess.run(tf.global_variables_initializer()) 155 | 156 | converter = tf.lite.TFLiteConverter.from_session( 157 | sess, [inputs], [outputs]) 158 | 159 | converter.default_ranges_stats=(0, 6) 160 | converter.inference_type = tf.lite.constants.QUANTIZED_UINT8 161 | input_arrays = converter.get_input_arrays() 162 | converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev 163 | tflite_model = converter.convert() 164 | open("narrow_unet.tflite", "wb").write(tflite_model) 165 | ''' 166 | 167 | ''' 168 | adb push narrow_unet.tflite /data/local/tmp 169 | adb shell /data/local/tmp/benchmark_model --graph=/data/local/tmp/narrow_unet.tflite --num_threads=1 170 | ''' 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | -------------------------------------------------------------------------------- /mobilenetv2_075.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/deepupe_tensorflow/ce3406d93d773ecbc6def363b0b7315a9331d3fb/mobilenetv2_075.npy -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow.contrib.slim as slim 4 | from mobilenetv2 import backbone 5 | from utils import apply_bilateral_grid 6 | from tensorflow.python.framework import graph_util 7 | 8 | 9 | def get_tensor_shape(x): 10 | a = x.get_shape().as_list() 11 | b = [tf.shape(x)[i] for i in range(len(a))] 12 | return [aa if type(aa) is int else bb for aa, bb in zip(a, b)] 13 | 14 | 15 | 16 | def coef_mobilenetv2(inputs, width=0.75, luma_bins=8, is_training='False', name='coefficients'): 17 | with tf.variable_scope(name): 18 | 19 | with slim.arg_scope([slim.separable_convolution2d, slim.fully_connected], 20 | normalizer_fn=slim.batch_norm, activation_fn=tf.nn.relu6): 21 | with slim.arg_scope([slim.batch_norm], 22 | is_training=is_training, center=True, scale=True): 23 | 24 | x = backbone(inputs, width=width, is_training=is_training) 25 | 26 | for _ in range(2): 27 | x = slim.convolution2d(x, 48, [3, 3], stride=1) 28 | 29 | pool = tf.reduce_mean(x, axis=[1, 2], 
keepdims=False) 30 | 31 | fc1 = slim.fully_connected(pool, 192) 32 | 33 | fc2 = slim.fully_connected(fc1, 96) 34 | 35 | fc3 = slim.fully_connected(fc2, 48) 36 | 37 | feat1 = slim.convolution2d(x, 48, [3, 3], stride=1) 38 | 39 | feat2 = slim.convolution2d(feat1, 48, [3, 3], stride=1, 40 | normalizer_fn=None, activation_fn=None) 41 | 42 | bs, ch = tf.shape(fc3)[0], tf.shape(fc3)[1] 43 | fc_reshape = tf.reshape(fc3, [bs, 1, 1, ch]) 44 | fusion = tf.nn.relu6(feat2 + fc_reshape) 45 | 46 | conv7 = slim.convolution2d(fusion, 24*luma_bins, [1, 1], stride=1, 47 | normalizer_fn=None, activation_fn=None) 48 | 49 | stack1 = tf.stack(tf.split(conv7, 24, axis=3), axis=4) 50 | stack2 = tf.stack(tf.split(stack1, 4, axis=4), axis=5) 51 | #print(stack2.get_shape().as_list()) 52 | # [1, 16, 16, 8, 9, 4] 53 | b, h, w, ch1, ch2, ch3 = get_tensor_shape(stack2) 54 | stack2 = tf.reshape(stack2, [b, h, w, ch1*ch2*ch3]) 55 | return stack2 56 | 57 | 58 | 59 | ''' 60 | def coefficients(inputs, luma_bins=8, is_training='False', name='coefficients'): 61 | with tf.variable_scope(name): 62 | 63 | with slim.arg_scope([slim.separable_convolution2d, slim.fully_connected], 64 | normalizer_fn=slim.batch_norm, activation_fn=tf.nn.relu6): 65 | with slim.arg_scope([slim.batch_norm], 66 | is_training=is_training, center=True, scale=True): 67 | 68 | x = slim.convolution2d(inputs, 32, [3, 3], stride=2) 69 | 70 | x = slim.convolution2d(x, 64, [3, 3], stride=2) 71 | 72 | x = slim.convolution2d(x, 96, [3, 3], stride=2) 73 | 74 | x = slim.convolution2d(x, 128, [3, 3], stride=2) 75 | 76 | 77 | conv4 = x 78 | 79 | for _ in range(2): 80 | x = slim.convolution2d(x, 48, [3, 3], stride=2) 81 | 82 | 83 | pool = tf.reduce_mean(x, axis=[1, 2], keepdims=False) 84 | 85 | fc1 = slim.fully_connected(pool, 192) 86 | 87 | fc2 = slim.fully_connected(fc1, 96) 88 | 89 | fc3 = slim.fully_connected(fc2, 48) 90 | 91 | conv5 = slim.convolution2d(conv4, 48, [3, 3], stride=1) 92 | 93 | conv6 = slim.convolution2d(conv5, 48, [3, 3], stride=1, 94 | normalizer_fn=None, activation_fn=None) 95 | 96 | bs, ch = tf.shape(fc3)[0], tf.shape(fc3)[1] 97 | fc_reshape = tf.reshape(fc3, [bs, 1, 1, ch]) 98 | fusion = tf.nn.relu6(conv6 + fc_reshape) 99 | 100 | conv7 = slim.convolution2d(fusion, 24*luma_bins, [1, 1], stride=1, 101 | normalizer_fn=None, activation_fn=None) 102 | 103 | stack1 = tf.stack(tf.split(conv7, 24, axis=3), axis=4) 104 | stack2 = tf.stack(tf.split(stack1, 4, axis=4), axis=5) 105 | #print(stack2.get_shape().as_list()) 106 | # [1, 16, 16, 8, 9, 4] 107 | b, h, w, ch1, ch2, ch3 = get_tensor_shape(stack2) 108 | stack2 = tf.reshape(stack2, [b, h, w, ch1*ch2*ch3]) 109 | return stack2 110 | ''' 111 | 112 | 113 | def guide(inputs, is_training=False, name='guide'): 114 | with tf.variable_scope(name): 115 | in_ch = inputs.get_shape().as_list()[-1] 116 | idtity = np.identity(in_ch, dtype=np.float32)\ 117 | + np.random.randn(1).astype(np.float32)*1e-4 118 | ccm = tf.get_variable('ccm', dtype=tf.float32, initializer=idtity) 119 | ccm_bias = tf.get_variable('ccm_bias', shape=[in_ch,], dtype=tf.float32, 120 | initializer=tf.constant_initializer(0.0)) 121 | 122 | guidemap = tf.matmul(tf.reshape(inputs, [-1, in_ch]), ccm) 123 | guidemap = tf.nn.bias_add(guidemap, ccm_bias, name='ccm_bias_add') 124 | guidemap = tf.reshape(guidemap, tf.shape(inputs)) 125 | 126 | shifts_ = np.linspace(0, 1, 16, endpoint=False, dtype=np.float32) 127 | shifts_ = shifts_[np.newaxis, np.newaxis, np.newaxis, :] 128 | shifts_ = np.tile(shifts_, (1, 1, in_ch, 1)) 129 | 130 | guidemap = 
tf.expand_dims(guidemap, 4) 131 | shifts = tf.get_variable('shifts', dtype=tf.float32, initializer=shifts_) 132 | 133 | slopes_ = np.zeros([1, 1, 1, in_ch, 16], dtype=np.float32) 134 | slopes_[:, :, :, :, 0] = 1.0 135 | slopes = tf.get_variable('slopes', dtype=tf.float32, initializer=slopes_) 136 | 137 | guidemap = tf.reduce_sum(slopes*tf.nn.relu6(guidemap-shifts), reduction_indices=[4]) 138 | guidemap = slim.convolution2d(guidemap, 1, [1, 1], activation_fn=None, 139 | weights_initializer=tf.constant_initializer(1.0/in_ch)) 140 | guidemap = tf.clip_by_value(guidemap, 0, 1) 141 | guidemap = tf.squeeze(guidemap, squeeze_dims=[3,]) 142 | 143 | return guidemap 144 | 145 | 146 | 147 | def inference(hr_input, width=0.75, lr_size=(256, 256), is_training=False, name='inference'): 148 | with tf.variable_scope(name): 149 | 150 | lr_input = tf.image.resize_images(hr_input, lr_size, 151 | tf.image.ResizeMethod.BILINEAR) 152 | coeffs = coef_mobilenetv2(lr_input, width=width, is_training=is_training) 153 | #coeffs = coefficients(lr_input, is_training=is_training) 154 | guidemap = guide(hr_input, is_training=is_training) 155 | output = apply_bilateral_grid(coeffs, guidemap, hr_input) 156 | return output 157 | 158 | 159 | 160 | 161 | if __name__ == '__main__': 162 | 163 | hr_input = tf.placeholder(tf.float32, [1, 1024, 1024, 3]) 164 | outputs = inference(hr_input, is_training=False) 165 | print(outputs.get_shape().as_list()) 166 | 167 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import tensorflow as tf 4 | import numpy as np 5 | from network import inference 6 | from dataset import get_test_loader 7 | from tqdm import tqdm 8 | 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 11 | 12 | 13 | def test(): 14 | 15 | 16 | input_image = tf.placeholder(tf.float32, [None, None, None, 3]) 17 | result_image = inference(input_image)  # keep the default 'inference' scope so variables match the checkpoint saved by train.py 18 | 19 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.99) 20 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 21 | saver = tf.train.Saver() 22 | 23 | if not os.path.exists('results'): 24 | os.mkdir('results') 25 | 26 | with tf.device('/device:GPU:0'): 27 | 28 | data_dir = 'data_loacation_in_your_computer' 29 | dataloader = get_test_loader(data_dir) 30 | 31 | sess.run(tf.global_variables_initializer()) 32 | saver.restore(sess, tf.train.latest_checkpoint('saved_models')) 33 | 34 | for idx, batch in tqdm(enumerate(dataloader)): 35 | result = sess.run([result_image], feed_dict={input_image: batch[0]}) 36 | result = np.squeeze(result)*255 37 | result = np.clip(result, 0, 255).astype(np.uint8)  # convert to 8-bit before cv2.imwrite 38 | ground_truth = np.squeeze(batch[1]) 39 | save_out_path = os.path.join('results', '{}.jpg'.format(str(idx).zfill(4))) 40 | save_gt_path = os.path.join('results', '{}_gt.jpg'.format(str(idx).zfill(4))) 41 | cv2.imwrite(save_out_path, result) 42 | cv2.imwrite(save_gt_path, ground_truth) 43 | 44 | if __name__ == '__main__': 45 | test() 46 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import tensorflow as tf 4 | import numpy as np 5 | from network import inference 6 | from dataset import get_train_loader 7 | from tqdm import tqdm 8 | 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 11 | 12 | 13 | def log10(x): 14 | numerator =
tf.log(x) 15 | denominator = tf.log(tf.constant(10, dtype=numerator.dtype)) 16 | return numerator / denominator 17 | 18 | 19 | def color_loss(image, label, len_reg=0): 20 | 21 | vec1 = tf.reshape(image, [-1, 3]) 22 | vec2 = tf.reshape(label, [-1, 3]) 23 | clip_value = 0.999999 24 | norm_vec1 = tf.nn.l2_normalize(vec1, 1) 25 | norm_vec2 = tf.nn.l2_normalize(vec2, 1) 26 | dot = tf.reduce_sum(norm_vec1*norm_vec2, 1) 27 | dot = tf.clip_by_value(dot, -clip_value, clip_value) 28 | angle = tf.acos(dot) * (180/math.pi) 29 | 30 | return tf.reduce_mean(angle) 31 | 32 | 33 | def smoothness_loss(image): 34 | clip_low, clip_high = 0.000001, 0.999999 35 | image = tf.clip_by_value(image, clip_low, clip_high) 36 | image_h, image_w = tf.shape(image)[1], tf.shape(image)[2] 37 | tv_x = tf.reduce_mean((image[:, 1:, :, :]-image[:, :image_h-1, :, :])**2) 38 | tv_y = tf.reduce_mean((image[:, :, 1:, :]-image[:, :, :image_w-1, :])**2) 39 | total_loss = (tv_x + tv_y)/2 40 | ''' 41 | log_image = tf.log(image) 42 | log_tv_x = tf.reduce_mean((log_image[:, 1:, :, :]- 43 | log_image[:, :image_h-1, :, :])**1.2) 44 | log_tv_y = tf.reduce_mean((log_image[:, :, 1:, :]- 45 | log_image[:, :, :image_w-1, :])**1.2) 46 | total_loss = tv_x / (log_tv_x + 1e-4) + tv_y / (log_tv_y + 1e-4) 47 | ''' 48 | return total_loss 49 | 50 | 51 | 52 | def reconstruct_loss(image, label): 53 | l2_loss = tf.reduce_mean(tf.square(label-image)) 54 | return l2_loss 55 | 56 | 57 | def cal_psnr(pred, label): 58 | label_tmp, pred_tmp = label*255, pred*255 59 | mse = tf.reduce_mean(tf.squared_difference(label_tmp, pred_tmp)) 60 | mse = tf.cast(mse, tf.float32) 61 | train_psnr = tf.constant(10, dtype=tf.float32)*\ 62 | log10(tf.constant(255**2, dtype=tf.float32)/mse) 63 | return train_psnr 64 | 65 | 66 | def train(): 67 | total_epoch, total_iter = 100, 0 68 | best_loss, init_lr = 1e10, 5e-5 69 | batch_size, image_h, image_w = 8, 512, 512 70 | 71 | 72 | image = tf.placeholder(tf.float32, [None, image_h, image_w, 3]) 73 | label = tf.placeholder(tf.float32, [None, image_h, image_w, 3]) 74 | lr = tf.placeholder(tf.float32) 75 | 76 | pred = inference(image, width=0.75, is_training=True) 77 | c_loss = color_loss(pred, label) 78 | s_loss = smoothness_loss(pred) 79 | r_loss = reconstruct_loss(pred, label) 80 | total_loss = 1e-2*c_loss + 1e2*s_loss + r_loss 81 | #total_loss = c_loss + r_loss 82 | 83 | all_vars = tf.trainable_variables() 84 | backbone_vars = [var for var in all_vars if 'backbone' in var.name] 85 | train_psnr = cal_psnr(pred, label) 86 | 87 | 88 | tf.summary.scalar('loss', total_loss) 89 | tf.summary.scalar('color_loss', c_loss) 90 | tf.summary.scalar('smoothness_loss', s_loss) 91 | tf.summary.scalar('reconstruct_loss', r_loss) 92 | tf.summary.scalar('psnr', train_psnr) 93 | 94 | 95 | optimizer = tf.train.AdamOptimizer(learning_rate=lr) 96 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 97 | train_op = optimizer.minimize(total_loss) 98 | train_op = tf.group([train_op, update_ops]) 99 | 100 | 101 | config = tf.ConfigProto() 102 | config.gpu_options.allow_growth = True 103 | sess = tf.Session(config=config) 104 | ''' 105 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.75) 106 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 107 | ''' 108 | train_writer = tf.summary.FileWriter('train_log', sess.graph) 109 | summary_op = tf.summary.merge_all() 110 | saver = tf.train.Saver() 111 | 112 | with tf.device('/device:GPU:0'): 113 | 114 | sess.run(tf.global_variables_initializer()) 115 | 116 | weight = 
np.load('mobilenetv2_075.npy', allow_pickle=True) 117 | assign_ops = [] 118 | for var, para in zip(backbone_vars, weight): 119 | assign_ops.append(var.assign(para)) 120 | sess.run(assign_ops) 121 | 122 | 123 | data_dir = 'data_loacation_in_your_computer' 124 | dataloader = get_train_loader((image_h, image_w), batch_size, data_dir) 125 | 126 | for epoch in range(total_epoch): 127 | for batch in tqdm(dataloader): 128 | total_iter += 1 129 | 130 | _, train_info, loss = sess.run([train_op, summary_op, total_loss], 131 | feed_dict={image: batch[0], 132 | label: batch[1], 133 | lr: init_lr}) 134 | train_writer.add_summary(train_info, total_iter) 135 | 136 | if np.mod(total_iter, 20) == 0: 137 | print('{}th epoch, {}th iter, loss: {}'.format(epoch, total_iter, loss)) 138 | if loss < best_loss: 139 | best_loss = loss 140 | saver.save(sess, 'saved_models/model', global_step=total_iter) 141 | 142 | 143 | 144 | 145 | if __name__ == '__main__': 146 | train() 147 | 148 | #test() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def get_tensor_shape(x): 4 | a = x.get_shape().as_list() 5 | b = [tf.shape(x)[i] for i in range(len(a))] 6 | return [aa if type(aa) is int else bb for aa, bb in zip(a, b)] 7 | 8 | 9 | # [b, n, c] 10 | def sample_1d(image, y_idx): 11 | # img: [b, h, c] 12 | # y_idx: [b, n], 0 <= pos < h, dtpye=int32 13 | b, h, c = get_tensor_shape(image) 14 | b, n = get_tensor_shape(y_idx) 15 | 16 | b_idx = tf.range(b, dtype=tf.int32) # [b] 17 | b_idx = tf.expand_dims(b_idx, -1) # [b, 1] 18 | b_idx = tf.tile(b_idx, [1, n]) # [b, n] 19 | 20 | y_idx = tf.clip_by_value(y_idx, 0, h - 1) # [b, n] 21 | a_idx = tf.stack([b_idx, y_idx], axis=-1) # [b, n, 2] 22 | 23 | output = tf.gather_nd(image, a_idx) 24 | #print('sample 1d shape:', output.get_shape().as_list()) 25 | return output 26 | 27 | # [b, n, c] 28 | def interp_1d(image, y): 29 | # img: [b, h, c] 30 | # y: [b, n], 0 <= pos < h, dtpye=int32 31 | 32 | b, h, c = get_tensor_shape(image) 33 | b, n = get_tensor_shape(y) 34 | 35 | y_0 = tf.floor(y) # [b, n] 36 | y_1 = y_0 + 1 37 | 38 | _sample_func = lambda y_x: sample_1d(image, 39 | tf.cast(y_x, tf.int32)) 40 | y_0_val = _sample_func(y_0) # [b, n, c] 41 | y_1_val = _sample_func(y_1) 42 | 43 | w_0 = y_1 - y # [b, n] 44 | w_1 = y - y_0 45 | 46 | w_0 = tf.expand_dims(w_0, -1) # [b, n, 1] 47 | w_1 = tf.expand_dims(w_1, -1) 48 | 49 | return w_0*y_0_val + w_1*y_1_val 50 | 51 | # [b, h, w, 3] 52 | def apply_bilateral_grid(bilateral_grid, guide, in_image): 53 | 54 | # bilateral_grid :[b, ?, ?, d*3*4] 55 | # guide: [b, h, w], 0 <= guide <= 1 56 | # in_image: [b, h, w, 3] 57 | 58 | 59 | b, _, _, d34, = get_tensor_shape(bilateral_grid) 60 | b, h, w, = get_tensor_shape(guide) 61 | b, h, w, _, = get_tensor_shape(in_image) 62 | 63 | d = d34//3//4 64 | 65 | bilateral_grid = tf.image.resize_images(bilateral_grid, [h, w]) 66 | # [b, h, w, d*3*4] 67 | 68 | coef = interp_1d(tf.reshape(bilateral_grid, [b*h*w, d, 3*4]), 69 | (d - 1)*tf.reshape(guide, [b*h*w, 1])) 70 | # [b*h*w, 1, 3*4] 71 | coef = tf.reshape(coef, [b, h, w, 3, 4]) 72 | # [b, h, w, 3, 4] 73 | 74 | ''' 75 | in_image = tf.reshape(in_image, [b, h, w, 3, 1]) 76 | # [b, h, w, 3, 1] 77 | in_image = tf.pad(in_image, [[0, 0], [0, 0], [0, 0], [0, 1], [0, 0]], 78 | mode='CONSTANT', constant_values=1) 79 | 80 | # [b, h, w, 4, 1] 81 | 82 | out_image = tf.matmul(coef, in_image) # [b, h, w, 3, 1] 83 | out_image = 
tf.reshape(out_image, [b, h, w, 3]) # [b, h, w, 3] 84 | ''' 85 | 86 | in_image = tf.pad(in_image, [[0, 0], [0, 0], [0, 0], [0, 1]], 87 | mode='CONSTANT', constant_values=1) 88 | out_image = reduce_maltiply(in_image, coef) 89 | 90 | 91 | return tf.clip_by_value(out_image, -1, 1) 92 | 93 | 94 | 95 | def reduce_maltiply(image, coef): 96 | ch = coef.get_shape().as_list()[3] 97 | output = [] 98 | for i in range(ch): 99 | mul_channel = image * tf.squeeze(coef[:, :, :, i, :]) 100 | reduced_channel = tf.reduce_mean(mul_channel, axis=3) 101 | output.append(reduced_channel) 102 | output = tf.stack(output, axis=3) 103 | return output 104 | 105 | 106 | 107 | if __name__ == '__main__': 108 | image = tf.placeholder(tf.float32, [1, 1024, 1024, 3]) 109 | guide = tf.placeholder(tf.float32, [1, 1024, 1024]) 110 | grid = tf.placeholder(tf.float32, [1, 16, 16, 36]) 111 | output = apply_bilateral_grid(grid, guide, image) 112 | print(output.get_shape().as_list()) 113 | --------------------------------------------------------------------------------
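A closing note on utils.py: apply_bilateral_grid first resizes the grid to the full image resolution, uses the guide map to linearly interpolate one of the d depth bins per pixel (interp_1d), and then applies the resulting per-pixel 3x4 affine transform to the augmented [r, g, b, 1] vector. The NumPy sketch below is my own reference for that last step, written the way the commented-out tf.matmul branch computes it; note that the active reduce_maltiply path takes a mean over the four coefficient entries rather than a sum, so its output differs from this reference by a constant factor of 4.

```python
import numpy as np

def apply_coeffs_reference(coef, image):
    # coef:  [b, h, w, 3, 4] per-pixel affine transforms sliced from the grid
    # image: [b, h, w, 3] RGB input
    ones = np.ones_like(image[..., :1])
    augmented = np.concatenate([image, ones], axis=-1)     # [b, h, w, 4]
    # out[..., c] = sum_k coef[..., c, k] * augmented[..., k]
    return np.einsum('bhwck,bhwk->bhwc', coef, augmented)  # [b, h, w, 3]
```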