├── src ├── fog.png └── model.png ├── LICENSE ├── extract.py ├── operations.py ├── legacy ├── utils.py ├── vgg19small.py └── replicate.py ├── vgg19.py ├── main.py ├── issue.md ├── README.md └── model.py /src/fog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thatbrguy/Dehaze-GAN/HEAD/src/fog.png -------------------------------------------------------------------------------- /src/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thatbrguy/Dehaze-GAN/HEAD/src/model.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Bharath Raj 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import h5py 4 | import numpy as np 5 | from skimage.transform import resize 6 | 7 | if __name__ == '__main__': 8 | 9 | if not os.path.exists('A'): 10 | os.mkdir('A') 11 | 12 | if not os.path.exists('B'): 13 | os.mkdir('B') 14 | 15 | with h5py.File('data.mat', 'r') as f: 16 | images = np.array(f['images']) 17 | depths = np.array(f['depths']) 18 | 19 | images = images.transpose(0, 1, 3, 2) 20 | depths = depths.transpose(2, 1, 0) 21 | depths = (depths - np.min(depths, axis = (0, 1))) / np.max(depths, axis = (0, 1)) 22 | depths = ((1 - depths) * np.random.uniform(0.2, 0.4, size = (1449, ))).transpose(2, 0, 1) 23 | 24 | for i in range(len(images)): 25 | fog = (images[i] * depths[i]) + (1 - depths[i]) * np.ones_like(depths[i]) * 255 26 | fog = resize(fog.transpose(1, 2, 0), (256, 256, 3), mode = 'reflect') 27 | img = resize(images[i].transpose(1, 2, 0), (256, 256, 3), mode = 'reflect') 28 | img = (img * 255).astype(np.uint8) 29 | 30 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 31 | fog = cv2.cvtColor(fog.astype(np.uint8), cv2.COLOR_RGB2BGR) 32 | 33 | cv2.imwrite(os.path.join('A', str(i).zfill(4) + '.png'), fog) 34 | cv2.imwrite(os.path.join('B', str(i).zfill(4) + '.png'), img) 35 | 36 | print('Extracting image:', i, end = '\r') 37 | 38 | print('Done.') 39 | -------------------------------------------------------------------------------- /operations.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def Conv(input_, kernel_size, stride, output_channels, padding = 'SAME', mode = None): 4 | 5 | with tf.variable_scope("Conv") as scope: 6 | 7 | input_channels = input_.get_shape()[-1] 8 | kernel_shape = [kernel_size, kernel_size, input_channels, output_channels] 9 | 10 | kernel = tf.get_variable("Filter", shape = kernel_shape, dtype = tf.float32, initializer = tf.keras.initializers.he_normal()) 11 | 12 | # Patchwise Discriminator (PatchGAN) requires some modifications. 
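# In 'discriminator' mode the caller (see the PatchGAN discriminator in model.py)
# passes kernel_size = 4 and padding = 'VALID', so the explicit one-pixel constant
# padding below mimics the padded 4x4 convolutions used in pix2pix-style PatchGANs.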
13 | if mode == 'discriminator': 14 | input_ = tf.pad(input_, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT") 15 | 16 | return tf.nn.conv2d(input_, kernel, strides = [1, stride, stride, 1], padding = padding) 17 | 18 | def TransposeConv(input_, output_channels, kernel_size = 4): 19 | 20 | with tf.variable_scope("TransposeConv") as scope: 21 | 22 | input_height, input_width, input_channels = [int(d) for d in input_.get_shape()[1:]] 23 | batch_size = tf.shape(input_)[0] 24 | 25 | kernel_shape = [kernel_size, kernel_size, output_channels, input_channels] 26 | output_shape = tf.stack([batch_size, input_height*2, input_width*2, output_channels]) 27 | 28 | kernel = tf.get_variable(name = "filter", shape = kernel_shape, dtype=tf.float32, initializer = tf.keras.initializers.he_normal()) 29 | 30 | return tf.nn.conv2d_transpose(input_, kernel, output_shape, [1, 2, 2, 1], padding="SAME") 31 | 32 | def MaxPool(input_): 33 | with tf.variable_scope("MaxPool"): 34 | return tf.nn.max_pool(input_, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 35 | 36 | def AvgPool(input_, k = 2): 37 | with tf.variable_scope("AvgPool"): 38 | return tf.nn.avg_pool(input_, ksize=[1, k, k, 1], strides=[1, k, k, 1], padding='VALID') 39 | 40 | def ReLU(input_): 41 | with tf.variable_scope("ReLU"): 42 | return tf.nn.relu(input_) 43 | 44 | def LeakyReLU(input_, leak = 0.2): 45 | with tf.variable_scope("LeakyReLU"): 46 | return tf.maximum(input_, leak * input_) 47 | 48 | def BatchNorm(input_, isTrain, name='BN', decay = 0.99): 49 | with tf.variable_scope(name) as scope: 50 | return tf.contrib.layers.batch_norm(input_, is_training = isTrain, decay = decay) 51 | 52 | def DropOut(input_, isTrain, rate=0.2, name='drop') : 53 | with tf.variable_scope(name) as scope: 54 | return tf.layers.dropout(inputs=input_, rate=rate, training=isTrain) 55 | -------------------------------------------------------------------------------- /legacy/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def conv(ip, out_channels, stride): 4 | 5 | with tf.variable_scope('Conv') as scope: 6 | ip_channels = ip.get_shape()[-1] 7 | h = tf.get_variable(name = 'filter', shape = [4, 4, ip_channels, out_channels], dtype = tf.float32, initializer = tf.keras.initializers.he_normal()) 8 | padded_input = tf.pad(ip, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT") 9 | conv = tf.nn.conv2d(padded_input, h, [1, stride, stride, 1], padding="VALID") 10 | return conv 11 | 12 | def deconv(ip, out_channels, filters = 4): 13 | 14 | with tf.variable_scope("DeConv") as scope: 15 | in_height, in_width, in_channels = [int(d) for d in ip.get_shape()[1:]] 16 | batch_size = tf.shape(ip)[0] 17 | output_shape = tf.stack([batch_size, in_height*2, in_width*2, out_channels]) 18 | h = tf.get_variable(name = "filter", shape = [filters, filters, out_channels, in_channels], dtype=tf.float32, initializer = tf.keras.initializers.he_normal()) 19 | conv = tf.nn.conv2d_transpose(ip, h, output_shape, [1, 2, 2, 1], padding="SAME") 20 | return conv 21 | 22 | def lrelu(ip, leak = 0.2): 23 | with tf.name_scope("LeakyRelu"): 24 | return tf.maximum(ip, leak*ip) 25 | 26 | def batchnorm(input): 27 | with tf.variable_scope("BatchNorm"): 28 | # this block looks like it has 3 inputs on the graph unless we do this 29 | input = tf.identity(input) 30 | 31 | channels = input.get_shape()[3] 32 | offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer()) 33 | scale = 
tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.random_normal_initializer(1.0, 0.02)) 34 | mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False) 35 | variance_epsilon = 1e-5 36 | normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon) 37 | return normalized 38 | 39 | def linear(ip, output_size, name = 'linear'): 40 | shape = ip.get_shape().as_list() 41 | with tf.variable_scope(name): 42 | w = tf.get_variable("w_"+name, shape = [shape[1], output_size] , dtype = tf.float32, 43 | initializer = tf.keras.initializers.he_normal()) 44 | bias = tf.get_variable("b_"+name, [output_size], initializer=tf.constant_initializer(0.0)) 45 | 46 | return tf.matmul(ip,w) + bias 47 | 48 | def Conv(ip, filter, stride, output_ch, padding = 'SAME'): 49 | input_ch = ip.get_shape()[3] 50 | h = tf.get_variable("Filter", shape = [filter, filter, input_ch, output_ch], dtype = tf.float32, initializer = tf.keras.initializers.he_normal()) 51 | b = tf.get_variable("Bias", shape = [output_ch], dtype = tf.float32, initializer = tf.constant_initializer(0)) 52 | return tf.nn.conv2d(ip, h, strides = [1, stride, stride, 1], padding = padding) 53 | 54 | def MaxPool(ip): 55 | return tf.nn.max_pool(ip, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') 56 | 57 | def AvgPool(ip, k=2): 58 | return tf.nn.avg_pool(ip, ksize=[1,k,k,1], strides=[1,k,k,1], padding='VALID') 59 | 60 | def Relu(ip): 61 | return tf.nn.relu(ip) 62 | 63 | def BatchNorm(ip, isTrain, decay = 0.99): 64 | return tf.contrib.layers.batch_norm(ip, is_training = isTrain, decay = decay) 65 | 66 | def DropOut(x, rate = 0.2) : 67 | return tf.nn.dropout(x, keep_prob=(1-rate)) -------------------------------------------------------------------------------- /vgg19.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | import numpy as np 5 | import time 6 | import inspect 7 | 8 | VGG_MEAN = [103.939, 116.779, 123.68] 9 | 10 | 11 | class Vgg19: 12 | def __init__(self, vgg19_npy_path=None): 13 | self.data_dict = np.load('./vgg19.npy', encoding='latin1').item() 14 | 15 | def feature_map(self, rgb): 16 | """ 17 | load variable from npy to build the VGG 18 | 19 | :param rgb: rgb image [batch, height, width, 3] values scaled [0, 1] 20 | """ 21 | 22 | start_time = time.time() 23 | rgb_scaled = rgb * 255.0 24 | 25 | # Convert RGB to BGR 26 | red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled) 27 | assert red.get_shape().as_list()[1:] == [224, 224, 1] 28 | assert green.get_shape().as_list()[1:] == [224, 224, 1] 29 | assert blue.get_shape().as_list()[1:] == [224, 224, 1] 30 | bgr = tf.concat(axis=3, values=[ 31 | blue - VGG_MEAN[0], 32 | green - VGG_MEAN[1], 33 | red - VGG_MEAN[2], 34 | ]) 35 | assert bgr.get_shape().as_list()[1:] == [224, 224, 3] 36 | 37 | self.conv1_1 = self.conv_layer(bgr, "conv1_1") 38 | self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2") 39 | self.pool1 = self.max_pool(self.conv1_2, 'pool1') 40 | 41 | self.conv2_1 = self.conv_layer(self.pool1, "conv2_1") 42 | self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2") 43 | self.pool2 = self.max_pool(self.conv2_2, 'pool2') 44 | 45 | self.conv3_1 = self.conv_layer(self.pool2, "conv3_1") 46 | self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2") 47 | self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3") 48 | self.conv3_4 = self.conv_layer(self.conv3_3, "conv3_4") 49 | self.pool3 = self.max_pool(self.conv3_4, 
'pool3') 50 | 51 | self.conv4_1 = self.conv_layer(self.pool3, "conv4_1") 52 | self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2") 53 | self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3") 54 | self.conv4_4 = self.conv_layer(self.conv4_3, "conv4_4") 55 | self.pool4 = self.max_pool(self.conv4_4, 'pool4') 56 | 57 | output = self.pool4 58 | 59 | self.conv5_1 = self.conv_layer(self.pool4, "conv5_1") 60 | self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2") 61 | self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3") 62 | self.conv5_4 = self.conv_layer(self.conv5_3, "conv5_4") 63 | self.pool5 = self.max_pool(self.conv5_4, 'pool5') 64 | 65 | return self.pool4 66 | 67 | def avg_pool(self, bottom, name): 68 | return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name) 69 | 70 | def max_pool(self, bottom, name): 71 | return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name) 72 | 73 | def conv_layer(self, bottom, name): 74 | with tf.variable_scope(name): 75 | filt = self.get_conv_filter(name) 76 | 77 | conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME') 78 | 79 | conv_biases = self.get_bias(name) 80 | bias = tf.nn.bias_add(conv, conv_biases) 81 | 82 | relu = tf.nn.relu(bias) 83 | return relu 84 | 85 | def get_conv_filter(self, name): 86 | return tf.constant(self.data_dict[name][0], name="filter") 87 | 88 | def get_bias(self, name): 89 | return tf.constant(self.data_dict[name][1], name="biases") 90 | -------------------------------------------------------------------------------- /legacy/vgg19small.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | import numpy as np 5 | import time 6 | import inspect 7 | 8 | VGG_MEAN = [103.939, 116.779, 123.68] 9 | 10 | 11 | class Vgg19: 12 | def __init__(self, vgg19_npy_path=None): 13 | self.data_dict = np.load('./vgg19.npy', encoding='latin1').item() 14 | 15 | def feature_map(self, rgb): 16 | """ 17 | load variable from npy to build the VGG 18 | 19 | :param rgb: rgb image [batch, height, width, 3] values scaled [0, 1] 20 | """ 21 | 22 | start_time = time.time() 23 | rgb_scaled = rgb * 255.0 24 | 25 | # Convert RGB to BGR 26 | red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled) 27 | assert red.get_shape().as_list()[1:] == [224, 224, 1] 28 | assert green.get_shape().as_list()[1:] == [224, 224, 1] 29 | assert blue.get_shape().as_list()[1:] == [224, 224, 1] 30 | bgr = tf.concat(axis=3, values=[ 31 | blue - VGG_MEAN[0], 32 | green - VGG_MEAN[1], 33 | red - VGG_MEAN[2], 34 | ]) 35 | assert bgr.get_shape().as_list()[1:] == [224, 224, 3] 36 | 37 | self.conv1_1 = self.conv_layer(bgr, "conv1_1") 38 | self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2") 39 | self.pool1 = self.max_pool(self.conv1_2, 'pool1') 40 | 41 | self.conv2_1 = self.conv_layer(self.pool1, "conv2_1") 42 | self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2") 43 | self.pool2 = self.max_pool(self.conv2_2, 'pool2') 44 | 45 | self.conv3_1 = self.conv_layer(self.pool2, "conv3_1") 46 | self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2") 47 | self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3") 48 | self.conv3_4 = self.conv_layer(self.conv3_3, "conv3_4") 49 | self.pool3 = self.max_pool(self.conv3_4, 'pool3') 50 | 51 | self.conv4_1 = self.conv_layer(self.pool3, "conv4_1") 52 | self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2") 53 | self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3") 54 
| self.conv4_4 = self.conv_layer(self.conv4_3, "conv4_4") 55 | self.pool4 = self.max_pool(self.conv4_4, 'pool4') 56 | 57 | output = self.pool4 58 | 59 | self.conv5_1 = self.conv_layer(self.pool4, "conv5_1") 60 | self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2") 61 | self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3") 62 | self.conv5_4 = self.conv_layer(self.conv5_3, "conv5_4") 63 | self.pool5 = self.max_pool(self.conv5_4, 'pool5') 64 | 65 | return self.pool4 66 | 67 | def avg_pool(self, bottom, name): 68 | return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name) 69 | 70 | def max_pool(self, bottom, name): 71 | return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name) 72 | 73 | def conv_layer(self, bottom, name): 74 | with tf.variable_scope(name): 75 | filt = self.get_conv_filter(name) 76 | 77 | conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME') 78 | 79 | conv_biases = self.get_bias(name) 80 | bias = tf.nn.bias_add(conv, conv_biases) 81 | 82 | relu = tf.nn.relu(bias) 83 | return relu 84 | 85 | def get_conv_filter(self, name): 86 | return tf.constant(self.data_dict[name][0], name="filter") 87 | 88 | def get_bias(self, name): 89 | return tf.constant(self.data_dict[name][1], name="biases") 90 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from model import GAN 4 | 5 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 6 | #os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 7 | 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument("--lr", help="Learning Rate (Default = 0.001)", 11 | type = float, default = 0.001) 12 | parser.add_argument("--D_filters", help="Number of filters in the 1st conv layer of the discriminator (Default = 64)", 13 | type = int, default = 64) 14 | parser.add_argument("--layers", help="Number of layers per dense block (Default = 4)", 15 | type = int, default = 4) 16 | parser.add_argument("--growth_rate", help="Growth Rate of the dense block (Default = 12) ", 17 | type = int, default = 12) 18 | parser.add_argument("--gan_wt", help="Weight of the GAN loss factor (Default = 2)", 19 | type = float, default = 2) 20 | parser.add_argument("--l1_wt", help="Weight of the L1 loss factor (Default = 100)", 21 | type = float, default = 100) 22 | parser.add_argument("--vgg_wt", help="Weight of the VGG loss factor (Default = 10)", 23 | type = float, default = 10) 24 | parser.add_argument("--restore", help = "Restore checkpoint for training (Default = False)", 25 | type = bool, default = False) 26 | parser.add_argument("--batch_size", help="Set the batch size (Default = 1)", 27 | type = int, default = 1) 28 | parser.add_argument("--decay", help="Batchnorm decay (Default = 0.99)", 29 | type = float, default = 0.99) 30 | parser.add_argument("--epochs", help = "Epochs (Default = 200)", 31 | type = int, default = 200) 32 | parser.add_argument("--model_name", help = "Set a model name", 33 | default = 'model') 34 | parser.add_argument("--save_samples", help = "Generate image samples after validation (Default = False)", 35 | type = bool, default = False) 36 | parser.add_argument("--sample_image_dir", help = "Directory containing sample images (Used only if save_samples is True; Default = samples)", 37 | default = 'samples') 38 | parser.add_argument("--A_dir", help = "Directory containing the input images for training, testing or 
inference (Default = A)", 39 | default = 'A') 40 | parser.add_argument("--B_dir", help = "Directory containing the target images for training or testing. In inference mode, this is used to store results (Default = B)", 41 | default = 'B') 42 | parser.add_argument("--custom_data", help = "Using your own data as input and target (Default = True)", 43 | type = bool, default = True) 44 | parser.add_argument("--val_fraction", help = "Fraction of dataset to be split for validation (Default = 0.15)", 45 | type = float, default = 0.15) 46 | parser.add_argument("--val_threshold", help = "Number of steps to wait before validation is enabled. (Default = 0)", 47 | type = int, default = 0) 48 | parser.add_argument("--val_frequency", help = "Number of batches to wait before perfoming the next validation run (Default = 20)", 49 | type = int, default = 20) 50 | parser.add_argument("--logger_frequency", help = "Number of batches to wait before logging the next set of loss values (Default = 20)", 51 | type = int, default = 20) 52 | parser.add_argument("--mode", help = "Select between train, test or inference modes", 53 | default = 'train', choices = ['train', 'test', 'inference']) 54 | 55 | if __name__ == '__main__': 56 | 57 | args = parser.parse_args() 58 | net = GAN(args) 59 | if args.mode == 'train': 60 | net.train() 61 | if args.mode == 'test': 62 | net.test(args.A_dir, args.B_dir) 63 | if args.mode == 'inference': 64 | net.inference(args.A_dir, args.B_dir) 65 | -------------------------------------------------------------------------------- /issue.md: -------------------------------------------------------------------------------- 1 | # June 2021 Update 2 | 3 | On a recent review of some code I found a few issues. In this document, I provide an explanation of the issues, their possible impact, and the remedies I have taken to account for the issues. 4 | 5 | This codebase was created a few years ago when I was an undergraduate student and some of the files are cleaned up versions of the ones used for my experiments. Unfortunately since I was less experienced and messy back then, these issues crept in and it was hard to analyze if something went wrong and if so what are the causes and effects. I have documented and analyzed these issues and suggested remedies based on my recent review. I apologize for the inconvenience. 6 | 7 | ## Issues 8 | 9 | ### 1. Colorspace Issue 10 | 11 | In the codebase, the file `model.py` loads images using `cv2.imread`. This function loads the image to a NumPy array with the channels in the BGR format. Hence, the model would be trained and tested on BGR images. For ease of training and testing for my experiments, I used to store the images in NumPy files. This was done so that I can simply load the NumPy files instead of reading the images using OpenCV (you might have seen mentions of `A_train.npy`, `B_train.npy`, `A_test.npy` etc. in the codebase). The NumPy files had images in the BGR format and hence the model would be trained and tested on BGR images. Also, to save the model output, I used `cv2.imwrite`, which expects a NumPy array with the channels in BGR format. 12 | 13 | I recently realized that the file `extract.py` has a colorspace issue while saving the images to files. I observed that I used `cv2.imwrite` in `extract.py`, and that I passed NumPy arrays with channels in the RGB format to that function in `extract.py`. 
This is a mistake; I should have passed NumPy arrays with channels in BGR format to `cv2.imwrite` so that the files are saved in the correct colorspace. 14 | 15 | I have made a change in `extract.py` to fix this issue. Now, the images are converted to the BGR format in `extract.py` before they are saved. 16 | 17 | This issue **does not** affect `replicate.py` since it directly reads `A_test.npy` and `B_test.npy`. As mentioned before, both NumPy files have images already stored in the BGR format, which is then consumed by the model. 18 | 19 | ### 2. VGG Issue 20 | 21 | As mentioned in `README.md`, I used the pre-trained VGG 19 model from [machrisaa](https://github.com/machrisaa/tensorflow-vgg)'s implementation. This model is used for calculating the perceptual loss. On a recent review of the function `feature_map` in the files `vgg19.py` and `legacy/vgg19small.py`, I noticed that the docstring mentioned the following: 22 | 23 | ``` 24 | load variable from npy to build the VGG 25 | 26 | :param rgb: rgb image [batch, height, width, 3] values scaled [0, 1] 27 | ``` 28 | 29 | Unfortunately, for the final experiments that I had run, I might have forgotten to take into account the above docstring. 30 | 31 | Based on the docstring, it seems like the function needed an input with the following characteristics: 32 | 33 | - Image input in the RGB format with the values in the range `[0, 1]` 34 | 35 | However, the codebase provided the function with an input with the following characteristics: 36 | 37 | - Image input in the BGR format with the values in the range `[-1, 1]` 38 | 39 | Hence, the VGG model was not used as it ideally should have been used. 40 | 41 | It is interesting to note that we still achieved good performance even though the input to the VGG model did not meet the requirements of the docstring. To think about why and the potential impact, let us analyze the purpose of using this VGG model. We used the pre-trained VGG model to extract features from images. These features are then used in the computation of the perceptual loss. The VGG is only required for training and is not needed for testing/inference. 42 | 43 | We can now state a hypothesis. Even in the incorrect setting, we are still computing the perceptual loss between the extracted features of a ground truth image and the extracted features of the corresponding generated image. This still motivates the model to generate images such that the features extracted from these generated images are similar to the features extracted from the ground truth images. Hence, the perceptual loss even in the incorrect case still likely serves a purpose. After all, we were able to train the model using the incorrect setting, and then also get good performance during testing. 44 | 45 | However, it is also possible that providing input in the correct colorspace format and with values in the correct range could give better results since the features may have more "useful information". But in both the incorrect and correct settings, we are still motivating the model to generate images such that the features extracted from these generated images are similar to the features extracted from the ground truth images. 46 | 47 | However, do note that the above points are just hypotheses. A proper comparison experiment should be done to analyze the impact accurately. 48 | 49 | I apologize for the inconvenience caused by these mistakes. The next section discusses what remedies were done in light of these issues.
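Before moving on, to make the mismatch concrete, the sketch below shows one way the input could be prepared so that it matches the docstring (RGB channels, values in `[0, 1]`), assuming the tensor produced by the rest of the pipeline is in the BGR format with values in `[-1, 1]`. This is only an illustration written for this document; it is not the exact patch applied on the `alt` branch, and the helper name `prepare_for_vgg` is hypothetical.

```
import tensorflow as tf

def prepare_for_vgg(bgr_minus_one_to_one):
    # Hypothetical helper: convert a BGR tensor with values in [-1, 1]
    # into the RGB, [0, 1] tensor described by the feature_map docstring.
    rescaled = (bgr_minus_one_to_one + 1.0) / 2.0
    # Reorder the channels from BGR to RGB.
    b, g, r = tf.split(rescaled, num_or_size_splits=3, axis=3)
    return tf.concat([r, g, b], axis=3)
```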
50 | 51 | ## Remedies 52 | 53 | In light of these issues, a few changes were made to the codebase. I was quite torn about the right way to remedy these issues. On one hand, I wanted to keep the codebase as consistent with the paper as possible. On the other hand, I wanted to fix the issues as well. 54 | 55 | After much thought, I have decided to create **two versions**. Each version will lie on a different **branch**. In the `master` branch, the version that is more consistent with the paper is kept. In the `alt` branch, the version that has resolved an additional issue is kept. The specifics of both branches are given below: 56 | 57 | - The `master` branch: 58 | - Colorspace issue in `extract.py` is fixed. 59 | - The `alt` branch: 60 | - Colorspace issue in `extract.py` is fixed. 61 | - The function `feature_map` in `vgg19.py` is modified such that it can accept BGR images with values in the range `[-1, 1]`. 62 | 63 | The reason the `master` branch does not change the `feature_map` function is that the experiments for the paper likely used the function with incorrect input characteristics (see the 2nd issue in the issues section). Hence, to be consistent with the paper, the mistake is left as such without correction in the `master` branch alone. If you would like to use the function **without** the mistake, please use the `alt` branch. However, do note that I have **not tested** the code in the `alt` branch and hence cannot provide performance comparisons with the `master` branch. 64 | 65 | The entire `legacy` folder in **both branches** has **no changes made**. This is because the `legacy` folder is used for replicating the paper results and it is being preserved as such (even though it has the same VGG issue). 66 | 67 | ## Step 2 from README Instructions 68 | 69 | Step 2 of the instructions section in the README file asks the user to choose one of the two branches (among `master` or `alt`) for their use case. If you have reached this document from step 2 of the README instructions, please read the rest of this document to understand the differences between the two branches. If you have already read the rest of the document, follow the below instructions to complete step 2 of the instructions section in the README file: 70 | 71 | - Make sure that step 1 of the instructions in the README file has been completed (i.e. git cloning the repository). 72 | 73 | - Now, **if** you would like to use the `master` branch: 74 | 75 | - Verify that you are in the `master` branch by running `git branch`. 76 | - Verify that you are **either** using the commit `07bd52840574d0c8e1f3a0482538a674a64618c4` **or** any commit made **after** that by running `git log`. The commit message of commit `07bd52840574d0c8e1f3a0482538a674a64618c4` should say `june common patch` and it should have the month and year as `June 2021`. For simplicity, you can just use the newest commit in the branch (since the newest commit is made after the commit `07bd52840574d0c8e1f3a0482538a674a64618c4`). 77 | - Once both are verified, you can proceed to step 3 of the instructions in the README file. 78 | 79 | - Else, **if** you would like to use the `alt` branch: 80 | 81 | - First, execute the below commands: 82 | ``` 83 | git fetch origin alt:alt 84 | git checkout alt 85 | ``` 86 | - Verify that you are in the `alt` branch by running `git branch`. 87 | - Verify that you are **either** using the commit `95c4e96635e6954499b9c27693a3e11f45019995` **or** any commit made **after** that by running `git log`.
The commit message of commit `95c4e96635e6954499b9c27693a3e11f45019995` should say `vgg patch` and it should have the month and year as `June 2021`. For simplicity, you can just use the newest commit in the branch (since the newest commit is made after the commit `95c4e96635e6954499b9c27693a3e11f45019995`). 88 | - Once both are verified, you can proceed to step 3 of the instructions in the README file. 89 | 90 | 91 | -------------------------------------------------------------------------------- /legacy/replicate.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import os 5 | import time 6 | import cv2 7 | from skimage.measure import compare_ssim as ssim 8 | from skimage.measure import compare_psnr 9 | import vgg19small 10 | from utils import * 11 | 12 | class GAN(): 13 | 14 | def __init__(self): 15 | self.gen_filters = 64 16 | self.disc_filters = 64 17 | self.layers = 4 18 | self.growth_rate = 12 19 | self.gan_wt = 2 20 | self.l1_wt = 100 21 | self.vgg_wt = 10 22 | self.num = 14 23 | self.restore = True 24 | self.ckpt_dir = './model/'+str(self.num)+'/checkpoint' 25 | self.batch_sz = 1 26 | self.epochs = 1000 27 | self.lr = 0.001 28 | self.total_image_count = 1550 * 2 #Due to flips 29 | self.score_best = -1 30 | 31 | 32 | def Layer(self, ip): 33 | 34 | with tf.variable_scope("Composite"): 35 | next_layer = batchnorm(ip) 36 | next_layer = Relu(next_layer) 37 | next_layer = Conv(next_layer, filter = 3, stride = 1, output_ch = self.growth_rate) 38 | next_layer = DropOut(next_layer, rate=0.2) 39 | 40 | return next_layer 41 | 42 | def TransitionDown(self, ip, name): 43 | 44 | with tf.variable_scope(name): 45 | 46 | reduction = 0.5 47 | next_layer = batchnorm(ip) 48 | reduced_output_size = int(int(ip.get_shape()[-1]) * reduction) 49 | next_layer = Conv(next_layer, filter = 1, stride =1, output_ch = reduced_output_size) 50 | next_layer = DropOut(next_layer, rate=0.2) 51 | next_layer = AvgPool(next_layer) 52 | 53 | return next_layer 54 | 55 | def TransitionUp(self, ip, output_ch, name): 56 | 57 | with tf.variable_scope(name): 58 | next_layer = deconv(ip, output_ch, filters = 3) 59 | return next_layer 60 | 61 | def DenseBlock(self, ip, name, layers = 4): 62 | 63 | with tf.variable_scope(name): 64 | for i in range(layers): 65 | with tf.variable_scope("Layer" + str(i+1)) as scope: 66 | output = self.Layer(ip) 67 | output = tf.concat([ip, output], axis = 3) 68 | ip = output 69 | 70 | return output 71 | 72 | def tiramisu(self, ip): 73 | 74 | with tf.variable_scope('InputConv') as scope: 75 | ip = Conv(ip, filter = 3, stride = 1, output_ch = self.growth_rate*4) 76 | 77 | collect_conv = [] 78 | 79 | for i in range(1,6): 80 | ip = self.DenseBlock(ip, 'Encoder' + str(i), layers = self.layers) 81 | collect_conv.append(ip) 82 | ip = self.TransitionDown(ip, 'TD' + str(i)) 83 | 84 | ip = self.DenseBlock(ip, 'BottleNeck', layers = 15) 85 | 86 | for i in range(1,6): 87 | ip = self.TransitionUp(ip, self.growth_rate*4, 'TU' + str(6 - i)) 88 | ip = tf.concat([ip, collect_conv[6-i-1]], axis = 3, name = 'Decoder' + str(6-i) + '/Concat') 89 | ip = self.DenseBlock(ip, 'Decoder' + str(6 - i), layers = self.layers) 90 | 91 | with tf.variable_scope('OutputConv') as scope: 92 | output = Conv(ip, filter = 1, stride = 1, output_ch = 3) 93 | 94 | return tf.nn.tanh(output) 95 | 96 | 97 | def discriminator(self, ip, target): 98 | 99 | #Using the PatchGAN as a discriminator 100 | layer_count = 4 101 | 
stride = 2 102 | ndf = self.disc_filters 103 | ip = tf.concat([ip, target], axis = 3, name = 'Concat') 104 | 105 | layer_specs = ndf * np.array([1,2,4,8]) 106 | 107 | for i, out_ch in enumerate(layer_specs,1): 108 | 109 | with tf.variable_scope('Layer'+str(i)) as scope: 110 | if i != 1: 111 | ip = batchnorm(ip) 112 | ip = lrelu(ip) 113 | if i == layer_count: 114 | stride = 1 115 | ip = conv(ip, out_ch, stride = stride) 116 | 117 | with tf.variable_scope('Final_Layer') as scope: 118 | ip = conv(ip, out_channels = 1, stride = 1) 119 | output = tf.sigmoid(ip) 120 | 121 | return output 122 | 123 | def build_vgg(self, img): 124 | 125 | model = vgg19small.Vgg19() 126 | img = tf.image.resize_images(img, [224,224]) 127 | layer = model.feature_map(img) 128 | return layer 129 | 130 | 131 | def build(self): 132 | 133 | EPS = 10e-12 134 | 135 | with tf.variable_scope('Placeholders') as scope: 136 | self.RealA = tf.placeholder(name = 'A', shape = [None, 256, 256, 3], dtype = tf.float32) 137 | self.RealB = tf.placeholder(name = 'B', shape = [None, 256, 256, 3], dtype = tf.float32) 138 | self.step = tf.train.get_or_create_global_step() 139 | 140 | with tf.variable_scope('Generator') as scope: 141 | self.FakeB = self.tiramisu(self.RealA) 142 | 143 | with tf.name_scope('Real_Discriminator'): 144 | with tf.variable_scope('Discriminator') as scope: 145 | self.predict_real = self.discriminator(self.RealA, self.RealB) 146 | 147 | with tf.name_scope('Fake_Discriminator'): 148 | with tf.variable_scope('Discriminator', reuse = True) as scope: 149 | self.predict_fake = self.discriminator(self.RealA, self.FakeB) 150 | 151 | with tf.name_scope('Real_VGG'): 152 | with tf.variable_scope('VGG') as scope: 153 | self.RealB_VGG = self.build_vgg(self.RealB) 154 | 155 | with tf.name_scope('Fake_VGG'): 156 | with tf.variable_scope('VGG', reuse = True) as scope: 157 | self.FakeB_VGG = self.build_vgg(self.FakeB) 158 | 159 | with tf.name_scope('DiscriminatorLoss'): 160 | self.D_loss = tf.reduce_mean(-( tf.log(self.predict_real + EPS) + tf.log(1 - self.predict_fake + EPS) )) 161 | 162 | with tf.name_scope('GeneratorLoss'): 163 | self.gan_loss = tf.reduce_mean(-tf.log(self.predict_fake + EPS )) 164 | self.l1_loss = tf.reduce_mean(tf.abs(self.RealB - self.FakeB)) 165 | self.vgg_loss = (1e-5) * tf.losses.mean_squared_error(self.RealB_VGG, self.FakeB_VGG) 166 | 167 | self.G_loss = self.gan_wt * self.gan_loss + self.l1_wt * self.l1_loss + self.vgg_wt * self.vgg_loss 168 | 169 | with tf.name_scope('Summary'): 170 | dloss_sum = tf.summary.scalar('Discriminator Loss', self.D_loss) 171 | gloss_sum = tf.summary.scalar('Generator Loss', self.G_loss) 172 | gan_loss_sum = tf.summary.scalar('GAN Loss', self.gan_loss) 173 | l1_loss_sum = tf.summary.scalar('L1 Loss', self.l1_loss) 174 | vgg_loss_sum = tf.summary.scalar('VGG Loss', self.gan_loss) 175 | output_im = tf.summary.image('Output', self.FakeB, max_outputs = 1) 176 | target_im = tf.summary.image('Target', self.RealB, max_outputs = 1) 177 | input_im = tf.summary.image('Input', self.RealA, max_outputs = 1) 178 | 179 | self.image_summary = tf.summary.merge([output_im, target_im, input_im]) 180 | self.g_summary = tf.summary.merge([gan_loss_sum, l1_loss_sum, vgg_loss_sum, gloss_sum]) 181 | self.d_summary = dloss_sum 182 | 183 | with tf.name_scope('Variables'): 184 | self.G_vars = [var for var in tf.trainable_variables() if var.name.startswith("Generator")] 185 | self.D_vars = [var for var in tf.trainable_variables() if var.name.startswith("Discriminator")] 186 | 187 | with 
tf.name_scope('Save'): 188 | self.saver = tf.train.Saver(max_to_keep = 10) 189 | 190 | with tf.name_scope('Optimizer'): 191 | with tf.name_scope("Discriminator_Train"): 192 | discrim_optim = tf.train.AdamOptimizer(self.lr, beta1 = 0.5) 193 | self.discrim_grads_and_vars = discrim_optim.compute_gradients(self.D_loss, var_list = self.D_vars) 194 | self.discrim_train = discrim_optim.apply_gradients(self.discrim_grads_and_vars, global_step = self.step) 195 | 196 | with tf.name_scope("Generator_Train"): 197 | gen_optim = tf.train.AdamOptimizer(self.lr, beta1 = 0.5) 198 | self.gen_grads_and_vars = gen_optim.compute_gradients(self.G_loss, var_list = self.G_vars) 199 | self.gen_train = gen_optim.apply_gradients(self.gen_grads_and_vars, global_step = self.step) 200 | 201 | 202 | def test(self, ckpt_dir): 203 | 204 | # Weight values as used in the paper. 205 | total_ssim = 0 206 | total_psnr = 0 207 | psnr_weight = 1/20 208 | ssim_weight = 1 209 | 210 | self.A_test = np.load('A_test.npy') #Valset 2 211 | self.B_test = np.load('B_test.npy') 212 | self.A_test = (self.A_test/255)*2 - 1 213 | 214 | print('Building Model') 215 | self.build() 216 | print('Model Built') 217 | 218 | with tf.Session() as sess: 219 | 220 | print('Loading Checkpoint') 221 | self.ckpt = tf.train.latest_checkpoint(ckpt_dir, latest_filename=None) 222 | self.saver.restore(sess, self.ckpt) 223 | print('Checkpoint Loaded') 224 | 225 | for i in range(len(self.A_test)): 226 | 227 | x = np.expand_dims(self.A_test[i], axis = 0) 228 | feed = {self.RealA :x} 229 | img = self.FakeB.eval(feed_dict = feed) 230 | 231 | print('Test image', i, end = '\r') 232 | 233 | A_img = (((img[0] + 1)/2) * 255).astype(np.uint8) 234 | B_img = (self.B_test[i]).astype(np.uint8) 235 | 236 | psnr = compare_psnr(B_img, A_img) 237 | s = ssim(B_img, A_img, multichannel = True) 238 | 239 | total_psnr = total_psnr + psnr 240 | total_ssim = total_ssim + s 241 | 242 | average_psnr = total_psnr / len(self.A_test) 243 | average_ssim = total_ssim / len(self.A_test) 244 | 245 | score = average_psnr * psnr_weight + average_ssim * ssim_weight 246 | 247 | line = 'Score: %.6f, PSNR: %.6f, SSIM: %.6f' %(score, average_psnr, average_ssim) 248 | print(line) 249 | 250 | 251 | if __name__ == '__main__': 252 | 253 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 254 | #os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 255 | obj = GAN() 256 | obj.test(os.path.join(os.getcwd(), 'model', 'checkpoint')) 257 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dehaze-GAN 2 | This repository contains TensorFlow code for the paper titled **Single Image Haze Removal using a Generative Adversarial Network.** [[Demo](https://www.youtube.com/watch?v=ioSL6ese46A)][[Arxiv](http://arxiv.org/abs/1810.09479)] 3 | 4 |

5 | Dehaze-GAN in action 6 | 7 |

8 | 9 | ## Update (June 2021) (important) 10 | On a recent review of some code I found a few issues. In the document [issue.md](issue.md), I have provided an explanation of the issues, their possible impact, and the remedies I have taken to account for the issues. Please read the document before attempting to use the codebase. I apologize for the inconvenience. 11 | 12 | ## Update (August 2020) 13 | This work has been accepted for the 2020 International Conference on Wireless Communications, Signal Processing and Networking (**WiSPNET 2020**). The arxiv version of this work (originally published around 2018) will be updated with the accepted version of the paper. ~~Will attach a document listing out the major updates to the paper soon!~~ 14 | 15 | > (EDIT: Oct 2021) I wanted to upload a document describing the major changes between v1 and v2 of the arxiv papers but I was not able to allocate enough time to do so. The v1 paper was put on arxiv sometime in October 2018. The v2 version is shorter, but is better written than the v1 version and also has made some corrections. The v2 version of the paper is the one that was accepted for the WiSPNET 2020 conference. The interested reader can access both versions from arxiv. However, it is recommended to follow the v2 version of the paper (along with `issue.md` and other docs in this codebase). 16 | 17 | ## Features: 18 | The model has the following components: 19 | - The 56-Layer Tiramisu as the generator. 20 | - A patch-wise discriminator. 21 | - A weighted loss function involving three components, namely: 22 | - GAN loss component. 23 | - Perceptual loss component (aka VGG loss component). 24 | - L1 loss component. 25 | 26 | Please refer to the [paper](http://arxiv.org/abs/1810.09479) for a detailed description. 27 | 28 |
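> **Note:** For reference, `model.py` combines the three components into the generator objective as the weighted sum `G_loss = gan_wt * gan_loss + l1_wt * l1_loss + vgg_wt * vgg_loss`, with default weights of `2`, `100` and `10` respectively (see `main.py`).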

29 | Block diagram of the Dehaze-GAN 30 | 31 |

32 | 33 | ## Notes: 34 | 1. The first version of this project was completed around December 2017. The demo video (dated March 2018) reflects the performance of one of the final versions; however, some iterative improvements were made after that. 35 | 2. This repository contains code that can be used for any application, and is not limited to dehazing. 36 | 3. For recreating the results reported in the paper, use the `legacy` folder (for more details, refer to the Replicating section below). This repository is the refactored version of the final model, but it uses newer versions of some TensorFlow operations. Those operations are not available in the old saved checkpoints. 37 | 4. ~~The codebase uses OpenCV's `imread` and `imwrite` functions without converting them from BGR to RGB space. However, there might be cases where this type of usage (such as was raised in this issue about `extract.py`) may not be desirable. To maintain reproducibility, the original code is left intact. If your application desires usage of images in the RGB space, you could manually convert them from BGR to RGB.~~ (**UPDATE: June 2021**) Please ignore the content that is stricken through. On a recent analysis, I made a few key observations. Please refer to the document [issue.md](issue.md) for a detailed explanation of the colorspace issue. 38 | 39 | ## Requirements: 40 | - TensorFlow (version 1.4+) 41 | - Matplotlib 42 | - Numpy 43 | - Scikit-Image 44 | 45 | ## Instructions: 46 | 1. Clone the repository using: 47 | ``` 48 | git clone https://github.com/thatbrguy/Dehaze-GAN.git 49 | ``` 50 | 51 | 2. The codebase has two branches, namely `master` and `alt`. There are some differences between the two branches. The user has to choose one of the two branches to use based on their preferences. Please read [issue.md](issue.md) to understand the differences between the two branches. For instructions on setting up the desired branch, go to the section `Step 2 from README Instructions` in `issue.md` ([link](issue.md/#step-2-from-readme-instructions)). 52 | 53 | 3. A VGG-19 pretrained on the ImageNet dataset is required to calculate the perceptual loss. In this work, we used the weights provided by [machrisaa](https://github.com/machrisaa/tensorflow-vgg)'s implementation. Download the weights from this [link](https://mega.nz/#!xZ8glS6J!MAnE91ND_WyfZ_8mvkuSa2YcA7q-1ehfSm-Q1fxOvvs) and include it in this repository. 54 | > **Note:** You may consider using a different implementation of ImageNet-pretrained VGG-19 which can automatically download the weights (for example, the implementation from Keras). However, in that case you would also have to modify the codebase a lot so that it can correctly use your desired implementation. Hence, please be careful if you want to use a different implementation. 55 | 56 | 4. Download the dataset. 57 | - We used the [NYU Depth Dataset V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) and the [Make 3D](http://make3d.cs.cornell.edu/data.html) dataset for training. The following code will download the NYU Depth Dataset V2 and create the hazy and clear image pairs. The images will be placed in the directories `A` and `B`, respectively. The formula and values used to create the synthetic haze are stated in our paper. 58 | ``` 59 | wget -O data.mat http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat 60 | python extract.py 61 | ``` 62 | > **Note 1:** The above step is only given for the NYU dataset.
If you are interested in creating hazy and clear image pairs for any other dataset which has RGB and depth information, please create your own script. The method mentioned in `extract.py` can be adapted for your custom dataset. 63 | 64 | > **Note 2:** For training the model in the paper (and also for validation and some testing experiments), we also used some data from the Make 3D dataset. Around the time the codebase was released, I think there were some issues with accessing the dataset link and hence `extract.py` was only given for the NYU dataset. If the link is accessible now and you are interested in using the Make 3D dataset, you can create your own script to create hazy and clear image pairs by using the method in `extract.py` as a reference. 65 | 66 | 5. In case you want to use your own dataset, follow these instructions. If not, skip this step. 67 | - Create two directories `A` and `B` in this repository. 68 | - Place the input images into directory `A` and target images into directory `B`. 69 | - Ensure that an input and target image pair has the **same name**, otherwise the program will throw an error (For instance, if `0001.png` is present in `A` it must also be present in `B`). 70 | - Resize all images to be of size `(256, 256, 3)`. 71 | 72 | 6. Train the model by using the following code. 73 | ``` 74 | python main.py \ 75 | --A_dir A \ 76 | --B_dir B \ 77 | --batch_size 2 \ 78 | --epochs 20 79 | ``` 80 | The file `main.py` supports a lot of options, which are listed below: 81 | - `--mode`: Select between `train`, `test` and `inference` modes. For `test` and `inference` modes, please place the checkpoint files at `./model/checkpoint` (you can replace `model` with your setting of the `--model_name` argument). Default value is `train`. 82 | - `--model_name`: Tensorboard, logs, samples and checkpoint files are stored in a folder named `model_name`. This argument allows you to provide the name of that folder. Default value is `model`. 83 | - `--lr`: Sets the learning rate for both the generator and the discriminator. Default value is `0.001`. 84 | - `--epochs`: Sets the number of epochs. Default value is `200`. 85 | - `--batch_size`: Sets the batch_size. Default value is `1`. 86 | - `--restore`: Boolean flag that enables restoring from old checkpoint. Checkpoints must be stored at `model_name/checkpoint`. Default value is `False`. 87 | - `--gan_wt`: Weight factor for the GAN loss. Default value is `2`. 88 | - `--l1_wt`: Weight factor for the L1 loss. Default value is `100`. 89 | - `--vgg_wt`: Weight factor for the Perceptual loss (VGG loss). Default value is `10`. 90 | - `--growth_rate`: Growth rate of the dense block. Refer to the DenseNet paper to learn more. Default value is `12`. 91 | - `--layers`: Number of layers per dense block. Default value is `4`. 92 | - `--decay`: Decay for the batchnorm operation. Default value is `0.99`. 93 | - `--D_filters`: Number of filters in the 1st conv layer of the discriminator. Number of filters is multiplied by 2 for every successive layer. Default value is `64`. 94 | - `--save_samples`: Since GAN convergence is hard to interpret from metrics, you can choose to visualize the output of the generator after each validation run. This boolean flag enables the behavior. Default value is `False`. 95 | - `--sample_image_dir`: If `save_samples` is set to `True`, you must provide sample images placed in a directory. Give the name of that directory to this argument. Default value is `samples`. 
96 | - `--custom_data`: Boolean flag that allows you to use your own data for training. Default is `True`. (Note: As of now, I have not linked the data I used for training). 97 | - `--A_dir`: Directory containing the input images. Only used when `custom_data` is set to `True`. Default value is `A`. 98 | - `--B_dir`: Directory containing the target images. Only used when `custom_data` is set to `True`. Default value is `B`. 99 | - `--val_fraction`: Fraction of the data to be used for validation. Only used when `custom_data` is set to `True`. Default value is `0.15`. 100 | - `--val_threshold`: Number of steps to wait before validation is enabled. Usually, the GAN performs suboptimally for quite a while. Hence, disabling validation initially can prevent unnecessary validation and speed up training. Default value is `0`. 101 | - `--val_frequency`: Number of batches to wait before performing the next validation run. Setting this to `1` will perform validation after one discriminator and generator step. You can set it to a higher value to speed up training. Default value is `20`. 102 | - `--logger_frequency`: Number of batches to wait before logging the next set of loss values. Setting it to a higher value will reduce clutter and slightly increase training speed. Default value is `20`. 103 | 104 | ## Replicating: 105 | The code in `legacy` can be used for replicating the results of our model on the test split of the "Custom Dataset" as mentioned in the paper. The following steps explain how to replicate the results: 106 | 107 | 1. Download the model checkpoint and the data used for testing from this [link](https://drive.google.com/file/d/1d2HUyumIu6BwSYiPuGOdnvTiCArQjrCm/view?usp=sharing). Place the tar file inside the `legacy` folder. Extract the contents using the following commands. 108 | ``` 109 | cd legacy 110 | tar -xzvf replicate.tar.gz 111 | ``` 112 | 2. Move the weights of the pretrained VGG-19 into the `legacy` folder. The download [link](https://mega.nz/#!xZ8glS6J!MAnE91ND_WyfZ_8mvkuSa2YcA7q-1ehfSm-Q1fxOvvs) is reproduced here for convenience. 113 | 114 | 3. Run the code from the `legacy` folder using: 115 | ``` 116 | python replicate.py 117 | ``` 118 | 119 | If you would like to replicate the results on the SOTS outdoor dataset as well, you may do so. However, a ready-to-use script to quickly test it is not provided; it should be pretty straightforward to write and add a small function in `replicate.py` to test the same. 120 | 121 | ## Known Issues: 122 | 1. For newer versions of NumPy, you may get an error when the codebase attempts to load the pretrained VGG-19 weights NumPy file (for example, see issue [#7](https://github.com/thatbrguy/Dehaze-GAN/issues/7)). One simple solution to avoid this issue is to use a lower version of NumPy (as mentioned in issue [#7](https://github.com/thatbrguy/Dehaze-GAN/issues/7)). If you prefer to use a higher version of NumPy, you can take a look at PR [#17](https://github.com/thatbrguy/Dehaze-GAN/pull/17) to see how to modify the code so that it can work with higher versions of NumPy. 123 | 124 | 2. This repository may not work with TensorFlow 2.x out of the box. It also may not work with eager mode for TensorFlow 1.x out of the box. The code was created a few years ago, so consider using an older version of TensorFlow 1.x (maybe around 1.4 to 1.9) in the graph execution mode (which is the default mode for TensorFlow 1.x). 125 | 126 | ## License: 127 | This repository is open source under the MIT license.
Feel free to use it for academic and educational purposes. 128 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import cv2 4 | import vgg19 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from skimage.measure import compare_psnr 9 | from skimage.measure import compare_ssim 10 | from operations import TransposeConv, DropOut 11 | from operations import Conv, ReLU, LeakyReLU, AvgPool, BatchNorm 12 | 13 | class GAN(): 14 | 15 | def __init__(self, args): 16 | self.num_discriminator_filters = args.D_filters 17 | self.layers = args.layers 18 | self.growth_rate = args.growth_rate 19 | self.gan_wt = args.gan_wt 20 | self.l1_wt = args.l1_wt 21 | self.vgg_wt = args.vgg_wt 22 | self.restore = args.restore 23 | self.batch_size = args.batch_size 24 | self.epochs = args.epochs 25 | self.lr = args.lr 26 | self.model_name = args.model_name 27 | self.decay = args.decay 28 | self.save_samples = args.save_samples 29 | self.sample_image_dir = args.sample_image_dir 30 | self.A_dir = args.A_dir 31 | self.B_dir = args.B_dir 32 | self.custom_data = args.custom_data 33 | self.val_fraction = args.val_fraction 34 | self.val_threshold = args.val_threshold 35 | self.val_frequency = args.val_frequency 36 | self.logger_frequency = args.logger_frequency 37 | 38 | self.EPS = 10e-12 39 | self.score_best = -1 40 | self.ckpt_dir = os.path.join(os.getcwd(), self.model_name, 'checkpoint') 41 | self.tensorboard_dir = os.path.join(os.getcwd(), self.model_name, 'tensorboard') 42 | 43 | def Layer(self, input_): 44 | """ 45 | This function creates the components inside a composite layer 46 | of a Dense Block. 47 | """ 48 | with tf.variable_scope("Composite"): 49 | next_layer = BatchNorm(input_, isTrain = self.isTrain) 50 | next_layer = ReLU(next_layer) 51 | next_layer = Conv(next_layer, kernel_size = 3, stride = 1, output_channels = self.growth_rate) 52 | next_layer = DropOut(next_layer, isTrain = self.isTrain, rate = 0.2) 53 | 54 | return next_layer 55 | 56 | def TransitionDown(self, input_, name): 57 | 58 | with tf.variable_scope(name): 59 | 60 | reduction = 0.5 61 | reduced_output_size = int(int(input_.get_shape()[-1]) * reduction) 62 | 63 | next_layer = BatchNorm(input_, isTrain = self.isTrain, decay = self.decay) 64 | next_layer = Conv(next_layer, kernel_size = 1, stride = 1, output_channels = reduced_output_size) 65 | next_layer = DropOut(next_layer, isTrain = self.isTrain, rate = 0.2) 66 | next_layer = AvgPool(next_layer) 67 | 68 | return next_layer 69 | 70 | def TransitionUp(self, input_, output_channels, name): 71 | 72 | with tf.variable_scope(name): 73 | next_layer = TransposeConv(input_, output_channels = output_channels, kernel_size = 3) 74 | 75 | return next_layer 76 | 77 | def DenseBlock(self, input_, name, layers = 4): 78 | 79 | with tf.variable_scope(name): 80 | for i in range(layers): 81 | with tf.variable_scope("Layer" + str(i + 1)) as scope: 82 | output = self.Layer(input_) 83 | output = tf.concat([input_, output], axis=3) 84 | input_ = output 85 | 86 | return output 87 | 88 | def generator(self, input_): 89 | """ 90 | 54 Layer Tiramisu 91 | """ 92 | with tf.variable_scope('InputConv') as scope: 93 | input_ = Conv(input_, kernel_size = 3, stride=1, output_channels = self.growth_rate * 4) 94 | 95 | collect_conv = [] 96 | 97 | for i in range(1, 6): 98 | input_ = self.DenseBlock(input_, name = 'Encoder' + str(i), layers = self.layers) 99 | 
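# Store the output of each encoder dense block before downsampling; the decoder
# loop below concatenates these tensors back in as skip connections.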
collect_conv.append(input_) 100 | input_ = self.TransitionDown(input_, name = 'TD' + str(i)) 101 | 102 | input_ = self.DenseBlock(input_, name = 'BottleNeck', layers = 15) 103 | 104 | for i in range(1, 6): 105 | input_ = self.TransitionUp(input_, output_channels = self.growth_rate * 4, name = 'TU' + str(6 - i)) 106 | input_ = tf.concat([input_, collect_conv[6 - i - 1]], axis = 3, name = 'Decoder' + str(6 - i) + '/Concat') 107 | input_ = self.DenseBlock(input_, name = 'Decoder' + str(6 - i), layers = self.layers) 108 | 109 | with tf.variable_scope('OutputConv') as scope: 110 | output = Conv(input_, kernel_size = 1, stride = 1, output_channels = 3) 111 | 112 | return tf.nn.tanh(output) 113 | 114 | def discriminator(self, input_, target, stride = 2, layer_count = 4): 115 | """ 116 | Using the PatchGAN as a discriminator 117 | """ 118 | input_ = tf.concat([input_, target], axis=3, name='Concat') 119 | layer_specs = self.num_discriminator_filters * np.array([1, 2, 4, 8]) 120 | 121 | for i, output_channels in enumerate(layer_specs, 1): 122 | 123 | with tf.variable_scope('Layer' + str(i)) as scope: 124 | 125 | if i != 1: 126 | input_ = BatchNorm(input_, isTrain = self.isTrain) 127 | 128 | if i == layer_count: 129 | stride = 1 130 | 131 | input_ = LeakyReLU(input_) 132 | input_ = Conv(input_, output_channels = output_channels, kernel_size = 4, stride = stride, padding = 'VALID', mode = 'discriminator') 133 | 134 | with tf.variable_scope('Final_Layer') as scope: 135 | output = Conv(input_, output_channels = 1, kernel_size = 4, stride = 1, padding = 'VALID', mode = 'discriminator') 136 | 137 | return tf.sigmoid(output) 138 | 139 | def build_vgg(self, img): 140 | 141 | model = vgg19.Vgg19() 142 | img = tf.image.resize_images(img, [224, 224]) 143 | layer = model.feature_map(img) 144 | return layer 145 | 146 | def build_model(self): 147 | 148 | with tf.variable_scope('Placeholders') as scope: 149 | self.RealA = tf.placeholder(name='A', shape=[None, 256, 256, 3], dtype=tf.float32) 150 | self.RealB = tf.placeholder(name='B', shape=[None, 256, 256, 3], dtype=tf.float32) 151 | self.isTrain = tf.placeholder(name = "isTrain", shape = None, dtype = tf.bool) 152 | self.step = tf.train.get_or_create_global_step() 153 | 154 | with tf.variable_scope('Generator') as scope: 155 | self.FakeB = self.generator(self.RealA) 156 | 157 | with tf.name_scope('Real_Discriminator'): 158 | with tf.variable_scope('Discriminator') as scope: 159 | self.predict_real = self.discriminator(self.RealA, self.RealB) 160 | 161 | with tf.name_scope('Fake_Discriminator'): 162 | with tf.variable_scope('Discriminator', reuse=True) as scope: 163 | self.predict_fake = self.discriminator(self.RealA, self.FakeB) 164 | 165 | with tf.name_scope('Real_VGG'): 166 | with tf.variable_scope('VGG') as scope: 167 | self.RealB_VGG = self.build_vgg(self.RealB) 168 | 169 | with tf.name_scope('Fake_VGG'): 170 | with tf.variable_scope('VGG', reuse=True) as scope: 171 | self.FakeB_VGG = self.build_vgg(self.FakeB) 172 | 173 | with tf.name_scope('DiscriminatorLoss'): 174 | self.D_loss = tf.reduce_mean(-(tf.log(self.predict_real + self.EPS) + tf.log(1 - self.predict_fake + self.EPS))) 175 | 176 | with tf.name_scope('GeneratorLoss'): 177 | self.gan_loss = tf.reduce_mean(-tf.log(self.predict_fake + self.EPS)) 178 | self.l1_loss = tf.reduce_mean(tf.abs(self.RealB - self.FakeB)) 179 | self.vgg_loss = (1e-5) * tf.losses.mean_squared_error(self.RealB_VGG, self.FakeB_VGG) 180 | 181 | self.G_loss = self.gan_wt * self.gan_loss + self.l1_wt * self.l1_loss + self.vgg_wt * 
182 | 
183 |         with tf.name_scope('Summary'):
184 |             D_loss_sum = tf.summary.scalar('Discriminator Loss', self.D_loss)
185 |             G_loss_sum = tf.summary.scalar('Generator Loss', self.G_loss)
186 |             gan_loss_sum = tf.summary.scalar('GAN Loss', self.gan_loss)
187 |             l1_loss_sum = tf.summary.scalar('L1 Loss', self.l1_loss)
188 |             vgg_loss_sum = tf.summary.scalar('VGG Loss', self.vgg_loss)
189 |             output_img = tf.summary.image('Output', self.FakeB, max_outputs = 1)
190 |             target_img = tf.summary.image('Target', self.RealB, max_outputs = 1)
191 |             input_img = tf.summary.image('Input', self.RealA, max_outputs = 1)
192 | 
193 |             self.image_summary = tf.summary.merge([output_img, target_img, input_img])
194 |             self.G_summary = tf.summary.merge([gan_loss_sum, l1_loss_sum, vgg_loss_sum, G_loss_sum])
195 |             self.D_summary = D_loss_sum
196 | 
197 |         with tf.name_scope('Variables'):
198 |             self.G_vars = [var for var in tf.trainable_variables() if var.name.startswith("Generator")]
199 |             self.D_vars = [var for var in tf.trainable_variables() if var.name.startswith("Discriminator")]
200 | 
201 |         with tf.name_scope('Save'):
202 |             self.saver = tf.train.Saver(max_to_keep=3)
203 | 
204 |         with tf.name_scope('Optimizer'):
205 | 
206 |             update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
207 |             with tf.control_dependencies(update_ops):
208 | 
209 |                 with tf.name_scope("Discriminator_Train"):
210 |                     D_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5)
211 |                     self.D_grads_and_vars = D_optimizer.compute_gradients(self.D_loss, var_list = self.D_vars)
212 |                     self.D_train = D_optimizer.apply_gradients(self.D_grads_and_vars, global_step = self.step)
213 | 
214 |                 with tf.name_scope("Generator_Train"):
215 |                     G_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5)
216 |                     self.G_grads_and_vars = G_optimizer.compute_gradients(self.G_loss, var_list = self.G_vars)
217 |                     self.G_train = G_optimizer.apply_gradients(self.G_grads_and_vars, global_step = self.step)
218 | 
219 |     def train(self):
220 | 
221 |         start_epoch = 0
222 |         logger_frequency = self.logger_frequency
223 |         val_frequency = self.val_frequency
224 |         val_threshold = self.val_threshold
225 | 
226 |         if not os.path.exists(self.model_name):
227 |             os.mkdir(self.model_name)
228 | 
229 |         print('Loading Model')
230 |         self.build_model()
231 |         print('Model Loaded')
232 | 
233 |         print('Loading Data')
234 | 
235 |         if self.custom_data:
236 | 
237 |             # Please ensure that the input images and target images have
238 |             # the same filename.
239 | 
240 |             data = sorted(os.listdir(self.A_dir))
241 | 
242 |             total_image_count = int(len(data) * (1 - self.val_fraction))
243 |             batches = total_image_count // self.batch_size
244 | 
245 |             train_data = data[: total_image_count]
246 |             val_data = data[total_image_count: ]
247 |             val_image_count = len(val_data)
248 | 
249 |             self.A_train = np.zeros((total_image_count, 256, 256, 3))
250 |             self.B_train = np.zeros((total_image_count, 256, 256, 3))
251 |             self.A_val = np.zeros((val_image_count, 256, 256, 3))
252 |             self.B_val = np.zeros((val_image_count, 256, 256, 3))
253 | 
254 |             print(self.A_train.shape, self.A_val.shape)
255 | 
256 |             for i, file in enumerate(train_data):
257 |                 self.A_train[i] = cv2.imread(os.path.join(os.getcwd(), self.A_dir, file), 1).astype(np.float32)
258 |                 self.B_train[i] = cv2.imread(os.path.join(os.getcwd(), self.B_dir, file), 1).astype(np.float32)
259 | 
260 |             for i, file in enumerate(val_data):
261 |                 self.A_val[i] = cv2.imread(os.path.join(os.getcwd(), self.A_dir, file), 1).astype(np.float32)
262 |                 self.B_val[i] = cv2.imread(os.path.join(os.getcwd(), self.B_dir, file), 1).astype(np.float32)
263 | 
264 |         else:
265 | 
266 |             self.A_train = np.load('A_train.npy').astype(np.float32)
267 |             self.B_train = np.load('B_train.npy').astype(np.float32)
268 |             self.A_val = np.load('A_val.npy').astype(np.float32) # Valset 2
269 |             self.B_val = np.load('B_val.npy').astype(np.float32)
270 | 
271 |             total_image_count = len(self.A_train)
272 |             val_image_count = len(self.A_val)
273 |             batches = total_image_count // self.batch_size
274 | 
275 |         self.A_val = (self.A_val / 255) * 2 - 1
276 |         self.B_val = (self.B_val / 255) * 2 - 1
277 |         self.A_train = (self.A_train / 255) * 2 - 1
278 |         self.B_train = (self.B_train / 255) * 2 - 1
279 | 
280 |         print('Data Loaded')
281 | 
282 | 
283 |         with tf.Session() as self.sess:
284 | 
285 |             init_op = tf.global_variables_initializer()
286 |             self.sess.run(init_op)
287 | 
288 |             if self.restore:
289 |                 print('Loading Checkpoint')
290 |                 ckpt = tf.train.latest_checkpoint(self.ckpt_dir)
291 |                 self.saver.restore(self.sess, ckpt)
292 |                 self.step = tf.train.get_or_create_global_step()
293 |                 print('Checkpoint Loaded')
294 | 
295 |             self.writer = tf.summary.FileWriter(self.tensorboard_dir, tf.get_default_graph())
296 |             total_parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])
297 |             G_parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables() if v.name.startswith("Generator")])
298 |             D_parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables() if v.name.startswith("Discriminator")])
299 |             loss_operations = [self.D_loss, self.G_loss, self.gan_loss, self.l1_loss, self.vgg_loss]
300 | 
301 |             counts = self.sess.run([G_parameter_count, D_parameter_count, total_parameter_count])
302 | 
303 |             print('Generator parameter count:', counts[0])
304 |             print('Discriminator parameter count:', counts[1])
305 |             print('Total parameter count:', counts[2])
306 | 
307 |             # The variable below is divided by 2 since both the Generator
308 |             # and the Discriminator increase the step count by 1
309 |             start = self.step.eval() // (batches * 2)
310 | 
311 |             for i in range(start, self.epochs):
312 | 
313 |                 print('Epoch:', i)
314 |                 shuffle = np.random.permutation(total_image_count)
315 | 
316 |                 for j in range(batches):
317 | 
318 |                     if j != batches - 1:
319 |                         current_batch = shuffle[j * self.batch_size: (j + 1) * self.batch_size]
320 |                     else:
321 |                         current_batch = shuffle[j * self.batch_size: ]
322 | 
323 |                     a = self.A_train[current_batch]
324 |                     b = self.B_train[current_batch]
325 |                     feed_dict = {self.RealA: a, self.RealB: b, self.isTrain: True}
326 | 
327 |                     begin = time.time()
328 |                     step = self.step.eval()
329 | 
330 |                     _, D_summary = self.sess.run([self.D_train, self.D_summary], feed_dict = feed_dict)
331 | 
332 |                     self.writer.add_summary(D_summary, step)
333 | 
334 |                     _, G_summary = self.sess.run([self.G_train, self.G_summary], feed_dict = feed_dict)
335 | 
336 |                     self.writer.add_summary(G_summary, step)
337 | 
338 |                     print('Time Per Step: ', format(time.time() - begin, '.3f'), end='\r')
339 | 
340 |                     if j % logger_frequency == 0:
341 |                         D_loss, G_loss, GAN_loss, L1_loss, VGG_loss = self.sess.run(loss_operations, feed_dict=feed_dict)
342 | 
343 |                         GAN_loss = GAN_loss * self.gan_wt
344 |                         L1_loss = L1_loss * self.l1_wt
345 |                         VGG_loss = VGG_loss * self.vgg_wt
346 | 
347 |                         trial_image_idx = np.random.randint(total_image_count)
348 |                         a = self.A_train[trial_image_idx]
349 |                         b = self.B_train[trial_image_idx]
350 | 
351 |                         if a.ndim == 3:
352 |                             a = np.expand_dims(a, axis = 0)
353 | 
354 |                         if b.ndim == 3:
355 |                             b = np.expand_dims(b, axis = 0)
356 | 
357 |                         feed_dict = {self.RealA: a, self.RealB: b, self.isTrain: False}
358 |                         img_summary = self.sess.run(self.image_summary, feed_dict=feed_dict)
359 |                         self.writer.add_summary(img_summary, step)
360 | 
361 |                         line = 'Batch: %d, D_Loss: %.3f, G_Loss: %.3f, GAN: %.3f, L1: %.3f, VGG: %.3f' % (
362 |                             j, D_loss, G_loss, GAN_loss, L1_loss, VGG_loss)
363 |                         print(line)
364 | 
365 |                     # The variable `step` counts both D and G updates as individual steps.
366 |                     # The variable `G_D_step` counts one D update followed by a G update
367 |                     # as a single step.
368 |                     G_D_step = step // 2
369 |                     print('GD', G_D_step, 'val', val_threshold)
370 | 
371 |                     if (val_threshold > G_D_step) and (j % val_frequency == 0):
372 |                         self.validate()
373 | 
374 | 
375 |     def validate(self):
376 | 
377 |         total_ssim = 0
378 |         total_psnr = 0
379 |         psnr_weight = 1/20
380 |         ssim_weight = 1
381 |         val_image_count = len(self.A_val)
382 | 
383 |         for i in range(val_image_count):
384 | 
385 |             x = np.expand_dims(self.A_val[i], axis = 0)
386 |             feed_dict = {self.RealA: x, self.isTrain: False}
387 |             generated_B = self.FakeB.eval(feed_dict = feed_dict)
388 | 
389 |             print('Validation Image', i, end = '\r')
390 | 
391 |             generated_B = (((generated_B[0] + 1)/2) * 255).astype(np.uint8)
392 |             real_B = (((self.B_val[i] + 1)/2)*255).astype(np.uint8)
393 | 
394 |             psnr = compare_psnr(real_B, generated_B)
395 |             ssim = compare_ssim(real_B, generated_B, multichannel = True)
396 | 
397 |             total_psnr = total_psnr + psnr
398 |             total_ssim = total_ssim + ssim
399 | 
400 |         average_psnr = total_psnr / val_image_count
401 |         average_ssim = total_ssim / val_image_count
402 | 
403 |         score = average_psnr * psnr_weight + average_ssim * ssim_weight
404 | 
405 | 
406 |         if(score > self.score_best):
407 | 
408 |             self.score_best = score
409 | 
410 |             self.saver.save(self.sess, os.path.join(self.ckpt_dir, 'gan'), global_step = self.step.eval())
411 |             line = 'Better Score: %.6f, PSNR: %.6f, SSIM: %.6f' %(score, average_psnr, average_ssim)
412 |             print(line)
413 | 
414 |             with open(os.path.join(self.ckpt_dir, 'logs.txt'),'a') as f:
415 |                 line += '\n'
416 |                 f.write(line)
417 | 
418 |         if self.save_samples:
419 | 
420 |             try:
421 |                 image_list = os.listdir(self.sample_image_dir)
422 |             except:
423 |                 print('Sample images not found. Terminating program')
424 |                 exit(0)
425 | 
426 |             for i, file in enumerate(image_list, 1):
427 | 
428 |                 print('Sample Image', i, end = '\r')
429 | 
430 |                 x = cv2.imread(os.path.join(self.sample_image_dir, file), 1)
431 |                 x = (x/255)*2 - 1
432 |                 x = np.reshape(x,(1,256,256,3))
433 | 
434 |                 feed_dict = {self.RealA: x, self.isTrain: False}
435 |                 img = self.FakeB.eval(feed_dict = feed_dict)
436 | 
437 |                 img = img[0,:,:,:]
438 |                 img = (((img + 1)/2) * 255).astype(np.uint8)
439 |                 cv2.imwrite(os.path.join(self.ckpt_dir, file), img)
440 | 
441 | 
442 |     def test(self, input_dir, GT_dir):
443 | 
444 |         total_ssim = 0
445 |         total_psnr = 0
446 |         psnr_weight = 1/20
447 |         ssim_weight = 1
448 | 
449 |         GT_list = sorted(os.listdir(GT_dir))  # sorted so inputs and ground truths pair up by filename
450 |         input_list = sorted(os.listdir(input_dir))
451 | 
452 |         print('Loading Model')
453 |         self.build_model()
454 |         print('Model Loaded')
455 | 
456 |         with tf.Session() as self.sess:
457 | 
458 |             init_op = tf.global_variables_initializer()
459 |             self.sess.run(init_op)
460 | 
461 |             print('Loading Checkpoint')
462 |             ckpt = tf.train.latest_checkpoint(self.ckpt_dir)
463 |             self.saver.restore(self.sess, ckpt)
464 |             self.step = tf.train.get_or_create_global_step()
465 |             print('Checkpoint Loaded')
466 | 
467 |             for i, (img_file, GT_file) in enumerate(zip(input_list, GT_list), 1):
468 | 
469 |                 img = cv2.imread(os.path.join(input_dir, img_file), 1)
470 |                 GT = cv2.imread(os.path.join(GT_dir, GT_file), 1).astype(np.uint8)
471 | 
472 |                 print('Test image', i, end = '\r')
473 | 
474 |                 img = ((np.expand_dims(img, axis = 0) / 255) * 2) - 1
475 |                 feed_dict = {self.RealA: img, self.isTrain: False}
476 |                 generated_B = self.FakeB.eval(feed_dict = feed_dict)
477 |                 generated_B = (((generated_B[0] + 1)/2) * 255).astype(np.uint8)
478 | 
479 |                 psnr = compare_psnr(GT, generated_B)
480 |                 ssim = compare_ssim(GT, generated_B, multichannel = True)
481 | 
482 |                 total_psnr = total_psnr + psnr
483 |                 total_ssim = total_ssim + ssim
484 | 
485 |             average_psnr = total_psnr / len(GT_list)
486 |             average_ssim = total_ssim / len(GT_list)
487 | 
488 |             score = average_psnr * psnr_weight + average_ssim * ssim_weight
489 | 
490 |             line = 'Score: %.6f, PSNR: %.6f, SSIM: %.6f' %(score, average_psnr, average_ssim)
491 |             print(line)
492 | 
493 | 
494 |     def inference(self, input_dir, result_dir):
495 | 
496 |         input_list = os.listdir(input_dir)
497 | 
498 |         if not os.path.exists(result_dir):
499 |             os.mkdir(result_dir)
500 | 
501 |         print('Loading Model')
502 |         self.build_model()
503 |         print('Model Loaded')
504 | 
505 |         with tf.Session() as self.sess:
506 | 
507 |             init_op = tf.global_variables_initializer()
508 |             self.sess.run(init_op)
509 | 
510 |             print('Loading Checkpoint')
511 |             ckpt = tf.train.latest_checkpoint(self.ckpt_dir)
512 |             self.saver.restore(self.sess, ckpt)
513 |             self.step = tf.train.get_or_create_global_step()
514 |             print('Checkpoint Loaded')
515 | 
516 |             for i, img_file in enumerate(input_list, 1):
517 | 
518 |                 img = cv2.imread(os.path.join(input_dir, img_file), 1)
519 | 
520 |                 print('Processing image', i, end = '\r')
521 | 
522 |                 img = ((np.expand_dims(img, axis = 0) / 255) * 2) - 1
523 |                 feed_dict = {self.RealA: img, self.isTrain: False}
524 |                 generated_B = self.FakeB.eval(feed_dict = feed_dict)
525 |                 generated_B = (((generated_B[0] + 1)/2) * 255).astype(np.uint8)
526 | 
527 |                 cv2.imwrite(os.path.join(result_dir, img_file), generated_B)
528 | 
529 |             print('Done.')
530 | 
--------------------------------------------------------------------------------