├── Figure
│   ├── Illumination_aware.png
│   └── PIAFusion.png
├── Fusion_results
│   ├── MSRS
│   │   ├── 00537D.png
│   │   ├── 00556D.png
│   │   ├── 00633D.png
│   │   ├── 00881N.png
│   │   └── 01023N.png
│   ├── RoadScene
│   │   ├── 037.png
│   │   ├── 100.png
│   │   └── 108.png
│   └── TNO
│       ├── 05.png
│       ├── 17.png
│       └── 18.png
├── LICENSE
├── README.md
├── __pycache__
│   ├── model.cpython-37.pyc
│   ├── ops.cpython-37.pyc
│   ├── train_network.cpython-37.pyc
│   └── utils.cpython-37.pyc
├── checkpoint
│   ├── Illumination
│   │   ├── Illumination.model-99.data-00000-of-00001
│   │   ├── Illumination.model-99.index
│   │   ├── Illumination.model-99.meta
│   │   └── checkpoint
│   └── PIAFusion
│       ├── IAFusion.model-29.data-00000-of-00001
│       ├── IAFusion.model-29.index
│       ├── IAFusion.model-29.meta
│       └── checkpoint
├── main.py
├── model.py
├── ops.py
├── test_data
│   ├── MSRS
│   │   ├── ir
│   │   │   ├── 00537D.png
│   │   │   ├── 00556D.png
│   │   │   ├── 00633D.png
│   │   │   ├── 00881N.png
│   │   │   └── 01023N.png
│   │   └── vi
│   │       ├── 00537D.png
│   │       ├── 00556D.png
│   │       ├── 00633D.png
│   │       ├── 00881N.png
│   │       └── 01023N.png
│   ├── RoadScene
│   │   ├── ir
│   │   │   ├── 037.png
│   │   │   ├── 100.png
│   │   │   └── 108.png
│   │   └── vi
│   │       ├── 037.png
│   │       ├── 100.png
│   │       └── 108.png
│   └── TNO
│       ├── ir
│       │   ├── 05.png
│       │   ├── 17.png
│       │   └── 18.png
│       └── vi
│           ├── 05.png
│           ├── 17.png
│           └── 18.png
├── train_network.py
└── utils.py

/Figure/Illumination_aware.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Figure/Illumination_aware.png
--------------------------------------------------------------------------------
/Figure/PIAFusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Figure/PIAFusion.png
--------------------------------------------------------------------------------
/Fusion_results/MSRS/00537D.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Fusion_results/MSRS/00537D.png
--------------------------------------------------------------------------------
/Fusion_results/MSRS/00556D.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Fusion_results/MSRS/00556D.png
--------------------------------------------------------------------------------
/Fusion_results/MSRS/00633D.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Fusion_results/MSRS/00633D.png
--------------------------------------------------------------------------------
/Fusion_results/MSRS/00881N.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Fusion_results/MSRS/00881N.png
--------------------------------------------------------------------------------
/Fusion_results/MSRS/01023N.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Fusion_results/MSRS/01023N.png
--------------------------------------------------------------------------------
/Fusion_results/RoadScene/037.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Fusion_results/RoadScene/037.png -------------------------------------------------------------------------------- /Fusion_results/RoadScene/100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Fusion_results/RoadScene/100.png -------------------------------------------------------------------------------- /Fusion_results/RoadScene/108.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Fusion_results/RoadScene/108.png -------------------------------------------------------------------------------- /Fusion_results/TNO/05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Fusion_results/TNO/05.png -------------------------------------------------------------------------------- /Fusion_results/TNO/17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Fusion_results/TNO/17.png -------------------------------------------------------------------------------- /Fusion_results/TNO/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/Fusion_results/TNO/18.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Linfeng Tang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PIAFusion
This is the official TensorFlow implementation of “[PIAFusion: A progressive infrared and visible image fusion network based on illumination aware](https://www.sciencedirect.com/science/article/abs/pii/S156625352200032X)”.

The PyTorch implementation of our project, contributed by @[linklist2](https://github.com/linklist2), is available at [https://github.com/linklist2/PIAFusion_pytorch](https://github.com/linklist2/PIAFusion_pytorch).

A new benchmark dataset for infrared and visible image fusion, termed **[MSRS](https://github.com/Linfeng-Tang/MSRS)**, is released with this paper.

## Architecture
![The overall framework of the progressive infrared and visible image fusion algorithm based on illumination awareness.](https://github.com/Linfeng-Tang/PIAFusion/blob/main/Figure/PIAFusion.png)

## Example

![An example of illumination imbalance.](https://github.com/Linfeng-Tang/PIAFusion/blob/main/Figure/Illumination_aware.png)
An example of illumination imbalance. From left to right: infrared image, visible image, and the fused results of DenseFuse, FusionGAN, and our proposed PIAFusion.
In the daytime (top row), the visible image contains abundant information, such as texture details. At nighttime (bottom row), the salient targets and textures are instead mainly captured by the infrared image. Existing methods ignore this illumination imbalance, causing detail loss and thermal-target degradation. Our algorithm adaptively integrates meaningful information according to the illumination conditions.

## Recommended Environment

- [ ] tensorflow-gpu 1.14.0
- [ ] scipy 1.2.0
- [ ] numpy 1.19.2
- [ ] opencv 3.4.2

A pip-based setup sketch is given below, after the Acknowledgement section.

## Training

### Training the Illumination-Aware Sub-Network
Run: `python main.py --epoch=100 --is_train=True --model_type=Illum --DataSet=MSRS`
The dataset for training the illumination-aware sub-network can be downloaded from [data_illum.h5](https://pan.baidu.com/s/1D7XVGFyPgn9lH6JxYXt65Q?pwd=PIAF).

### Training the Illumination-Aware Fusion Network
Run: `python main.py --epoch=30 --is_train=True --model_type=PIAFusion --DataSet=MSRS`
The dataset for training the illumination-aware fusion network can be downloaded from [data_MSRS.h5](https://pan.baidu.com/s/1D7XVGFyPgn9lH6JxYXt65Q?pwd=PIAF).

## Testing
### The MSRS Dataset
Run: `python main.py --is_train=False --model_type=PIAFusion --DataSet=MSRS`

### The RoadScene Dataset
Run: `python main.py --is_train=False --model_type=PIAFusion --DataSet=RoadScene`

### The TNO Dataset
Run: `python main.py --is_train=False --model_type=PIAFusion --DataSet=TNO`

## Acknowledgement
Our Multi-Spectral Road Scenarios (**[MSRS](https://github.com/Linfeng-Tang/MSRS)**) dataset is constructed on the basis of the **[MFNet](https://www.mi.t.u-tokyo.ac.jp/static/projects/mil_multispectral/)** dataset [1].

[1] Ha, Q., Watanabe, K., Karasawa, T., Ushiku, Y., Harada, T., 2017. MFNet: Towards real-time semantic segmentation for autonomous vehicles with multi-spectral scenes, in: Proceedings of the IEEE International Conference on Intelligent Robots and Systems, pp. 5108-5115.
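
## Environment Setup (Sketch)
For reference, one possible way to create the recommended environment. This is a sketch rather than an official setup script: the Python version is inferred from the committed `cpython-37` cache files, and the pip package name `opencv-python` (with a 3.4.2.x build) is an assumption; adapt the CUDA/cuDNN side to your machine.

```
conda create -n piafusion python=3.7
conda activate piafusion
pip install tensorflow-gpu==1.14.0 scipy==1.2.0 numpy==1.19.2
pip install opencv-python==3.4.2.17  # any 3.4.2.x build should match the recommended opencv 3.4.2
```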
47 | 48 | ## If this work is helpful to you, please cite it as: 49 | ``` 50 | @article{Tang2022PIAFusion, 51 | title={PIAFusion: A progressive infrared and visible image fusion network based on illumination aware}, 52 | author={Tang, Linfeng and Yuan, Jiteng and Zhang, Hao and Jiang, Xingyu and Ma, Jiayi}, 53 | journal={Information Fusion}, 54 | volume = {83-84}, 55 | pages = {79-92}, 56 | year = {2022}, 57 | issn = {1566-2535}, 58 | publisher={Elsevier} 59 | } 60 | ``` 61 | -------------------------------------------------------------------------------- /__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/__pycache__/ops.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/train_network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/__pycache__/train_network.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /checkpoint/Illumination/Illumination.model-99.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/checkpoint/Illumination/Illumination.model-99.data-00000-of-00001 -------------------------------------------------------------------------------- /checkpoint/Illumination/Illumination.model-99.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/checkpoint/Illumination/Illumination.model-99.index -------------------------------------------------------------------------------- /checkpoint/Illumination/Illumination.model-99.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/checkpoint/Illumination/Illumination.model-99.meta -------------------------------------------------------------------------------- /checkpoint/Illumination/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "Illumination.model-99" 2 | all_model_checkpoint_paths: "Illumination.model-50" 3 | all_model_checkpoint_paths: "Illumination.model-51" 4 | all_model_checkpoint_paths: "Illumination.model-52" 5 | all_model_checkpoint_paths: "Illumination.model-53" 6 | all_model_checkpoint_paths: "Illumination.model-54" 7 | all_model_checkpoint_paths: "Illumination.model-55" 8 | all_model_checkpoint_paths: 
"Illumination.model-56" 9 | all_model_checkpoint_paths: "Illumination.model-57" 10 | all_model_checkpoint_paths: "Illumination.model-58" 11 | all_model_checkpoint_paths: "Illumination.model-59" 12 | all_model_checkpoint_paths: "Illumination.model-60" 13 | all_model_checkpoint_paths: "Illumination.model-61" 14 | all_model_checkpoint_paths: "Illumination.model-62" 15 | all_model_checkpoint_paths: "Illumination.model-63" 16 | all_model_checkpoint_paths: "Illumination.model-64" 17 | all_model_checkpoint_paths: "Illumination.model-65" 18 | all_model_checkpoint_paths: "Illumination.model-66" 19 | all_model_checkpoint_paths: "Illumination.model-67" 20 | all_model_checkpoint_paths: "Illumination.model-68" 21 | all_model_checkpoint_paths: "Illumination.model-69" 22 | all_model_checkpoint_paths: "Illumination.model-70" 23 | all_model_checkpoint_paths: "Illumination.model-71" 24 | all_model_checkpoint_paths: "Illumination.model-72" 25 | all_model_checkpoint_paths: "Illumination.model-73" 26 | all_model_checkpoint_paths: "Illumination.model-74" 27 | all_model_checkpoint_paths: "Illumination.model-75" 28 | all_model_checkpoint_paths: "Illumination.model-76" 29 | all_model_checkpoint_paths: "Illumination.model-77" 30 | all_model_checkpoint_paths: "Illumination.model-78" 31 | all_model_checkpoint_paths: "Illumination.model-79" 32 | all_model_checkpoint_paths: "Illumination.model-80" 33 | all_model_checkpoint_paths: "Illumination.model-81" 34 | all_model_checkpoint_paths: "Illumination.model-82" 35 | all_model_checkpoint_paths: "Illumination.model-83" 36 | all_model_checkpoint_paths: "Illumination.model-84" 37 | all_model_checkpoint_paths: "Illumination.model-85" 38 | all_model_checkpoint_paths: "Illumination.model-86" 39 | all_model_checkpoint_paths: "Illumination.model-87" 40 | all_model_checkpoint_paths: "Illumination.model-88" 41 | all_model_checkpoint_paths: "Illumination.model-89" 42 | all_model_checkpoint_paths: "Illumination.model-90" 43 | all_model_checkpoint_paths: "Illumination.model-91" 44 | all_model_checkpoint_paths: "Illumination.model-92" 45 | all_model_checkpoint_paths: "Illumination.model-93" 46 | all_model_checkpoint_paths: "Illumination.model-94" 47 | all_model_checkpoint_paths: "Illumination.model-95" 48 | all_model_checkpoint_paths: "Illumination.model-96" 49 | all_model_checkpoint_paths: "Illumination.model-97" 50 | all_model_checkpoint_paths: "Illumination.model-98" 51 | all_model_checkpoint_paths: "Illumination.model-99" 52 | -------------------------------------------------------------------------------- /checkpoint/PIAFusion/IAFusion.model-29.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/checkpoint/PIAFusion/IAFusion.model-29.data-00000-of-00001 -------------------------------------------------------------------------------- /checkpoint/PIAFusion/IAFusion.model-29.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/checkpoint/PIAFusion/IAFusion.model-29.index -------------------------------------------------------------------------------- /checkpoint/PIAFusion/IAFusion.model-29.meta: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/checkpoint/PIAFusion/IAFusion.model-29.meta
--------------------------------------------------------------------------------
/checkpoint/PIAFusion/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "IAFusion.model-29"
all_model_checkpoint_paths: "IAFusion.model-0"
all_model_checkpoint_paths: "IAFusion.model-1"
all_model_checkpoint_paths: "IAFusion.model-2"
all_model_checkpoint_paths: "IAFusion.model-3"
all_model_checkpoint_paths: "IAFusion.model-4"
all_model_checkpoint_paths: "IAFusion.model-5"
all_model_checkpoint_paths: "IAFusion.model-6"
all_model_checkpoint_paths: "IAFusion.model-7"
all_model_checkpoint_paths: "IAFusion.model-8"
all_model_checkpoint_paths: "IAFusion.model-9"
all_model_checkpoint_paths: "IAFusion.model-10"
all_model_checkpoint_paths: "IAFusion.model-11"
all_model_checkpoint_paths: "IAFusion.model-12"
all_model_checkpoint_paths: "IAFusion.model-13"
all_model_checkpoint_paths: "IAFusion.model-14"
all_model_checkpoint_paths: "IAFusion.model-15"
all_model_checkpoint_paths: "IAFusion.model-16"
all_model_checkpoint_paths: "IAFusion.model-17"
all_model_checkpoint_paths: "IAFusion.model-18"
all_model_checkpoint_paths: "IAFusion.model-19"
all_model_checkpoint_paths: "IAFusion.model-20"
all_model_checkpoint_paths: "IAFusion.model-21"
all_model_checkpoint_paths: "IAFusion.model-22"
all_model_checkpoint_paths: "IAFusion.model-23"
all_model_checkpoint_paths: "IAFusion.model-24"
all_model_checkpoint_paths: "IAFusion.model-25"
all_model_checkpoint_paths: "IAFusion.model-26"
all_model_checkpoint_paths: "IAFusion.model-27"
all_model_checkpoint_paths: "IAFusion.model-28"
all_model_checkpoint_paths: "IAFusion.model-29"
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# from train import model
from model import PIAFusion
import numpy as np
import tensorflow as tf

import pprint
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Command-line flags; the bracketed value in each help string is the default.
flags = tf.app.flags
flags.DEFINE_integer("epoch", 30, "Number of epochs [30]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("image_size", 64, "The size of image to use [64]")
flags.DEFINE_integer("label_size", 2, "The size of label to produce [2]")
flags.DEFINE_float("learning_rate", 1e-3, "The learning rate of the gradient descent algorithm [1e-3]")
flags.DEFINE_integer("stride", 24, "The size of stride to apply to the input image [24]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
flags.DEFINE_string("summary_dir", "log", "Name of log directory [log]")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")
flags.DEFINE_string("model_type", 'PIAFusion', "Illum for training the illumination-aware network, "
                                               "PIAFusion for training the fusion network [PIAFusion]")
flags.DEFINE_string("DataSet", 'MSRS', "The dataset for testing: TNO, RoadScene or MSRS [MSRS]")
FLAGS = flags.FLAGS

pp = pprint.PrettyPrinter()


def main(_):
    if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir) 29 | config = tf.ConfigProto() 30 | config.gpu_options.allow_growth = True 31 | with tf.Session(config=config) as sess: 32 | piafusion = PIAFusion(sess, 33 | image_size=FLAGS.image_size, 34 | label_size=FLAGS.label_size, 35 | batch_size=FLAGS.batch_size, 36 | checkpoint_dir=FLAGS.checkpoint_dir, 37 | model_type=FLAGS.model_type, 38 | phase=FLAGS.is_train, 39 | Data_set=FLAGS.DataSet) 40 | if FLAGS.is_train: 41 | piafusion.train(FLAGS) 42 | else: 43 | piafusion.test(FLAGS) 44 | 45 | 46 | if __name__ == '__main__': 47 | tf.app.run() 48 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import h5py 5 | import tensorflow as tf 6 | from ops import * 7 | from train_network import PIAFusion, Illumination_classifier 8 | from utils import * 9 | 10 | PIAfusion_net = PIAFusion() 11 | IC_net = Illumination_classifier() 12 | 13 | 14 | class PIAFusion(object): 15 | def __init__(self, 16 | sess, 17 | image_size=132, 18 | label_size=120, 19 | batch_size=32, 20 | checkpoint_dir=None, 21 | model_type=None, 22 | phase=None, 23 | Data_set=None): 24 | 25 | self.sess = sess 26 | self.image_size = image_size 27 | self.label_size = label_size 28 | self.batch_size = batch_size 29 | self.checkpoint_dir = checkpoint_dir 30 | self.model_type = model_type 31 | self.phase = phase 32 | self.DataSet = Data_set 33 | print('Parameters Setting: \n' 34 | 'Image Size: {}\n' 35 | 'Batch Size: {}\n' 36 | 'Model Type: {}\n' 37 | 'Is Training:{}\n'.format( 38 | self.image_size, self.batch_size, self.model_type, self.phase)) 39 | 40 | def build_classifier_model(self): 41 | with tf.name_scope('input'): 42 | # image patch 43 | self.images = tf.compat.v1.placeholder(tf.float32, [self.batch_size, self.image_size, self.image_size, 3], 44 | name='images') 45 | self.labels = tf.compat.v1.placeholder(tf.float32, [self.batch_size, self.label_size], name='label') 46 | 47 | with tf.compat.v1.variable_scope('classifier', reuse=False): 48 | self.predicted_label = IC_net.illumination_classifier(self.images, reuse=False) 49 | 50 | with tf.name_scope("learn_rate"): 51 | self.lr = tf.placeholder(tf.float32, name='lr') 52 | 53 | with tf.name_scope('c_loss'): 54 | self.classifier_loss = tf.reduce_mean( 55 | tf.nn.softmax_cross_entropy_with_logits(logits=self.predicted_label, labels=self.labels)) 56 | # Operation comparing prediction with true label 57 | correct_prediction = tf.equal(tf.argmax(self.predicted_label, 1), tf.argmax(self.labels, 1)) 58 | 59 | # Operation calculating the accuracy of our predictions 60 | self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 61 | # tf.compat.v1.summary.scalar which is used to display scalar information 62 | # used to display loss 63 | tf.compat.v1.summary.scalar('classifier loss', self.classifier_loss) 64 | self.c_loss_total = 10 * self.classifier_loss 65 | # display total_loss 66 | tf.compat.v1.summary.scalar('loss_c', self.c_loss_total) 67 | self.saver = tf.compat.v1.train.Saver(max_to_keep=50) 68 | with tf.name_scope('image'): 69 | tf.compat.v1.summary.image('image', self.images[0:1, :, :, 0:3]) 70 | 71 | def initial_classifier_model(self, Illum_images): 72 | with tf.compat.v1.variable_scope('classifier', reuse=False): 73 | self.predicted_label = IC_net.illumination_classifier(Illum_images, reuse=False) 74 | self.Illum_saver = tf.compat.v1.train.Saver(max_to_keep=50) 75 | 76 
| def build_PIAFusion_model(self): 77 | with tf.name_scope('input'): 78 | # Visible image patch 79 | self.vi_images = tf.compat.v1.placeholder(tf.float32, [self.batch_size, self.image_size, self.image_size, 3], name='vi_images') 80 | self.ir_images = tf.compat.v1.placeholder(tf.float32, [self.batch_size, self.image_size, self.image_size, 1], name='ir_images') 81 | self.Y_images, self.Cb_images, self.Cr_images = RGB2YCbCr(self.vi_images) 82 | 83 | with tf.name_scope("learn_rate"): 84 | self.lr = tf.placeholder(tf.float32, name='lr') 85 | tf.global_variables_initializer().run() 86 | self.initial_classifier_model(self.vi_images) 87 | with tf.compat.v1.variable_scope('PIAFusion', reuse=False): 88 | self.fused_images, self.vi_features, self.ir_features = PIAfusion_net.PIAFusion(self.Y_images, self.ir_images, reuse=False) 89 | self.RGB_fused_images = YCbCr2RGB(self.fused_images, self.Cb_images, self.Cr_images, mode=1) 90 | print(self.checkpoint_dir) 91 | could_load = self.load(self.Illum_saver, self.checkpoint_dir, model_dir="%s" % ("Illumination")) 92 | if could_load: 93 | print(" [*] Load SUCCESS") 94 | else: 95 | print(" [!] Load failed...") 96 | with tf.compat.v1.variable_scope('classifier', reuse=True): 97 | self.predicted_label = IC_net.illumination_classifier(self.vi_images, reuse=True) 98 | day_probability = self.predicted_label[:, 0] 99 | night_probability = self.predicted_label[:, 1] 100 | self.vi_w, self.ir_w = illumination_mechanism(day_probability, night_probability) 101 | self.vi_w = tf.reshape(self.vi_w, shape=[self.batch_size, 1, 1, 1]) 102 | self.ir_w = tf.reshape(self.ir_w, shape=[self.batch_size, 1, 1, 1]) 103 | with tf.name_scope('grad_bin'): 104 | self.Image_vi_grad = gradient(self.Y_images) 105 | self.Image_ir_grad = gradient(self.ir_images) 106 | self.Image_fused_grad = gradient(self.fused_images) 107 | self.Image_max_grad = tf.round((self.Image_vi_grad + self.Image_ir_grad) // ( 108 | tf.abs(self.Image_vi_grad + self.Image_ir_grad) + 0.0000000001)) * tf.maximum( 109 | tf.abs(self.Image_vi_grad), tf.abs(self.Image_ir_grad)) 110 | self.concat_images = tf.concat([self.ir_images, self.Y_images], axis=-1, ) 111 | self.pseudo_images = 0.7 * tf.reduce_max(self.concat_images, axis=-1, keepdims=True) + 0.3 * (tf.multiply(self.vi_w, self.Y_images) + tf.multiply(self.ir_w, self.ir_images)) 112 | 113 | self.RGB_pseudo_images = YCbCr2RGB(self.pseudo_images, self.Cb_images, self.Cr_images, mode=1) 114 | with tf.name_scope('f_loss'): 115 | self.ir_l1_loss = tf.reduce_mean(tf.abs(self.fused_images - self.ir_images)) 116 | self.vi_l1_loss = tf.reduce_mean(tf.abs(self.fused_images - self.Y_images)) 117 | self.ir_grad_loss = tf.reduce_mean(tf.abs(self.Image_fused_grad - self.Image_ir_grad)) 118 | self.vi_grad_loss = tf.reduce_mean(tf.abs(self.Image_fused_grad - self.Image_vi_grad)) 119 | self.joint_grad_loss = L1_loss(self.Image_fused_grad, self.Image_max_grad) 120 | self.pixel_loss = L1_loss(self.pseudo_images, self.fused_images) 121 | self.f_total_loss = 50 * self.pixel_loss + 50 * self.joint_grad_loss 122 | 123 | tf.compat.v1.summary.scalar('IR L1 loss', self.ir_l1_loss) 124 | tf.compat.v1.summary.scalar('VI L1 loss', self.vi_l1_loss) 125 | tf.compat.v1.summary.scalar('IR Gradient loss', self.ir_grad_loss) 126 | tf.compat.v1.summary.scalar('Fusion model total loss', self.f_total_loss) 127 | tf.compat.v1.summary.scalar('VI Gradient loss', self.vi_grad_loss) 128 | 129 | self.saver = tf.compat.v1.train.Saver(max_to_keep=50) 130 | 131 | with tf.name_scope('image'): 132 | 
tf.compat.v1.summary.image('ir_image', self.ir_images[0:1, :, :, :]) 133 | tf.compat.v1.summary.image('vi_image', self.vi_images[0:1, :, :, :]) 134 | tf.compat.v1.summary.image('fused image', self.RGB_fused_images[0:1, :, :, :]) 135 | tf.compat.v1.summary.image('pseudo images', self.RGB_pseudo_images[0:1, :, :, :]) 136 | tf.compat.v1.summary.image('ir_feature', self.ir_features[0:1, :, :, 0:1]) 137 | tf.compat.v1.summary.image('vi_feature', self.vi_features[0:1, :, :, 0:1]) 138 | tf.compat.v1.summary.image('joint_gradient', self.Image_max_grad[0:1, :, :, 0:1]) 139 | tf.compat.v1.summary.image('fused_gradient', self.Image_fused_grad[0:1, :, :, 0:1]) 140 | 141 | 142 | def train(self, config): 143 | variables_dir = './variables' 144 | check_folder(variables_dir) 145 | variables_name = os.path.join(variables_dir, self.model_type + '.txt') 146 | if os.path.exists(variables_name): 147 | os.remove(variables_name) 148 | if config.model_type == 'Illum': 149 | print('train Illumination classifier!') 150 | print('Data Preparation!~') 151 | dataset_name = 'data_illum.h5' 152 | f = h5py.File(dataset_name, 'r') 153 | sources = f['data'][:] 154 | print(sources.shape) 155 | sources = np.transpose(sources, (0, 3, 2, 1)) 156 | images = sources[:, :, :, 0:3] 157 | labels = sources[:, 0, 0, 3:5] 158 | images = images ## input image [0, 1] 159 | self.build_classifier_model() 160 | num_imgs = sources.shape[0] 161 | # num_imgs = 800 162 | mod = num_imgs % self.batch_size 163 | n_batches = int(num_imgs // self.batch_size) 164 | print('Train images number %d, Batches: %d.\n' % (num_imgs, n_batches)) 165 | self.iteration = n_batches 166 | if mod > 0: 167 | print('Train set has been trimmed %d samples...\n' % mod) 168 | sources = sources[:-mod] 169 | print("source shape:", sources.shape) 170 | batch_idxs = n_batches 171 | tensorboard_path, log_path = form_results(dataset=config.DataSet, model_type=self.model_type) 172 | t_vars = tf.trainable_variables() 173 | C_vars = [var for var in t_vars if 'Classifier' in var.name] 174 | log_name = log_path + '/log.txt' 175 | if os.path.exists(log_name): 176 | os.remove(log_name) 177 | for var in C_vars: 178 | with open(variables_name, 'a') as log: 179 | log.write(var.name) 180 | log.write('\n') 181 | 182 | self.C_vars = C_vars 183 | with tf.name_scope('train_step'): 184 | self.train_classifier_op = tf.train.AdamOptimizer(config.learning_rate).minimize(self.c_loss_total, 185 | var_list=self.C_vars) 186 | 187 | self.summary_op = tf.summary.merge_all() 188 | writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=self.sess.graph) 189 | tf.initialize_all_variables().run() 190 | counter = 0 191 | start_time = time.time() 192 | total_classifier_loss = 0 193 | total_loss = 0 194 | total_accuracy = 0 195 | show_num = 5 196 | show_count = 0 197 | if config.is_train: 198 | self.init_lr = config.learning_rate 199 | self.decay_epoch = int(config.epoch / 2) 200 | print("Training...") 201 | for ep in range(config.epoch): 202 | # Run by batch images 203 | lr = self.init_lr if ep < self.decay_epoch else self.init_lr * (config.epoch - ep) / ( 204 | config.epoch - self.decay_epoch) # linear decay 205 | batch_idxs = batch_idxs 206 | for idx in range(0, batch_idxs): 207 | batch_images = images[idx * config.batch_size: (idx + 1) * config.batch_size] 208 | batch_labels = labels[idx * config.batch_size: (idx + 1) * config.batch_size] 209 | counter += 1 210 | _, err_g, batch_classifer_loss, batch_accuracy, summary_str, predicted_label = self.sess.run( 211 | [self.train_classifier_op, 
self.c_loss_total, self.classifier_loss, self.accuracy, 212 | self.summary_op, self.predicted_label], 213 | feed_dict={self.images: batch_images, self.labels: batch_labels, self.lr: lr}) 214 | # Write the statistics to the log file 215 | total_classifier_loss += batch_classifer_loss 216 | total_loss += err_g 217 | total_accuracy += batch_accuracy 218 | show_count += 1 219 | writer.add_summary(summary_str, global_step=counter) 220 | if idx % show_num == show_num - 1: 221 | print("learn rate:[%0.6f]" % (lr)) 222 | print( 223 | "Epoch:[%d/%d], step:[%d/%d], time: [%4.4f], loss_g:[%.4f], classifier_loss:[%.4f], accuracy:[%.4f]" 224 | % ((ep + 1), config.epoch, idx + 1, batch_idxs, time.time() - start_time, 225 | total_loss / show_count, total_classifier_loss / show_count, 226 | total_accuracy / show_count)) 227 | # print(predicted_label) 228 | with open(log_path + '/log.txt', 'a') as log: 229 | log.write( 230 | "Epoch:[%d/%d], step:[%d/%d], time: [%4.4f], loss_g:[%.4f], classifier_loss:[%.4f], accuracy:[%.4f] \n" 231 | % ((ep + 1), config.epoch, idx + 1, batch_idxs, time.time() - start_time, 232 | total_loss / show_count, total_classifier_loss / show_count, 233 | total_accuracy / show_count)) 234 | total_classifier_loss = 0 235 | total_loss = 0 236 | total_accuracy = 0 237 | show_count = 0 238 | start_time = time.time() 239 | self.save(config.checkpoint_dir, ep) 240 | else: 241 | print(self.model_type == 'PIAFusion') 242 | print("Data preparation!") 243 | dataset_name = 'data_VIF.h5' 244 | # if config.DataSet == 'TNO': 245 | # dataset_name = 'data_VIF.h5' 246 | # elif config.DataSet == 'RoadScene': 247 | # dataset_name = 'data_road.h5' 248 | f = h5py.File(dataset_name, 'r') 249 | sources = f['data'][:] 250 | print(sources.shape) 251 | sources = np.transpose(sources, (0, 3, 2, 1)) 252 | images = sources 253 | images = images 254 | self.build_PIAFusion_model() 255 | if config.is_train: 256 | print('images shape: ', images.shape) 257 | num_imgs = sources.shape[0] 258 | mod = num_imgs % self.batch_size 259 | n_batches = int(num_imgs // self.batch_size) 260 | print('Train images number %d, Batches: %d.\n' % (num_imgs, n_batches)) 261 | self.iteration = n_batches 262 | if mod > 0: 263 | print('Train set has been trimmed %d samples...\n' % mod) 264 | sources = sources[:-mod] 265 | print("source shape:", sources.shape) 266 | batch_idxs = n_batches 267 | tensorboard_path, log_path = form_results(dataset=config.DataSet, model_type=self.model_type) 268 | t_vars = tf.trainable_variables() 269 | f_vars = [var for var in t_vars if 'classifier' not in var.name] 270 | log_name = log_path + '/log.txt' 271 | if os.path.exists(log_name): 272 | os.remove(log_name) 273 | for var in f_vars: 274 | with open(variables_name, 'a') as log: 275 | log.write(var.name) 276 | log.write('\n') 277 | 278 | self.f_vars = f_vars 279 | with tf.name_scope('train_step'): 280 | self.train_iafusion_op = tf.train.AdamOptimizer(config.learning_rate).minimize(self.f_total_loss, 281 | var_list=self.f_vars) 282 | I_vars = tf.global_variables() 283 | I_vars = [var for var in I_vars if 'classifier' not in var.name] 284 | init = tf.variables_initializer(I_vars) 285 | self.sess.run(init) 286 | self.summary_op = tf.summary.merge_all() 287 | writer = tf.summary.FileWriter(logdir=tensorboard_path, graph=self.sess.graph) 288 | counter = 0 289 | start_time = time.time() 290 | total_ir_l1_loss = 0 291 | total_vi_l1_loss = 0 292 | total_ir_grad_loss = 0 293 | total_vi_grad_loss = 0 294 | total_loss = 0 295 | show_num = 5 296 | show_count = 0 297 | 
self.init_lr = config.learning_rate 298 | self.decay_epoch = int(config.epoch / 2) 299 | print("Training...") 300 | print(self.sess.run(tf.get_default_graph().get_tensor_by_name("classifier/Classifier/conv1/bias:0"))) 301 | for ep in range(config.epoch): 302 | # Run by batch images 303 | lr = self.init_lr if ep < self.decay_epoch else self.init_lr * (config.epoch - ep) / ( 304 | config.epoch - self.decay_epoch) # linear decay 305 | batch_idxs = batch_idxs 306 | for idx in range(0, batch_idxs): 307 | batch_images = images[idx * config.batch_size: (idx + 1) * config.batch_size] 308 | vi_batch_images = batch_images[:, :, :, 0:3] 309 | ir_batch_images = batch_images[:, :, :, 3:4] 310 | counter += 1 311 | _, err_g, ir_batch_l1_loss, vi_batch_l1_loss, ir_batch_grad_loss, vi_batch_grad_loss, vi_batch_w, ir_batch_w, summary_str = self.sess.run( 312 | [self.train_iafusion_op, self.f_total_loss, self.ir_l1_loss, self.vi_l1_loss, self.ir_grad_loss, 313 | self.vi_grad_loss, self.vi_w, self.ir_w, self.summary_op], 314 | feed_dict={self.ir_images:ir_batch_images, self.vi_images:vi_batch_images, self.lr:lr}) 315 | 316 | # Write the statistics to the log file 317 | total_ir_l1_loss += ir_batch_l1_loss 318 | total_vi_l1_loss += vi_batch_l1_loss 319 | total_ir_grad_loss += ir_batch_grad_loss 320 | total_vi_grad_loss += vi_batch_grad_loss 321 | total_loss += err_g 322 | show_count += 1 323 | writer.add_summary(summary_str, global_step=counter) 324 | if idx % show_num == show_num - 1: 325 | print("learn rate:[%0.6f]" % (lr)) 326 | print( 327 | "Epoch:[%d/%d], step:[%d/%d], time: [%4.4f], loss_g:[%.4f], ir_L1_loss:[%.4f], vi_L1_loss:[%.4f], ir_gradient_loss:[%.4f], vi_gradient_loss:[%.4f], vi_weight:[%.4f], ir_weight:[%.4f]" 328 | % ((ep + 1), config.epoch, idx + 1, batch_idxs, time.time() - start_time, 329 | err_g, ir_batch_l1_loss, vi_batch_l1_loss, 330 | ir_batch_grad_loss, vi_batch_grad_loss, vi_batch_w[1], ir_batch_w[1])) 331 | print('vi_weight:', vi_batch_w[10], ', ir_weight:', ir_batch_w[10]) 332 | with open(log_path + '/log.txt', 'a') as log: 333 | log.write( 334 | "Epoch:[%d/%d], step:[%d/%d], time: [%4.4f], loss_g:[%.4f], ir_L1_loss:[%.4f], vi_L1_loss:[%.4f], ir_gradient_loss:[%.4f], vi_gradient_loss:[%.4f]\n" 335 | % ((ep + 1), config.epoch, idx + 1, batch_idxs, time.time() - start_time, 336 | err_g, ir_batch_l1_loss, vi_batch_l1_loss, 337 | ir_batch_grad_loss, vi_batch_grad_loss)) 338 | counter = 0 339 | total_ir_l1_loss = 0 340 | total_vi_l1_loss = 0 341 | total_ir_grad_loss = 0 342 | total_vi_grad_loss = 0 343 | total_illumination_loss = 0 344 | total_loss = 0 345 | show_num = 5 346 | show_count = 0 347 | start_time = time.time() 348 | self.save(config.checkpoint_dir, ep) 349 | 350 | def test(self, config): 351 | if self.model_type == 'Illum': 352 | test_day_dir = './test_data/Illum/day' 353 | test_night_dir = './/test_data/Illum/night' 354 | with tf.name_scope('input'): 355 | # infrared image patch 356 | self.images = tf.placeholder(tf.float32, [1, None, None, 3], name='images') 357 | self.initial_classifier_model(self.images) 358 | tf.global_variables_initializer().run() 359 | print(self.checkpoint_dir) 360 | could_load = self.load(self.Illum_saver, self.checkpoint_dir) 361 | if could_load: 362 | print(" [*] Load SUCCESS") 363 | else: 364 | print(" [!] 
Load failed...") 365 | 366 | filelist = os.listdir(test_day_dir) 367 | # filelist.sort(key=lambda x: int(x[0:-4])) 368 | with tf.compat.v1.variable_scope('classifier', reuse=True): 369 | self.predicted_label = IC_net.illumination_classifier(self.images, reuse=True) 370 | True_count = 0 371 | Total_count = 0 372 | for item in filelist: 373 | test_day_file = os.path.join(os.path.abspath(test_day_dir), item) 374 | test_day_image = load_test_data(test_day_file, mode=2) 375 | test_day_image = np.asarray(test_day_image) 376 | print('test_day_image:', test_day_image.shape) 377 | predicted_label = self.sess.run(self.predicted_label, feed_dict={self.images: test_day_image}) 378 | correct_prediction = np.argmax(predicted_label, 1) 379 | if correct_prediction[0] == 0: 380 | True_count += 1 381 | Total_count += 1 382 | print('input: {}, predicted_label: {}, correct_prediction: {}'.format('ir image', predicted_label, 383 | correct_prediction)) 384 | filelist = os.listdir(test_night_dir) 385 | # filelist.sort(key=lambda x: int(x[0:-4])) 386 | for item in filelist: 387 | test_night_file = os.path.join(os.path.abspath(test_night_dir), item) 388 | test_night_image = load_test_data(test_night_file, mode=2) 389 | test_night_image = np.asarray(test_night_image) 390 | print('test_night_image:', test_night_image.shape) 391 | predicted_label = self.sess.run(self.predicted_label, feed_dict={self.images: test_night_image}) 392 | correct_prediction = np.argmax(predicted_label, 1) 393 | if correct_prediction[0] == 1: 394 | True_count += 1 395 | Total_count += 1 396 | print('input: {}, predicted_label: {}, correct_prediction: {}'.format('ir image', predicted_label, 397 | correct_prediction)) 398 | print('Testing Ending, Testing number is {}, Testing accuracy is {:.2f}%'.format(Total_count, 399 | True_count / Total_count * 100)) 400 | else: 401 | method_name = 'PIAFusion' 402 | time_save_name = '../Time_Statistics_all.xlsx' 403 | sheet_name = config.DataSet 404 | if config.DataSet == 'MSRS': 405 | test_ir_dir = r'./test_data/MSRS/ir' 406 | test_vi_dir = r'./test_data/MSRS/vi' 407 | elif config.DataSet == 'TNO': 408 | test_ir_dir = r'./test_data/TNO/ir' 409 | test_vi_dir =r'./test_data/TNO/vi' 410 | elif config.DataSet == 'RoadScene': 411 | test_ir_dir = r'./test_data/RoadScene/ir' 412 | test_vi_dir = r'./test_data/RoadScene/vi' 413 | self.build_PIAFusion_model() 414 | tf.global_variables_initializer().run() 415 | print(self.checkpoint_dir) 416 | could_load = self.load(self.saver, self.checkpoint_dir) 417 | 418 | if could_load: 419 | print(" [*] Load SUCCESS") 420 | else: 421 | print(" [!] 
Load failed...") 422 | 423 | filelist = os.listdir(test_ir_dir) 424 | if config.DataSet != 'MSRS': 425 | filelist.sort(key=lambda x: int(x[0:-4])) 426 | else: 427 | filelist.sort(key=lambda x: int(x[0:-5])) 428 | self.fusion_dir = os.path.join('./Fusion_results', config.DataSet) 429 | check_folder(self.fusion_dir) 430 | 431 | with tf.name_scope('input'): 432 | # infrared image patch 433 | self.ir_images = tf.placeholder(tf.float32, [1, None, None, 1], name='ir_images') 434 | self.vi_images = tf.placeholder(tf.float32, [1, None, None, 3], name='vi_images') 435 | self.Y_images, self.Cb_images, self.Cr_images = RGB2YCbCr(self.vi_images) 436 | # print(self.Y_images.get_shape().as_list()) 437 | 438 | 439 | with tf.compat.v1.variable_scope('PIAFusion', reuse=False): 440 | self.fused_images = PIAfusion_net.PIAFusion(self.Y_images, self.ir_images, reuse=True, Feature_out=False) 441 | self.RGB_fused_images = YCbCr2RGB(self.fused_images, self.Cb_images, self.Cr_images, mode=2) 442 | time_list = [] 443 | for item in filelist: 444 | test_ir_file = os.path.join(os.path.abspath(test_ir_dir), item) 445 | test_vi_file = os.path.join(os.path.abspath(test_vi_dir), item) 446 | self.fusion_path = os.path.join(self.fusion_dir, item) 447 | print("Fusion save path: {}".format(self.fusion_path)) 448 | test_ir_image = load_test_data(test_ir_file) 449 | test_vi_image = load_test_data(test_vi_file, mode=2) 450 | test_ir_image = np.asarray(test_ir_image) 451 | test_vi_image = np.asarray(test_vi_image) 452 | start = time.time() 453 | fused_image = self.sess.run( 454 | self.RGB_fused_images, 455 | feed_dict={self.ir_images: test_ir_image, self.vi_images: test_vi_image}) 456 | fused_image = fused_image.squeeze() 457 | fused_image = fused_image * 255.0 458 | end = time.time() 459 | time_list.append(end-start) 460 | print('Fusion image {} ,times is {}'.format(item, end-start)) 461 | cv2.imwrite(self.fusion_path, fused_image) 462 | # i = 10 463 | # writexls(time_save_name, method_name, time_list, sheet_name, i) 464 | 465 | def save(self, checkpoint_dir, step): 466 | print(self.model_type) 467 | if self.model_type == 'Illum': 468 | model_name = 'Illumination.model' 469 | model_dir = "%s" % ("Illumination") 470 | elif self.model_type =='PIAFusion': 471 | model_name = "IAFusion.model" 472 | model_dir = "%s" % ("PIAFusion") 473 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir) 474 | check_folder(checkpoint_dir) 475 | self.saver.save(self.sess, 476 | os.path.join(checkpoint_dir, model_name), 477 | global_step=step) 478 | 479 | def load(self, saver, checkpoint_dir, model_dir=None): 480 | if model_dir == None: 481 | if self.model_type == 'Illum': 482 | model_dir = "%s" % ("Illumination") 483 | elif self.model_type == 'PIAFusion': 484 | model_dir = "%s" % ("PIAFusion") 485 | else: 486 | model_dir = model_dir 487 | print('model_dir: ', model_dir) 488 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir) 489 | print(checkpoint_dir) 490 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 491 | if ckpt and ckpt.model_checkpoint_path: 492 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 493 | print(ckpt_name) 494 | saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 495 | return True 496 | else: 497 | return False 498 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tf_contrib 3 | import numpy as np 4 | 5 | 
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Xavier : tf_contrib.layers.xavier_initializer()
# He : tf_contrib.layers.variance_scaling_initializer()

weight_init = tf.truncated_normal_initializer(stddev=1e-3)
weight_regularizer = None


##################################################################################
# Layer
##################################################################################

def conv(x, channels=1, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, scope='conv', sn=False, norm=False):
    """2D convolution with manual zero or reflect padding and optional spectral normalization."""
    with tf.variable_scope(scope):
        if pad > 0:
            # Pad symmetrically when (kernel - stride) is even; otherwise put the
            # remainder on the bottom/right so the output size matches 'SAME'.
            if (kernel - stride) % 2 == 0:
                pad_top = pad
                pad_bottom = pad
                pad_left = pad
                pad_right = pad
            else:
                pad_top = pad
                pad_bottom = kernel - stride - pad_top
                pad_left = pad
                pad_right = kernel - stride - pad_left

            if pad_type == 'zero':
                x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
            if pad_type == 'reflect':
                x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')
        w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
                            regularizer=weight_regularizer)
        if sn:
            w = weights_spectral_norm(w)
        x = tf.nn.conv2d(input=x, filter=w, strides=[1, stride, stride, 1], padding='VALID')
        if use_bias:
            bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
            x = tf.nn.bias_add(x, bias)
        if norm:
            x = batch_norm(x)

        return x


def up_sample(x, up_x=None, kernel=4, stride=2, padding='SAME'):
    """Nearest-neighbour upsampling; 'kernel' is only used in the non-'SAME' branch."""
    x_shape = x.get_shape().as_list()
    if padding == 'SAME':
        if x_shape[1] is None:
            # Dynamic spatial size: match the reference feature map up_x.
            output_shape = tf.shape(up_x[0, :, :, 0])
        else:
            output_shape = [x_shape[1] * stride, x_shape[2] * stride]
    else:
        output_shape = [x_shape[1] * stride + max(kernel - stride, 0),
                        x_shape[2] * stride + max(kernel - stride, 0)]
    up_sample = tf.image.resize_images(x, output_shape, method=1)
    return up_sample


def depthwise_conv(x, kernel=3, stride=1, scope='depthwise_conv', sn=True, norm=False):
    with tf.variable_scope(scope):
        w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], 1], initializer=weight_init,
                            regularizer=weight_regularizer)
        if sn:
            w = weights_spectral_norm(w)
        # bias = tf.get_variable("bias", [x.get_shape().as_list()[-1]], initializer=tf.constant_initializer(0.0))
        x = tf.nn.depthwise_conv2d(x, w, strides=[1, stride, stride, 1], padding='SAME')
        # x = tf.nn.bias_add(x, bias)
        if norm:
            x = batch_norm(x)
        return x


def deconv(x, up_features=None, channels=1, kernel=4, stride=2, pad=1, use_bias=True, scope='deconv', sn=False, norm=False):
    """Upsample x to the spatial size of up_features, convolve, and add a skip connection."""
    with tf.variable_scope(scope):
        output_shape = tf.shape(up_features[0, :, :, 0])
        up_sample = tf.image.resize_images(x, output_shape, method=1)
        x = conv(up_sample, channels, kernel=kernel, stride=stride, pad=pad, pad_type='reflect', use_bias=use_bias, sn=sn)
        x = x + up_features
        return x


def attribute_connet(x, channels, use_bias=True, sn=True, scope='attribute'):
    with tf.variable_scope(scope):
        x = tf.layers.dense(x, units=channels, kernel_initializer=weight_init,
                            kernel_regularizer=weight_regularizer, use_bias=use_bias)
        return x

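
# A minimal usage sketch of the conv() wrapper above. This is an illustrative,
# hypothetical example, not part of the original file: with kernel=3, stride=1
# and pad=1, the manual padding reproduces 'SAME' output sizes, while
# pad_type='reflect' avoids zero-padding artifacts at image borders.
#
#   import numpy as np
#   import tensorflow as tf
#   from ops import conv
#
#   x = tf.placeholder(tf.float32, [1, 64, 64, 1])
#   y = conv(x, channels=16, kernel=3, stride=1, pad=1, pad_type='reflect', scope='demo')
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       out = sess.run(y, feed_dict={x: np.random.rand(1, 64, 64, 1).astype(np.float32)})
#       print(out.shape)  # (1, 64, 64, 16)
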
96 | def fully_conneted(x, channels, use_bias=True, sn=True, scope='fully'): 97 | with tf.variable_scope(scope): 98 | x = tf.layers.flatten(x) 99 | shape = x.get_shape().as_list() 100 | x_channel = shape[-1] 101 | if sn: 102 | w = tf.get_variable("kernel", [x_channel, channels], tf.float32, initializer=weight_init, 103 | regularizer=weight_regularizer) 104 | if use_bias: 105 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 106 | 107 | x = tf.matmul(x, spectral_norm(w)) + bias 108 | else: 109 | x = tf.matmul(x, spectral_norm(w)) 110 | 111 | else: 112 | x = tf.layers.dense(x, units=channels, kernel_initializer=weight_init, 113 | kernel_regularizer=weight_regularizer, use_bias=use_bias) 114 | print('fully_connected shape: ', x.get_shape().as_list()) 115 | return x 116 | 117 | ################################################################################## 118 | # Activation function 119 | ################################################################################## 120 | 121 | def lrelu(x, alpha=0.1): 122 | # pytorch alpha is 0.01 123 | return tf.nn.leaky_relu(x, alpha) 124 | 125 | 126 | def relu(x): 127 | return tf.nn.relu(x) 128 | 129 | 130 | def tanh(x): 131 | return tf.tanh(x) 132 | 133 | 134 | ################################################################################## 135 | # Normalization function 136 | ################################################################################## 137 | 138 | def instance_norm(x, scope='instance_norm'): 139 | return tf_contrib.layers.instance_norm(x, 140 | epsilon=1e-05, 141 | center=True, scale=True, 142 | scope=scope) 143 | 144 | 145 | def batch_norm(x, scope='batch_norm'): 146 | return tf_contrib.layers.batch_norm(x, 147 | decay=0.999, 148 | center=True, 149 | scale=False, 150 | epsilon=0.001, 151 | scope=scope) 152 | 153 | 154 | def layer_norm(x, scope='layer_norm'): 155 | return tf_contrib.layers.layer_norm(x, 156 | center=True, scale=True, 157 | scope=scope) 158 | 159 | 160 | def spectral_norm(w, iteration=1): 161 | w_shape = w.shape.as_list() 162 | w = tf.reshape(w, [-1, w_shape[-1]]) 163 | 164 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False) 165 | 166 | u_hat = u 167 | v_hat = None 168 | for i in range(iteration): 169 | """ 170 | power iteration 171 | Usually iteration = 1 will be enough 172 | """ 173 | v_ = tf.matmul(u_hat, tf.transpose(w)) 174 | v_hat = tf.nn.l2_normalize(v_) 175 | 176 | u_ = tf.matmul(v_hat, w) 177 | u_hat = tf.nn.l2_normalize(u_) 178 | 179 | u_hat = tf.stop_gradient(u_hat) 180 | v_hat = tf.stop_gradient(v_hat) 181 | 182 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) 183 | 184 | with tf.control_dependencies([u.assign(u_hat)]): 185 | w_norm = w / sigma 186 | w_norm = tf.reshape(w_norm, w_shape) 187 | 188 | return w_norm 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | ################################################################################## 197 | # Loss function 198 | ################################################################################## 199 | def L1_loss(x, y): 200 | loss = tf.reduce_mean(tf.abs(y - x)) 201 | return loss 202 | 203 | 204 | def Fro_LOSS(batchimg): 205 | fro_norm = tf.square(tf.norm(batchimg, axis=[1, 2], ord='fro')) / (int(batchimg.shape[1]) * int(batchimg.shape[2])) 206 | # print('fro_norm shape:', fro_norm.get_shape().as_list()) 207 | E = tf.reduce_mean(fro_norm) 208 | return E 209 | 210 | 211 | def gradient(input): 212 | filter1 = tf.reshape(tf.constant([[-1., 
0., 1.], [-2., 0., 2.], [-1., 0., 1.]]), [3, 3, 1, 1]) 213 | filter2 = tf.reshape(tf.constant([[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]]), [3, 3, 1, 1]) 214 | Gradient1 = tf.nn.conv2d(input, filter1, strides=[1, 1, 1, 1], padding='SAME') 215 | Gradient2 = tf.nn.conv2d(input, filter2, strides=[1, 1, 1, 1], padding='SAME') 216 | Gradient = tf.abs(Gradient1) + tf.abs(Gradient2) 217 | return Gradient 218 | 219 | 220 | def Gradient_loss(image_A, image_B): 221 | gradient_A = gradient(image_A) 222 | gradient_B = gradient(image_B) 223 | grad_loss = tf.reduce_mean(L1_loss(gradient_A, gradient_B)) 224 | return grad_loss 225 | 226 | 227 | def weights_spectral_norm(weights, u=None, iteration=1, update_collection=None, reuse=False, name='weights_SN'): 228 | with tf.variable_scope(name) as scope: 229 | if reuse: 230 | scope.reuse_variables() 231 | 232 | w_shape = weights.get_shape().as_list() 233 | w_mat = tf.reshape(weights, [-1, w_shape[-1]]) 234 | if u is None: 235 | u = tf.get_variable('u', shape=[1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), 236 | trainable=False) 237 | 238 | def power_iteration(u, ite): 239 | v_ = tf.matmul(u, tf.transpose(w_mat)) 240 | v_hat = l2_norm(v_) 241 | u_ = tf.matmul(v_hat, w_mat) 242 | u_hat = l2_norm(u_) 243 | return u_hat, v_hat, ite + 1 244 | 245 | u_hat, v_hat, _ = power_iteration(u, iteration) 246 | 247 | sigma = tf.matmul(tf.matmul(v_hat, w_mat), tf.transpose(u_hat)) 248 | 249 | w_mat = w_mat / sigma 250 | 251 | if update_collection is None: 252 | with tf.control_dependencies([u.assign(u_hat)]): 253 | w_norm = tf.reshape(w_mat, w_shape) 254 | else: 255 | if not (update_collection == 'NO_OPS'): 256 | print(update_collection) 257 | tf.add_to_collection(update_collection, u.assign(u_hat)) 258 | 259 | w_norm = tf.reshape(w_mat, w_shape) 260 | return w_norm 261 | 262 | 263 | def l2_norm(input_x, epsilon=1e-12): 264 | input_x_norm = input_x / (tf.reduce_sum(input_x ** 2) ** 0.5 + epsilon) 265 | return input_x_norm 266 | 267 | -------------------------------------------------------------------------------- /test_data/MSRS/ir/00537D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/MSRS/ir/00537D.png -------------------------------------------------------------------------------- /test_data/MSRS/ir/00556D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/MSRS/ir/00556D.png -------------------------------------------------------------------------------- /test_data/MSRS/ir/00633D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/MSRS/ir/00633D.png -------------------------------------------------------------------------------- /test_data/MSRS/ir/00881N.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/MSRS/ir/00881N.png -------------------------------------------------------------------------------- /test_data/MSRS/ir/01023N.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/MSRS/ir/01023N.png -------------------------------------------------------------------------------- /test_data/MSRS/vi/00537D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/MSRS/vi/00537D.png -------------------------------------------------------------------------------- /test_data/MSRS/vi/00556D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/MSRS/vi/00556D.png -------------------------------------------------------------------------------- /test_data/MSRS/vi/00633D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/MSRS/vi/00633D.png -------------------------------------------------------------------------------- /test_data/MSRS/vi/00881N.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/MSRS/vi/00881N.png -------------------------------------------------------------------------------- /test_data/MSRS/vi/01023N.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/MSRS/vi/01023N.png -------------------------------------------------------------------------------- /test_data/RoadScene/ir/037.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/RoadScene/ir/037.png -------------------------------------------------------------------------------- /test_data/RoadScene/ir/100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/RoadScene/ir/100.png -------------------------------------------------------------------------------- /test_data/RoadScene/ir/108.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/RoadScene/ir/108.png -------------------------------------------------------------------------------- /test_data/RoadScene/vi/037.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/RoadScene/vi/037.png -------------------------------------------------------------------------------- /test_data/RoadScene/vi/100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/RoadScene/vi/100.png -------------------------------------------------------------------------------- /test_data/RoadScene/vi/108.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/RoadScene/vi/108.png
--------------------------------------------------------------------------------
/test_data/TNO/ir/05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/TNO/ir/05.png
--------------------------------------------------------------------------------
/test_data/TNO/ir/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/TNO/ir/17.png
--------------------------------------------------------------------------------
/test_data/TNO/ir/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/TNO/ir/18.png
--------------------------------------------------------------------------------
/test_data/TNO/vi/05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/TNO/vi/05.png
--------------------------------------------------------------------------------
/test_data/TNO/vi/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/TNO/vi/17.png
--------------------------------------------------------------------------------
/test_data/TNO/vi/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linfeng-Tang/PIAFusion/af3e87118d6bfe8e5aabafe7f2650e9df53bd61e/test_data/TNO/vi/18.png
--------------------------------------------------------------------------------
/train_network.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from utils import weights_spectral_norm
3 | from ops import *
4 | from utils import *
5 | 
6 | class Illumination_classifier():
7 |     def illumination_classifier(self, input_image, reuse=False):
8 |         # four stride-2 convolutions reduce a 256x256 input to features of shape [batch_size, 16, 16, 128]
9 |         channel = 16
10 |         with tf.compat.v1.variable_scope('Classifier', reuse=reuse):
11 |             x = input_image
12 |             for i in range(4):
13 |                 x = conv(x, channel, kernel=4, stride=2, pad=1, pad_type='reflect', scope='conv' + str(i + 1), sn=False, norm=False)
14 |                 x = lrelu(x)
15 |                 channel = channel * 2
16 |             x = tf.reduce_mean(x, axis=[1, 2])  # global average pooling -> [batch_size, 128]
17 |             x = tf.layers.flatten(x)
18 |             x = tf.layers.dense(inputs=x, units=128)
19 |             out = tf.layers.dense(inputs=x, units=2)
20 |             out = tf.abs(out)  # non-negative (day, night) scores
21 |         return out
22 | 
23 | class PIAFusion():
24 |     def CMDAF(self, F_vi, F_ir):
25 |         sub_vi_ir = tf.subtract(F_vi, F_ir)
26 |         sub_w_vi_ir = tf.reduce_mean(sub_vi_ir, axis=[1, 2], keepdims=True)  # Global Average Pooling
27 |         w_vi_ir = tf.nn.sigmoid(sub_w_vi_ir)
28 | 
29 |         sub_ir_vi = tf.subtract(F_ir, F_vi)
30 |         sub_w_ir_vi = tf.reduce_mean(sub_ir_vi, axis=[1, 2], keepdims=True)  # Global Average Pooling
31 |         w_ir_vi = tf.nn.sigmoid(sub_w_ir_vi)
32 | 
33 |         F_dvi = tf.multiply(w_vi_ir, sub_ir_vi)  # amplify the difference signal (author's note: should this be sub_ir_vi?)
34 |         F_dir = tf.multiply(w_ir_vi, sub_vi_ir)
35 | 
36 |         F_fvi = tf.add(F_vi, F_dir)
37 |         F_fir = tf.add(F_ir, F_dvi)
38 |         return F_fvi, F_fir
39 | 
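    # A minimal NumPy re-statement of CMDAF, added here for sanity-checking
    # shapes and semantics offline; it mirrors the TensorFlow arithmetic above
    # exactly. The method name and the use of NumPy are editorial assumptions
    # and this helper plays no part in the trained network.
    @staticmethod
    def CMDAF_reference(F_vi, F_ir):
        import numpy as np
        sigmoid = lambda t: 1.0 / (1.0 + np.exp(-t))
        sub_vi_ir = F_vi - F_ir
        sub_ir_vi = F_ir - F_vi
        # sigmoid-gated global average pooling of the signed feature differences
        w_vi_ir = sigmoid(sub_vi_ir.mean(axis=(1, 2), keepdims=True))
        w_ir_vi = sigmoid(sub_ir_vi.mean(axis=(1, 2), keepdims=True))
        # each stream is compensated with the gated difference from the other
        return F_vi + w_ir_vi * sub_vi_ir, F_ir + w_vi_ir * sub_ir_vi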
40 |     def Encoder(self, vi_image, ir_image, reuse=False):
41 |         channel = 16
42 |         with tf.compat.v1.variable_scope('encoder', reuse=reuse):
43 |             x_ir = conv(ir_image, channel, kernel=1, stride=1, pad=0, pad_type='reflect', scope='conv5x5_ir')  # 1x1 convolution despite the 'conv5x5' scope name
44 |             x_ir = lrelu(x_ir)
45 |             x_vi = conv(vi_image, channel, kernel=1, stride=1, pad=0, pad_type='reflect', scope='conv5x5_vi')
46 |             x_vi = lrelu(x_vi)
47 |             block_num = 4
48 |             for i in range(block_num):  # four shared convolution blocks in the feature extractor
49 |                 input_ir = x_ir
50 |                 input_vi = x_vi
51 |                 with tf.compat.v1.variable_scope('Conv{}'.format(i + 1), reuse=False):
52 |                     # conv1
53 |                     x_ir = conv(input_ir, channel, kernel=3, stride=1, pad=1, pad_type='reflect', scope='conv3x3')
54 |                     x_ir = lrelu(x_ir)
55 |                 with tf.compat.v1.variable_scope('Conv{}'.format(i + 1), reuse=True):
56 |                     # conv1
57 |                     x_vi = conv(input_vi, channel, kernel=3, stride=1, pad=1, pad_type='reflect', scope='conv3x3')
58 |                     x_vi = lrelu(x_vi)
59 |                 # reuse one convolutional layer so that features with a consistent distribution are extracted from the various source images
60 |                 if i != block_num - 1:
61 |                     channel = channel * 2
62 |                     x_vi, x_ir = self.CMDAF(x_vi, x_ir)
63 |         print('channel:', channel)
64 |         return x_vi, x_ir
65 | 
66 | 
67 |     def Decoder(self, x, reuse=False):
68 |         channel = x.get_shape().as_list()[-1]
69 |         print('channel:', channel)
70 | 
71 |         with tf.compat.v1.variable_scope('decoder', reuse=reuse):
72 |             block_num = 4
73 |             for i in range(block_num):  # four convolution blocks that halve the channel count
74 | 
75 |                 features = x
76 |                 x = conv(features, channel, kernel=3, stride=1, pad=1, pad_type='reflect', scope='conv{}'.format(i + 1))
77 |                 x = lrelu(x)
78 |                 channel = channel // 2  # integer division keeps the filter count an int
79 |             print('final channel:', channel)
80 |             x = conv(x, 1, kernel=1, stride=1, pad=0, pad_type='reflect', scope='conv1x1')
81 |             x = tf.nn.tanh(x) / 2 + 0.5  # map the output into [0, 1]
82 |         return x
83 | 
84 |     def PIAFusion(self, vi_image, ir_image, reuse=False, Feature_out=True):
85 |         vi_stream, ir_stream = self.Encoder(vi_image=vi_image, ir_image=ir_image, reuse=reuse)
86 |         stream = tf.concat([vi_stream, ir_stream], axis=-1)
87 |         fused_image = self.Decoder(stream, reuse=reuse)
88 |         if Feature_out:
89 |             return fused_image, vi_stream, ir_stream
90 |         else:
91 |             return fused_image
92 | 
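# Editor's sketch (not the repo's training entry point -- see main.py): how the
# classes above are wired together in TF 1.x graph mode. The single-channel
# 256x256 shapes are assumptions chosen for illustration only.
if __name__ == '__main__':
    vi_ph = tf.compat.v1.placeholder(tf.float32, [1, 256, 256, 1], name='vi')
    ir_ph = tf.compat.v1.placeholder(tf.float32, [1, 256, 256, 1], name='ir')
    net = PIAFusion()
    fused, f_vi, f_ir = net.PIAFusion(vi_ph, ir_ph, reuse=False, Feature_out=True)
    print('fused image shape:', fused.get_shape().as_list())  # expected [1, 256, 256, 1]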
22 | """ 23 | if not os.path.exists(results_path): 24 | os.mkdir(results_path) 25 | folder_name = "/{0}_{1}_model".format(model_type, dataset) 26 | tensorboard_path = results_path + folder_name + '/Tensorboard' 27 | log_path = results_path + folder_name + '/log' 28 | if os.path.exists(results_path + folder_name): 29 | shutil.rmtree(results_path + folder_name) 30 | if not os.path.exists(results_path + folder_name): 31 | os.mkdir(results_path + folder_name) 32 | os.mkdir(tensorboard_path) 33 | os.mkdir(log_path) 34 | return tensorboard_path, log_path 35 | 36 | def RGB2YCbCr(RGB_image): 37 | ## RGB_image [-1, 1] 38 | test_num1 = 16.0 / 255.0 39 | test_num2 = 128.0 / 255.0 40 | R = RGB_image[:, :, :, 0:1] 41 | G = RGB_image[:, :, :, 1:2] 42 | B = RGB_image[:, :, :, 2:3] 43 | Y = 0.257 * R + 0.564 * G + 0.098 * B + test_num1 44 | Cb = - 0.148 * R - 0.291 * G + 0.439 * B + test_num2 45 | Cr = 0.439 * R - 0.368 * G - 0.071 * B + test_num2 46 | return Y, Cb, Cr 47 | 48 | def RGB2Gray(RGB_image): 49 | RGB_image = RGB_image * 255.0 50 | 51 | 52 | def YCbCr2RGB(Y, Cb, Cr, mode=1): 53 | ## Y, Cb, Cr :[-1, 1] 54 | test_num1 = 16.0 / 255.0 55 | test_num2 = 128.0 / 255.0 56 | R = 1.164 * (Y - test_num1) + 1.596 * (Cr - test_num2) 57 | G = 1.164 * (Y - test_num1) - 0.392 * (Cb - test_num2) - 0.813 * (Cr - test_num2) 58 | B = 1.164 * (Y - test_num1) + 2.017 * (Cb - test_num2) 59 | RGB_image = tf.concat([R, G, B], axis=-1) 60 | BGR_image = tf.concat([B, G, R], axis=-1) 61 | if mode == 1: 62 | return RGB_image 63 | else: 64 | return BGR_image 65 | 66 | def illumination_mechanism(day_probability, night_probability, scheme_num=1): 67 | if scheme_num == 1: 68 | vi_w = day_probability / tf.add(day_probability, night_probability) 69 | ir_w = night_probability / tf.add(day_probability, night_probability) 70 | elif scheme_num == 2: 71 | vi_w = tf.exp(day_probability) / tf.add(tf.exp(day_probability), tf.exp(night_probability)) 72 | ir_w = tf.exp(night_probability) / tf.add(tf.exp(day_probability), tf.exp(night_probability)) 73 | elif scheme_num == 3: 74 | day_probability = tf.log(day_probability) 75 | night_probability = tf.log(night_probability) 76 | vi_w = day_probability / tf.add(day_probability, night_probability) 77 | ir_w = night_probability / tf.add(day_probability, night_probability) 78 | elif scheme_num == 4: 79 | min_labels = 0.1 * tf.ones_like(day_probability) 80 | max_labels = 0.9 * tf.ones_like(day_probability) 81 | 82 | min_labels_2 = 0.4 * tf.ones_like(day_probability) 83 | max_labels_2 = 0.6 * tf.ones_like(day_probability) 84 | vi_w = tf.where(day_probability - night_probability>0, max_labels, min_labels) 85 | ir_w = tf.where(day_probability - night_probability>0, min_labels, max_labels) 86 | return vi_w, ir_w 87 | def gradient(input): 88 | filter1 = tf.reshape(tf.constant([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]), [3, 3, 1, 1]) 89 | filter2 = tf.reshape(tf.constant([[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]]), [3, 3, 1, 1]) 90 | Gradient1 = tf.nn.conv2d(input, filter1, strides=[1, 1, 1, 1], padding='SAME') 91 | Gradient2 = tf.nn.conv2d(input, filter2, strides=[1, 1, 1, 1], padding='SAME') 92 | Gradient = tf.abs(Gradient1) + tf.abs(Gradient2) 93 | return Gradient 94 | 95 | 96 | def weights_spectral_norm(weights, u=None, iteration=1, update_collection=None, reuse=False, name='weights_SN'): 97 | with tf.variable_scope(name) as scope: 98 | if reuse: 99 | scope.reuse_variables() 100 | 101 | w_shape = weights.get_shape().as_list() 102 | w_mat = tf.reshape(weights, [-1, w_shape[-1]]) 103 | if u is 
66 | def illumination_mechanism(day_probability, night_probability, scheme_num=1):
67 |     if scheme_num == 1:  # linear normalization of the two scores
68 |         vi_w = day_probability / tf.add(day_probability, night_probability)
69 |         ir_w = night_probability / tf.add(day_probability, night_probability)
70 |     elif scheme_num == 2:  # softmax over the two scores
71 |         vi_w = tf.exp(day_probability) / tf.add(tf.exp(day_probability), tf.exp(night_probability))
72 |         ir_w = tf.exp(night_probability) / tf.add(tf.exp(day_probability), tf.exp(night_probability))
73 |     elif scheme_num == 3:  # normalize the log-scores
74 |         day_probability = tf.log(day_probability)
75 |         night_probability = tf.log(night_probability)
76 |         vi_w = day_probability / tf.add(day_probability, night_probability)
77 |         ir_w = night_probability / tf.add(day_probability, night_probability)
78 |     elif scheme_num == 4:  # hard 0.9/0.1 assignment to the more likely condition
79 |         min_labels = 0.1 * tf.ones_like(day_probability)
80 |         max_labels = 0.9 * tf.ones_like(day_probability)
81 | 
82 |         min_labels_2 = 0.4 * tf.ones_like(day_probability)  # (unused alternative pair)
83 |         max_labels_2 = 0.6 * tf.ones_like(day_probability)  # (unused alternative pair)
84 |         vi_w = tf.where(day_probability - night_probability > 0, max_labels, min_labels)
85 |         ir_w = tf.where(day_probability - night_probability > 0, min_labels, max_labels)
86 |     return vi_w, ir_w
87 | def gradient(input):
88 |     filter1 = tf.reshape(tf.constant([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]), [3, 3, 1, 1])  # Sobel x
89 |     filter2 = tf.reshape(tf.constant([[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]]), [3, 3, 1, 1])  # Sobel y
90 |     Gradient1 = tf.nn.conv2d(input, filter1, strides=[1, 1, 1, 1], padding='SAME')
91 |     Gradient2 = tf.nn.conv2d(input, filter2, strides=[1, 1, 1, 1], padding='SAME')
92 |     Gradient = tf.abs(Gradient1) + tf.abs(Gradient2)
93 |     return Gradient
94 | 
95 | 
96 | def weights_spectral_norm(weights, u=None, iteration=1, update_collection=None, reuse=False, name='weights_SN'):
97 |     with tf.variable_scope(name) as scope:
98 |         if reuse:
99 |             scope.reuse_variables()
100 | 
101 |         w_shape = weights.get_shape().as_list()
102 |         w_mat = tf.reshape(weights, [-1, w_shape[-1]])
103 |         if u is None:
104 |             u = tf.get_variable('u', shape=[1, w_shape[-1]], initializer=tf.truncated_normal_initializer(),
105 |                                 trainable=False)
106 | 
107 |         def power_iteration(u, ite):
108 |             v_ = tf.matmul(u, tf.transpose(w_mat))
109 |             v_hat = l2_norm(v_)
110 |             u_ = tf.matmul(v_hat, w_mat)
111 |             u_hat = l2_norm(u_)
112 |             return u_hat, v_hat, ite + 1
113 | 
114 |         u_hat, v_hat, _ = power_iteration(u, iteration)
115 | 
116 |         sigma = tf.matmul(tf.matmul(v_hat, w_mat), tf.transpose(u_hat))  # estimate of the largest singular value
117 | 
118 |         w_mat = w_mat / sigma
119 | 
120 |         if update_collection is None:
121 |             with tf.control_dependencies([u.assign(u_hat)]):
122 |                 w_norm = tf.reshape(w_mat, w_shape)
123 |         else:
124 |             if not (update_collection == 'NO_OPS'):
125 |                 print(update_collection)
126 |                 tf.add_to_collection(update_collection, u.assign(u_hat))
127 | 
128 |             w_norm = tf.reshape(w_mat, w_shape)
129 |     return w_norm
130 | 
131 | 
132 | def lrelu(x, leak=0.2):
133 |     return tf.maximum(x, leak * x)
134 | 
135 | def l2_norm(input_x, epsilon=1e-12):
136 |     input_x_norm = input_x / (tf.reduce_sum(input_x ** 2) ** 0.5 + epsilon)
137 |     return input_x_norm
138 | 
139 | def check_folder(dir):
140 |     if not os.path.exists(dir):
141 |         os.makedirs(dir)
142 | 
143 | def load_test_data(image_path, mode=1):
144 | 
145 |     if mode == 1:  # grayscale
146 |         print('image_path: ', image_path)
147 |         img = cv2.imread(image_path, 0)
148 |         print('image shape: ', img.shape)
149 |         img = np.expand_dims(img, axis=0)
150 |         img = np.expand_dims(img, axis=-1)
151 |         img = preprocessing(img)
152 |     else:  # color, returned as RGB
153 |         print('image_path: ', image_path)
154 |         img = cv2.imread(image_path, 1)
155 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
156 |         print('image shape: ', img.shape)
157 |         img = np.expand_dims(img, axis=0)
158 |         img = preprocessing(img)
159 |     return img
160 | 
161 | def preprocessing(x):
162 |     x = x / 255.0  # scale to [0, 1]
163 |     return x
164 | 
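# Editor's check (a NumPy-only sketch): the power iteration inside
# weights_spectral_norm converges to the largest singular value of the
# flattened kernel. The 27x16 shape and function name are illustrative; the
# layer itself amortizes the loop, running one step per training update.
def _check_power_iteration(iterations=1000):
    rng = np.random.default_rng(0)
    w = rng.standard_normal((27, 16))             # e.g. a flattened 3x3x3 -> 16 kernel
    u = rng.standard_normal((1, 16))
    for _ in range(iterations):
        v = u @ w.T
        v = v / (np.sum(v ** 2) ** 0.5 + 1e-12)   # same l2_norm as above
        u = v @ w
        u = u / (np.sum(u ** 2) ** 0.5 + 1e-12)
    sigma = (v @ w @ u.T).item()                  # v_hat * W * u_hat^T, as above
    assert abs(sigma - np.linalg.svd(w, compute_uv=False)[0]) < 1e-3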
165 | def append_df_to_excel(filename, df, sheet_name='Sheet1', startrow=None,
166 |                        truncate_sheet=False,
167 |                        **to_excel_kwargs):
168 |     """
169 |     Append a DataFrame [df] to an existing Excel file [filename]
170 |     in the [sheet_name] sheet.
171 |     If [filename] doesn't exist, this function will create it.
172 | 
173 |     Parameters:
174 |       filename : File path or existing ExcelWriter
175 |                  (Example: '/path/to/file.xlsx')
176 |       df : DataFrame to save to the workbook
177 |       sheet_name : Name of the sheet which will contain the DataFrame.
178 |                    (default: 'Sheet1')
179 |       startrow : upper-left cell row at which to dump the data frame.
180 |                  By default (startrow=None) the last row of the existing
181 |                  sheet is found and writing starts at the next row.
182 |       truncate_sheet : truncate (remove and recreate) [sheet_name]
183 |                        before writing the DataFrame to the Excel file
184 |       to_excel_kwargs : arguments passed through to `DataFrame.to_excel()`
185 |                        [can be a dictionary]
186 | 
187 |     Returns: None
188 | 
189 |     (c) [MaxU](https://stackoverflow.com/users/5741205/maxu?tab=profile)
190 |     """
191 |     # note: this recipe relies on the pre-1.5 pandas ExcelWriter API
192 |     # (assigning writer.book / writer.sheets and calling writer.save())
193 |     # ignore [engine] parameter if it was passed
194 |     if 'engine' in to_excel_kwargs:
195 |         to_excel_kwargs.pop('engine')
196 | 
197 |     writer = pd.ExcelWriter(filename, engine='openpyxl')
198 | 
199 |     # Python 2.x: define [FileNotFoundError] exception if it doesn't exist
200 |     try:
201 |         FileNotFoundError
202 |     except NameError:
203 |         FileNotFoundError = IOError
204 | 
205 |     try:
206 |         # try to open an existing workbook
207 |         writer.book = load_workbook(filename)
208 | 
209 |         # get the last row in the existing Excel sheet
210 |         # if it was not specified explicitly
211 |         if startrow is None and sheet_name in writer.book.sheetnames:
212 |             startrow = writer.book[sheet_name].max_row
213 | 
214 |         # truncate sheet
215 |         if truncate_sheet and sheet_name in writer.book.sheetnames:
216 |             # index of [sheet_name] sheet
217 |             idx = writer.book.sheetnames.index(sheet_name)
218 |             # remove [sheet_name]
219 |             writer.book.remove(writer.book.worksheets[idx])
220 |             # create an empty sheet [sheet_name] using the old index
221 |             writer.book.create_sheet(sheet_name, idx)
222 | 
223 |         # copy existing sheets
224 |         writer.sheets = {ws.title: ws for ws in writer.book.worksheets}
225 |     except FileNotFoundError:
226 |         # file does not exist yet, we will create it
227 |         pass
228 | 
229 |     if startrow is None:
230 |         startrow = 0
231 | 
232 |     # write out the new sheet
233 |     df.to_excel(writer, sheet_name, startrow=startrow, **to_excel_kwargs)
234 | 
235 |     # save the workbook
236 |     writer.save()
237 | 
238 | 
239 | def writexls(save_name, method_name, time_list, sheet_name, i):
240 |     df = pd.DataFrame({method_name: time_list})
241 |     append_df_to_excel(save_name, df, sheet_name=sheet_name, index=False, startrow=0, startcol=i)
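# Editor's usage sketch for writexls: each call writes one column of per-image
# runtimes for one method, with columns placed side by side via the column
# index i. The file path and timing values below are made up for illustration.
if __name__ == '__main__':
    check_folder('./Results')
    writexls('./Results/runtime.xlsx', 'PIAFusion', [0.031, 0.029, 0.030],
             sheet_name='MSRS', i=0)
--------------------------------------------------------------------------------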