├── code ├── README.md ├── attention.py ├── decoder.py ├── densefuse_net.py ├── encoder.py ├── fusion_addition.py ├── fusion_l1norm.py ├── generate.py ├── icme2020_supplement.zip ├── icme2020template (1).pdf ├── main.py ├── new_attention.py ├── ssim_loss_function.py ├── train_recons.py └── utils.py └── icme2020template .pdf /code/README.md: -------------------------------------------------------------------------------- 1 | # Learning-attention-guided-deep-multi-scale-feature-ensemble-for-infrared-and-visible-image-fusion -------------------------------------------------------------------------------- /code/attention.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python import pywrap_tensorflow 3 | import numpy as np 4 | import cv2 5 | WEIGHT_INIT_STDDEV = 0.1 6 | def conv2d(x, kernel, bias, use_relu=True): 7 | # padding image with reflection mode 8 | x_padded = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT') 9 | 10 | # conv and add bias 11 | # num_maps = x_padded.shape[3] 12 | # out = __batch_normalize(x_padded, num_maps) 13 | # out = tf.nn.relu(out) 14 | out = tf.nn.conv2d(x_padded, kernel, strides=[1, 1, 1, 1], padding='VALID') 15 | out = tf.nn.bias_add(out, bias) 16 | out = tf.nn.relu(out) 17 | return out 18 | 19 | class Attention(object): 20 | def __init__(self, model_pre_path_a): 21 | 22 | self.weights = [] 23 | self.model_pre_path = model_pre_path_a 24 | 25 | with tf.variable_scope('get_attention'): 26 | self.weights.append(self._create_variables(1, 2, 3, scope='attention_block_conv1')) 27 | self.weights.append(self._create_variables(2, 4, 3, scope='attention_block_conv2')) 28 | self.weights.append(self._create_variables(4, 8, 3, scope='attention_block_conv3')) 29 | self.weights.append(self._create_variables(8, 16, 3, scope='attention_block_conv4')) 30 | self.weights.append(self._create_variables(16, 32, 3, scope='attention_block_conv5')) 31 | self.weights.append(self._create_variables(32, 64, 3, scope='attention_block_conv6')) 32 | self.weights.append(self._create_variables(64, 64, 1, scope='attention_block_conv7')) 33 | 34 | def _create_variables(self, input_filters, output_filters, kernel_size, scope): 35 | # 3 * 3 * input * output 36 | shape = [kernel_size, kernel_size, input_filters, output_filters] 37 | if self.model_pre_path: 38 | reader = pywrap_tensorflow.NewCheckpointReader(self.model_pre_path) 39 | with tf.variable_scope(scope): 40 | kernel = tf.Variable(reader.get_tensor('encoder/' + scope + '/kernel'), name='kernel') 41 | bias = tf.Variable(reader.get_tensor('encoder/' + scope + '/bias'), name='bias') 42 | else: 43 | with tf.variable_scope(scope): 44 | # truncated_normal 从截断的正态分布中输出随机值 45 | #第一个参数是张量的维度,第二个是标准差 46 | kernel = tf.Variable(tf.truncated_normal(shape, stddev=WEIGHT_INIT_STDDEV), name='kernel') 47 | bias = tf.Variable(tf.zeros([output_filters]), name='bias') 48 | return (kernel, bias) 49 | def get_attention(self, image): 50 | out = image 51 | for i in range(len(self.weights)): 52 | kernel, bias = self.weights[i] 53 | if i == 6: 54 | out = tf.nn.conv2d(out, kernel, strides=[1, 1, 1, 1], padding='VALID') 55 | out = tf.nn.bias_add(out, bias) 56 | out = tf.nn.relu(out) 57 | elif i % 2 == 0: 58 | out = conv2d(out, kernel, bias, use_relu=True) 59 | else: 60 | out = tf.nn.atrous_conv2d(out, filters=kernel, rate=2, padding='SAME') 61 | 62 | return out 63 | 64 | def guideFilter(self, I, p, winSize, eps): 65 | 66 | mean_I = cv2.blur(I, winSize) # I的均值平滑 67 | mean_p = 
cv2.blur(p, winSize) # p的均值平滑 68 | 69 | mean_II = cv2.blur(I * I, winSize) # I*I的均值平滑 70 | mean_Ip = cv2.blur(I * p, winSize) # I*p的均值平滑 71 | 72 | var_I = mean_II - mean_I * mean_I # 方差 73 | cov_Ip = mean_Ip - mean_I * mean_p # 协方差 74 | 75 | a = cov_Ip / (var_I + eps) # 相关因子a 76 | b = mean_p - a * mean_I # 相关因子b 77 | 78 | mean_a = cv2.blur(a, winSize) # 对a进行均值平滑 79 | mean_b = cv2.blur(b, winSize) # 对b进行均值平滑 80 | 81 | q = mean_a * I + mean_b 82 | 83 | return q 84 | 85 | def RollingGuidance(self,I, sigma_s, sigma_r, iteration): 86 | sigma_s = (sigma_s, sigma_s) 87 | out = cv2.GaussianBlur(I, sigma_s, 0) 88 | sigma_r = sigma_r*sigma_r 89 | for i in range(iteration): 90 | out = self.guideFilter(out, I, sigma_s, sigma_r) 91 | 92 | return out 93 | 94 | def Grad(self,I1): 95 | G1 = [] 96 | L1 = [] 97 | G1.append(I1) 98 | sigma_s = 3 99 | sigma_r = [0.5, 0.5, 0.5, 0.5] 100 | iteration = [3, 3, 3, 3] 101 | indice = (1, 2, 3) 102 | for i in indice: 103 | G1.append(self.RollingGuidance(G1[i - 1], sigma_s, sigma_r[i - 1], iteration[i - 1])) 104 | L1.append(G1[i - 1] - G1[i]) 105 | sigma_s = 3 * sigma_s 106 | sigma_s = (3, 3) 107 | G1.append(cv2.GaussianBlur(G1[3], sigma_s, 0)) 108 | L1.append(G1[3] - G1[4]) 109 | L1.append(G1[4]) 110 | grad = L1[0] 111 | return grad 112 | -------------------------------------------------------------------------------- /code/decoder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python import pywrap_tensorflow 3 | 4 | WEIGHT_INIT_STDDEV = 0.1 5 | 6 | 7 | class Decoder(object): 8 | 9 | def __init__(self, model_pre_path): 10 | self.weight_vars = [] 11 | self.model_pre_path = model_pre_path 12 | 13 | with tf.variable_scope('decoder'): 14 | 15 | self.weight_vars.append(self._create_variables(64, 64, 3, scope='conv2_1')) 16 | self.weight_vars.append(self._create_variables(64, 32, 3, scope='conv2_2')) 17 | self.weight_vars.append(self._create_variables(64, 16, 3, scope='conv2_3')) 18 | self.weight_vars.append(self._create_variables(32, 1 , 3, scope='conv2_4')) 19 | 20 | def _create_variables(self, input_filters, output_filters, kernel_size, scope): 21 | 22 | if self.model_pre_path: 23 | reader = pywrap_tensorflow.NewCheckpointReader(self.model_pre_path) 24 | with tf.variable_scope(scope): 25 | kernel = tf.Variable(reader.get_tensor('decoder/' + scope + '/kernel'), name='kernel') 26 | bias = tf.Variable(reader.get_tensor('decoder/' + scope + '/bias'), name='bias') 27 | else: 28 | with tf.variable_scope(scope): 29 | shape = [kernel_size, kernel_size, input_filters, output_filters] 30 | kernel = tf.Variable(tf.truncated_normal(shape, stddev=WEIGHT_INIT_STDDEV), name='kernel') 31 | bias = tf.Variable(tf.zeros([output_filters]), name='bias') 32 | return (kernel, bias) 33 | 34 | def decode(self, image,block,block2): 35 | final_layer_idx = len(self.weight_vars) - 1 36 | 37 | 38 | out = image 39 | for i in range(len(self.weight_vars)): 40 | kernel, bias = self.weight_vars[i] 41 | 42 | if i == final_layer_idx: 43 | out = conv2d(out, kernel, bias, use_relu=False) 44 | else: 45 | if i==2: 46 | out = conv2d(out, kernel, bias) 47 | out=tf.concat([out,block],3) 48 | elif i==1: 49 | out = conv2d(out, kernel, bias) 50 | out = tf.concat([out, block2], 3) 51 | else: 52 | out = conv2d(out, kernel, bias) 53 | # print('decoder ', i) 54 | # print('decoder out:', out.shape) 55 | return out 56 | 57 | 58 | def conv2d(x, kernel, bias, use_relu=True): 59 | # padding image with reflection mode 60 | x_padded = 
tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT') 61 | 62 | # conv and add bias 63 | out = tf.nn.conv2d(x_padded, kernel, strides=[1, 1, 1, 1], padding='VALID') 64 | out = tf.nn.bias_add(out, bias) 65 | 66 | if use_relu: 67 | out = tf.nn.relu(out) 68 | 69 | return out 70 | 71 | -------------------------------------------------------------------------------- /code/densefuse_net.py: -------------------------------------------------------------------------------- 1 | # DenseFuse Network 2 | # Encoder -> Addition/L1-norm -> Decoder 3 | 4 | import tensorflow as tf 5 | 6 | from encoder import Encoder 7 | from decoder import Decoder 8 | from fusion_addition import Strategy 9 | 10 | class DenseFuseNet(object): 11 | 12 | def __init__(self, model_pre_path): 13 | print("------------------------------------") 14 | print(model_pre_path) 15 | self.encoder = Encoder(model_pre_path) 16 | self.decoder = Decoder(model_pre_path) 17 | 18 | def transform_addition(self, img1, img2): 19 | # encode image 20 | enc_1, enc_1_res_block,enc_1_block,enc_1_block2 = self.encoder.encode(img1) 21 | enc_2, enc_2_res_block ,enc_2_block,enc_2_block2= self.encoder.encode(img2) 22 | target_features = Strategy(enc_1, enc_2) 23 | # target_features = enc_c 24 | self.target_features = target_features 25 | print('target_features:', target_features.shape) 26 | # decode target features back to image 27 | generated_img = self.decoder.decode(target_features,enc_1_block,enc_1_block2) 28 | return generated_img 29 | #------------------------------------------------------------------------ 30 | #不涉及融合层的图像encoder decoder 31 | def transform_recons(self, img): 32 | # encode image 33 | 34 | enc, enc_res_block ,block,block2= self.encoder.encode(img) 35 | 36 | target_features = enc 37 | self.target_features = target_features 38 | generated_img = self.decoder.decode(target_features,block,block2) 39 | return generated_img 40 | 41 | #----------------------------------------------------------------------------- 42 | def transform_encoder(self, img): 43 | # encode image 44 | enc, enc_res_block,block,block2 = self.encoder.encode(img) 45 | return enc, enc_res_block,block,block2 46 | 47 | def transform_decoder(self, feature,block,block2): 48 | # decode image 49 | generated_img = self.decoder.decode(feature,block,block2) 50 | return generated_img 51 | -------------------------------------------------------------------------------- /code/encoder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python import pywrap_tensorflow 3 | import numpy as np 4 | import cv2 5 | import attention 6 | WEIGHT_INIT_STDDEV = 0.1 7 | DENSE_layers = 3 8 | DECAY = .9 9 | EPSILON = 1e-8 10 | class Encoder(object): 11 | def __init__(self, model_pre_path): 12 | self.weight_vars = [] 13 | self.model_pre_path = model_pre_path 14 | 15 | with tf.variable_scope('encoder'): 16 | self.weight_vars.append(self._create_variables(1, 16, 3, scope='conv1_1')) 17 | #--------------------------------------------------------------------------------------------------- 18 | self.weight_vars.append(self._create_variables(16, 16, 3, scope='dil_block_conv1')) 19 | self.weight_vars.append(self._create_variables(32, 16, 3, scope='dil_block_conv2')) 20 | self.weight_vars.append(self._create_variables(48, 16, 3, scope='dil_block_conv3')) 21 | 22 | self.weight_vars.append(self._create_variables(16, 16, 3, scope='dil_block_conv4')) 23 | self.weight_vars.append(self._create_variables(16, 32, 3, 
scope='dil_block_conv5')) 24 | self.weight_vars.append(self._create_variables(32, 64, 3, scope='dil_block_conv6')) 25 | 26 | self.weight_vars.append(self._create_variables(16, 16, 3, scope='dil_block_conv7')) 27 | self.weight_vars.append(self._create_variables(16, 32, 3, scope='dil_block_conv8')) 28 | self.weight_vars.append(self._create_variables(32, 64, 3, scope='dil_block_conv9')) 29 | #---------------------------------------------------------------------------------------------------- 30 | # self.weight_vars.append(self._create_variables(64, 32, 3, scope='conv1_2')) 31 | 32 | def _create_variables(self, input_filters, output_filters, kernel_size, scope): 33 | # 3 * 3 * input * output 34 | shape = [kernel_size, kernel_size, input_filters, output_filters] 35 | if self.model_pre_path: 36 | 37 | reader = pywrap_tensorflow.NewCheckpointReader(self.model_pre_path) 38 | with tf.variable_scope(scope): 39 | kernel = tf.Variable(reader.get_tensor('encoder/' + scope + '/kernel'), name='kernel') 40 | bias = tf.Variable(reader.get_tensor('encoder/' + scope + '/bias'), name='bias') 41 | else: 42 | with tf.variable_scope(scope): 43 | # truncated_normal 从截断的正态分布中输出随机值 44 | # 第一个参数是张量的维度,第二个是标准差 45 | 46 | kernel = tf.Variable(tf.truncated_normal(shape, stddev=WEIGHT_INIT_STDDEV), name='kernel') 47 | bias = tf.Variable(tf.zeros([output_filters]), name='bias') 48 | return (kernel, bias) 49 | 50 | # ================================================================================================================= 51 | # ================================================================================================================= 52 | def cbam_module(self, inputs, reduction_ratio=0.5, name=""): 53 | with tf.variable_scope("cbam_" + name, reuse=tf.AUTO_REUSE): ##tf.AUTO_REUSE 54 | batch_size, hidden_num = inputs.get_shape().as_list()[0], inputs.get_shape().as_list()[3] 55 | # batch_size = inputs.get_shape().as_list()[0] 56 | # hidden_num = out_dim 57 | # print('=====================================') 58 | # print(inputs.shape) 59 | # print(hidden_num) 60 | 61 | maxpool_channel = tf.reduce_max(tf.reduce_max(inputs, axis=1, keep_dims=True), axis=2, keep_dims=True) 62 | avgpool_channel = tf.reduce_mean(tf.reduce_mean(inputs, axis=1, keep_dims=True), axis=2, keep_dims=True) 63 | 64 | # print('----------------------------------') 65 | # print(maxpool_channel.shape) 66 | # print(avgpool_channel.shape) 67 | 68 | # 上面全局池化结果为batsize * 1 * 1 * channel,它这个拉平输入到全连接层 69 | # 这个拉平,它会保留batsize,所以结果是[batsize,channel] 70 | maxpool_channel = tf.layers.Flatten()(maxpool_channel) 71 | avgpool_channel = tf.layers.Flatten()(avgpool_channel) 72 | 73 | mlp_1_max = tf.layers.dense(inputs=maxpool_channel, units=int(hidden_num * reduction_ratio), name="mlp_1", 74 | reuse=None, activation=tf.nn.elu) ####relu 75 | mlp_2_max = tf.layers.dense(inputs=mlp_1_max, units=hidden_num, name="mlp_2", reuse=None) 76 | mlp_2_max = tf.reshape(mlp_2_max, [batch_size, 1, 1, hidden_num]) 77 | 78 | mlp_1_avg = tf.layers.dense(inputs=avgpool_channel, units=int(hidden_num * reduction_ratio), name="mlp_1", 79 | reuse=True, activation=tf.nn.elu) 80 | mlp_2_avg = tf.layers.dense(inputs=mlp_1_avg, units=hidden_num, name="mlp_2", reuse=True) 81 | mlp_2_avg = tf.reshape(mlp_2_avg, [batch_size, 1, 1, hidden_num]) 82 | 83 | channel_attention = tf.nn.sigmoid(mlp_2_max + mlp_2_avg) 84 | channel_refined_feature = inputs * channel_attention 85 | 86 | maxpool_spatial = tf.reduce_max(inputs, axis=3, keep_dims=True) 87 | avgpool_spatial = tf.reduce_mean(inputs, axis=3, 
keep_dims=True) 88 | max_avg_pool_spatial = tf.concat([maxpool_spatial, avgpool_spatial], axis=3) 89 | conv_layer = tf.layers.conv2d(inputs=max_avg_pool_spatial, filters=1, kernel_size=(7, 7), padding="same", 90 | activation=None) 91 | spatial_attention = tf.nn.sigmoid(conv_layer) 92 | 93 | refined_feature = channel_refined_feature * spatial_attention 94 | # print(refined_feature.shape) 95 | 96 | return refined_feature 97 | # ================================================================================================================= 98 | # ================================================================================================================= 99 | 100 | def encode(self, image): 101 | dia_indices_1 = (1, 2, 3) 102 | dia_indices_2 = (4, 5, 6) 103 | dia_indices_3 = (7, 8, 9) 104 | res_block=[] 105 | 106 | #out = image 107 | for i in range(len(self.weight_vars)): 108 | kernel, bias = self.weight_vars[i] 109 | #filter= tf.constant(value=1, shape=[3, 3, 16, 16], dtype=tf.float32) 110 | if i==0: 111 | former = conv2d(image, kernel, bias, use_relu=True) 112 | res_block.append(former) # 0 113 | if i in dia_indices_1: 114 | print("----------------------------------") 115 | print(i) 116 | if(i==1): 117 | x = tf.nn.atrous_conv2d(former, filters=kernel,rate=1,padding='SAME') 118 | x = tf.nn.bias_add(x, bias) 119 | x = tf.nn.relu(x) 120 | res_block.append(x) # 1 121 | x = tf.concat([x,former],3) 122 | elif(i==2): 123 | y=tf.nn.atrous_conv2d(x, filters=kernel,rate=1,padding='SAME') 124 | y = tf.nn.bias_add(y, bias) 125 | y = tf.nn.relu(y) 126 | res_block.append(y) # 2 127 | y = tf.concat([y,x],3) 128 | else: 129 | z = tf.nn.atrous_conv2d(y, filters=kernel, rate=1, padding='SAME') 130 | z = tf.nn.bias_add(z, bias) 131 | z = tf.nn.relu(z) 132 | res_block.append(z) # 3 133 | z = tf.concat([z, y], 3) 134 | 135 | if i in dia_indices_2: 136 | print("----------------------------------") 137 | print(i) 138 | if (i == 4): 139 | x = tf.nn.atrous_conv2d(former, filters=kernel, rate=2, padding='SAME') 140 | x = tf.nn.bias_add(x, bias) 141 | x = tf.nn.relu(x) 142 | res_block.append(x) #4 143 | elif (i == 5): 144 | y = tf.nn.atrous_conv2d(x, filters=kernel, rate=2, padding='SAME') 145 | y = tf.nn.bias_add(y, bias) 146 | y = tf.nn.relu(y) 147 | res_block.append(y) #5 148 | else: 149 | z = tf.nn.atrous_conv2d(y, filters=kernel, rate=2, padding='SAME') 150 | z = tf.nn.bias_add(z, bias) 151 | z = tf.nn.relu(z) 152 | res_block.append(z) # 6 153 | 154 | 155 | if i in dia_indices_3: 156 | print("----------------------------------") 157 | print(i) 158 | if (i == 7): 159 | x = tf.nn.atrous_conv2d(former, filters=kernel, rate=4, padding='SAME') 160 | x = tf.nn.bias_add(x, bias) 161 | x = tf.nn.relu(x) 162 | res_block.append(x) #7 163 | elif (i == 8): 164 | y = tf.nn.atrous_conv2d(x, filters=kernel, rate=4, padding='SAME') 165 | y = tf.nn.bias_add(y, bias) 166 | y = tf.nn.relu(y) 167 | res_block.append(y) #8 168 | else: 169 | z = tf.nn.atrous_conv2d(y, filters=kernel, rate=4, padding='SAME') 170 | z = tf.nn.bias_add(z, bias) 171 | z = tf.nn.relu(z) 172 | res_block.append(z) # 9 173 | 174 | 175 | feature = res_block[0] 176 | mix_indices = (1, 2, 3) 177 | for i in mix_indices: 178 | feature = tf.concat([feature, res_block[i]], 3) 179 | 180 | 181 | 182 | 183 | out = 1 * feature + 0.1 * res_block[6] + 0.1 * res_block[9] 184 | 185 | block = res_block[1] + 0.1 * res_block[4] + 0.1 * res_block[7] 186 | block2 = tf.concat([res_block[1], res_block[2]], 3) + 0.1 * res_block[5] + 0.1 * res_block[8] 187 | 188 | 
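# --- channel bookkeeping for the tensors assembled above (as defined in this encoder) ---
# feature = concat(res_block[0..3])                                         -> 4 x 16 = 64 channels (dense rate-1 branch)
# out     = 1 * feature + 0.1 * res_block[6] + 0.1 * res_block[9]           (rate-2 and rate-4 branch outputs, 64 channels each)
# block   = res_block[1] + 0.1 * res_block[4] + 0.1 * res_block[7]          -> 16 channels
# block2  = concat(res_block[1], res_block[2]) + 0.1*res_block[5] + 0.1*res_block[8]  -> 32 channels
# block2 and block are the skip connections that decoder.decode() concatenates after
# conv2_2 (32 + 32 -> 64 channels into conv2_3) and after conv2_3 (16 + 16 -> 32 channels into conv2_4).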
print(self.weight_vars[0]) 189 | 190 | return out,res_block,block,block2 191 | 192 | 193 | 194 | #--------------------------------------------------------------------------------- 195 | # x : 输入 196 | # kernel, bias : 卷积核, 偏移量 197 | # use_relu : 激活 198 | def conv2d(x, kernel, bias, use_relu=True): 199 | # padding image with reflection mode 200 | x_padded = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT') 201 | 202 | # conv and add bias 203 | # num_maps = x_padded.shape[3] 204 | # out = __batch_normalize(x_padded, num_maps) 205 | # out = tf.nn.relu(out) 206 | out = tf.nn.conv2d(x_padded, kernel, strides=[1, 1, 1, 1], padding='VALID') 207 | out = tf.nn.bias_add(out, bias) 208 | out = tf.nn.relu(out) 209 | return out 210 | 211 | 212 | def transition_block(x, kernel, bias): 213 | 214 | num_maps = x.shape[3] 215 | out = __batch_normalize(x, num_maps) 216 | out = tf.nn.relu(out) 217 | out = conv2d(out, kernel, bias, use_relu=False) 218 | 219 | return out 220 | 221 | 222 | def __batch_normalize(inputs, num_maps, is_training=True): 223 | # Trainable variables for scaling and offsetting our inputs 224 | # scale = tf.Variable(tf.ones([num_maps], dtype=tf.float32)) 225 | # offset = tf.Variable(tf.zeros([num_maps], dtype=tf.float32)) 226 | 227 | # Mean and variances related to our current batch 228 | batch_mean, batch_var = tf.nn.moments(inputs, [0, 1, 2]) 229 | 230 | # # Create an optimizer to maintain a 'moving average' 231 | # ema = tf.train.ExponentialMovingAverage(decay=DECAY) 232 | # 233 | # def ema_retrieve(): 234 | # return ema.average(batch_mean), ema.average(batch_var) 235 | # 236 | # # If the net is being trained, update the average every training step 237 | # def ema_update(): 238 | # ema_apply = ema.apply([batch_mean, batch_var]) 239 | # 240 | # # Make sure to compute the new means and variances prior to returning their values 241 | # with tf.control_dependencies([ema_apply]): 242 | # return tf.identity(batch_mean), tf.identity(batch_var) 243 | # 244 | # # Retrieve the means and variances and apply the BN transformation 245 | # mean, var = tf.cond(tf.equal(is_training, True), ema_update, ema_retrieve) 246 | bn_inputs = tf.nn.batch_normalization(inputs, batch_mean, batch_var, None, None, EPSILON) 247 | 248 | return bn_inputs -------------------------------------------------------------------------------- /code/fusion_addition.py: -------------------------------------------------------------------------------- 1 | # Additioin 2 | 3 | def Strategy(content, style): 4 | # return tf.reduce_sum(content, style) 5 | return content+style 6 | 7 | -------------------------------------------------------------------------------- /code/fusion_l1norm.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def L1_norm(source_en_a, source_en_b): 5 | result = [] 6 | narry_a = source_en_a 7 | narry_b = source_en_b 8 | 9 | dimension = source_en_a.shape 10 | 11 | print(source_en_a.shape) 12 | print(source_en_b.shape) 13 | 14 | # caculate L1-norm 15 | temp_abs_a = tf.abs(narry_a) 16 | temp_abs_b = tf.abs(narry_b) 17 | _l1_a = tf.reduce_sum(temp_abs_a,3) 18 | _l1_b = tf.reduce_sum(temp_abs_b,3) 19 | 20 | _l1_a = tf.reduce_sum(_l1_a, 0) 21 | _l1_b = tf.reduce_sum(_l1_b, 0) 22 | l1_a = _l1_a.eval() 23 | l1_b = _l1_b.eval() 24 | 25 | # caculate the map for source images 26 | mask_value = l1_a + l1_b + 0.0000001 27 | 28 | mask_sign_a = l1_a/mask_value 29 | mask_sign_b = l1_b/mask_value 30 | 31 | array_MASK_a = 
mask_sign_a 32 | array_MASK_b = mask_sign_b 33 | 34 | for i in range(dimension[3]): 35 | temp_matrix = array_MASK_a*narry_a[0,:,:,i] + array_MASK_b*narry_b[0,:,:,i] 36 | result.append(temp_matrix) 37 | 38 | result = np.stack(result, axis=-1) 39 | 40 | resule_tf = np.reshape(result, (dimension[0], dimension[1], dimension[2], dimension[3])) 41 | 42 | return resule_tf 43 | 44 | def L1_norm_attention(source_en_a,feation_a, source_en_b,feation_b): 45 | result = [] 46 | narry_a = source_en_a 47 | narry_b = source_en_b 48 | 49 | dimension = source_en_a.shape 50 | 51 | print(source_en_a.shape) 52 | print(source_en_b.shape) 53 | 54 | # caculate L1-norm 55 | temp_abs_a = tf.abs(narry_a) 56 | temp_abs_b = tf.abs(narry_b) 57 | _l1_a = tf.reduce_sum(temp_abs_a,3) 58 | _l1_b = tf.reduce_sum(temp_abs_b,3) 59 | 60 | _l1_a = tf.reduce_sum(_l1_a, 0) 61 | _l1_b = tf.reduce_sum(_l1_b, 0) 62 | l1_a = _l1_a.eval() 63 | l1_b = _l1_b.eval() 64 | 65 | # caculate the map for source images 66 | mask_value = l1_a + l1_b + 0.0000001 67 | 68 | mask_sign_a = l1_a/mask_value 69 | mask_sign_b = l1_b/mask_value 70 | 71 | array_MASK_a = mask_sign_a 72 | array_MASK_b = mask_sign_b 73 | 74 | for i in range(dimension[3]): 75 | temp_matrix = array_MASK_a*narry_a[0,:,:,i] + array_MASK_b*narry_b[0,:,:,i] 76 | result.append(temp_matrix) 77 | 78 | result = np.stack(result, axis=-1) 79 | 80 | resule_tf = np.reshape(result, (dimension[0], dimension[1], dimension[2], dimension[3])) 81 | 82 | return resule_tf -------------------------------------------------------------------------------- /code/generate.py: -------------------------------------------------------------------------------- 1 | # Use a trained DenseFuse Net to generate fused images 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import attention 6 | from datetime import datetime 7 | 8 | from fusion_l1norm import L1_norm, L1_norm_attention 9 | from densefuse_net import DenseFuseNet 10 | from utils import get_images, save_images, get_train_images, get_train_images_rgb 11 | from fusion_addition import Strategy 12 | import cv2 13 | import attention 14 | TRAINING_IMAGE_SHAPE = (256, 256, 1) # (height, width, color_channels) 15 | TRAINING_IMAGE_SHAPE_OR = (256, 256, 1) # (height, width, color_channels) 16 | 17 | def guideFilter(I, p, winSize, eps): 18 | 19 | mean_I = cv2.blur(I, winSize) # I的均值平滑 20 | mean_p = cv2.blur(p, winSize) # p的均值平滑 21 | 22 | mean_II = cv2.blur(I * I, winSize) # I*I的均值平滑 23 | mean_Ip = cv2.blur(I * p, winSize) # I*p的均值平滑 24 | 25 | var_I = mean_II - mean_I * mean_I # 方差 26 | cov_Ip = mean_Ip - mean_I * mean_p # 协方差 27 | 28 | a = cov_Ip / (var_I + eps + 0.0000001) # 相关因子a 29 | b = mean_p - a * mean_I # 相关因子b 30 | 31 | mean_a = cv2.blur(a, winSize) # 对a进行均值平滑 32 | mean_b = cv2.blur(b, winSize) # 对b进行均值平滑 33 | 34 | q = mean_a * I + mean_b 35 | 36 | return q 37 | 38 | def RollingGuidance(I, sigma_s, sigma_r, iteration): 39 | sigma_s = (sigma_s, sigma_s) 40 | out = cv2.GaussianBlur(I, sigma_s, 0) 41 | for i in range(iteration): 42 | out = guideFilter(out, I, sigma_s, sigma_r*sigma_r) 43 | 44 | return out 45 | 46 | def Grad(I1): 47 | G1=[] 48 | L1=[] 49 | G1.append(I1) 50 | sigma_s = 3 51 | sigma_r = [0.5, 0.5, 0.5, 0.5] 52 | iteration = [4, 4, 4, 4] 53 | indice=(1,2,3) 54 | for i in indice: 55 | G1.append(RollingGuidance(G1[i-1],sigma_s,sigma_r[i-1],iteration[i-1])) 56 | L1.append(G1[i-1]-G1[i]) 57 | sigma_s = 3 * sigma_s 58 | sigma_s = (3, 3) 59 | G1.append(cv2.GaussianBlur(G1[3], sigma_s, 0)) 60 | L1.append(G1[3]-G1[4]) 61 | L1.append(G1[4]) 62 | 
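# G1 is a rolling-guidance smoothing pyramid of the input (G1[0] = I1), and
# L1[i] = G1[i] - G1[i+1] are the corresponding detail (residual) layers,
# with L1[4] = G1[4] kept as the final base layer.
# Only the finest detail layer L1[0] is returned below; _get_attention() feeds it
# to Attention.get_attention() as the edge/detail map for each source image.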
grad = L1[0] 63 | return grad 64 | 65 | def generate(infrared_path, visible_path, model_path, model_pre_path,model_path_a,model_pre_path_a ,ssim_weight, index, IS_VIDEO, IS_RGB, type='addition', output_path=None): 66 | 67 | if IS_VIDEO: 68 | print('video_addition') 69 | _handler_video(infrared_path, visible_path, model_path, model_pre_path, ssim_weight, output_path=output_path) 70 | else: 71 | if IS_RGB: 72 | print('RGB - addition') 73 | _handler_rgb(infrared_path, visible_path, model_path, model_pre_path, ssim_weight, index, 74 | output_path=output_path) 75 | 76 | print('RGB - l1') 77 | _handler_rgb_l1(infrared_path, visible_path, model_path, model_pre_path, ssim_weight, index, 78 | output_path=output_path) 79 | else: 80 | _handler_mix_a(infrared_path, visible_path, model_path, model_pre_path,model_path_a,model_pre_path_a, ssim_weight, index, 81 | output_path=output_path) 82 | #if type == 'addition': 83 | #print('addition') 84 | #_handler(infrared_path, visible_path, model_path, model_pre_path, ssim_weight, index, output_path=output_path) 85 | #elif type == 'l1': 86 | #print('l1') 87 | #_handler_l1(infrared_path, visible_path, model_path, model_pre_path, ssim_weight, index, output_path=output_path) 88 | 89 | 90 | def _handler(ir_path, vis_path, model_path, model_pre_path, ssim_weight, index, output_path=None): 91 | ir_img = get_train_images(ir_path, flag=False) 92 | vis_img = get_train_images(vis_path, flag=False) 93 | # ir_img = get_train_images_rgb(ir_path, flag=False) 94 | # vis_img = get_train_images_rgb(vis_path, flag=False) 95 | dimension = ir_img.shape 96 | 97 | ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 98 | vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 99 | 100 | ir_img = np.transpose(ir_img, (0, 2, 1, 3)) 101 | vis_img = np.transpose(vis_img, (0, 2, 1, 3)) 102 | 103 | print('img shape final:', ir_img.shape) 104 | 105 | with tf.Graph().as_default(), tf.Session() as sess: 106 | infrared_field = tf.placeholder( 107 | tf.float32, shape=ir_img.shape, name='content') 108 | visible_field = tf.placeholder( 109 | tf.float32, shape=ir_img.shape, name='style') 110 | 111 | dfn = DenseFuseNet(model_pre_path) 112 | 113 | output_image = dfn.transform_addition(infrared_field, visible_field) 114 | # restore the trained model and run the style transferring 115 | saver = tf.train.Saver() 116 | saver.restore(sess, model_path) 117 | 118 | output = sess.run(output_image, feed_dict={infrared_field: ir_img, visible_field: vis_img}) 119 | 120 | save_images(ir_path, output, output_path, 121 | prefix='fused' + str(index), suffix='_densefuse_addition_'+str(ssim_weight)) 122 | 123 | 124 | def _handler_l1(ir_path, vis_path, model_path, model_pre_path, ssim_weight, index, output_path=None): 125 | ir_img = get_train_images(ir_path, flag=False) 126 | vis_img = get_train_images(vis_path, flag=False) 127 | dimension = ir_img.shape 128 | 129 | ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 130 | vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 131 | 132 | ir_img = np.transpose(ir_img, (0, 2, 1, 3)) 133 | vis_img = np.transpose(vis_img, (0, 2, 1, 3)) 134 | 135 | print('img shape final:', ir_img.shape) 136 | 137 | with tf.Graph().as_default(), tf.Session() as sess: 138 | 139 | # build the dataflow graph 140 | infrared_field = tf.placeholder( 141 | tf.float32, shape=ir_img.shape, name='content') 142 | visible_field = tf.placeholder( 143 | tf.float32, shape=ir_img.shape, name='style') 144 | 145 | dfn = 
DenseFuseNet(model_pre_path) 146 | 147 | enc_ir,enc_ir_res_block,enc_ir_res_block1,enc_ir_res_block2 = dfn.transform_encoder(infrared_field) 148 | enc_vis,enc_vis_res_block,enc_vis_res_block1,enc_vis_res_block2 = dfn.transform_encoder(visible_field) 149 | 150 | target = tf.placeholder( 151 | tf.float32, shape=enc_ir.shape, name='target') 152 | block1 = tf.placeholder( 153 | tf.float32, shape=enc_ir_res_block1.shape, name='block1') 154 | block2= tf.placeholder( 155 | tf.float32, shape=enc_ir_res_block2.shape, name='block2') 156 | output_image = dfn.transform_decoder(target,block1,block2) 157 | 158 | # restore the trained model and run the style transferring 159 | saver = tf.train.Saver() 160 | saver.restore(sess, model_path) 161 | 162 | enc_ir_temp, enc_vis_temp,ir_block1,ir_block2,vis_block1,vis_block2 = sess.run([enc_ir, enc_vis,enc_ir_res_block1,enc_ir_res_block2 ,enc_vis_res_block1,enc_vis_res_block2], feed_dict={infrared_field: ir_img, visible_field: vis_img}) 163 | feature = L1_norm(enc_ir_temp, enc_vis_temp) 164 | t_block1=L1_norm(ir_block1,vis_block1) 165 | 166 | t_block2 = L1_norm(ir_block2, vis_block2) 167 | output = sess.run(output_image, feed_dict={target: feature ,block1:t_block1,block2:t_block2}) 168 | save_images(ir_path, output, output_path, 169 | prefix='fused' + str(index), suffix='_densefuse_l1norm_'+str(ssim_weight)) 170 | 171 | 172 | def _handler_video(ir_path, vis_path, model_path, model_pre_path, ssim_weight, output_path=None): 173 | infrared = ir_path[0] 174 | img = get_train_images(infrared, flag=False) 175 | img = img.reshape([1, img.shape[0], img.shape[1], img.shape[2]]) 176 | img = np.transpose(img, (0, 2, 1, 3)) 177 | print('img shape final:', img.shape) 178 | num_imgs = len(ir_path) 179 | 180 | with tf.Graph().as_default(), tf.Session() as sess: 181 | # build the dataflow graph 182 | infrared_field = tf.placeholder( 183 | tf.float32, shape=img.shape, name='content') 184 | visible_field = tf.placeholder( 185 | tf.float32, shape=img.shape, name='style') 186 | 187 | dfn = DenseFuseNet(model_pre_path) 188 | 189 | output_image = dfn.transform_addition(infrared_field, visible_field) 190 | 191 | # restore the trained model and run the style transferring 192 | saver = tf.train.Saver() 193 | saver.restore(sess, model_path) 194 | 195 | ##################GET IMAGES################################################################################### 196 | start_time = datetime.now() 197 | for i in range(num_imgs): 198 | print('image number:', i) 199 | infrared = ir_path[i] 200 | visible = vis_path[i] 201 | 202 | ir_img = get_train_images(infrared, flag=False) 203 | vis_img = get_train_images(visible, flag=False) 204 | dimension = ir_img.shape 205 | 206 | ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 207 | vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 208 | 209 | ir_img = np.transpose(ir_img, (0, 2, 1, 3)) 210 | vis_img = np.transpose(vis_img, (0, 2, 1, 3)) 211 | 212 | ################FEED######################################## 213 | output = sess.run(output_image, feed_dict={infrared_field: ir_img, visible_field: vis_img}) 214 | save_images(infrared, output, output_path, 215 | prefix='fused' + str(i), suffix='_addition_' + str(ssim_weight)) 216 | ###################################################################################################### 217 | elapsed_time = datetime.now() - start_time 218 | print('Dense block video==> elapsed time: %s' % (elapsed_time)) 219 | 220 | 221 | def _handler_rgb(ir_path, vis_path, 
model_path, model_pre_path, ssim_weight, index, output_path=None): 222 | # ir_img = get_train_images(ir_path, flag=False) 223 | # vis_img = get_train_images(vis_path, flag=False) 224 | ir_img = get_train_images_rgb(ir_path, flag=False) 225 | vis_img = get_train_images_rgb(vis_path, flag=False) 226 | dimension = ir_img.shape 227 | 228 | ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 229 | vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 230 | 231 | #ir_img = np.transpose(ir_img, (0, 2, 1, 3)) 232 | #vis_img = np.transpose(vis_img, (0, 2, 1, 3)) 233 | 234 | ir_img1 = ir_img[:, :, :, 0] 235 | ir_img1 = ir_img1.reshape([1, dimension[0], dimension[1], 1]) 236 | ir_img2 = ir_img[:, :, :, 1] 237 | ir_img2 = ir_img2.reshape([1, dimension[0], dimension[1], 1]) 238 | ir_img3 = ir_img[:, :, :, 2] 239 | ir_img3 = ir_img3.reshape([1, dimension[0], dimension[1], 1]) 240 | 241 | vis_img1 = vis_img[:, :, :, 0] 242 | vis_img1 = vis_img1.reshape([1, dimension[0], dimension[1], 1]) 243 | vis_img2 = vis_img[:, :, :, 1] 244 | vis_img2 = vis_img2.reshape([1, dimension[0], dimension[1], 1]) 245 | vis_img3 = vis_img[:, :, :, 2] 246 | vis_img3 = vis_img3.reshape([1, dimension[0], dimension[1], 1]) 247 | 248 | print('img shape final:', ir_img1.shape) 249 | 250 | with tf.Graph().as_default(), tf.Session() as sess: 251 | infrared_field = tf.placeholder( 252 | tf.float32, shape=ir_img1.shape, name='content') 253 | visible_field = tf.placeholder( 254 | tf.float32, shape=ir_img1.shape, name='style') 255 | 256 | dfn = DenseFuseNet(model_pre_path) 257 | 258 | output_image = dfn.transform_addition(infrared_field, visible_field) 259 | # restore the trained model and run the style transferring 260 | saver = tf.train.Saver() 261 | saver.restore(sess, model_path) 262 | 263 | output1 = sess.run(output_image, feed_dict={infrared_field: ir_img1, visible_field: vis_img1}) 264 | output2 = sess.run(output_image, feed_dict={infrared_field: ir_img2, visible_field: vis_img2}) 265 | output3 = sess.run(output_image, feed_dict={infrared_field: ir_img3, visible_field: vis_img3}) 266 | 267 | output1 = output1.reshape([1, dimension[0], dimension[1]]) 268 | output2 = output2.reshape([1, dimension[0], dimension[1]]) 269 | output3 = output3.reshape([1, dimension[0], dimension[1]]) 270 | 271 | output = np.stack((output1, output2, output3), axis=-1) 272 | #output = np.transpose(output, (0, 2, 1, 3)) 273 | save_images(ir_path, output, output_path, 274 | prefix='fused' + str(index), suffix='_densefuse_addition_'+str(ssim_weight)) 275 | 276 | 277 | def _handler_rgb_l1(ir_path, vis_path, model_path, model_pre_path, ssim_weight, index, output_path=None): 278 | # ir_img = get_train_images(ir_path, flag=False) 279 | # vis_img = get_train_images(vis_path, flag=False) 280 | ir_img = get_train_images_rgb(ir_path, flag=False) 281 | vis_img = get_train_images_rgb(vis_path, flag=False) 282 | dimension = ir_img.shape 283 | 284 | ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 285 | vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 286 | 287 | #ir_img = np.transpose(ir_img, (0, 2, 1, 3)) 288 | #vis_img = np.transpose(vis_img, (0, 2, 1, 3)) 289 | 290 | ir_img1 = ir_img[:, :, :, 0] 291 | ir_img1 = ir_img1.reshape([1, dimension[0], dimension[1], 1]) 292 | ir_img2 = ir_img[:, :, :, 1] 293 | ir_img2 = ir_img2.reshape([1, dimension[0], dimension[1], 1]) 294 | ir_img3 = ir_img[:, :, :, 2] 295 | ir_img3 = ir_img3.reshape([1, dimension[0], dimension[1], 1]) 296 | 297 | vis_img1 
= vis_img[:, :, :, 0] 298 | vis_img1 = vis_img1.reshape([1, dimension[0], dimension[1], 1]) 299 | vis_img2 = vis_img[:, :, :, 1] 300 | vis_img2 = vis_img2.reshape([1, dimension[0], dimension[1], 1]) 301 | vis_img3 = vis_img[:, :, :, 2] 302 | vis_img3 = vis_img3.reshape([1, dimension[0], dimension[1], 1]) 303 | 304 | print('img shape final:', ir_img1.shape) 305 | 306 | with tf.Graph().as_default(), tf.Session() as sess: 307 | infrared_field = tf.placeholder( 308 | tf.float32, shape=ir_img1.shape, name='content') 309 | visible_field = tf.placeholder( 310 | tf.float32, shape=ir_img1.shape, name='style') 311 | 312 | dfn = DenseFuseNet(model_pre_path) 313 | 314 | enc_ir,enc_ir_res_block = dfn.transform_encoder(infrared_field) 315 | enc_vis,enc_vis_res_block = dfn.transform_encoder(visible_field) 316 | 317 | target = tf.placeholder( 318 | tf.float32, shape=enc_ir.shape, name='target') 319 | 320 | output_image = dfn.transform_decoder(target) 321 | 322 | # restore the trained model and run the style transferring 323 | saver = tf.train.Saver() 324 | saver.restore(sess, model_path) 325 | 326 | enc_ir_temp, enc_vis_temp = sess.run([enc_ir, enc_vis], feed_dict={infrared_field: ir_img1, visible_field: vis_img1}) 327 | feature = L1_norm(enc_ir_temp, enc_vis_temp) 328 | output1 = sess.run(output_image, feed_dict={target: feature}) 329 | 330 | enc_ir_temp, enc_vis_temp = sess.run([enc_ir, enc_vis], feed_dict={infrared_field: ir_img2, visible_field: vis_img2}) 331 | feature = L1_norm(enc_ir_temp, enc_vis_temp) 332 | output2 = sess.run(output_image, feed_dict={target: feature}) 333 | 334 | enc_ir_temp, enc_vis_temp = sess.run([enc_ir, enc_vis], feed_dict={infrared_field: ir_img3, visible_field: vis_img3}) 335 | feature = L1_norm(enc_ir_temp, enc_vis_temp) 336 | output3 = sess.run(output_image, feed_dict={target: feature}) 337 | 338 | output1 = output1.reshape([1, dimension[0], dimension[1]]) 339 | output2 = output2.reshape([1, dimension[0], dimension[1]]) 340 | output3 = output3.reshape([1, dimension[0], dimension[1]]) 341 | 342 | output = np.stack((output1, output2, output3), axis=-1) 343 | #output = np.transpose(output, (0, 2, 1, 3)) 344 | save_images(ir_path, output, output_path, 345 | prefix='fused' + str(index), suffix='_densefuse_l1norm_'+str(ssim_weight)) 346 | def _handler_mix(ir_path, vis_path, model_path, model_pre_path, ssim_weight, index, output_path=None): 347 | mix_block=[] 348 | ir_img = get_train_images(ir_path, flag=False) 349 | vis_img = get_train_images(vis_path, flag=False) 350 | dimension = ir_img.shape 351 | ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 352 | vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 353 | ir_img = np.transpose(ir_img, (0, 2, 1, 3)) 354 | vis_img = np.transpose(vis_img, (0, 2, 1, 3)) 355 | 356 | print('img shape final:', ir_img.shape) 357 | with tf.Graph().as_default(), tf.Session() as sess: 358 | infrared_field = tf.placeholder( 359 | tf.float32, shape=ir_img.shape, name='content') 360 | visible_field = tf.placeholder( 361 | tf.float32, shape=vis_img.shape, name='style') 362 | 363 | # ----------------------------------------------- 364 | 365 | dfn = DenseFuseNet(model_pre_path) 366 | 367 | #sess.run(tf.global_variables_initializer()) 368 | 369 | enc_ir,enc_ir_res_block ,enc_ir_block,enc_ir_block2= dfn.transform_encoder(infrared_field) 370 | enc_vis,enc_vis_res_block,enc_vis_block,enc_vis_block2 = dfn.transform_encoder(visible_field) 371 | 372 | result = tf.placeholder( 373 | tf.float32, shape=enc_ir.shape, 
name='target') 374 | 375 | 376 | 377 | saver = tf.train.Saver() 378 | saver.restore(sess, model_path) 379 | 380 | enc_ir_temp, enc_ir_res_block_temp, enc_ir_block_temp, enc_ir_block2_temp = sess.run( 381 | [enc_ir, enc_ir_res_block, enc_ir_block, enc_ir_block2], feed_dict={infrared_field: ir_img}) 382 | enc_vis_temp, enc_vis_res_block_temp, enc_vis_block_temp, enc_vis_block2_temp = sess.run( 383 | [enc_vis, enc_vis_res_block, enc_vis_block, enc_vis_block2], feed_dict={visible_field: vis_img}) 384 | 385 | block = L1_norm(enc_ir_block_temp, enc_vis_block_temp) 386 | block2=L1_norm(enc_ir_block2_temp,enc_vis_block2_temp) 387 | 388 | first_first = L1_norm(enc_ir_res_block_temp[0], enc_vis_res_block_temp[0]) 389 | first_second = Strategy(enc_ir_res_block_temp[1], enc_vis_res_block_temp[1]) 390 | #first_third = L1_norm_attention(enc_ir_res_block_temp[2],feation_ir, enc_vis_res_block_temp[2],feation_vis) 391 | #first_four = L1_norm_attention(enc_ir_res_block_temp[3],feation_ir, enc_vis_res_block_temp[3],feation_vis) 392 | first_third=L1_norm(enc_ir_res_block_temp[2],enc_vis_res_block_temp[2]) 393 | first_four=Strategy(enc_ir_res_block_temp[3],enc_vis_res_block_temp[3]) 394 | first_first = tf.concat([first_first, tf.to_int32(first_second, name='ToInt')],3) 395 | first_first = tf.concat([first_first, tf.to_int32(first_third, name='ToInt')],3) 396 | first_first = tf.concat([first_first, first_four],3) 397 | 398 | first = first_first 399 | 400 | second = L1_norm(enc_ir_res_block_temp[6], enc_vis_res_block_temp[6]) 401 | third = L1_norm(enc_ir_res_block_temp[9], enc_vis_res_block_temp[9]) 402 | 403 | feature = 1 * first + 0.1 * second + 0.1 * third 404 | 405 | #--------------------------------------------------------- 406 | # block=Strategy(enc_ir_block_temp,enc_vis_block_temp) 407 | # block2=L1_norm(enc_ir_block2_temp,enc_vis_block2_temp) 408 | #--------------------------------------------------------- 409 | 410 | feature = feature.eval() 411 | 412 | output_image = dfn.transform_decoder(result, block, block2) 413 | 414 | # output = dfn.transform_decoder(feature) 415 | # print(type(feature)) 416 | # output = sess.run(output_image, feed_dict={result: feature,enc_res_block:block,enc_res_block2:block2}) 417 | output = sess.run(output_image, feed_dict={result: feature}) 418 | 419 | save_images(ir_path, output, output_path, 420 | prefix='fused' + str(index), suffix='_mix_' + str(ssim_weight)) 421 | def _get_attention(ir_path,vis_path,model_path_a,model_pre_path_a): 422 | ir_img = get_train_images(ir_path, flag=False) 423 | vis_img = get_train_images(vis_path, flag=False) 424 | dimension = ir_img.shape 425 | ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 426 | vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 427 | ir_img = np.transpose(ir_img, (0, 2, 1, 3)) 428 | vis_img = np.transpose(vis_img, (0, 2, 1, 3)) 429 | g1 = tf.Graph() # 加载到Session 1的graph 430 | 431 | sess1 = tf.Session(graph=g1) # Session1 432 | 433 | with sess1.as_default(): 434 | with g1.as_default(), tf.Session() as sess: 435 | infrared_field = tf.placeholder( 436 | tf.float32, shape=ir_img.shape, name='content') 437 | visible_field = tf.placeholder( 438 | tf.float32, shape=vis_img.shape, name='style') 439 | edge_ir = tf.placeholder(tf.float32, shape=ir_img.shape, name='attention') 440 | edge_vis = tf.placeholder(tf.float32, shape=ir_img.shape, name='attention') 441 | 442 | # ----------------------------------------------- 443 | image_ir = sess.run(infrared_field, feed_dict={infrared_field: 
ir_img}) 444 | image_vis = sess.run(visible_field, feed_dict={visible_field: vis_img}) 445 | 446 | p_vis = image_vis[0] 447 | p_ir = image_ir[0] 448 | 449 | p_vis = np.squeeze(p_vis) # 降维 450 | p_ir = np.squeeze(p_ir) 451 | 452 | guideFilter_img_vis = Grad(p_vis) 453 | guideFilter_img_ir = Grad(p_ir) 454 | 455 | guideFilter_img_vis[guideFilter_img_vis < 0] = 0 456 | guideFilter_img_ir[guideFilter_img_ir < 0] = 0 457 | guideFilter_img_vis = np.expand_dims(guideFilter_img_vis, axis=-1) 458 | guideFilter_img_ir = np.expand_dims(guideFilter_img_ir, axis=-1) 459 | guideFilter_img_vis = np.expand_dims(guideFilter_img_vis, axis=0) 460 | guideFilter_img_ir = np.expand_dims(guideFilter_img_ir, axis=0) 461 | 462 | a = attention.Attention(model_pre_path_a) 463 | saver = tf.train.Saver() 464 | saver.restore(sess, model_path_a) 465 | 466 | feature_a=a.get_attention(edge_ir) 467 | feature_b=a.get_attention(edge_vis) 468 | 469 | 470 | edge_ir_temp = sess.run([feature_a], feed_dict={edge_ir: guideFilter_img_ir}) 471 | edge_vis_temp = sess.run([feature_b], feed_dict={edge_vis: guideFilter_img_vis}) 472 | '''feature_a = a.get_attention(edge_ir_temp) 473 | feature_b = a.get_attention(edge_vis_temp)''' 474 | 475 | return edge_ir_temp,edge_vis_temp 476 | 477 | 478 | def _handler_mix_a(ir_path, vis_path, model_path, model_pre_path,model_path_a,model_pre_path_a, ssim_weight, index, output_path=None): 479 | ir_img = get_train_images(ir_path, flag=False) 480 | vis_img = get_train_images(vis_path, flag=False) 481 | dimension = ir_img.shape 482 | ir_img = ir_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 483 | vis_img = vis_img.reshape([1, dimension[0], dimension[1], dimension[2]]) 484 | ir_img = np.transpose(ir_img, (0, 2, 1, 3)) 485 | vis_img = np.transpose(vis_img, (0, 2, 1, 3)) 486 | 487 | 488 | g2 = tf.Graph() # 加载到Session 2的graph 489 | 490 | sess2 = tf.Session(graph=g2) # Session2 491 | 492 | with sess2.as_default(): # 1 493 | with g2.as_default(),tf.Session() as sess: 494 | infrared_field = tf.placeholder( 495 | tf.float32, shape=ir_img.shape, name='content') 496 | visible_field = tf.placeholder( 497 | tf.float32, shape=vis_img.shape, name='style') 498 | 499 | dfn = DenseFuseNet(model_pre_path) 500 | 501 | # sess.run(tf.global_variables_initializer()) 502 | 503 | enc_ir, enc_ir_res_block, enc_ir_block, enc_ir_block2 = dfn.transform_encoder(infrared_field) 504 | enc_vis, enc_vis_res_block, enc_vis_block, enc_vis_block2 = dfn.transform_encoder(visible_field) 505 | 506 | result = tf.placeholder( 507 | tf.float32, shape=enc_ir.shape, name='target') 508 | 509 | saver = tf.train.Saver() 510 | saver.restore(sess, model_path) 511 | print("______000________") 512 | feature_a,feature_b=_get_attention(ir_path,vis_path,model_path_a,model_pre_path_a) 513 | print("______111________") 514 | print(feature_a[0].shape) 515 | 516 | enc_ir_temp, enc_ir_res_block_temp, enc_ir_block_temp, enc_ir_block2_temp = sess.run( 517 | [enc_ir, enc_ir_res_block, enc_ir_block, enc_ir_block2], feed_dict={infrared_field: ir_img}) 518 | print("______222________") 519 | enc_vis_temp, enc_vis_res_block_temp, enc_vis_block_temp, enc_vis_block2_temp = sess.run( 520 | [enc_vis, enc_vis_res_block, enc_vis_block, enc_vis_block2], feed_dict={visible_field: vis_img}) 521 | print("______333________") 522 | # ---------------------------------------------------------------------------------------------------------- 523 | # ---------------------------------------------------------------------------------------------------------- 524 | 525 | 526 | 
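# Fusion of the encoder features computed above ("跳跃部分" below = skip-connection
# part, "空洞卷积部分" = dilated-convolution part):
#   - block / block2 (skip connections) are fused by addition (Strategy) and then
#     zeroed with "* 0" in this configuration; the commented L1_norm_attention calls
#     are the attention-weighted alternative.
#   - res_block[0..3] (dense rate-1 branch) are fused by addition and concatenated.
#   - res_block[6] (rate-2 branch) is fused by L1-norm and res_block[9] (rate-4 branch)
#     by addition; both are added with weight 0.1, mirroring the encoder ensemble.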
#----------------------------------跳跃部分----------------------------------------------------------------- 527 | block = Strategy(enc_ir_block_temp, enc_vis_block_temp) * 0 528 | block2 = Strategy(enc_ir_block2_temp, enc_vis_block2_temp) * 0 529 | #block = L1_norm_attention(enc_ir_block_temp, feature_a, enc_vis_block_temp, feature_b) 530 | #block2 = L1_norm_attention(enc_ir_block2_temp, feature_a, enc_vis_block2_temp, feature_b) 531 | # ---------------------------------------------------------------------------------------------------------- 532 | 533 | first_first = Strategy(enc_ir_res_block_temp[0], enc_vis_res_block_temp[0]) 534 | #first_first = L1_norm_attention(enc_ir_res_block_temp[0],feature_a, enc_vis_res_block_temp[0],feature_b) 535 | first_second = Strategy(enc_ir_res_block_temp[1], enc_vis_res_block_temp[1]) 536 | #first_second = L1_norm_attention(enc_ir_res_block_temp[1],feature_a, enc_vis_res_block_temp[1],feature_b) 537 | first_third = Strategy(enc_ir_res_block_temp[2], enc_vis_res_block_temp[2]) 538 | #first_third = L1_norm_attention(enc_ir_res_block_temp[2], feature_a, enc_vis_res_block_temp[2], feature_b) 539 | #first_third = L1_norm(enc_ir_res_block_temp[2], enc_vis_res_block_temp[2]) * 0 540 | first_four = Strategy(enc_ir_res_block_temp[3], enc_vis_res_block_temp[3]) 541 | #first_four = L1_norm_attention(enc_ir_res_block_temp[3], feature_a, enc_vis_res_block_temp[3], feature_b) 542 | #first_four = L1_norm(enc_ir_res_block_temp[3], enc_vis_res_block_temp[3]) 543 | first_first = tf.concat([first_first, tf.to_int32(first_second, name='ToInt')], 3) 544 | first_first = tf.concat([first_first, tf.to_int32(first_third, name='ToInt')], 3) 545 | first_first = tf.concat([first_first, first_four], 3) 546 | print("______444________") 547 | 548 | first = first_first 549 | 550 | # -------------------------------------空洞卷积部分--------------------------------------------------------- 551 | #second = L1_norm_attention(enc_ir_res_block_temp[6],feature_a, enc_vis_res_block_temp[6],feature_b) 552 | print (enc_ir_res_block_temp[6].shape) 553 | second = L1_norm(enc_ir_res_block_temp[6], enc_vis_res_block_temp[6]) 554 | print("______4545________") 555 | third = Strategy(enc_ir_res_block_temp[9], enc_vis_res_block_temp[9]) 556 | print("______555________") 557 | # ---------------------------------------------------------------------------------------------------------- 558 | 559 | 560 | # ---------------------------------------------------------------------------------------------------------- 561 | # ---------------------------------------------------------------------------------------------------------- 562 | 563 | 564 | 565 | # -------------------------------------空洞卷积部分--------------------------------------------------------- 566 | feature = 1 * first + 0.1 * second + 0.1 * third 567 | # ---------------------------------------------------------------------------------------------------------- 568 | print ("51515151") 569 | 570 | # --------------------------------------------------------- 571 | # block=Strategy(enc_ir_block_temp,enc_vis_block_temp) 572 | # block2=L1_norm(enc_ir_block2_temp,enc_vis_block2_temp) 573 | # --------------------------------------------------------- 574 | 575 | feature = feature.eval() 576 | 577 | 578 | print ("52525252") 579 | 580 | # --------------将特征图压成单通道---------------------------------- 581 | #feature_map_vis_out = sess.run(tf.reduce_sum(feature_a[0], 3, keep_dims=True)) 582 | #feature_map_ir_out = sess.run(tf.reduce_sum(feature_b[0],3, keep_dims=True)) 583 | # 
------------------------------------------------------------------ 584 | print (result.shape) 585 | print ("5555555") 586 | output_image = dfn.transform_decoder(result, block, block2) 587 | print("______666________") 588 | # output = dfn.transform_decoder(feature) 589 | # print(type(feature)) 590 | # output = sess.run(output_image, feed_dict={result: feature,enc_res_block:block,enc_res_block2:block2}) 591 | 592 | output = sess.run(output_image, feed_dict={result: feature}) 593 | print (output_image.shape) 594 | print("______777________") 595 | save_images(ir_path, output, output_path, 596 | prefix='' + str(index), suffix='-1') 597 | #prefix = '' + str(index), suffix = '-4' + str(ssim_weight)) 598 | #save_images(ir_path, feature_map_vis_out, output_path, 599 | # prefix='fused' + str(index), suffix='vis' + str(ssim_weight)) 600 | #save_images(ir_path, feature_map_ir_out, output_path, 601 | # prefix='fused' + str(index), suffix='ir' + str(ssim_weight)) 602 | 603 | 604 | 605 | -------------------------------------------------------------------------------- /code/icme2020_supplement.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ngujiang/Learning-attention-guided-deep-multi-scale-feature-ensemble-for-infrared-and-visible-image-fusion/77ce05bc6f87247b925b83c1b78a457a9e4b6f97/code/icme2020_supplement.zip -------------------------------------------------------------------------------- /code/icme2020template (1).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ngujiang/Learning-attention-guided-deep-multi-scale-feature-ensemble-for-infrared-and-visible-image-fusion/77ce05bc6f87247b925b83c1b78a457a9e4b6f97/code/icme2020template (1).pdf -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | # Demo - train the DenseFuse network & use it to generate an image 2 | 3 | from __future__ import print_function 4 | 5 | import time 6 | 7 | from train_recons import train_recons,train_recons_a 8 | from generate import generate 9 | from utils import list_images 10 | import os 11 | 12 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 13 | 14 | # True for training phase 15 | IS_TRAINING = False 16 | IS_TRAINING_A = False 17 | # True for video sequences(frames) 18 | IS_VIDEO = False 19 | # True for RGB images 20 | is_RGB = False 21 | 22 | BATCH_SIZE = 32 23 | EPOCHES = 4 24 | 25 | SSIM_WEIGHTS = [1, 10, 100, 1000] 26 | SSIM_WEIGHTS_A=[1, 10, 100] 27 | 28 | #------------------------------------------------------------------------------------------------- 29 | MODEL_SAVE_PATHS = [ 30 | '/data/ljy/paper_again/19-11-20-final/model/densefuse_model_bs2_epoch4_all_weight_1e0.ckpt', 31 | '/data/ljy/paper_again/19-11-20-final/model/densefuse_model_bs2_epoch4_all_weight_1e1.ckpt', 32 | '/data/ljy/paper_again/19-11-20-final/model/densefuse_model_bs2_epoch4_all_weight_1e2.ckpt', 33 | '/data/ljy/paper_again/19-11-20-final/model/densefuse_model_bs2_epoch4_all_weight_1e3.ckpt', 34 | ] 35 | MODEL_SAVE_PATHS_A = [ 36 | '/data/ljy/paper_again/19-11-20-final/model_a/densefuse_model_bs2_epoch4_all_weight_1e0.ckpt', 37 | '/data/ljy/paper_again/19-11-20-final/model_a/densefuse_model_bs2_epoch4_all_weight_1e1.ckpt', 38 | '/data/ljy/paper_again/19-11-20-final/model_a/densefuse_model_bs2_epoch4_all_weight_1e2.ckpt', 39 | ] 40 | #ckpt文件用于保存tensorflow的模型 41 | 
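# (The .ckpt paths above are TensorFlow checkpoints saved/restored with tf.train.Saver;
#  main() pairs them index-by-index with SSIM_WEIGHTS / SSIM_WEIGHTS_A via zip().)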
#----------------------------------------------------------------------------------------------- 42 | # MODEL_SAVE_PATH = './models/deepfuse_dense_model_bs4_epoch2_relu_pLoss_noconv_test.ckpt' 43 | # model_pre_path = './models/deepfuse_dense_model_bs2_epoch2_relu_pLoss_noconv_NEW.ckpt' 44 | 45 | # In testing process, 'model_pre_path' is set to None 46 | # The "model_pre_path" in "main.py" is just a pre-train model and not necessary for training and testing. 47 | # It is set as None when you want to train your own model. 48 | # If you already train a model, you can set it as your model for initialize weights. 49 | model_pre_path = None 50 | model_pre_path_a = None 51 | def main(): 52 | 53 | if IS_TRAINING: 54 | #------------------------------------------------------------------------------------------------------- 55 | original_imgs_path = list_images('/data/ljy/train_mix/mix_256/') 56 | validatioin_imgs_path = list_images('/data/ljy/修改专用/imagefusion_densefuse-master/validation/validation/') 57 | #--------------------------------------------------------------------------------------------------------- 58 | for ssim_weight, model_save_path in zip(SSIM_WEIGHTS, MODEL_SAVE_PATHS): 59 | print('\nBegin to train the network ...\n') 60 | train_recons(original_imgs_path, validatioin_imgs_path, model_save_path, model_pre_path, ssim_weight, EPOCHES, BATCH_SIZE, debug=True) 61 | 62 | print('\nSuccessfully! Done training...\n') 63 | #==================================================================================================== 64 | elif IS_TRAINING_A: 65 | original_imgs_path = list_images('/data/ljy/train_mix/mix_256/') 66 | validatioin_imgs_path = list_images('/data/ljy/修改专用/imagefusion_densefuse-master/validation/validation/') 67 | for ssim_weight_a, model_save_path_a in zip(SSIM_WEIGHTS_A, MODEL_SAVE_PATHS_A): 68 | print('\nBegin to train the attention network ...\n') 69 | train_recons_a(original_imgs_path, validatioin_imgs_path, model_save_path_a, model_pre_path_a, ssim_weight_a,EPOCHES, BATCH_SIZE,MODEL_SAVE_PATHS[0], debug=True) 70 | print('\nSuccessfully! 
Done training...\n') 71 | 72 | 73 | 74 | 75 | #================================================================================================ 76 | else: 77 | if IS_VIDEO: 78 | ssim_weight = SSIM_WEIGHTS[0] 79 | model_path = MODEL_SAVE_PATHS[0] 80 | 81 | IR_path = list_images('video/1_IR/') 82 | VIS_path = list_images('video/1_VIS/') 83 | output_save_path = 'video/fused'+ str(ssim_weight) +'/' 84 | generate(IR_path, VIS_path, model_path, model_pre_path, 85 | ssim_weight, 0, IS_VIDEO, 'addition', output_path=output_save_path) 86 | else: 87 | ssim_weight = SSIM_WEIGHTS[1] 88 | model_path = MODEL_SAVE_PATHS[1] 89 | model_path_a=MODEL_SAVE_PATHS_A[1] 90 | print('\nBegin to generate pictures ...\n') 91 | # path = 'images/IV_images/' 92 | path = '/data/ljy/IV_images/' 93 | for i in range(20): 94 | #if i != 1 : 95 | # continue 96 | index = i + 1 97 | infrared = path + 'IR' + str(index) + '.png' 98 | visible = path + 'VIS' + str(index) + '.png' 99 | 100 | # RGB images 101 | #infrared = path + 'lytro-' + str(index) + '-A.jpg' 102 | #visible = path + 'lytro-' + str(index) + '-B.jpg' 103 | 104 | # choose fusion layer 105 | #fusion_type = 'addition' 106 | fusion_type = 'l1' 107 | # for ssim_weight, model_path in zip(SSIM_WEIGHTS, MODEL_SAVE_PATHS): 108 | # output_save_path = 'outputs' 109 | # 110 | # generate(infrared, visible, model_path, model_pre_path, 111 | # ssim_weight, index, IS_VIDEO, is_RGB, type = fusion_type, output_path = output_save_path) 112 | 113 | output_save_path = '/data/ljy/paper_again/19-11-20-final/attention/' 114 | generate(infrared, visible, model_path, model_pre_path,model_path_a,model_pre_path_a, 115 | ssim_weight, index, IS_VIDEO, is_RGB, type = fusion_type, output_path = output_save_path) 116 | 117 | 118 | if __name__ == '__main__': 119 | main() 120 | 121 | -------------------------------------------------------------------------------- /code/new_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | from collections import namedtuple 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | 8 | from bilinear_sampler import * 9 | 10 | monodepth_parameters = namedtuple('parameters', 11 | 'encoder, ' 12 | 'height, width, ' 13 | 'batch_size, ' 14 | 'num_threads, ' 15 | 'num_epochs, ' 16 | 'do_stereo, ' 17 | 'wrap_mode, ' 18 | 'use_deconv, ' 19 | 'alpha_image_loss, ' 20 | 'disp_gradient_loss_weight, ' 21 | 'lr_loss_weight, ' 22 | 'full_summary') 23 | 24 | 25 | class MonodepthModel(object): 26 | """monodepth model""" 27 | 28 | def __init__(self, params, mode, left, right, reuse_variables=None, model_index=0): 29 | self.params = params 30 | self.mode = mode 31 | self.left = left 32 | self.right = right 33 | self.model_collection = ['model_' + str(model_index)] 34 | 35 | self.reuse_variables = reuse_variables 36 | 37 | self.build_model() 38 | self.build_outputs() 39 | 40 | if self.mode == 'test': 41 | return 42 | 43 | self.build_losses() 44 | self.build_summaries() 45 | 46 | def gradient_x(self, img): 47 | gx = img[:, :, :-1, :] - img[:, :, 1:, :] 48 | return gx 49 | 50 | def gradient_y(self, img): 51 | gy = img[:, :-1, :, :] - img[:, 1:, :, :] 52 | return gy 53 | 54 | def upsample_nn(self, x, ratio): 55 | s = tf.shape(x) 56 | h = s[1] 57 | w = s[2] 58 | return tf.image.resize_nearest_neighbor(x, [h * ratio, w * ratio]) 59 | 60 | def scale_pyramid(self, img, num_scales): 61 | scaled_imgs = [img] 62 | s = tf.shape(img) 63 | h = 
s[1] 64 | w = s[2] 65 | for i in range(num_scales - 1): 66 | ratio = 2 ** (i + 1) 67 | nh = h // ratio 68 | nw = w // ratio 69 | scaled_imgs.append(tf.image.resize_area(img, [nh, nw])) 70 | return scaled_imgs 71 | 72 | def generate_image_left(self, img, disp): 73 | return bilinear_sampler_1d_h(img, -disp) 74 | 75 | def generate_image_right(self, img, disp): 76 | return bilinear_sampler_1d_h(img, disp) 77 | 78 | def SSIM(self, x, y): 79 | C1 = 0.01 ** 2 80 | C2 = 0.03 ** 2 81 | 82 | mu_x = slim.avg_pool2d(x, 3, 1, 'VALID') 83 | mu_y = slim.avg_pool2d(y, 3, 1, 'VALID') 84 | 85 | sigma_x = slim.avg_pool2d(x ** 2, 3, 1, 'VALID') - mu_x ** 2 86 | sigma_y = slim.avg_pool2d(y ** 2, 3, 1, 'VALID') - mu_y ** 2 87 | sigma_xy = slim.avg_pool2d(x * y, 3, 1, 'VALID') - mu_x * mu_y 88 | 89 | SSIM_n = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2) 90 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2) 91 | 92 | SSIM = SSIM_n / SSIM_d 93 | 94 | return tf.clip_by_value((1 - SSIM) / 2, 0, 1) 95 | 96 | def get_disparity_smoothness(self, disp, pyramid): 97 | disp_gradients_x = [self.gradient_x(d) for d in disp] 98 | disp_gradients_y = [self.gradient_y(d) for d in disp] 99 | 100 | image_gradients_x = [self.gradient_x(img) for img in pyramid] 101 | image_gradients_y = [self.gradient_y(img) for img in pyramid] 102 | 103 | weights_x = [tf.exp(-tf.reduce_mean(tf.abs(g), 3, keep_dims=True)) for g in image_gradients_x] 104 | weights_y = [tf.exp(-tf.reduce_mean(tf.abs(g), 3, keep_dims=True)) for g in image_gradients_y] 105 | 106 | smoothness_x = [disp_gradients_x[i] * weights_x[i] for i in range(4)] 107 | smoothness_y = [disp_gradients_y[i] * weights_y[i] for i in range(4)] 108 | return smoothness_x + smoothness_y 109 | 110 | def get_disp(self, x): 111 | disp = 0.3 * self.conv(x, 2, 3, 1, tf.nn.sigmoid) 112 | return disp 113 | 114 | def conv(self, x, num_out_layers, kernel_size, stride, activation_fn=tf.nn.elu): 115 | p = np.floor((kernel_size - 1) / 2).astype(np.int32) 116 | p_x = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]]) 117 | return slim.conv2d(p_x, num_out_layers, kernel_size, stride, 'VALID', activation_fn=activation_fn) 118 | 119 | def conv_dia(self, x, num_out_layers, kernel_size, stride, rate_dia, activation_fn=tf.nn.elu): 120 | # p = np.floor((kernel_size - 1) / 2).astype(np.int32) 121 | # p_x = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]]) 122 | stride = 1 123 | return slim.conv2d(x, num_out_layers, kernel_size, stride, 'SAME', activation_fn=activation_fn, rate=rate_dia) 124 | 125 | def conv_block(self, x, num_out_layers, kernel_size): 126 | conv1 = self.conv(x, num_out_layers, kernel_size, 1) 127 | conv2 = self.conv(conv1, num_out_layers, kernel_size, 2) 128 | return conv2 129 | 130 | def maxpool(self, x, kernel_size): 131 | p = np.floor((kernel_size - 1) / 2).astype(np.int32) 132 | p_x = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]]) 133 | return slim.max_pool2d(p_x, kernel_size) 134 | 135 | ##########################################################################CBAM-Net 136 | def cbam_module(self, inputs, out_dim, reduction_ratio=0.5, name=""): 137 | with tf.variable_scope("cbam_" + name, reuse=tf.AUTO_REUSE): ##tf.AUTO_REUSE 138 | batch_size, hidden_num = inputs.get_shape().as_list()[0], inputs.get_shape().as_list()[3] 139 | # batch_size = inputs.get_shape().as_list()[0] 140 | # hidden_num = out_dim 141 | # print('=====================================') 142 | # print(inputs.shape) 143 | # print(hidden_num) 144 | 145 | maxpool_channel = tf.reduce_max(tf.reduce_max(inputs, axis=1, 
keep_dims=True), axis=2, keep_dims=True) 146 | avgpool_channel = tf.reduce_mean(tf.reduce_mean(inputs, axis=1, keep_dims=True), axis=2, keep_dims=True) 147 | 148 | # print('----------------------------------') 149 | # print(maxpool_channel.shape) 150 | # print(avgpool_channel.shape) 151 | 152 | # 上面全局池化结果为batsize * 1 * 1 * channel,它这个拉平输入到全连接层 153 | # 这个拉平,它会保留batsize,所以结果是[batsize,channel] 154 | maxpool_channel = tf.layers.Flatten()(maxpool_channel) 155 | avgpool_channel = tf.layers.Flatten()(avgpool_channel) 156 | 157 | mlp_1_max = tf.layers.dense(inputs=maxpool_channel, units=int(hidden_num * reduction_ratio), name="mlp_1", 158 | reuse=None, activation=tf.nn.elu) ####relu 159 | mlp_2_max = tf.layers.dense(inputs=mlp_1_max, units=hidden_num, name="mlp_2", reuse=None) 160 | mlp_2_max = tf.reshape(mlp_2_max, [batch_size, 1, 1, hidden_num]) 161 | 162 | mlp_1_avg = tf.layers.dense(inputs=avgpool_channel, units=int(hidden_num * reduction_ratio), name="mlp_1", 163 | reuse=True, activation=tf.nn.elu) 164 | mlp_2_avg = tf.layers.dense(inputs=mlp_1_avg, units=hidden_num, name="mlp_2", reuse=True) 165 | mlp_2_avg = tf.reshape(mlp_2_avg, [batch_size, 1, 1, hidden_num]) 166 | 167 | channel_attention = tf.nn.sigmoid(mlp_2_max + mlp_2_avg) 168 | channel_refined_feature = inputs * channel_attention 169 | 170 | maxpool_spatial = tf.reduce_max(inputs, axis=3, keep_dims=True) 171 | avgpool_spatial = tf.reduce_mean(inputs, axis=3, keep_dims=True) 172 | max_avg_pool_spatial = tf.concat([maxpool_spatial, avgpool_spatial], axis=3) 173 | conv_layer = tf.layers.conv2d(inputs=max_avg_pool_spatial, filters=1, kernel_size=(7, 7), padding="same", 174 | activation=None) 175 | spatial_attention = tf.nn.sigmoid(conv_layer) 176 | 177 | refined_feature = channel_refined_feature * spatial_attention 178 | # print(refined_feature.shape) 179 | 180 | return refined_feature 181 | 182 | ########################################################################## 183 | def resconv(self, x, num_layers, stride, attention=False): 184 | do_proj = tf.shape(x)[3] != num_layers or stride == 2 185 | shortcut = [] 186 | # print(num_layers) 187 | conv1 = self.conv(x, num_layers, 1, 1) 188 | # if num_layers == 128: 189 | # print(conv1) 190 | if stride == 2 and num_layers != 64: 191 | rate_dia = np.int(num_layers / 64) 192 | conv2 = self.conv_dia(conv1, num_layers, 3, stride, rate_dia) 193 | 194 | else: 195 | conv2 = self.conv(conv1, num_layers, 3, stride) 196 | conv3 = self.conv(conv2, 4 * num_layers, 1, 1, None) 197 | if attention == True: ##CBAM-Net layer 198 | # reduction_ratio = 16 ######## 199 | # print(num_layers) 200 | if num_layers == 64: 201 | conv3 = self.cbam_module(conv3, out_dim=num_layers, reduction_ratio=0.5, name='64') 202 | if num_layers == 128: 203 | conv3 = self.cbam_module(conv3, out_dim=num_layers, reduction_ratio=0.5, name='128') 204 | if num_layers == 256: 205 | conv3 = self.cbam_module(conv3, out_dim=num_layers, reduction_ratio=0.5, name='256') 206 | if num_layers == 512: 207 | conv3 = self.cbam_module(conv3, out_dim=num_layers, reduction_ratio=0.5, name='512') 208 | 209 | if do_proj: 210 | if stride == 2 and num_layers != 64: 211 | rate_dia = np.int(num_layers / 64) 212 | 213 | shortcut = self.conv_dia(x, 4 * num_layers, 1, stride, rate_dia) 214 | # print('====ok') 215 | # print(shortcut) 216 | # print(conv2) 217 | else: 218 | shortcut = self.conv(x, 4 * num_layers, 1, stride, None) ###### print('=====ok') 219 | else: 220 | shortcut = x 221 | return tf.nn.elu(conv3 + shortcut) 222 | 223 | def resblock(self, x, 
num_layers, num_blocks, attention=False): 224 | out = x 225 | for i in range(num_blocks - 1): 226 | out = self.resconv(out, num_layers, 1, attention) 227 | # print('++++++++++++++++++++++++++++++++++') 228 | out = self.resconv(out, num_layers, 2, attention) 229 | return out 230 | 231 | def upconv(self, x, num_out_layers, kernel_size, scale): 232 | upsample = self.upsample_nn(x, scale) 233 | conv = self.conv(upsample, num_out_layers, kernel_size, 1) 234 | return conv 235 | 236 | def deconv(self, x, num_out_layers, kernel_size, scale): 237 | p_x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]]) 238 | conv = slim.conv2d_transpose(p_x, num_out_layers, kernel_size, scale, 'SAME') 239 | return conv[:, 3:-1, 3:-1, :] 240 | 241 | def build_vgg(self): 242 | # set convenience functions 243 | conv = self.conv 244 | if self.params.use_deconv: 245 | upconv = self.deconv 246 | else: 247 | upconv = self.upconv 248 | 249 | with tf.variable_scope('encoder'): 250 | conv1 = self.conv_block(self.model_input, 32, 7) # H/2 251 | conv2 = self.conv_block(conv1, 64, 5) # H/4 252 | conv3 = self.conv_block(conv2, 128, 3) # H/8 253 | conv4 = self.conv_block(conv3, 256, 3) # H/16 254 | conv5 = self.conv_block(conv4, 512, 3) # H/32 255 | conv6 = self.conv_block(conv5, 512, 3) # H/64 256 | conv7 = self.conv_block(conv6, 512, 3) # H/128 257 | 258 | with tf.variable_scope('skips'): 259 | skip1 = conv1 260 | skip2 = conv2 261 | skip3 = conv3 262 | skip4 = conv4 263 | skip5 = conv5 264 | skip6 = conv6 265 | 266 | with tf.variable_scope('decoder'): 267 | upconv7 = upconv(conv7, 512, 3, 2) # H/64 268 | concat7 = tf.concat([upconv7, skip6], 3) 269 | iconv7 = conv(concat7, 512, 3, 1) 270 | 271 | upconv6 = upconv(iconv7, 512, 3, 2) # H/32 272 | concat6 = tf.concat([upconv6, skip5], 3) 273 | iconv6 = conv(concat6, 512, 3, 1) 274 | 275 | upconv5 = upconv(iconv6, 256, 3, 2) # H/16 276 | concat5 = tf.concat([upconv5, skip4], 3) 277 | iconv5 = conv(concat5, 256, 3, 1) 278 | 279 | upconv4 = upconv(iconv5, 128, 3, 2) # H/8 280 | concat4 = tf.concat([upconv4, skip3], 3) 281 | iconv4 = conv(concat4, 128, 3, 1) 282 | self.disp4 = self.get_disp(iconv4) 283 | udisp4 = self.upsample_nn(self.disp4, 2) 284 | 285 | upconv3 = upconv(iconv4, 64, 3, 2) # H/4 286 | concat3 = tf.concat([upconv3, skip2, udisp4], 3) 287 | iconv3 = conv(concat3, 64, 3, 1) 288 | self.disp3 = self.get_disp(iconv3) 289 | udisp3 = self.upsample_nn(self.disp3, 2) 290 | 291 | upconv2 = upconv(iconv3, 32, 3, 2) # H/2 292 | concat2 = tf.concat([upconv2, skip1, udisp3], 3) 293 | iconv2 = conv(concat2, 32, 3, 1) 294 | self.disp2 = self.get_disp(iconv2) 295 | udisp2 = self.upsample_nn(self.disp2, 2) 296 | 297 | upconv1 = upconv(iconv2, 16, 3, 2) # H 298 | concat1 = tf.concat([upconv1, udisp2], 3) 299 | iconv1 = conv(concat1, 16, 3, 1) 300 | self.disp1 = self.get_disp(iconv1) 301 | 302 | def build_resnet50(self): 303 | # set convenience functions 304 | conv = self.conv 305 | if self.params.use_deconv: 306 | upconv = self.deconv 307 | else: 308 | upconv = self.upconv 309 | 310 | with tf.variable_scope('encoder'): 311 | conv1 = conv(self.model_input, 64, 7, 2) # H/2 - 64D 312 | pool1 = self.maxpool(conv1, 3) # H/4 - 64D 313 | conv2 = self.resblock(pool1, 64, 3, True) # H/8 - 256D 314 | conv3 = self.resblock(conv2, 128, 4, True) # H/16 - 512D ##True 315 | conv4 = self.resblock(conv3, 256, 6, True) # H/32 - 1024D 316 | conv5 = self.resblock(conv4, 512, 3, True) # H/64 - 2048D 317 | 318 | with tf.variable_scope('skips'): 319 | skip1 = conv1 320 | skip2 = pool1 321 | skip3 = conv2 322 | 
skip4 = conv3 323 | skip5 = conv4 324 | 325 | # DECODING 326 | with tf.variable_scope('decoder'): 327 | 328 | # upconv5 = upconv(iconv6, 256, 3, 2) #H/1 329 | # print("ok") 330 | # print(conv5) 331 | # print(conv4) 332 | # print(conv3) 333 | # print(conv2) 334 | # print(skip2) 335 | # concat5 = tf.concat([conv5, skip5, skip4], 3) 336 | # iconv5 = conv(concat5, 256, 3, 1) 337 | # 338 | # upconv4 = upconv(iconv5, 128, 3, 2) #H/8 339 | concat4 = tf.concat([conv5, skip5, skip4, skip3], 3) 340 | iconv4 = conv(concat4, 128, 3, 1) 341 | self.disp4 = self.get_disp(iconv4) 342 | udisp4 = self.upsample_nn(self.disp4, 2) 343 | ############ 344 | self.disp4_up = self.upsample_nn(self.disp4, 8) 345 | # upconv4_4 = upconv(iconv4, 128, 3, 8) 346 | upconv4_4 = self.upsample_nn(iconv4, 8) 347 | 348 | upconv3 = upconv(iconv4, 64, 3, 2) # H/4 349 | concat3 = tf.concat([upconv3, skip2, udisp4], 3) 350 | iconv3 = conv(concat3, 64, 3, 1) 351 | self.disp3 = self.get_disp(iconv3) 352 | udisp3 = self.upsample_nn(self.disp3, 2) 353 | ############# 354 | self.disp3_up = self.upsample_nn(self.disp3, 4) 355 | # upconv3_3 = upconv(iconv3, 64, 3, 4) 356 | upconv3_3 = self.upsample_nn(iconv3, 4) 357 | 358 | upconv2 = upconv(iconv3, 32, 3, 2) # H/2 359 | concat2 = tf.concat([upconv2, skip1, udisp3], 3) 360 | iconv2 = conv(concat2, 32, 3, 1) 361 | self.disp2 = self.get_disp(iconv2) 362 | udisp2 = self.upsample_nn(self.disp2, 2) 363 | ########### 364 | self.disp2_up = self.upsample_nn(self.disp2, 2) 365 | # upconv2_2 = upconv(iconv2, 32, 3, 2) 366 | upconv2_2 = self.upsample_nn(iconv2, 2) 367 | 368 | # upconv1 = upconv(iconv2, 16, 3, 2) #H 369 | # concat1 = tf.concat([upconv1, udisp2], 3) 370 | # iconv1 = conv(concat1, 16, 3, 1) 371 | # self.disp1 = self.get_disp(iconv1) 372 | print('===================ok') 373 | # print(upconv2_2) 374 | # print(upconv3_3) 375 | # print(upconv4_4) 376 | upconv2_3_4 = conv(tf.concat([upconv4_4, upconv3_3, upconv2_2], 3), 16, 3, 1) 377 | # upconv2_3_4 = conv(tf.concat([ upconv3_3, upconv3_3], 3), 16, 3, 1) 378 | # upconv2_3_4 = upconv2_2 379 | 380 | upconv1 = upconv(iconv2, 16, 3, 2) # H 381 | concat1 = tf.concat([upconv1, udisp2, upconv2_3_4], 3) 382 | # concat1 = tf.concat([upconv1, udisp2], 3) 383 | # iconv1 = conv(concat1, 16, 3, 1) 384 | iconv1 = conv(concat1, 32, 3, 1) 385 | iconv1 = conv(iconv1, 32, 3, 1) 386 | iconv1 = conv(iconv1, 16, 3, 1) 387 | self.disp1 = self.get_disp(iconv1) 388 | 389 | def build_model(self): 390 | with slim.arg_scope([slim.conv2d, slim.conv2d_transpose], activation_fn=tf.nn.elu): 391 | with tf.variable_scope('model', reuse=self.reuse_variables): 392 | 393 | self.left_pyramid = self.scale_pyramid(self.left, 4) 394 | if self.mode == 'train': 395 | self.right_pyramid = self.scale_pyramid(self.right, 4) 396 | 397 | if self.params.do_stereo: 398 | self.model_input = tf.concat([self.left, self.right], 3) 399 | else: 400 | self.model_input = self.left 401 | 402 | # build model 403 | if self.params.encoder == 'vgg': 404 | self.build_vgg() 405 | elif self.params.encoder == 'resnet50': 406 | self.build_resnet50() 407 | else: 408 | return None 409 | 410 | def build_outputs(self): 411 | # STORE DISPARITIES 412 | with tf.variable_scope('disparities'): 413 | self.disp_est = [self.disp1, self.disp2, self.disp3, self.disp4] 414 | self.disp_left_est = [tf.expand_dims(d[:, :, :, 0], 3) for d in self.disp_est] 415 | self.disp_right_est = [tf.expand_dims(d[:, :, :, 1], 3) for d in self.disp_est] 416 | 417 | # self.left_est = [self.generate_image_left(self.right_pyramid[i], 
self.disp_left_est[i]) for i in range(4)] 418 | # self.l1_left = [tf.abs( self.left_est[i] - self.left_pyramid[i]) for i in range(4)] 419 | 420 | if self.mode == 'test': 421 | return 422 | 423 | # GENERATE IMAGES 424 | with tf.variable_scope('images'): 425 | self.left_est = [self.generate_image_left(self.right_pyramid[i], self.disp_left_est[i]) for i in range(4)] 426 | self.right_est = [self.generate_image_right(self.left_pyramid[i], self.disp_right_est[i]) for i in range(4)] 427 | 428 | # LR CONSISTENCY 429 | with tf.variable_scope('left-right'): 430 | self.right_to_left_disp = [self.generate_image_left(self.disp_right_est[i], self.disp_left_est[i]) for i in 431 | range(4)] 432 | self.left_to_right_disp = [self.generate_image_right(self.disp_left_est[i], self.disp_right_est[i]) for i in 433 | range(4)] 434 | 435 | # DISPARITY SMOOTHNESS 436 | with tf.variable_scope('smoothness'): 437 | self.disp_left_smoothness = self.get_disparity_smoothness(self.disp_left_est, self.left_pyramid) 438 | self.disp_right_smoothness = self.get_disparity_smoothness(self.disp_right_est, self.right_pyramid) 439 | 440 | def build_losses(self): 441 | with tf.variable_scope('losses', reuse=self.reuse_variables): 442 | # IMAGE RECONSTRUCTION 443 | # L1 444 | self.l1_left = [tf.abs(self.left_est[i] - self.left_pyramid[i]) for i in range(4)] 445 | self.l1_reconstruction_loss_left = [tf.reduce_mean(l) for l in self.l1_left] 446 | self.l1_right = [tf.abs(self.right_est[i] - self.right_pyramid[i]) for i in range(4)] 447 | self.l1_reconstruction_loss_right = [tf.reduce_mean(l) for l in self.l1_right] 448 | 449 | # SSIM 450 | self.ssim_left = [self.SSIM(self.left_est[i], self.left_pyramid[i]) for i in range(4)] 451 | self.ssim_loss_left = [tf.reduce_mean(s) for s in self.ssim_left] 452 | self.ssim_right = [self.SSIM(self.right_est[i], self.right_pyramid[i]) for i in range(4)] 453 | self.ssim_loss_right = [tf.reduce_mean(s) for s in self.ssim_right] 454 | 455 | # WEIGTHED SUM 456 | self.image_loss_right = [ 457 | self.params.alpha_image_loss * self.ssim_loss_right[i] + (1 - self.params.alpha_image_loss) * 458 | self.l1_reconstruction_loss_right[i] for i in range(4)] 459 | self.image_loss_left = [ 460 | self.params.alpha_image_loss * self.ssim_loss_left[i] + (1 - self.params.alpha_image_loss) * 461 | self.l1_reconstruction_loss_left[i] for i in range(4)] 462 | self.image_loss = tf.add_n(self.image_loss_left + self.image_loss_right) 463 | 464 | # DISPARITY SMOOTHNESS 465 | self.disp_left_loss = [tf.reduce_mean(tf.abs(self.disp_left_smoothness[i])) / 2 ** i for i in range(4)] 466 | self.disp_right_loss = [tf.reduce_mean(tf.abs(self.disp_right_smoothness[i])) / 2 ** i for i in range(4)] 467 | self.disp_gradient_loss = tf.add_n(self.disp_left_loss + self.disp_right_loss) 468 | 469 | # LR CONSISTENCY 470 | self.lr_left_loss = [tf.reduce_mean(tf.abs(self.right_to_left_disp[i] - self.disp_left_est[i])) for i in 471 | range(4)] 472 | self.lr_right_loss = [tf.reduce_mean(tf.abs(self.left_to_right_disp[i] - self.disp_right_est[i])) for i in 473 | range(4)] 474 | self.lr_loss = tf.add_n(self.lr_left_loss + self.lr_right_loss) 475 | 476 | # TOTAL LOSS 477 | self.total_loss = self.image_loss + self.params.disp_gradient_loss_weight * self.disp_gradient_loss + self.params.lr_loss_weight * self.lr_loss 478 | 479 | def build_summaries(self): 480 | # SUMMARIES 481 | with tf.device('/cpu:0'): 482 | for i in range(4): 483 | tf.summary.scalar('ssim_loss_' + str(i), self.ssim_loss_left[i] + self.ssim_loss_right[i], 484 | 
collections=self.model_collection) 485 | tf.summary.scalar('l1_loss_' + str(i), 486 | self.l1_reconstruction_loss_left[i] + self.l1_reconstruction_loss_right[i], 487 | collections=self.model_collection) 488 | tf.summary.scalar('image_loss_' + str(i), self.image_loss_left[i] + self.image_loss_right[i], 489 | collections=self.model_collection) 490 | tf.summary.scalar('disp_gradient_loss_' + str(i), self.disp_left_loss[i] + self.disp_right_loss[i], 491 | collections=self.model_collection) 492 | tf.summary.scalar('lr_loss_' + str(i), self.lr_left_loss[i] + self.lr_right_loss[i], 493 | collections=self.model_collection) 494 | tf.summary.image('disp_left_est_' + str(i), self.disp_left_est[i], max_outputs=4, 495 | collections=self.model_collection) 496 | tf.summary.image('disp_right_est_' + str(i), self.disp_right_est[i], max_outputs=4, 497 | collections=self.model_collection) 498 | 499 | if self.params.full_summary: 500 | tf.summary.image('left_est_' + str(i), self.left_est[i], max_outputs=4, 501 | collections=self.model_collection) 502 | tf.summary.image('right_est_' + str(i), self.right_est[i], max_outputs=4, 503 | collections=self.model_collection) 504 | tf.summary.image('ssim_left_' + str(i), self.ssim_left[i], max_outputs=4, 505 | collections=self.model_collection) 506 | tf.summary.image('ssim_right_' + str(i), self.ssim_right[i], max_outputs=4, 507 | collections=self.model_collection) 508 | tf.summary.image('l1_left_' + str(i), self.l1_left[i], max_outputs=4, 509 | collections=self.model_collection) 510 | tf.summary.image('l1_right_' + str(i), self.l1_right[i], max_outputs=4, 511 | collections=self.model_collection) 512 | 513 | if self.params.full_summary: 514 | tf.summary.image('left', self.left, max_outputs=4, collections=self.model_collection) 515 | tf.summary.image('right', self.right, max_outputs=4, collections=self.model_collection) 516 | -------------------------------------------------------------------------------- /code/ssim_loss_function.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def _tf_fspecial_gauss(size, sigma): 5 | """Function to mimic the 'fspecial' gaussian MATLAB function 6 | """ 7 | x_data, y_data = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] 8 | 9 | x_data = np.expand_dims(x_data, axis=-1) 10 | x_data = np.expand_dims(x_data, axis=-1) 11 | 12 | y_data = np.expand_dims(y_data, axis=-1) 13 | y_data = np.expand_dims(y_data, axis=-1) 14 | 15 | x = tf.constant(x_data, dtype=tf.float32) 16 | y = tf.constant(y_data, dtype=tf.float32) 17 | 18 | g = tf.exp(-((x**2 + y**2)/(2.0*sigma**2))) 19 | return g / tf.reduce_sum(g) 20 | 21 | 22 | def SSIM_LOSS(img1, img2, size=11, sigma=1.5): 23 | window = _tf_fspecial_gauss(size, sigma) # window shape [size, size] 24 | K1 = 0.01 25 | K2 = 0.03 26 | L = 1 # depth of image (255 in case the image has a differnt scale) 27 | C1 = (K1*L)**2 28 | C2 = (K2*L)**2 29 | mu1 = tf.nn.conv2d(img1, window, strides=[1,1,1,1], padding='VALID') 30 | mu2 = tf.nn.conv2d(img2, window, strides=[1,1,1,1],padding='VALID') 31 | mu1_sq = mu1*mu1 32 | mu2_sq = mu2*mu2 33 | mu1_mu2 = mu1*mu2 34 | sigma1_sq = tf.nn.conv2d(img1*img1, window, strides=[1,1,1,1],padding='VALID') - mu1_sq 35 | sigma2_sq = tf.nn.conv2d(img2*img2, window, strides=[1,1,1,1],padding='VALID') - mu2_sq 36 | sigma12 = tf.nn.conv2d(img1*img2, window, strides=[1,1,1,1],padding='VALID') - mu1_mu2 37 | 38 | value = (2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2) 39 | value = 
tf.reduce_mean(value) 40 | return value -------------------------------------------------------------------------------- /code/train_recons.py: -------------------------------------------------------------------------------- 1 | # Train the DenseFuse Net 2 | 3 | from __future__ import print_function 4 | 5 | import scipy.io as scio 6 | import numpy as np 7 | import tensorflow as tf 8 | import cv2 9 | import matplotlib.pyplot as plt 10 | import attention 11 | import decoder 12 | from ssim_loss_function import SSIM_LOSS 13 | from densefuse_net import DenseFuseNet 14 | from utils import get_train_images 15 | 16 | STYLE_LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1') 17 | 18 | TRAINING_IMAGE_SHAPE = (256, 256, 1) # (height, width, color_channels) 19 | TRAINING_IMAGE_SHAPE_OR = (256, 256, 1) # (height, width, color_channels) 20 | 21 | LEARNING_RATE = 1e-4 22 | LEARNING_RATE_2 = 1e-4 23 | 24 | EPSILON = 1e-5 25 | 26 | 27 | 28 | def train_recons(original_imgs_path, validatioin_imgs_path, save_path, model_pre_path, ssim_weight, EPOCHES_set, BATCH_SIZE, debug=False, logging_period=1): 29 | if debug: 30 | from datetime import datetime 31 | start_time = datetime.now() 32 | EPOCHS = EPOCHES_set 33 | print("EPOCHES : ", EPOCHS) # EPOCHS = 4: number of passes over the whole dataset; training runs n*4 steps in total 34 | print("BATCH_SIZE: ", BATCH_SIZE) # BATCH_SIZE = 2: two samples per batch, n/2 batches in total; the weights are updated after every batch 35 | 36 | num_val = len(validatioin_imgs_path) # number of validation samples 37 | num_imgs = len(original_imgs_path) # number of training samples 38 | # num_imgs = 100 39 | original_imgs_path = original_imgs_path[:num_imgs] # no-op slice: the list is assigned back to itself unchanged 40 | mod = num_imgs % BATCH_SIZE # remainder that does not fill a complete batch 41 | 42 | print('Train images number %d.\n' % num_imgs) 43 | print('Train images samples %s.\n' % str(num_imgs / BATCH_SIZE)) 44 | 45 | if mod > 0: 46 | print('Train set has been trimmed %d samples...\n' % mod) 47 | original_imgs_path = original_imgs_path[:-mod] # drop the trailing paths so the count divides evenly into batches 48 | 49 | # get the training image shape 50 | # height, width and channel count of the training patches: 256, 256, 1 51 | HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE 52 | INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS) # batch input shape (not used below) 53 | 54 | HEIGHT_OR, WIDTH_OR, CHANNELS_OR = TRAINING_IMAGE_SHAPE_OR 55 | INPUT_SHAPE_OR = (BATCH_SIZE, HEIGHT_OR, WIDTH_OR, CHANNELS_OR) # same shape with the _OR suffix; this is the one actually fed to the placeholder 56 | 57 | # create the graph 58 | with tf.Graph().as_default(), tf.Session() as sess: 59 | original = tf.placeholder(tf.float32, shape=INPUT_SHAPE_OR, name='original') 60 | # the placeholder only reserves a node in the graph; data is supplied at run time through feed_dict() 61 | # first argument: dtype, usually tf.float32 or tf.float64 62 | # second argument: shape (batch, height, width, channels) 63 | # third argument: name 64 | # returns a Tensor 65 | source = original # source is simply an alias of the input placeholder 66 | 67 | print('source :', source.shape) 68 | print('original:', original.shape) 69 | 70 | # create the deepfuse net (encoder and decoder) 71 | # build the reconstruction network 72 | dfn = DenseFuseNet(model_pre_path) # model_pre_path is an optional pre-trained checkpoint; it defaults to None, and when set, training starts from those weights 73 | generated_img = dfn.transform_recons(source) # reconstructed image 74 | print('generate:', generated_img.shape) 75 | 76 | ######################################################################################### 77 | # COST FUNCTION 78 | ssim_loss_value = SSIM_LOSS(original, generated_img) # compute SSIM 79 | pixel_loss = tf.reduce_sum(tf.square(original - generated_img)) 80 | pixel_loss = pixel_loss/(BATCH_SIZE*HEIGHT*WIDTH) # pixel (mean squared error) loss 81 | ssim_loss = 1 - ssim_loss_value # SSIM loss 82 | 83 | loss = ssim_weight*ssim_loss + pixel_loss # total loss 84 | #train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss) # Adam (adaptive moment estimation), a gradient-descent variant
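# Summary of the objective defined above (the notation here is illustrative only, not from the
# original code): with I the input patch and I_hat = dfn.transform_recons(I), the step below minimizes
#     loss = ssim_weight * (1 - SSIM(I, I_hat)) + sum((I - I_hat)^2) / (BATCH_SIZE * HEIGHT * WIDTH)
# so ssim_weight trades structural similarity against per-pixel reconstruction error. In this
# first phase every DenseFuseNet encoder/decoder variable is trainable.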
85 | train_op = tf.train.AdamOptimizer(LEARNING_RATE_2).minimize(loss) # Adam (adaptive moment estimation), a gradient-descent variant 86 | ########################################################################################## 87 | 88 | sess.run(tf.global_variables_initializer()) 89 | 90 | # saver = tf.train.Saver() 91 | saver = tf.train.Saver(keep_checkpoint_every_n_hours=1) 92 | 93 | # ** Start Training ** 94 | step = 0 95 | count_loss = 0 96 | n_batches = int(len(original_imgs_path) // BATCH_SIZE) 97 | val_batches = int(len(validatioin_imgs_path) // BATCH_SIZE) 98 | 99 | if debug: 100 | elapsed_time = datetime.now() - start_time 101 | print('\nElapsed time for preprocessing before actually train the model: %s' % elapsed_time) 102 | print('Now begin to train the model...\n') 103 | start_time = datetime.now() 104 | 105 | Loss_all = [i for i in range(EPOCHS * n_batches)] 106 | Loss_ssim = [i for i in range(EPOCHS * n_batches)] 107 | Loss_pixel = [i for i in range(EPOCHS * n_batches)] 108 | Val_ssim_data = [i for i in range(EPOCHS * n_batches)] 109 | Val_pixel_data = [i for i in range(EPOCHS * n_batches)] 110 | for epoch in range(EPOCHS): 111 | 112 | np.random.shuffle(original_imgs_path) 113 | 114 | for batch in range(n_batches): 115 | # retrieve a batch of training images 116 | 117 | original_path = original_imgs_path[batch*BATCH_SIZE:(batch*BATCH_SIZE + BATCH_SIZE)] 118 | original_batch = get_train_images(original_path, crop_height=HEIGHT, crop_width=WIDTH, flag=False) 119 | original_batch = original_batch.reshape([BATCH_SIZE, 256, 256, 1]) 120 | 121 | # print('original_batch shape final:', original_batch.shape) 122 | 123 | # run the training step 124 | sess.run(train_op, feed_dict={original: original_batch}) 125 | step += 1 126 | if debug: 127 | is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1) 128 | 129 | if is_last_step or step % logging_period == 0: 130 | elapsed_time = datetime.now() - start_time 131 | _ssim_loss, _loss, _p_loss = sess.run([ssim_loss, loss, pixel_loss], feed_dict={original: original_batch}) 132 | Loss_all[count_loss] = _loss 133 | Loss_ssim[count_loss] = _ssim_loss 134 | Loss_pixel[count_loss] = _p_loss 135 | print('epoch: %d/%d, step: %d, total loss: %s, elapsed time: %s' % (epoch, EPOCHS, step, _loss, elapsed_time)) 136 | print('p_loss: %s, ssim_loss: %s ,w_ssim_loss: %s ' % (_p_loss, _ssim_loss, ssim_weight * _ssim_loss)) 137 | 138 | # evaluate SSIM and pixel loss over the whole validation set at every logging step 139 | val_ssim_acc = 0 140 | val_pixel_acc = 0 141 | np.random.shuffle(validatioin_imgs_path) 142 | val_start_time = datetime.now() 143 | for v in range(val_batches): 144 | val_original_path = validatioin_imgs_path[v * BATCH_SIZE:(v * BATCH_SIZE + BATCH_SIZE)] 145 | val_original_batch = get_train_images(val_original_path, crop_height=HEIGHT, crop_width=WIDTH,flag=False) 146 | val_original_batch = val_original_batch.reshape([BATCH_SIZE, 256, 256, 1]) 147 | val_ssim, val_pixel = sess.run([ssim_loss, pixel_loss], feed_dict={original: val_original_batch}) 148 | val_ssim_acc = val_ssim_acc + (1 - val_ssim) 149 | val_pixel_acc = val_pixel_acc + val_pixel 150 | Val_ssim_data[count_loss] = val_ssim_acc/val_batches 151 | Val_pixel_data[count_loss] = val_pixel_acc / val_batches 152 | val_es_time = datetime.now() - val_start_time 153 | print('validation value, SSIM: %s, Pixel: %s, elapsed time: %s' % (val_ssim_acc/val_batches, val_pixel_acc / val_batches, val_es_time)) 154 | print('------------------------------------------------------------------------------') 155 | count_loss += 1 156 | 157 |
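# After the last epoch the checkpoint and the logged loss curves are written out below. A minimal,
# illustrative way to inspect the saved curves later (the key names follow the scio.savemat calls
# below; substitute whatever output path was actually used):
#   data = scio.loadmat('.../DeepDenseLossData' + str(ssim_weight) + '.mat')
#   plt.plot(np.ravel(data['loss']))
#   plt.show()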
158 | # ** Done Training & Save the model ** 159 | saver.save(sess, save_path) 160 | #---------------------------------------------------------------------------------------------------------------- 161 | loss_data = Loss_all[:count_loss] 162 | scio.savemat('/data/ljy/paper_again/19-11-20-final/loss/DeepDenseLossData' + str(ssim_weight) + '.mat', 163 | {'loss': loss_data}) 164 | 165 | loss_ssim_data = Loss_ssim[:count_loss] 166 | scio.savemat('/data/ljy/paper_again/19-11-20-final/loss/DeepDenseLossSSIMData' + str( 167 | ssim_weight) + '.mat', {'loss_ssim': loss_ssim_data}) 168 | 169 | loss_pixel_data = Loss_pixel[:count_loss] 170 | scio.savemat('/data/ljy/paper_again/19-11-20-final/loss/DeepDenseLossPixelData.mat' + str( 171 | ssim_weight) + '', {'loss_pixel': loss_pixel_data}) 172 | 173 | validation_ssim_data = Val_ssim_data[:count_loss] 174 | scio.savemat('/data/ljy/paper_again/19-11-20-final/val/Validation_ssim_Data.mat' + str( 175 | ssim_weight) + '', {'val_ssim': validation_ssim_data}) 176 | 177 | validation_pixel_data = Val_pixel_data[:count_loss] 178 | scio.savemat('/data/ljy/paper_again/19-11-20-final/val/Validation_pixel_Data.mat' + str( 179 | ssim_weight) + '', {'val_pixel': validation_pixel_data}) 180 | #---------------------------------------------------------------------------------------------------- 181 | if debug: 182 | elapsed_time = datetime.now() - start_time 183 | print('Done training! Elapsed time: %s' % elapsed_time) 184 | print('Model is saved to: %s' % save_path) 185 | 186 | def train_recons_a(original_imgs_path, validatioin_imgs_path, save_path_a, model_pre_path_a, ssim_weight_a, EPOCHES_set, BATCH_SIZE,MODEL_SAVE_PATHS, debug=False, logging_period=1): 187 | if debug: 188 | from datetime import datetime 189 | start_time = datetime.now() 190 | EPOCHS = EPOCHES_set 191 | print("EPOCHES : ", EPOCHS) # EPOCHS = 4: number of passes over the whole dataset; training runs n*4 steps in total 192 | print("BATCH_SIZE: ", BATCH_SIZE) # BATCH_SIZE = 2: two samples per batch, n/2 batches in total; the weights are updated after every batch 193 | 194 | num_val = len(validatioin_imgs_path) # number of validation samples 195 | num_imgs = len(original_imgs_path) # number of training samples 196 | # num_imgs = 100 197 | original_imgs_path = original_imgs_path[:num_imgs] # no-op slice: the list is assigned back to itself unchanged 198 | mod = num_imgs % BATCH_SIZE # remainder that does not fill a complete batch 199 | 200 | print('Train images number %d.\n' % num_imgs) 201 | print('Train images samples %s.\n' % str(num_imgs / BATCH_SIZE)) 202 | 203 | if mod > 0: 204 | print('Train set has been trimmed %d samples...\n' % mod) 205 | original_imgs_path = original_imgs_path[:-mod] # drop the trailing paths so the count divides evenly into batches 206 | 207 | # get the training image shape 208 | # height, width and channel count of the training patches: 256, 256, 1 209 | HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE 210 | INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS) # batch input shape (not used below) 211 | 212 | HEIGHT_OR, WIDTH_OR, CHANNELS_OR = TRAINING_IMAGE_SHAPE_OR 213 | INPUT_SHAPE_OR = (BATCH_SIZE, HEIGHT_OR, WIDTH_OR, CHANNELS_OR) # same shape with the _OR suffix; this is the one actually fed to the placeholders 214 | 215 | # create the graph 216 | with tf.Graph().as_default(), tf.Session() as sess: 217 | original = tf.placeholder(tf.float32, shape=INPUT_SHAPE_OR, name='original') 218 | attention_map = tf.placeholder(tf.float32, shape=INPUT_SHAPE_OR, name='attention') 219 | # the placeholders only reserve nodes in the graph; data is supplied at run time through feed_dict() 220 | # first argument: dtype, usually tf.float32 or tf.float64 221 | # second argument: shape (batch, height, width, channels) 222 | # third argument: name 223 | # returns a Tensor 224 | source = original # source is simply an alias of the input placeholder 225 | 226 | print('source :', source.shape) 227 | print('original:', original.shape) 228 | 229 | # create the deepfuse net (encoder and decoder) 230 | # build the network from the reconstruction checkpoint 231 | model_pre_path=MODEL_SAVE_PATHS
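# Rough sketch of what the block below builds (my reading of the code, stated cautiously, not an
# official description): the pre-trained DenseFuse encoder yields the multi-scale features
# enc_res_block; the Attention network converts the rolling-guidance gradient map fed through
# `attention_map` into a per-pixel weight map w; two encoder responses are re-weighted by that map
# and added, scaled by 0.1, to the concatenated multi-scale features before decoding:
#     t_decode = concat(enc_res_block[0..3]) + 0.1*w*enc_res_block[6] + 0.1*w*enc_res_block[9]
# Only the attention weights are optimized (var_list=atn.weights); the encoder/decoder keep their
# pre-trained values.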
232 | dfn = DenseFuseNet(model_pre_path) # model_pre_path is the user-supplied checkpoint (here MODEL_SAVE_PATHS); when it is not None the encoder/decoder start from those weights 233 | 234 | atn = attention.Attention(None) 235 | enc, enc_res_block, enc_block, enc_block2 = dfn.transform_encoder(source) 236 | weight_map=atn.get_attention(attention_map) 237 | enc_res_block_6_a= tf.multiply(enc_res_block[6],weight_map) 238 | enc_res_block_9_a=tf.multiply(enc_res_block[9],weight_map) 239 | feature = enc_res_block[0] 240 | mix_indices = (1, 2, 3) 241 | for i in mix_indices: 242 | feature = tf.concat([feature, enc_res_block[i]], 3) 243 | t_decode=feature+0.1*enc_res_block_6_a+0.1*enc_res_block_9_a 244 | generated_img = dfn.transform_decoder(t_decode,enc_block,enc_block2) 245 | print('generate:', generated_img.shape) 246 | ssim_loss_value = SSIM_LOSS(original, generated_img) # compute SSIM 247 | pixel_loss = tf.reduce_sum(tf.square(original - generated_img)) 248 | pixel_loss = pixel_loss / (BATCH_SIZE * HEIGHT * WIDTH) # pixel (mean squared error) loss 249 | ssim_loss = 1 - ssim_loss_value # SSIM loss 250 | 251 | loss = ssim_weight_a * ssim_loss + pixel_loss # total loss 252 | # train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss) # Adam (adaptive moment estimation), a gradient-descent variant 253 | train_op = tf.train.AdamOptimizer(LEARNING_RATE_2).minimize(loss,var_list=atn.weights) # Adam, restricted to the attention weights 254 | sess.run(tf.global_variables_initializer()) 255 | 256 | # saver = tf.train.Saver() 257 | saver = tf.train.Saver(keep_checkpoint_every_n_hours=1) 258 | 259 | # ** Start Training ** 260 | step = 0 261 | count_loss = 0 262 | n_batches = int(len(original_imgs_path) // BATCH_SIZE) 263 | val_batches = int(len(validatioin_imgs_path) // BATCH_SIZE) 264 | 265 | if debug: 266 | elapsed_time = datetime.now() - start_time 267 | print('\nElapsed time for preprocessing before actually train the model: %s' % elapsed_time) 268 | print('Now begin to train the model...\n') 269 | start_time = datetime.now() 270 | 271 | Loss_all = [i for i in range(EPOCHS * n_batches)] 272 | Loss_ssim = [i for i in range(EPOCHS * n_batches)] 273 | Loss_pixel = [i for i in range(EPOCHS * n_batches)] 274 | Val_ssim_data = [i for i in range(EPOCHS * n_batches)] 275 | Val_pixel_data = [i for i in range(EPOCHS * n_batches)] 276 | for epoch in range(EPOCHS): 277 | 278 | np.random.shuffle(original_imgs_path) 279 | 280 | for batch in range(n_batches): 281 | # retrieve a batch of training images 282 | 283 | original_path = original_imgs_path[batch * BATCH_SIZE:(batch * BATCH_SIZE + BATCH_SIZE)] 284 | original_batch = get_train_images(original_path, crop_height=HEIGHT, crop_width=WIDTH, flag=False) 285 | original_batch = original_batch.reshape([BATCH_SIZE, 256, 256, 1]) 286 | 287 | # print('original_batch shape final:', original_batch.shape) 288 | # ----------------------------------------------- 289 | imag = sess.run(original, feed_dict={original: original_batch}) 290 | guideFilter_imgs = np.zeros(INPUT_SHAPE_OR) 291 | for i in range(BATCH_SIZE): 292 | input = np.squeeze(imag[i]) 293 | out = atn.Grad(input) 294 | out = np.expand_dims(out, axis=-1) 295 | out[out < 0] = 0 296 | guideFilter_imgs[i] = out 297 | # ---------------------------------------------- 298 | # run the training step 299 | sess.run(train_op, feed_dict={original: original_batch, attention_map:guideFilter_imgs }) 300 | step += 1 301 | if debug: 302 | is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1) 303 | 304 | if is_last_step or step % logging_period == 0: 305 | elapsed_time = datetime.now() - start_time 306 | _ssim_loss, _loss, _p_loss = sess.run([ssim_loss, loss, pixel_loss], 307 | 
feed_dict={original: original_batch, attention_map: guideFilter_imgs}) 308 | Loss_all[count_loss] = _loss 309 | Loss_ssim[count_loss] = _ssim_loss 310 | Loss_pixel[count_loss] = _p_loss 311 | print('epoch: %d/%d, step: %d, total loss: %s, elapsed time: %s' % ( 312 | epoch, EPOCHS, step, _loss, elapsed_time)) 313 | print('p_loss: %s, ssim_loss: %s ,w_ssim_loss: %s ' % ( 314 | _p_loss, _ssim_loss, ssim_weight_a * _ssim_loss)) 315 | 316 | # calculate the accuracy rate for 1000 images, every 100 steps 317 | val_ssim_acc = 0 318 | val_pixel_acc = 0 319 | np.random.shuffle(validatioin_imgs_path) 320 | val_start_time = datetime.now() 321 | for v in range(val_batches): 322 | val_original_path = validatioin_imgs_path[v * BATCH_SIZE:(v * BATCH_SIZE + BATCH_SIZE)] 323 | val_original_batch = get_train_images(val_original_path, crop_height=HEIGHT, 324 | crop_width=WIDTH, flag=False) 325 | val_original_batch = val_original_batch.reshape([BATCH_SIZE, 256, 256, 1]) 326 | val_ssim, val_pixel = sess.run([ssim_loss, pixel_loss], 327 | feed_dict={original: val_original_batch, attention_map: guideFilter_imgs}) 328 | val_ssim_acc = val_ssim_acc + (1 - val_ssim) 329 | val_pixel_acc = val_pixel_acc + val_pixel 330 | Val_ssim_data[count_loss] = val_ssim_acc / val_batches 331 | Val_pixel_data[count_loss] = val_pixel_acc / val_batches 332 | val_es_time = datetime.now() - val_start_time 333 | print('validation value, SSIM: %s, Pixel: %s, elapsed time: %s' % ( 334 | val_ssim_acc / val_batches, val_pixel_acc / val_batches, val_es_time)) 335 | print('------------------------------------------------------------------------------') 336 | count_loss += 1 337 | 338 | # ** Done Training & Save the model ** 339 | saver.save(sess, save_path_a) 340 | # ---------------------------------------------------------------------------------------------------------------- 341 | loss_data = Loss_all[:count_loss] 342 | scio.savemat('/data/ljy/paper_again/19-11-20-final/model_a/loss/DeepDenseLossData' + str(ssim_weight_a) + '.mat', 343 | {'loss': loss_data}) 344 | 345 | loss_ssim_data = Loss_ssim[:count_loss] 346 | scio.savemat('/data/ljy/paper_again/19-11-20-final/model_a/loss/DeepDenseLossSSIMData' + str( 347 | ssim_weight_a) + '.mat', {'loss_ssim': loss_ssim_data}) 348 | 349 | loss_pixel_data = Loss_pixel[:count_loss] 350 | scio.savemat('/data/ljy/paper_again/19-11-20-final/model_a/loss/DeepDenseLossPixelData.mat' + str( 351 | ssim_weight_a) + '', {'loss_pixel': loss_pixel_data}) 352 | 353 | validation_ssim_data = Val_ssim_data[:count_loss] 354 | scio.savemat('/data/ljy/paper_again/19-11-20-final/model_a/val/Validation_ssim_Data.mat' + str( 355 | ssim_weight_a) + '', {'val_ssim': validation_ssim_data}) 356 | 357 | validation_pixel_data = Val_pixel_data[:count_loss] 358 | scio.savemat('/data/ljy/paper_again/19-11-20-final/model_a/val/Validation_pixel_Data.mat' + str( 359 | ssim_weight_a) + '', {'val_pixel': validation_pixel_data}) 360 | # ---------------------------------------------------------------------------------------------------- 361 | if debug: 362 | elapsed_time = datetime.now() - start_time 363 | print('Done training! 
Elapsed time: %s' % elapsed_time) 364 | print('Model is saved to: %s' % save_path_a) 365 | -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | # Utility 2 | 3 | import numpy as np 4 | 5 | from os import listdir, mkdir, sep 6 | from os.path import join, exists, splitext 7 | from scipy.misc import imread, imsave, imresize 8 | import skimage 9 | import skimage.io 10 | import skimage.transform 11 | import tensorflow as tf 12 | from PIL import Image 13 | from functools import reduce 14 | 15 | def list_images(directory): 16 | images = [] 17 | for file in listdir(directory): 18 | name = file.lower() 19 | if name.endswith('.png'): 20 | images.append(join(directory, file)) 21 | elif name.endswith('.jpg'): 22 | images.append(join(directory, file)) 23 | elif name.endswith('.jpeg'): 24 | images.append(join(directory, file)) 25 | return images 26 | 27 | 28 | def get_train_images(paths, resize_len=512, crop_height=256, crop_width=256, flag = True): 29 | if isinstance(paths, str): 30 | paths = [paths] 31 | 32 | images = [] 33 | ny = 0 34 | nx = 0 35 | for path in paths: 36 | image = imread(path, mode='L') 37 | # image = imread(path, mode='RGB') 38 | 39 | if flag: 40 | image = np.stack(image, axis=0) 41 | image = np.stack((image, image, image), axis=-1) 42 | else: 43 | image = np.stack(image, axis=0) 44 | image = np.stack(image, axis=-1) 45 | 46 | images.append(image) 47 | images = np.stack(images, axis=-1) 48 | 49 | return images 50 | 51 | 52 | def get_train_images_rgb(path, resize_len=512, crop_height=256, crop_width=256, flag = True): 53 | 54 | # image = imread(path, mode='L') 55 | image = imread(path, mode='RGB') 56 | 57 | return image 58 | 59 | 60 | def get_images(paths, height=None, width=None): 61 | if isinstance(paths, str): 62 | paths = [paths] 63 | 64 | images = [] 65 | for path in paths: 66 | image = imread(path, mode='RGB') 67 | 68 | if height is not None and width is not None: 69 | image = imresize(image, [height, width], interp='nearest') 70 | 71 | images.append(image) 72 | 73 | images = np.stack(images, axis=0) 74 | print('images shape gen:', images.shape) 75 | return images 76 | 77 | 78 | def save_images(paths, datas, save_path, prefix=None, suffix=None): 79 | if isinstance(paths, str): 80 | paths = [paths] 81 | 82 | assert(len(paths) == len(datas)) 83 | 84 | if not exists(save_path): 85 | mkdir(save_path) 86 | 87 | if prefix is None: 88 | prefix = '' 89 | if suffix is None: 90 | suffix = '' 91 | 92 | for i, path in enumerate(paths): 93 | data = datas[i] 94 | # print('data ==>>\n', data) 95 | if data.shape[2] == 1: 96 | data = data.reshape([data.shape[0], data.shape[1]]) 97 | # print('data reshape==>>\n', data) 98 | 99 | name, ext = splitext(path) 100 | name = name.split(sep)[-1] 101 | 102 | path = join(save_path, prefix + suffix + ext) 103 | print('data path==>>', path) 104 | 105 | 106 | # new_im = Image.fromarray(data) 107 | # new_im.show() 108 | 109 | imsave(path, data) 110 | 111 | def get_l2_norm_loss(diffs): 112 | shape = diffs.get_shape().as_list() 113 | size = reduce(lambda x, y: x * y, shape) ** 2 114 | sum_of_squared_diffs = tf.reduce_sum(tf.square(diffs)) 115 | return sum_of_squared_diffs / size -------------------------------------------------------------------------------- /icme2020template .pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ngujiang/Learning-attention-guided-deep-multi-scale-feature-ensemble-for-infrared-and-visible-image-fusion/77ce05bc6f87247b925b83c1b78a457a9e4b6f97/icme2020template .pdf --------------------------------------------------------------------------------