├── README.md ├── discriminator ├── utils.py ├── inference.py ├── ImageFogger ├── generator.py ├── export_graph.py ├── reader.py ├── vgg16.py ├── train.py ├── ops.py └── model.py /README.md: -------------------------------------------------------------------------------- 1 | # ES-CCGAN: the implementation of the paper "Unsupervised Haze Removal for High-Resolution Optical Remote-Sensing Images Based on Improved Generative Adversarial Networks"; the paper is available at "https://www.mdpi.com/2072-4292/12/24/4162". 2 | This is a remote-sensing image dehazing project, implemented in Python. 3 | To run this project you need to set up the environment, download the dataset, run a script to process the data, and then train and test the network models. 4 | I will show you, step by step, how to run this project, and I hope it is clear enough. 5 | 6 | --Prerequisite 7 | I tested the project on an Intel Core i9 with 64 GB of RAM and an RTX 2080 Ti GPU. Because training takes several days, I recommend using a sufficiently powerful CPU/GPU with about 24 GB of video memory. 8 | 9 | --Dataset 10 | I use a self-made remote-sensing image dataset that consists of 52376 haze-free images, 52376 hazy images, and 52376 haze-free images with blurred edges. All the images are 256 × 256 pixels in size. All of the data need to be converted to tfrecords. 
The code of generated haze remote sensing image is in the "ImageFogger.py" 11 | 12 | --Training 13 | To train a generator, run the following command 14 | python train.py 15 | 16 | --Test 17 | First, the model needs to transform to the type of '.pb', run the following command 18 | python export_graph.py 19 | Second, the haze image is dehazed as following: 20 | python inference.py 21 | -------------------------------------------------------------------------------- /discriminator: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import ops 3 | 4 | class Discriminator: 5 | def __init__(self, name, is_training, norm='instance', use_sigmoid=False): 6 | self.name = name 7 | self.is_training = is_training 8 | self.norm = norm 9 | self.reuse = False 10 | self.use_sigmoid = use_sigmoid 11 | 12 | def __call__(self, input): 13 | """ 14 | Args: 15 | input: batch_size x image_size x image_size x 3 16 | Returns: 17 | output: 4D tensor batch_size x out_size x out_size x 1 (default 1x5x5x1) 18 | filled with 0.9 if real, 0.0 if fake 19 | """ 20 | with tf.variable_scope(self.name): 21 | # convolution layers 22 | C64 = ops.Ck(input, 64, reuse=self.reuse, norm=None, 23 | is_training=self.is_training, name='C64') # (?, w/2, h/2, 64) 24 | C128 = ops.Ck(C64, 128, reuse=self.reuse, norm=self.norm, 25 | is_training=self.is_training, name='C128') # (?, w/4, h/4, 128) 26 | C256 = ops.Ck(C128, 256, reuse=self.reuse, norm=self.norm, 27 | is_training=self.is_training, name='C256') # (?, w/8, h/8, 256) 28 | C512 = ops.Ck(C256, 512,reuse=self.reuse, norm=self.norm, 29 | is_training=self.is_training, name='C512') # (?, w/16, h/16, 512) 30 | 31 | # apply a convolution to produce a 1 dimensional output (1 channel?) 
class ImagePool:
    """ History of generated images
    Same logic as https://github.com/junyanz/CycleGAN/blob/master/util/image_pool.lua

    The pool stores up to `pool_size` images. While it is still filling up,
    every queried image is stored and returned unchanged. Once it is full,
    each query returns, with probability 0.5, a copy of a randomly chosen
    stored image (which is then replaced by a copy of the new image);
    otherwise the new image is returned as is.

    Images must provide a `.copy()` method once the pool is full.
    """
    def __init__(self, pool_size):
        """
        Args:
          pool_size: number of images to keep; 0 disables the pool.
            Non-integer values are truncated to int.
        """
        # train.py declares the `pool_size` flag with DEFINE_float, so the
        # value arrives as a float (e.g. 50.0). random.randrange() rejects
        # non-integral floats and, from Python 3.12 on, any float at all,
        # so normalise to int once here.
        self.pool_size = int(pool_size)
        self.images = []

    def query(self, image):
        """Return either `image` itself or a previously stored image.

        Args:
          image: newly generated image
        Returns:
          `image`, or a copy of an older image taken from the pool
        """
        if self.pool_size == 0:
            # pooling disabled: pass-through
            return image

        if len(self.images) < self.pool_size:
            # pool not full yet: remember the image and use it directly
            self.images.append(image)
            return image

        if random.random() > 0.5:
            # use old image, and put the new one in its slot
            random_id = random.randrange(self.pool_size)
            tmp = self.images[random_id].copy()
            self.images[random_id] = image.copy()
            return tmp
        return image
def Noise(x,y):
    """Deterministic integer-hash noise for the lattice point (x, y).

    The two coordinates are folded into one integer (x + 57 * y), scrambled
    with a shift/xor and a cubic polynomial, masked to the low 31 bits and
    mapped to a float in the interval (-1.0, 1.0].

    Only integer operators (<<, ^, &) are applied to the inputs, so x and y
    must be integers (or integer arrays that support those operators).
    """
    folded = x + y * 57
    scrambled = (folded << 13) ^ folded
    poly = scrambled * (scrambled * scrambled * 15731 + 789221) + 1376312589
    low31 = poly & 0x7fffffff
    return 1.0 - low31 / 1073741824.0
class Generator:
    """Image-to-image generator: down-sampling convolutions, a transformer
    stage, and fractional-strided up-sampling convolutions.

    Calling the instance builds the network under the variable scope `name`.
    The first call creates the variables; later calls reuse them.
    """
    def __init__(self, name, is_training, ngf=64, norm='instance', image_size1=128, image_size2=128):
        """
        Args:
          name: string, variable scope that owns the generator variables
          is_training: boolean or BoolTensor, forwarded to the layers in ops
          ngf: integer, number of filters in the first conv layer
          norm: 'instance' or 'batch' or None, forwarded to the layers in ops
          image_size1: integer, selects the transformer stage (see __call__)
          image_size2: integer, stored only; not read inside this class
        """
        self.name = name
        self.reuse = False
        self.ngf = ngf
        self.norm = norm
        self.is_training = is_training
        self.image_size1 = image_size1
        self.image_size2 = image_size2

    def __call__(self, input):
        """
        Args:
          input: batch_size x width x height x 3
        Returns:
          output: same size as input
        """
        with tf.variable_scope(self.name):
            # conv layers
            # NOTE: the scope names ('c7s1_32', 'd64', ...) are kept unchanged
            # so saved variables keep their names; the real depths are
            # multiples of ngf, as noted on each line.
            c7s1_32 = ops.c7s1_k(input, self.ngf, is_training=self.is_training, norm=self.norm,
                reuse=self.reuse, name='c7s1_32')                             # (?, w, h, ngf)
            d64 = ops.dk(c7s1_32, 2*self.ngf, is_training=self.is_training, norm=self.norm,
                reuse=self.reuse, name='d64')                                 # (?, w/2, h/2, 2*ngf)
            d128 = ops.dk(d64, 4*self.ngf, is_training=self.is_training, norm=self.norm,
                reuse=self.reuse, name='d128')                                # (?, w/4, h/4, 4*ngf)

            if self.image_size1 <= 128:
                # use 6 residual blocks for 128x128 images.
                # norm / is_training are passed explicitly: the defaults of
                # ops.n_res_blocks ('instance', True) would otherwise silently
                # override the configuration of this generator.
                res_output = ops.n_res_blocks(d128, reuse=self.reuse, norm=self.norm,
                    is_training=self.is_training, n=6)                        # (?, w/4, h/4, 4*ngf)
            else:
                # higher resolution: dense block (this call replaced the
                # earlier 9-residual-block variant)
                res_output = ops.denseRK(d128, is_training=self.is_training, reuse=self.reuse)

            # fractional-strided convolution
            u64 = ops.uk(res_output, 2*self.ngf, is_training=self.is_training, norm=self.norm,
                reuse=self.reuse, name='u64')                                 # (?, w/2, h/2, 2*ngf)
            u32 = ops.uk(u64, self.ngf, is_training=self.is_training, norm=self.norm,
                reuse=self.reuse, name='u32')                                 # (?, w, h, ngf)

            # conv layer
            # Note: the paper said that ReLU and _norm were used
            # but actually tanh was used and no _norm here
            output = ops.c7s1_k(u32, 3, norm=None,
                activation='tanh', reuse=self.reuse, name='output')           # (?, w, h, 3)

        # set reuse=True for next call
        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

        return output

    def sample(self, input):
        """Run the generator and return the result as an encoded JPEG tensor.

        Args:
          input: batch holding exactly one image (axis 0 is squeezed away)
        Returns:
          scalar string tensor with the JPEG-encoded generated image
        """
        image = utils.batch_convert2int(self.__call__(input))
        image = tf.image.encode_jpeg(tf.squeeze(image, [0]))
        return image
def export_graph(model_name, XtoY=True):
    """Freeze one generator of the CycleGAN model into a binary GraphDef file.

    The model is rebuilt in a fresh graph, the latest checkpoint found in
    FLAGS.checkpoint_dir is restored, the variables are converted to
    constants and the result is written to 'pretrained-densenet/<model_name>'.
    The frozen graph takes the placeholder 'input_image'
    (image_size1 x image_size2 x 3, float32) and produces 'output_image'.

    Args:
      model_name: string, file name of the exported .pb file
      XtoY: boolean, export generator G (X -> Y) if True, otherwise F (Y -> X)
    Raises:
      ValueError: if FLAGS.checkpoint_dir contains no checkpoint
    """
    graph = tf.Graph()

    with graph.as_default():
        cycle_gan = CycleGAN(ngf=FLAGS.ngf, norm=FLAGS.norm,
                             image_size1=FLAGS.image_size1, image_size2=FLAGS.image_size2)

        input_image = tf.placeholder(tf.float32,
                                     shape=[FLAGS.image_size1, FLAGS.image_size2, 3],
                                     name='input_image')
        cycle_gan.model()
        generator = cycle_gan.G if XtoY else cycle_gan.F
        output_image = generator.sample(tf.expand_dims(input_image, 0))

        # give the output a stable name that inference.py can look up
        output_image = tf.identity(output_image, name='output_image')
        restore_saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        latest_ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if latest_ckpt is None:
            # latest_checkpoint() returns None when nothing is found; fail
            # with a message that names the directory instead of passing
            # None on to the saver.
            raise ValueError(
                'No checkpoint found in {!r}'.format(FLAGS.checkpoint_dir))
        restore_saver.restore(sess, latest_ckpt)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), [output_image.op.name])

        tf.train.write_graph(output_graph_def, 'pretrained-densenet', model_name, as_text=False)
'image/file_name': tf.FixedLenFeature([], tf.string), 36 | 'image/encoded_image': tf.FixedLenFeature([], tf.string), 37 | }) 38 | 39 | image_buffer = features['image/encoded_image'] 40 | image = tf.image.decode_jpeg(image_buffer, channels=3) 41 | image = self._preprocess(image) 42 | images = tf.train.shuffle_batch( 43 | [image], batch_size=self.batch_size, num_threads=self.num_threads, 44 | capacity=self.min_queue_examples + 3*self.batch_size, 45 | min_after_dequeue=self.min_queue_examples 46 | ) 47 | 48 | # tf.summary.image('_input', images) 49 | return images 50 | 51 | def _preprocess(self, image): 52 | image = tf.image.resize_images(image, size=(self.image_size1, self.image_size2)) 53 | image = utils.convert2float(image) 54 | image.set_shape([self.image_size1, self.image_size2, 3]) 55 | return image 56 | 57 | def test_reader(): 58 | TRAIN_FILE_1 = 'data/tfrecords/apple.tfrecords' 59 | TRAIN_FILE_2 = 'data/tfrecords/orange.tfrecords' 60 | 61 | with tf.Graph().as_default(): 62 | reader1 = Reader(TRAIN_FILE_1, batch_size=2) 63 | reader2 = Reader(TRAIN_FILE_2, batch_size=2) 64 | images_op1 = reader1.feed() 65 | images_op2 = reader2.feed() 66 | 67 | sess = tf.Session() 68 | init = tf.global_variables_initializer() 69 | sess.run(init) 70 | 71 | coord = tf.train.Coordinator() 72 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 73 | 74 | try: 75 | step = 0 76 | while not coord.should_stop(): 77 | batch_images1, batch_images2 = sess.run([images_op1, images_op2]) 78 | print("image shape: {}".format(batch_images1)) 79 | print("image shape: {}".format(batch_images2)) 80 | print("="*10) 81 | step += 1 82 | except KeyboardInterrupt: 83 | print('Interrupted') 84 | coord.request_stop() 85 | except Exception as e: 86 | coord.request_stop(e) 87 | finally: 88 | # When done, ask the threads to stop. 
class Vgg16:
    """VGG-16 convolutional layers built from pre-trained weights in a .npy file.

    All weights are embedded in the graph as constants (tf.constant), so
    nothing in this network is trainable.
    """

    def __init__(self, vgg16_npy_path=None):
        """
        :param vgg16_npy_path: path of the weight file; defaults to
            "vgg16.npy" in the directory that contains this module
        """
        if vgg16_npy_path is None:
            path = inspect.getfile(Vgg16)
            path = os.path.abspath(os.path.join(path, os.pardir))
            path = os.path.join(path, "vgg16.npy")
            vgg16_npy_path = path
            print(path)

        # The file stores a pickled dict {layer name: [weights, biases]}.
        # NumPy >= 1.16.3 refuses to unpickle unless allow_pickle=True is
        # given. NOTE: unpickling executes code from the file, so only load
        # weight files from a trusted source.
        self.data_dict = np.load(vgg16_npy_path, encoding='latin1', allow_pickle=True).item()
        print("npy file loaded")

    def build(self, rgb):
        """
        load variable from npy to build the VGG

        :param rgb: rgb image [batch, height, width, 3] values scaled [0, 1];
            the assertions below require height == width == 224
        :return: tuple (pool2, pool5) of feature maps
        """

        start_time = time.time()
        print("build model started")
        rgb_scaled = rgb * 255.0

        # Convert RGB to BGR
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
        assert red.get_shape().as_list()[1:] == [224, 224, 1]
        assert green.get_shape().as_list()[1:] == [224, 224, 1]
        assert blue.get_shape().as_list()[1:] == [224, 224, 1]
        bgr = tf.concat(axis=3, values=[
            blue - VGG_MEAN[0],
            green - VGG_MEAN[1],
            red - VGG_MEAN[2],
        ])
        assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

        self.conv1_1 = self.conv_layer(bgr, "conv1_1")
        self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
        self.pool1 = self.max_pool(self.conv1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
        self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
        self.pool2 = self.max_pool(self.conv2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
        self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
        self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
        self.pool3 = self.max_pool(self.conv3_3, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
        self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
        self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
        self.pool4 = self.max_pool(self.conv4_3, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
        self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
        self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
        self.pool5 = self.max_pool(self.conv5_3, 'pool5')

        # The fully connected head (fc6 / fc7 / fc8 / softmax) is not built:
        # only the convolutional feature maps are returned.
        print(("build model finished: %ds" % (time.time() - start_time)))
        return self.pool2, self.pool5

    def avg_pool(self, bottom, name):
        """2x2 average pooling with stride 2 and SAME padding."""
        return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def max_pool(self, bottom, name):
        """2x2 max pooling with stride 2 and SAME padding."""
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def conv_layer(self, bottom, name):
        """Stride-1 SAME convolution + bias + ReLU using the weights stored under `name`."""
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)

            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)

            relu = tf.nn.relu(bias)
            return relu

    def fc_layer(self, bottom, name):
        """Flatten `bottom` and apply the fully connected layer stored under `name`."""
        with tf.variable_scope(name):
            shape = bottom.get_shape().as_list()
            dim = 1
            for d in shape[1:]:
                dim *= d
            x = tf.reshape(bottom, [-1, dim])

            weights = self.get_fc_weight(name)
            biases = self.get_bias(name)

            # Fully connected layer: matmul followed by an explicit bias_add.
            fc = tf.nn.bias_add(tf.matmul(x, weights), biases)

            return fc

    def get_conv_filter(self, name):
        """Return the convolution kernel of layer `name` as a constant."""
        return tf.constant(self.data_dict[name][0], name="filter")

    def get_bias(self, name):
        """Return the bias vector of layer `name` as a constant."""
        return tf.constant(self.data_dict[name][1], name="biases")

    def get_fc_weight(self, name):
        """Return the fully connected weight matrix of layer `name` as a constant."""
        return tf.constant(self.data_dict[name][0], name="weights")
tf.flags.DEFINE_integer('lambda2', 10, 31 | 'weight for backward cycle loss (Y->X->Y), default: 10.0') 32 | tf.flags.DEFINE_float('learning_rate', 1e-4, 33 | 'initial learning rate for Adam, default: 0.0002') 34 | tf.flags.DEFINE_float('beta1', 0.5, 35 | 'momentum term of Adam, default: 0.5') 36 | tf.flags.DEFINE_float('pool_size', 50, 37 | 'size of image buffer that stores previously generated images, default: 50') 38 | tf.flags.DEFINE_integer('ngf', 64, 39 | 'number of gen filters in first conv layer, default: 64') 40 | 41 | tf.flags.DEFINE_string('X', 'RS_train/fog.tfrecords', 42 | 'X tfrecords file for training, default: data/tfrecords/apple.tfrecords') 43 | tf.flags.DEFINE_string('Y', 'RS_train/unfog.tfrecords', 44 | 'Y tfrecords file for training, default: data/tfrecords/orange.tfrecords') 45 | tf.flags.DEFINE_string('Y_smooth', 'RS_train/unfog_smooth.tfrecords', 46 | 'Y tfrecords file for training, default: data/tfrecords/orange.tfrecords') 47 | tf.flags.DEFINE_string('load_model', None, 48 | 'folder of saved model that you wish to continue training (e.g. 
def train():
    """Build the CycleGAN graph and run the training loop until interrupted.

    Checkpoints and summaries are written to FLAGS.load_model when it is
    given (training then resumes from the latest checkpoint in that folder),
    otherwise to "RS_train/checkpoints". A final checkpoint is always saved
    when the loop ends, whether by Ctrl-C or by an error.

    Raises:
      ValueError: if FLAGS.load_model is set but holds no checkpoint
    """
    if FLAGS.load_model is not None:
        checkpoints_dir = FLAGS.load_model
    else:
        checkpoints_dir = "RS_train/checkpoints"
        try:
            os.makedirs(checkpoints_dir)
        except OSError:
            # An already existing directory is fine; anything else
            # (permissions, a file in the way, ...) must not be hidden.
            if not os.path.isdir(checkpoints_dir):
                raise

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            Y_smooth_train_file=FLAGS.Y_smooth,
            batch_size=FLAGS.batch_size,
            image_size1=FLAGS.image_size1,
            image_size2=FLAGS.image_size2,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf
        )
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver(max_to_keep = 1000000)

    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.9
    with tf.Session(config=config, graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.latest_checkpoint(checkpoints_dir)
            if checkpoint is None:
                raise ValueError(
                    'No checkpoint found in {!r}'.format(checkpoints_dir))
            meta_graph_path = str(checkpoint) + '.meta'
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, checkpoint)
            # Checkpoints are saved below as "model.ckpt-<step>", so the step
            # to resume from is the numeric suffix of the checkpoint path
            # (previously a hard-coded 130000).
            try:
                step = int(str(checkpoint).rsplit('-', 1)[-1])
            except ValueError:
                logging.warning(
                    'Cannot read the step from %s; counting from 0', checkpoint)
                step = 0
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                        feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                                   cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}))
                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                # NOTE(review): a checkpoint every 10 steps combined with
                # max_to_keep=1000000 keeps a very large number of files on
                # disk; confirm that this frequency is intended.
                if step % 10 == 0:
                    save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
                    logging.info("Model saved in file: %s" % save_path)
                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # hand the error to the coordinator; coord.join() reports it
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def n_res_blocks(input, reuse, norm='instance', is_training=True, n=6):
    """ Chain n residual blocks (Rk), each with as many filters as the input depth
    Args:
      input: 4D tensor
      reuse: boolean
      norm: 'instance' or 'batch' or None
      is_training: boolean or BoolTensor
      n: integer, number of residual blocks; for n <= 0 the input is
         returned unchanged
    Returns:
      4D tensor (same shape as input)
    """
    depth = input.get_shape()[3]
    # Start from the input so that n <= 0 is a no-op instead of leaving
    # `output` unbound.
    output = input
    for i in range(1, n+1):
        output = Rk(output, depth, reuse, norm, is_training, 'R{}_{}'.format(depth, i))
    return output
fractional-strided-Convolution-BatchNorm-ReLU layer 99 | with k filters, stride 1/2 100 | Args: 101 | input: 4D tensor 102 | k: integer, number of filters (output depth) 103 | norm: 'instance' or 'batch' or None 104 | is_training: boolean or BoolTensor 105 | reuse: boolean 106 | name: string, e.g. 'c7sk-32' 107 | output_size: integer, desired output size of layer 108 | Returns: 109 | 4D tensor 110 | """ 111 | with tf.variable_scope(name, reuse=reuse): 112 | input_shape = input.get_shape().as_list() 113 | 114 | weights = _weights("weights", 115 | shape=[3, 3, k, input_shape[3]]) 116 | 117 | if not output_size1: 118 | output_size1 = input_shape[1]*2 119 | output_size2 = input_shape[2]*2 120 | output_shape = [input_shape[0], output_size1, output_size2, k] 121 | fsconv = tf.nn.conv2d_transpose(input, weights, 122 | output_shape=output_shape, 123 | strides=[1, 2, 2, 1], padding='SAME') 124 | normalized = _norm(fsconv, is_training, norm) 125 | output = tf.nn.relu(normalized) 126 | return output 127 | 128 | 129 | ### Discriminator layers 130 | def Ck(input, k, slope=0.2, stride=2, reuse=False, norm='instance', is_training=True, name=None): 131 | """ A 4x4 Convolution-BatchNorm-LeakyReLU layer with k filters and stride 2 132 | Args: 133 | input: 4D tensor 134 | k: integer, number of filters (output depth) 135 | slope: LeakyReLU's slope 136 | stride: integer 137 | norm: 'instance' or 'batch' or None 138 | is_training: boolean or BoolTensor 139 | reuse: boolean 140 | name: string, e.g. 
'C64' 141 | Returns: 142 | 4D tensor 143 | """ 144 | with tf.variable_scope(name, reuse=reuse): 145 | weights = _weights("weights", 146 | shape=[4, 4, input.get_shape()[3], k]) 147 | 148 | conv = tf.nn.conv2d(input, weights, 149 | strides=[1, stride, stride, 1], padding='SAME') 150 | 151 | normalized = _norm(conv, is_training, norm) 152 | output = _leaky_relu(normalized, slope) 153 | return output 154 | 155 | 156 | def last_conv(input, reuse=False, use_sigmoid=False, name=None): 157 | """ Last convolutional layer of discriminator network 158 | (1 filter with size 4x4, stride 1) 159 | Args: 160 | input: 4D tensor 161 | reuse: boolean 162 | use_sigmoid: boolean (False if use lsgan) 163 | name: string, e.g. 'C64' 164 | """ 165 | with tf.variable_scope(name, reuse=reuse): 166 | weights = _weights("weights", 167 | shape=[4, 4, input.get_shape()[3], 1]) 168 | biases = _biases("biases", [1]) 169 | 170 | conv = tf.nn.conv2d(input, weights, 171 | strides=[1, 1, 1, 1], padding='SAME') 172 | output = conv + biases 173 | if use_sigmoid: 174 | output = tf.sigmoid(output) 175 | return output 176 | 177 | ### Helpers 178 | def _weights(name, shape, mean=0.0, stddev=0.02): 179 | """ Helper to create an initialized Variable 180 | Args: 181 | name: name of the variable 182 | shape: list of ints 183 | mean: mean of a Gaussian 184 | stddev: standard deviation of a Gaussian 185 | Returns: 186 | A trainable variable 187 | """ 188 | var = tf.get_variable( 189 | name, shape, 190 | initializer=tf.random_normal_initializer( 191 | mean=mean, stddev=stddev, dtype=tf.float32)) 192 | return var 193 | 194 | def _biases(name, shape, constant=0.0): 195 | """ Helper to create an initialized Bias with constant 196 | """ 197 | return tf.get_variable(name, shape, 198 | initializer=tf.constant_initializer(constant)) 199 | 200 | def _leaky_relu(input, slope): 201 | return tf.maximum(slope*input, input) 202 | 203 | def _norm(input, is_training, norm='instance'): 204 | """ Use Instance Normalization or 
Batch Normalization or None 205 | """ 206 | if norm == 'instance': 207 | return _instance_norm(input) 208 | elif norm == 'batch': 209 | return _batch_norm(input, is_training) 210 | else: 211 | return input 212 | 213 | def _batch_norm(input, is_training): 214 | """ Batch Normalization 215 | """ 216 | with tf.variable_scope("batch_norm"): 217 | return tf.contrib.layers.batch_norm(input, 218 | decay=0.9, 219 | scale=True, 220 | updates_collections=None, 221 | is_training=is_training) 222 | 223 | def _instance_norm(input): 224 | """ Instance Normalization 225 | """ 226 | with tf.variable_scope("instance_norm"): 227 | depth = input.get_shape()[3] 228 | scale = _weights("scale", [depth], mean=1.0) 229 | offset = _biases("offset", [depth]) 230 | mean, variance = tf.nn.moments(input, axes=[1,2], keep_dims=True) 231 | epsilon = 1e-5 232 | inv = tf.rsqrt(variance + epsilon) 233 | normalized = (input-mean)*inv 234 | return scale*normalized + offset 235 | 236 | def safe_log(x, eps=1e-12): 237 | return tf.log(x + eps) 238 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import ops 3 | import subprocess as sb 4 | import utils 5 | from reader import Reader 6 | from discriminator import Discriminator 7 | from generator import Generator 8 | import numpy as np 9 | 10 | import vgg16 11 | 12 | REAL_LABEL = 0.9 13 | 14 | class CycleGAN: 15 | def __init__(self, 16 | X_train_file='', 17 | Y_train_file='', 18 | Y_smooth_train_file='', 19 | batch_size=1, 20 | image_size1=256, 21 | image_size2=256, 22 | use_lsgan=True, 23 | norm='instance', 24 | lambda1=10.0, 25 | lambda2=10.0, 26 | learning_rate=1e-4, 27 | beta1=0.5, 28 | ngf=64 29 | ): 30 | """ 31 | Args: 32 | X_train_file: string, X tfrecords file for training 33 | Y_train_file: string Y tfrecords file for training 34 | batch_size: integer, batch size 35 | image_size: integer, image 
size 36 | lambda1: integer, weight for forward cycle loss (X->Y->X) 37 | lambda2: integer, weight for backward cycle loss (Y->X->Y) 38 | use_lsgan: boolean 39 | norm: 'instance' or 'batch' 40 | learning_rate: float, initial learning rate for Adam 41 | beta1: float, momentum term of Adam 42 | ngf: number of gen filters in first conv layer 43 | """ 44 | self.lambda1 = lambda1 45 | self.lambda2 = lambda2 46 | self.use_lsgan = use_lsgan 47 | use_sigmoid = not use_lsgan 48 | self.batch_size = batch_size 49 | self.image_size1 = image_size1 50 | self.image_size2 = image_size2 51 | self.learning_rate = learning_rate 52 | self.beta1 = beta1 53 | self.X_train_file = X_train_file 54 | self.Y_train_file = Y_train_file 55 | self.Y_smooth_train_file = Y_smooth_train_file 56 | 57 | self.is_training = tf.placeholder_with_default(True, shape=[], name='is_training') 58 | 59 | 60 | self.G = Generator('G', self.is_training, ngf=ngf, norm=norm, image_size1=image_size1, image_size2=image_size2) 61 | self.D_Y = Discriminator('D_Y', 62 | self.is_training, norm=norm, use_sigmoid=use_sigmoid) 63 | self.F = Generator('F', self.is_training, ngf=ngf, norm=norm, image_size1=image_size1, image_size2=image_size2) 64 | self.D_X = Discriminator('D_X', 65 | self.is_training, norm=norm, use_sigmoid=use_sigmoid) 66 | 67 | self.fake_x = tf.placeholder(tf.float32, 68 | shape=[batch_size, image_size1, image_size2, 3]) 69 | self.fake_y = tf.placeholder(tf.float32, 70 | shape=[batch_size, image_size1, image_size2, 3]) 71 | 72 | self.vgg = vgg16.Vgg16() 73 | 74 | def model(self): 75 | X_reader = Reader(self.X_train_file, name='X', 76 | image_size1=self.image_size1, image_size2=self.image_size2, batch_size=self.batch_size) 77 | Y_reader = Reader(self.Y_train_file, name='Y', 78 | image_size1=self.image_size1, image_size2=self.image_size2, batch_size=self.batch_size) 79 | Y_smooth_reader = Reader(self.Y_smooth_train_file, name='Y_smooth', 80 | image_size1=self.image_size1, image_size2=self.image_size2, 
batch_size=self.batch_size) 81 | 82 | x = X_reader.feed() 83 | y = Y_reader.feed() 84 | y_smooth = Y_smooth_reader.feed() 85 | 86 | 87 | cycle_loss = self.cycle_consistency_loss(self.G, self.F, x, y) 88 | perceptual_loss = self.perceptual_similarity_loss(self.G, self.F, x, y, self.vgg) 89 | 90 | # X -> Y 91 | fake_y = self.G(x) 92 | G_gan_loss = self.generator_loss(self.D_Y, fake_y, use_lsgan=self.use_lsgan) 93 | G_loss = G_gan_loss + cycle_loss + perceptual_loss #+ pixel_loss 94 | D_Y_loss = self.discriminator_loss_Y(self.D_Y, y,self.fake_y,y_smooth,use_lsgan=self.use_lsgan) 95 | 96 | # Y -> X 97 | fake_x = self.F(y) 98 | F_gan_loss = self.generator_loss(self.D_X, fake_x, use_lsgan=self.use_lsgan) 99 | F_loss = F_gan_loss + cycle_loss + perceptual_loss #+ pixel_loss 100 | D_X_loss = self.discriminator_loss_X(self.D_X, x,self.fake_x,y_smooth, use_lsgan=self.use_lsgan) 101 | 102 | 103 | # summary 104 | 105 | tf.summary.scalar('loss/G', G_gan_loss) 106 | tf.summary.scalar('loss/D_Y', D_Y_loss) 107 | tf.summary.scalar('loss/F', F_gan_loss) 108 | tf.summary.scalar('loss/D_X', D_X_loss) 109 | tf.summary.scalar('loss/cycle', cycle_loss) 110 | tf.summary.scalar('loss/perceptual_loss', perceptual_loss) 111 | 112 | return G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x 113 | 114 | def optimize(self, G_loss, D_Y_loss, F_loss, D_X_loss): 115 | def make_optimizer(loss, variables, name='Adam'): 116 | """ Adam optimizer with learning rate 0.0002 for the first 100k steps (~100 epochs) 117 | and a linearly decaying rate that goes to zero over the next 100k steps 118 | """ 119 | global_step = tf.Variable(0, trainable=False) 120 | starter_learning_rate = self.learning_rate 121 | end_learning_rate = 0.0 122 | start_decay_step = 100000 123 | decay_steps = 100000 124 | beta1 = self.beta1 125 | learning_rate = ( 126 | tf.where( 127 | tf.greater_equal(global_step, start_decay_step), 128 | tf.train.polynomial_decay(starter_learning_rate, global_step-start_decay_step, 129 | decay_steps, 
end_learning_rate, 130 | power=1.0), 131 | starter_learning_rate 132 | ) 133 | 134 | ) 135 | tf.summary.scalar('learning_rate/{}'.format(name), learning_rate) 136 | 137 | learning_step = ( 138 | tf.train.AdamOptimizer(learning_rate, beta1=beta1, name=name) 139 | .minimize(loss, global_step=global_step, var_list=variables) 140 | ) 141 | return learning_step 142 | 143 | G_optimizer = make_optimizer(G_loss, self.G.variables, name='Adam_G') 144 | D_Y_optimizer = make_optimizer(D_Y_loss, self.D_Y.variables, name='Adam_D_Y') 145 | F_optimizer = make_optimizer(F_loss, self.F.variables, name='Adam_F') 146 | D_X_optimizer = make_optimizer(D_X_loss, self.D_X.variables, name='Adam_D_X') 147 | with tf.control_dependencies([G_optimizer, D_Y_optimizer, F_optimizer, D_X_optimizer]): 148 | return tf.no_op(name='optimizers') 149 | 150 | def discriminator_loss_Y(self, D, y, fake_y,fake_xy_smooth, use_lsgan=True): 151 | """ Note: default: D(y).shape == (batch_size,5,5,1), 152 | fake_buffer_size=50, batch_size=1 153 | Args: 154 | G: generator object 155 | D: discriminator object 156 | y: 4D tensor (batch_size, image_size, image_size, 3) 157 | Returns: 158 | loss: scalar 159 | """ 160 | if use_lsgan: 161 | # use mean squared error 162 | c = D(y) 163 | error_real = tf.reduce_mean(tf.squared_difference(c, REAL_LABEL)) 164 | error_fake_y = tf.reduce_mean(tf.square(D(fake_y))) 165 | error_fake_y_smooth = tf.reduce_mean(tf.square(D(fake_xy_smooth))) 166 | else: 167 | # use cross entropy 168 | error_real = -tf.reduce_mean(ops.safe_log(D(y))) 169 | error_fake_y = -tf.reduce_mean(ops.safe_log(1-D(fake_y))) 170 | error_fake_y_smooth = -tf.reduce_mean(ops.safe_log(1-D(fake_xy_smooth))) 171 | loss = (error_real + error_fake_y + error_fake_y_smooth) / 3 172 | return loss 173 | 174 | def discriminator_loss_X(self, D, y, fake_y, real_yx_smooth,use_lsgan=True): 175 | """ Note: default: D(y).shape == (batch_size,5,5,1), 176 | fake_buffer_size=50, batch_size=1 177 | Args: 178 | G: generator object 179 
| D: discriminator object 180 | y: 4D tensor (batch_size, image_size, image_size, 3) 181 | Returns: 182 | loss: scalar 183 | """ 184 | if use_lsgan: 185 | # use mean squared error 186 | c = D(y) 187 | error_real = tf.reduce_mean(tf.squared_difference(c, REAL_LABEL)) 188 | error_fake = tf.reduce_mean(tf.square(D(fake_y))) 189 | error_real_y_smooth = tf.reduce_mean(tf.squared_difference(D(real_yx_smooth),REAL_LABEL)) 190 | else: 191 | # use cross entropy 192 | error_real = -tf.reduce_mean(ops.safe_log(D(y))) 193 | error_fake = -tf.reduce_mean(ops.safe_log(1-D(fake_y))) 194 | error_real_y_smooth = -tf.reduce_mean(ops.safe_log(D(real_yx_smooth))) 195 | loss = (error_real + error_fake + error_real_y_smooth) / 3 196 | return loss 197 | 198 | def generator_loss(self, D, fake_y, use_lsgan=True): 199 | """ fool discriminator into believing that G(x) is real 200 | """ 201 | if use_lsgan: 202 | # use mean squared error 203 | loss = tf.reduce_mean(tf.squared_difference(D(fake_y), REAL_LABEL)) 204 | else: 205 | # heuristic, non-saturating loss 206 | loss = -tf.reduce_mean(ops.safe_log(D(fake_y))) / 2 207 | return loss 208 | 209 | def cycle_consistency_loss(self, G, F, x, y): 210 | """ cycle consistency loss (L1 norm) 211 | """ 212 | forward_loss = tf.reduce_mean(tf.abs(F(G(x))-x)) 213 | backward_loss = tf.reduce_mean(tf.abs(G(F(y))-y)) 214 | loss = self.lambda1*forward_loss + self.lambda2*backward_loss 215 | return loss 216 | 217 | def perceptual_similarity_loss(self, G, F, x, y, vgg): 218 | x1 = tf.image.resize_images(x, [224,224]) # to feed vgg, need resize 219 | y1 = tf.image.resize_images(y, [224,224]) 220 | 221 | rx = F(G(x)) #create reconstructed images 222 | ry = G(F(y)) 223 | 224 | rx1 = tf.image.resize_images(rx, [224,224]) # to feed vgg, need resize 225 | ry1 = tf.image.resize_images(ry, [224,224]) 226 | 227 | fx1, fx2 = vgg.build(x1) # extract features from vgg 228 | fy1, fy2 = vgg.build(y1) 229 | 230 | frx1, frx2 = vgg.build(rx1) # extract features from vgg (2nd 
pool & 5th pool 231 | fry1, fry2 = vgg.build(ry1) 232 | 233 | m1 = tf.reduce_mean(tf.squared_difference(fx1, frx1)) # mse difference 234 | m2 = tf.reduce_mean(tf.squared_difference(fx2, frx2)) 235 | 236 | m3 = tf.reduce_mean(tf.squared_difference(fy1, fry1)) 237 | m4 = tf.reduce_mean(tf.squared_difference(fy2, fry2)) 238 | 239 | perceptual_loss = (m1 + m2 + m3 + m4)*0.00001*0.5 # calculate perceptual loss and give weight (0.00001*0.5) 240 | return perceptual_loss 241 | 242 | # def pixel_wise_loss(self, G, F, x, y): 243 | # rx = F(G(x)) 244 | # ry = G(F(y)) 245 | # pixel_wise_loss = tf.reduce_mean(tf.squared_difference(x, rx)) + tf.reduce_mean(tf.squared_difference(y, ry)) 246 | # return 10*pixel_wise_loss 247 | 248 | 249 | --------------------------------------------------------------------------------