├── .idea
│   ├── densefuse-pytorch.iml
│   ├── encodings.xml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── __pycache__
│   ├── args_fusion.cpython-36.pyc
│   ├── fusion_strategy.cpython-36.pyc
│   ├── net.cpython-36.pyc
│   └── utils.cpython-36.pyc
├── args_fusion.py
├── fusion_strategy.py
├── images
│   ├── IV_images.zip
│   └── test-RGB.zip
├── models
│   ├── densefuse_gray.model
│   ├── densefuse_rgb.model
│   └── loss
│       └── readme.txt
├── net.py
├── pytorch_msssim
│   ├── __init__.py
│   └── __pycache__
│       └── __init__.cpython-36.pyc
├── test_image.py
├── train_densefuse.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # DenseFuse - Pytorch version
2 |
3 | Published in: IEEE Transactions on Image Processing
4 |
5 | *H. Li, X. J. Wu, “DenseFuse: A Fusion Approach to Infrared and Visible Images,” IEEE Trans. Image Process., vol. 28, no. 5, pp. 2614–2623, May 2019.*
6 |
7 | - [IEEEXplore](https://ieeexplore.ieee.org/document/8580578)
8 | - [arXiv](https://arxiv.org/abs/1804.08361)
9 |
10 | # The original TensorFlow version is available [here](https://github.com/hli1221/imagefusion_densefuse)
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/__pycache__/args_fusion.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hli1221/densefuse-pytorch/4394b63e9295db1c6b7a5c3664551c90f0605f2b/__pycache__/args_fusion.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/fusion_strategy.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hli1221/densefuse-pytorch/4394b63e9295db1c6b7a5c3664551c90f0605f2b/__pycache__/fusion_strategy.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/net.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hli1221/densefuse-pytorch/4394b63e9295db1c6b7a5c3664551c90f0605f2b/__pycache__/net.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hli1221/densefuse-pytorch/4394b63e9295db1c6b7a5c3664551c90f0605f2b/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/args_fusion.py:
--------------------------------------------------------------------------------
1 |
2 | class args():
3 |
4 | # training args
5 |     epochs = 4  # number of training epochs
6 |     batch_size = 4  # batch size for training
7 |     dataset = "MSCOCO 2014 path"  # placeholder: set to the local MS-COCO 2014 image folder
8 | HEIGHT = 256
9 | WIDTH = 256
10 |
11 |     save_model_dir = "models"  # path to folder where the trained model will be saved
12 |     save_loss_dir = "models/loss"  # path to folder where the loss data will be saved
13 |
14 | image_size = 256 #"size of training images, default is 256 X 256"
15 | cuda = 1 #"set it to 1 for running on GPU, 0 for CPU"
16 | seed = 42 #"random seed for training"
17 | ssim_weight = [1,10,100,1000,10000]
18 | ssim_path = ['1e0', '1e1', '1e2', '1e3', '1e4']
19 |
20 |     lr = 1e-4  # learning rate
21 |     lr_light = 1e-4  # learning rate
22 |     log_interval = 5  # number of batches after which the training loss is logged
23 | resume = None
24 | resume_auto_en = None
25 | resume_auto_de = None
26 | resume_auto_fn = None
27 |
28 | # for test Final_cat_epoch_9_Wed_Jan__9_04_16_28_2019_1.0_1.0.model
29 | model_path_gray = "./models/densefuse_gray.model"
30 | model_path_rgb = "./models/densefuse_rgb.model"
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/fusion_strategy.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | EPSILON = 1e-10
4 |
5 |
6 | # addition fusion strategy
7 | def addition_fusion(tensor1, tensor2):
8 | return (tensor1 + tensor2)/2
9 |
10 |
11 | # attention fusion strategy, average based on weight maps
12 | def attention_fusion_weight(tensor1, tensor2):
13 | # avg, max, nuclear
14 | f_spatial = spatial_fusion(tensor1, tensor2)
15 | tensor_f = f_spatial
16 | return tensor_f
17 |
18 |
19 | def spatial_fusion(tensor1, tensor2, spatial_type='sum'):
20 | shape = tensor1.size()
21 | # calculate spatial attention
22 | spatial1 = spatial_attention(tensor1, spatial_type)
23 | spatial2 = spatial_attention(tensor2, spatial_type)
24 |
25 | # get weight map, soft-max
26 | spatial_w1 = torch.exp(spatial1) / (torch.exp(spatial1) + torch.exp(spatial2) + EPSILON)
27 | spatial_w2 = torch.exp(spatial2) / (torch.exp(spatial1) + torch.exp(spatial2) + EPSILON)
28 |
29 | spatial_w1 = spatial_w1.repeat(1, shape[1], 1, 1)
30 | spatial_w2 = spatial_w2.repeat(1, shape[1], 1, 1)
31 |
32 | tensor_f = spatial_w1 * tensor1 + spatial_w2 * tensor2
33 |
34 | return tensor_f
35 |
36 |
37 | # spatial attention
38 | def spatial_attention(tensor, spatial_type='sum'):
39 |     if spatial_type == 'mean':
40 |         spatial = tensor.mean(dim=1, keepdim=True)
41 |     else:  # 'sum' (default)
42 |         spatial = tensor.sum(dim=1, keepdim=True)
43 | return spatial
44 |
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
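For reference, a minimal sketch (not part of the repository) of how the two strategies above combine a pair of encoder feature maps. The 64-channel shape matches the DenseFuse encoder output; the tensors themselves are random stand-ins:

    import torch
    import fusion_strategy

    # stand-ins for encoder outputs of an infrared / visible pair
    feat_ir = torch.rand(1, 64, 8, 8)
    feat_vis = torch.rand(1, 64, 8, 8)

    fused_add = fusion_strategy.addition_fusion(feat_ir, feat_vis)
    fused_att = fusion_strategy.attention_fusion_weight(feat_ir, feat_vis)

    # both strategies preserve the feature shape; the attention variant
    # weights each pixel by a softmax over the two spatial attention maps
    print(fused_add.shape, fused_att.shape)  # torch.Size([1, 64, 8, 8]) twice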
/images/IV_images.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hli1221/densefuse-pytorch/4394b63e9295db1c6b7a5c3664551c90f0605f2b/images/IV_images.zip
--------------------------------------------------------------------------------
/images/test-RGB.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hli1221/densefuse-pytorch/4394b63e9295db1c6b7a5c3664551c90f0605f2b/images/test-RGB.zip
--------------------------------------------------------------------------------
/models/densefuse_gray.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hli1221/densefuse-pytorch/4394b63e9295db1c6b7a5c3664551c90f0605f2b/models/densefuse_gray.model
--------------------------------------------------------------------------------
/models/densefuse_rgb.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hli1221/densefuse-pytorch/4394b63e9295db1c6b7a5c3664551c90f0605f2b/models/densefuse_rgb.model
--------------------------------------------------------------------------------
/models/loss/readme.txt:
--------------------------------------------------------------------------------
1 |
2 | The path of loss files.
3 |
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import fusion_strategy
7 |
8 |
9 | # Convolution operation
10 | class ConvLayer(torch.nn.Module):
11 | def __init__(self, in_channels, out_channels, kernel_size, stride, is_last=False):
12 | super(ConvLayer, self).__init__()
13 | reflection_padding = int(np.floor(kernel_size / 2))
14 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
15 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
16 | self.dropout = nn.Dropout2d(p=0.5)
17 | self.is_last = is_last
18 |
19 | def forward(self, x):
20 | out = self.reflection_pad(x)
21 | out = self.conv2d(out)
22 | if self.is_last is False:
23 | # out = F.normalize(out)
24 | out = F.relu(out, inplace=True)
25 | # out = self.dropout(out)
26 | return out
27 |
28 |
29 | # Dense convolution unit
30 | class DenseConv2d(torch.nn.Module):
31 | def __init__(self, in_channels, out_channels, kernel_size, stride):
32 | super(DenseConv2d, self).__init__()
33 | self.dense_conv = ConvLayer(in_channels, out_channels, kernel_size, stride)
34 |
35 | def forward(self, x):
36 | out = self.dense_conv(x)
37 | out = torch.cat([x, out], 1)
38 | return out
39 |
40 |
41 | # Dense Block unit
42 | class DenseBlock(torch.nn.Module):
43 | def __init__(self, in_channels, kernel_size, stride):
44 | super(DenseBlock, self).__init__()
45 | out_channels_def = 16
46 | denseblock = []
47 | denseblock += [DenseConv2d(in_channels, out_channels_def, kernel_size, stride),
48 | DenseConv2d(in_channels+out_channels_def, out_channels_def, kernel_size, stride),
49 | DenseConv2d(in_channels+out_channels_def*2, out_channels_def, kernel_size, stride)]
50 | self.denseblock = nn.Sequential(*denseblock)
51 |
52 | def forward(self, x):
53 | out = self.denseblock(x)
54 | return out
55 |
56 |
57 | # DenseFuse network
58 | class DenseFuse_net(nn.Module):
59 | def __init__(self, input_nc=1, output_nc=1):
60 | super(DenseFuse_net, self).__init__()
61 | denseblock = DenseBlock
62 | nb_filter = [16, 64, 32, 16]
63 | kernel_size = 3
64 | stride = 1
65 |
66 | # encoder
67 | self.conv1 = ConvLayer(input_nc, nb_filter[0], kernel_size, stride)
68 | self.DB1 = denseblock(nb_filter[0], kernel_size, stride)
69 |
70 | # decoder
71 | self.conv2 = ConvLayer(nb_filter[1], nb_filter[1], kernel_size, stride)
72 | self.conv3 = ConvLayer(nb_filter[1], nb_filter[2], kernel_size, stride)
73 | self.conv4 = ConvLayer(nb_filter[2], nb_filter[3], kernel_size, stride)
74 | self.conv5 = ConvLayer(nb_filter[3], output_nc, kernel_size, stride)
75 |
76 | def encoder(self, input):
77 | x1 = self.conv1(input)
78 | x_DB = self.DB1(x1)
79 | return [x_DB]
80 |
81 |     def fusion(self, en1, en2, strategy_type='addition'):
82 |         # choose how the two encoder feature maps are merged:
83 |         # 'attention_weight' -> spatial-attention weight maps,
84 |         # anything else -> element-wise averaging
85 |         if strategy_type == 'attention_weight':
86 |             fusion_function = fusion_strategy.attention_fusion_weight
87 |         else:
88 |             fusion_function = fusion_strategy.addition_fusion
89 |
90 |         f_0 = fusion_function(en1[0], en2[0])
91 |         return [f_0]
92 |
93 |
94 |
95 |
96 | def decoder(self, f_en):
97 | x2 = self.conv2(f_en[0])
98 | x3 = self.conv3(x2)
99 | x4 = self.conv4(x3)
100 | output = self.conv5(x4)
101 |
102 | return [output]
103 |
--------------------------------------------------------------------------------
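A minimal round-trip sketch of the network above: encode two single-channel images, fuse the feature maps, decode. The weights here are randomly initialized rather than the released checkpoints, so the output is only shape-correct:

    import torch
    from net import DenseFuse_net

    model = DenseFuse_net(input_nc=1, output_nc=1)
    model.eval()

    ir = torch.rand(1, 1, 256, 256)   # grayscale infrared image
    vis = torch.rand(1, 1, 256, 256)  # grayscale visible image

    with torch.no_grad():
        en_ir = model.encoder(ir)    # list holding one (1, 64, 256, 256) feature map
        en_vis = model.encoder(vis)
        f = model.fusion(en_ir, en_vis, strategy_type='addition')  # or 'attention_weight'
        fused = model.decoder(f)[0]  # fused image, same size as the inputs

    print(fused.shape)  # torch.Size([1, 1, 256, 256])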
/pytorch_msssim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from math import exp
4 | import numpy as np
5 |
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 |
12 | def create_window(window_size, channel=1):
13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
16 | return window
17 |
18 |
19 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
20 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
21 | if val_range is None:
22 | if torch.max(img1) > 128:
23 | max_val = 255
24 | else:
25 | max_val = 1
26 |
27 | if torch.min(img1) < -0.5:
28 | min_val = -1
29 | else:
30 | min_val = 0
31 | L = max_val - min_val
32 | else:
33 | L = val_range
34 |
35 | padd = 0
36 | (_, channel, height, width) = img1.size()
37 | if window is None:
38 | real_size = min(window_size, height, width)
39 | window = create_window(real_size, channel=channel).to(img1.device)
40 |
41 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
42 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
43 |
44 | mu1_sq = mu1.pow(2)
45 | mu2_sq = mu2.pow(2)
46 | mu1_mu2 = mu1 * mu2
47 |
48 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
49 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
50 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
51 |
52 | C1 = (0.01 * L) ** 2
53 | C2 = (0.03 * L) ** 2
54 |
55 | v1 = 2.0 * sigma12 + C2
56 | v2 = sigma1_sq + sigma2_sq + C2
57 | cs = torch.mean(v1 / v2) # contrast sensitivity
58 |
59 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
60 |
61 | if size_average:
62 | ret = ssim_map.mean()
63 | else:
64 | ret = ssim_map.mean(1).mean(1).mean(1)
65 |
66 | if full:
67 | return ret, cs
68 | return ret
69 |
70 |
71 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
72 | device = img1.device
73 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
74 | levels = weights.size()[0]
75 | mssim = []
76 | mcs = []
77 | for _ in range(levels):
78 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
79 | mssim.append(sim)
80 | mcs.append(cs)
81 |
82 | img1 = F.avg_pool2d(img1, (2, 2))
83 | img2 = F.avg_pool2d(img2, (2, 2))
84 |
85 | mssim = torch.stack(mssim)
86 | mcs = torch.stack(mcs)
87 |
88 | # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
89 | if normalize:
90 | mssim = (mssim + 1) / 2
91 | mcs = (mcs + 1) / 2
92 |
93 | pow1 = mcs ** weights
94 | pow2 = mssim ** weights
95 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
96 |     output = torch.prod(pow1[:-1]) * pow2[-1]
97 | return output
98 |
99 |
100 | # Classes to re-use window
101 | class SSIM(torch.nn.Module):
102 | def __init__(self, window_size=11, size_average=True, val_range=None):
103 | super(SSIM, self).__init__()
104 | self.window_size = window_size
105 | self.size_average = size_average
106 | self.val_range = val_range
107 |
108 | # Assume 1 channel for SSIM
109 | self.channel = 1
110 | self.window = create_window(window_size)
111 |
112 | def forward(self, img1, img2):
113 | (_, channel, _, _) = img1.size()
114 |
115 |         if channel == self.channel and self.window.dtype == img1.dtype and self.window.device == img1.device:
116 | window = self.window
117 | else:
118 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
119 | self.window = window
120 | self.channel = channel
121 |
122 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
123 |
124 | class MSSSIM(torch.nn.Module):
125 | def __init__(self, window_size=11, size_average=True, channel=3):
126 | super(MSSSIM, self).__init__()
127 | self.window_size = window_size
128 | self.size_average = size_average
129 | self.channel = channel
130 |
131 | def forward(self, img1, img2):
132 | # TODO: store window between calls if possible
133 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
134 |
--------------------------------------------------------------------------------
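A short usage sketch for the metrics above, on random images, so the printed values are only illustrative; `normalize=True` is what train_densefuse.py passes to keep early-training values well-behaved:

    import torch
    from pytorch_msssim import ssim, msssim, SSIM

    img1 = torch.rand(1, 1, 256, 256)
    img2 = 0.9 * img1 + 0.1 * torch.rand(1, 1, 256, 256)

    # functional form; val_range is inferred from the data when not given
    print(ssim(img1, img2).item())                    # near 1 for similar images
    print(msssim(img1, img2, normalize=True).item())

    # module form, reusing the Gaussian window across calls
    criterion = SSIM(window_size=11)
    loss = 1 - criterion(img1, img2)  # the usual "1 - SSIM" training loss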
/pytorch_msssim/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hli1221/densefuse-pytorch/4394b63e9295db1c6b7a5c3664551c90f0605f2b/pytorch_msssim/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/test_image.py:
--------------------------------------------------------------------------------
1 | # test phase
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | from torch.autograd import Variable
7 |
8 | from net import DenseFuse_net
9 | import utils
10 | from args_fusion import args
11 |
12 | def load_model(path, input_nc, output_nc):
13 |
14 | nest_model = DenseFuse_net(input_nc, output_nc)
15 | nest_model.load_state_dict(torch.load(path))
16 |
17 |     para = sum([np.prod(list(p.size())) for p in nest_model.parameters()])
18 |     type_size = 4  # bytes per float32 parameter
19 |     print('Model {} : size: {:.4f} MB'.format(nest_model._get_name(), para * type_size / 1000 / 1000))
20 |
21 | nest_model.eval()
22 | nest_model.cuda()
23 |
24 | return nest_model
25 |
26 |
27 | def _generate_fusion_image(model, strategy_type, img1, img2):
28 | # encoder
29 | # test = torch.unsqueeze(img_ir[:, i, :, :], 1)
30 | en_r = model.encoder(img1)
31 | # vision_features(en_r, 'ir')
32 | en_v = model.encoder(img2)
33 | # vision_features(en_v, 'vi')
34 | # fusion
35 | f = model.fusion(en_r, en_v, strategy_type=strategy_type)
36 | # f = en_v
37 | # decoder
38 | img_fusion = model.decoder(f)
39 | return img_fusion[0]
40 |
41 |
42 | def run_demo(model, infrared_path, visible_path, output_path_root, index, fusion_type, network_type, strategy_type, ssim_weight_str, mode):
43 | # if mode == 'L':
44 | ir_img = utils.get_test_images(infrared_path, height=None, width=None, mode=mode)
45 | vis_img = utils.get_test_images(visible_path, height=None, width=None, mode=mode)
46 | # else:
47 | # img_ir = utils.tensor_load_rgbimage(infrared_path)
48 | # img_ir = img_ir.unsqueeze(0).float()
49 | # img_vi = utils.tensor_load_rgbimage(visible_path)
50 | # img_vi = img_vi.unsqueeze(0).float()
51 |
52 | # dim = img_ir.shape
53 | if args.cuda:
54 | ir_img = ir_img.cuda()
55 | vis_img = vis_img.cuda()
56 | ir_img = Variable(ir_img, requires_grad=False)
57 | vis_img = Variable(vis_img, requires_grad=False)
58 | dimension = ir_img.size()
59 |
60 | img_fusion = _generate_fusion_image(model, strategy_type, ir_img, vis_img)
61 | ############################ multi outputs ##############################################
62 | file_name = 'fusion_' + fusion_type + '_' + str(index) + '_network_' + network_type + '_' + strategy_type + '_' + ssim_weight_str + '.png'
63 | output_path = output_path_root + file_name
64 | # # save images
65 | # utils.save_image_test(img_fusion, output_path)
66 | # utils.tensor_save_rgbimage(img_fusion, output_path)
67 | if args.cuda:
68 | img = img_fusion.cpu().clamp(0, 255).data[0].numpy()
69 | else:
70 | img = img_fusion.clamp(0, 255).data[0].numpy()
71 | img = img.transpose(1, 2, 0).astype('uint8')
72 | utils.save_images(output_path, img)
73 |
74 | print(output_path)
75 |
76 |
77 | def vision_features(feature_maps, img_type):
78 |     count = 0
79 |     for features in feature_maps:
80 |         count += 1
81 |         for index in range(features.size(1)):
82 |             file_name = 'feature_maps_' + img_type + '_level_' + str(count) + '_channel_' + str(index) + '.png'
83 |             output_path = 'outputs/feature_maps/' + file_name
84 |             fmap = features[:, index, :, :] * 255  # scale activations for viewing
85 |             img = fmap.cpu().clamp(0, 255).data[0].numpy().astype('uint8')
86 |             # save each channel as a single-channel image
87 |             utils.save_images(output_path, img[:, :, np.newaxis])
88 |
89 |
90 | def main():
91 | # run demo
92 | # test_path = "images/test-RGB/"
93 | test_path = "images/IV_images/"
94 | network_type = 'densefuse'
95 | fusion_type = 'auto' # auto, fusion_layer, fusion_all
96 | strategy_type_list = ['addition', 'attention_weight'] # addition, attention_weight, attention_enhance, adain_fusion, channel_fusion, saliency_mask
97 |
98 | output_path = './outputs/'
99 | strategy_type = strategy_type_list[0]
100 |
101 | if os.path.exists(output_path) is False:
102 | os.mkdir(output_path)
103 |
104 | # in_c = 3 for RGB images; in_c = 1 for gray images
105 | in_c = 1
106 | if in_c == 1:
107 | out_c = in_c
108 | mode = 'L'
109 | model_path = args.model_path_gray
110 | else:
111 | out_c = in_c
112 | mode = 'RGB'
113 | model_path = args.model_path_rgb
114 |
115 | with torch.no_grad():
116 | print('SSIM weight ----- ' + args.ssim_path[2])
117 | ssim_weight_str = args.ssim_path[2]
118 | model = load_model(model_path, in_c, out_c)
119 | for i in range(1):
120 | index = i + 1
121 | infrared_path = test_path + 'IR' + str(index) + '.jpg'
122 | visible_path = test_path + 'VIS' + str(index) + '.jpg'
123 | run_demo(model, infrared_path, visible_path, output_path, index, fusion_type, network_type, strategy_type, ssim_weight_str, mode)
124 | print('Done......')
125 |
126 | if __name__ == '__main__':
127 | main()
128 |
--------------------------------------------------------------------------------
/train_densefuse.py:
--------------------------------------------------------------------------------
1 | # Training DenseFuse network
2 | # auto-encoder
3 |
4 | import os
5 | import random
6 | import time
7 | import numpy as np
8 | from tqdm import trange
9 | import scipy.io as scio
10 | import torch
11 | from torch.optim import Adam
12 | from torch.autograd import Variable
13 | import utils
14 | from net import DenseFuse_net
15 | from args_fusion import args
16 | import pytorch_msssim
17 |
18 |
19 |
20 | def main():
21 | # os.environ["CUDA_VISIBLE_DEVICES"] = "3"
22 | original_imgs_path = utils.list_images(args.dataset)
23 | train_num = 40000
24 | original_imgs_path = original_imgs_path[:train_num]
25 | random.shuffle(original_imgs_path)
26 | # for i in range(5):
27 | i = 2
28 | train(i, original_imgs_path)
29 |
30 |
31 | def train(i, original_imgs_path):
32 |
33 | batch_size = args.batch_size
34 |
35 | # load network model, RGB
36 | in_c = 1 # 1 - gray; 3 - RGB
37 | if in_c == 1:
38 | img_model = 'L'
39 | else:
40 | img_model = 'RGB'
41 | input_nc = in_c
42 | output_nc = in_c
43 | densefuse_model = DenseFuse_net(input_nc, output_nc)
44 |
45 | if args.resume is not None:
46 | print('Resuming, initializing using weight from {}.'.format(args.resume))
47 | densefuse_model.load_state_dict(torch.load(args.resume))
48 | print(densefuse_model)
49 | optimizer = Adam(densefuse_model.parameters(), args.lr)
50 | mse_loss = torch.nn.MSELoss()
51 | ssim_loss = pytorch_msssim.msssim
52 |
53 | if args.cuda:
54 | densefuse_model.cuda()
55 |
56 | tbar = trange(args.epochs)
57 | print('Start training.....')
58 |
59 | # creating save path
60 | temp_path_model = os.path.join(args.save_model_dir, args.ssim_path[i])
61 | if os.path.exists(temp_path_model) is False:
62 | os.mkdir(temp_path_model)
63 |
64 | temp_path_loss = os.path.join(args.save_loss_dir, args.ssim_path[i])
65 | if os.path.exists(temp_path_loss) is False:
66 | os.mkdir(temp_path_loss)
67 |
68 | Loss_pixel = []
69 | Loss_ssim = []
70 | Loss_all = []
71 | all_ssim_loss = 0.
72 | all_pixel_loss = 0.
73 | for e in tbar:
74 | print('Epoch %d.....' % e)
75 | # load training database
76 | image_set_ir, batches = utils.load_dataset(original_imgs_path, batch_size)
77 | densefuse_model.train()
78 | count = 0
79 | for batch in range(batches):
80 | image_paths = image_set_ir[batch * batch_size:(batch * batch_size + batch_size)]
81 | img = utils.get_train_images_auto(image_paths, height=args.HEIGHT, width=args.WIDTH, mode=img_model)
82 |
83 | count += 1
84 | optimizer.zero_grad()
85 | img = Variable(img, requires_grad=False)
86 |
87 | if args.cuda:
88 | img = img.cuda()
89 | # get fusion image
90 | # encoder
91 | en = densefuse_model.encoder(img)
92 | # decoder
93 | outputs = densefuse_model.decoder(en)
94 | # resolution loss
95 | x = Variable(img.data.clone(), requires_grad=False)
96 |
97 | ssim_loss_value = 0.
98 | pixel_loss_value = 0.
99 | for output in outputs:
100 | pixel_loss_temp = mse_loss(output, x)
101 | ssim_loss_temp = ssim_loss(output, x, normalize=True)
102 | ssim_loss_value += (1-ssim_loss_temp)
103 | pixel_loss_value += pixel_loss_temp
104 | ssim_loss_value /= len(outputs)
105 | pixel_loss_value /= len(outputs)
106 |
107 | # total loss
108 | total_loss = pixel_loss_value + args.ssim_weight[i] * ssim_loss_value
109 | total_loss.backward()
110 | optimizer.step()
111 |
112 | all_ssim_loss += ssim_loss_value.item()
113 | all_pixel_loss += pixel_loss_value.item()
114 | if (batch + 1) % args.log_interval == 0:
115 | mesg = "{}\tEpoch {}:\t[{}/{}]\t pixel loss: {:.6f}\t ssim loss: {:.6f}\t total: {:.6f}".format(
116 | time.ctime(), e + 1, count, batches,
117 | all_pixel_loss / args.log_interval,
118 | all_ssim_loss / args.log_interval,
119 | (args.ssim_weight[i] * all_ssim_loss + all_pixel_loss) / args.log_interval
120 | )
121 | tbar.set_description(mesg)
122 | Loss_pixel.append(all_pixel_loss / args.log_interval)
123 | Loss_ssim.append(all_ssim_loss / args.log_interval)
124 | Loss_all.append((args.ssim_weight[i] * all_ssim_loss + all_pixel_loss) / args.log_interval)
125 |
126 | all_ssim_loss = 0.
127 | all_pixel_loss = 0.
128 |
129 | if (batch + 1) % (200 * args.log_interval) == 0:
130 | # save model
131 | densefuse_model.eval()
132 | densefuse_model.cpu()
133 | save_model_filename = args.ssim_path[i] + '/' + "Epoch_" + str(e) + "_iters_" + str(count) + "_" + \
134 | str(time.ctime()).replace(' ', '_').replace(':', '_') + "_" + args.ssim_path[
135 | i] + ".model"
136 | save_model_path = os.path.join(args.save_model_dir, save_model_filename)
137 | torch.save(densefuse_model.state_dict(), save_model_path)
138 | # save loss data
139 | # pixel loss
140 | loss_data_pixel = np.array(Loss_pixel)
141 | loss_filename_path = args.ssim_path[i] + '/' + "loss_pixel_epoch_" + str(
142 | args.epochs) + "_iters_" + str(count) + "_" + str(time.ctime()).replace(' ', '_').replace(':', '_') + "_" + \
143 | args.ssim_path[i] + ".mat"
144 | save_loss_path = os.path.join(args.save_loss_dir, loss_filename_path)
145 | scio.savemat(save_loss_path, {'loss_pixel': loss_data_pixel})
146 | # SSIM loss
147 | loss_data_ssim = np.array(Loss_ssim)
148 | loss_filename_path = args.ssim_path[i] + '/' + "loss_ssim_epoch_" + str(
149 | args.epochs) + "_iters_" + str(count) + "_" + str(time.ctime()).replace(' ', '_').replace(':', '_') + "_" + \
150 | args.ssim_path[i] + ".mat"
151 | save_loss_path = os.path.join(args.save_loss_dir, loss_filename_path)
152 | scio.savemat(save_loss_path, {'loss_ssim': loss_data_ssim})
153 | # all loss
154 | loss_data_total = np.array(Loss_all)
155 | loss_filename_path = args.ssim_path[i] + '/' + "loss_total_epoch_" + str(
156 | args.epochs) + "_iters_" + str(count) + "_" + str(time.ctime()).replace(' ', '_').replace(':', '_') + "_" + \
157 | args.ssim_path[i] + ".mat"
158 | save_loss_path = os.path.join(args.save_loss_dir, loss_filename_path)
159 | scio.savemat(save_loss_path, {'loss_total': loss_data_total})
160 |
161 | densefuse_model.train()
162 | densefuse_model.cuda()
163 |                 tbar.set_description("Checkpoint, trained model saved at " + save_model_path)
164 |
165 | # pixel loss
166 | loss_data_pixel = np.array(Loss_pixel)
167 | loss_filename_path = args.ssim_path[i] + '/' + "Final_loss_pixel_epoch_" + str(
168 | args.epochs) + "_" + str(time.ctime()).replace(' ', '_').replace(':','_') + "_" + \
169 | args.ssim_path[i] + ".mat"
170 | save_loss_path = os.path.join(args.save_loss_dir, loss_filename_path)
171 | scio.savemat(save_loss_path, {'loss_pixel': loss_data_pixel})
172 | # SSIM loss
173 | loss_data_ssim = np.array(Loss_ssim)
174 | loss_filename_path = args.ssim_path[i] + '/' + "Final_loss_ssim_epoch_" + str(
175 | args.epochs) + "_" + str(time.ctime()).replace(' ', '_').replace(':', '_') + "_" + \
176 | args.ssim_path[i] + ".mat"
177 | save_loss_path = os.path.join(args.save_loss_dir, loss_filename_path)
178 | scio.savemat(save_loss_path, {'loss_ssim': loss_data_ssim})
179 | # all loss
180 | loss_data_total = np.array(Loss_all)
181 | loss_filename_path = args.ssim_path[i] + '/' + "Final_loss_total_epoch_" + str(
182 | args.epochs) + "_" + str(time.ctime()).replace(' ', '_').replace(':', '_') + "_" + \
183 | args.ssim_path[i] + ".mat"
184 | save_loss_path = os.path.join(args.save_loss_dir, loss_filename_path)
185 | scio.savemat(save_loss_path, {'loss_total': loss_data_total})
186 | # save model
187 | densefuse_model.eval()
188 | densefuse_model.cpu()
189 |     save_model_filename = args.ssim_path[i] + '/' + "Final_epoch_" + str(args.epochs) + "_" + \
190 |                           str(time.ctime()).replace(' ', '_').replace(':', '_') + "_" + args.ssim_path[i] + ".model"
191 | save_model_path = os.path.join(args.save_model_dir, save_model_filename)
192 | torch.save(densefuse_model.state_dict(), save_model_path)
193 |
194 | print("\nDone, trained model saved at", save_model_path)
195 |
196 |
197 | if __name__ == "__main__":
198 | main()
199 |
--------------------------------------------------------------------------------
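Condensed to its core, each optimization step above minimizes pixel MSE plus a weighted (1 - MS-SSIM) term. A self-contained sketch of one such step, with a random batch standing in for the MS-COCO crops:

    import torch
    from torch.optim import Adam
    from net import DenseFuse_net
    import pytorch_msssim

    model = DenseFuse_net(input_nc=1, output_nc=1)
    optimizer = Adam(model.parameters(), lr=1e-4)
    mse_loss = torch.nn.MSELoss()
    ssim_weight = 100  # args.ssim_weight[2], the '1e2' setting

    img = torch.rand(4, 1, 256, 256)  # stand-in for one training batch

    optimizer.zero_grad()
    output = model.decoder(model.encoder(img))[0]   # autoencoder reconstruction
    pixel_loss = mse_loss(output, img)
    ssim_loss = 1 - pytorch_msssim.msssim(output, img, normalize=True)
    total_loss = pixel_loss + ssim_weight * ssim_loss
    total_loss.backward()
    optimizer.step()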
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import listdir, mkdir, sep
3 | from os.path import join, exists, splitext
4 | import random
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 | from torch.autograd import Variable
9 | from args_fusion import args
10 | import matplotlib as mpl
11 | import cv2
12 | from torchvision import datasets, transforms
13 |
14 |
15 |
16 |
17 | def list_images(directory):
18 | images = []
19 | names = []
20 | dir = listdir(directory)
21 | dir.sort()
22 | for file in dir:
23 | name = file.lower()
24 | if name.endswith('.png'):
25 | images.append(join(directory, file))
26 | elif name.endswith('.jpg'):
27 | images.append(join(directory, file))
28 | elif name.endswith('.jpeg'):
29 | images.append(join(directory, file))
30 | name1 = name.split('.')
31 | names.append(name1[0])
32 | return images
33 |
34 |
35 | def tensor_load_rgbimage(filename, size=None, scale=None, keep_asp=False):
36 | img = Image.open(filename).convert('RGB')
37 | if size is not None:
38 | if keep_asp:
39 | size2 = int(size * 1.0 / img.size[0] * img.size[1])
40 |             img = img.resize((size, size2), Image.LANCZOS)
41 |         else:
42 |             img = img.resize((size, size), Image.LANCZOS)
43 |
44 |     elif scale is not None:
45 |         img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.LANCZOS)
46 | img = np.array(img).transpose(2, 0, 1)
47 | img = torch.from_numpy(img).float()
48 | return img
49 |
50 |
51 | def tensor_save_rgbimage(tensor, filename, cuda=True):
52 | if cuda:
53 | # img = tensor.clone().cpu().clamp(0, 255).numpy()
54 | img = tensor.cpu().clamp(0, 255).data[0].numpy()
55 | else:
56 | # img = tensor.clone().clamp(0, 255).numpy()
57 | img = tensor.clamp(0, 255).numpy()
58 | img = img.transpose(1, 2, 0).astype('uint8')
59 | img = Image.fromarray(img)
60 | img.save(filename)
61 |
62 |
63 | def tensor_save_bgrimage(tensor, filename, cuda=False):
64 | (b, g, r) = torch.chunk(tensor, 3)
65 | tensor = torch.cat((r, g, b))
66 | tensor_save_rgbimage(tensor, filename, cuda)
67 |
68 |
69 | def gram_matrix(y):
70 | (b, ch, h, w) = y.size()
71 | features = y.view(b, ch, w * h)
72 | features_t = features.transpose(1, 2)
73 | gram = features.bmm(features_t) / (ch * h * w)
74 | return gram
75 |
76 |
77 | def matSqrt(x):
78 | U,D,V = torch.svd(x)
79 | return U * (D.pow(0.5).diag()) * V.t()
80 |
81 |
82 | # load training images
83 | def load_dataset(image_path, BATCH_SIZE, num_imgs=None):
84 | if num_imgs is None:
85 | num_imgs = len(image_path)
86 | original_imgs_path = image_path[:num_imgs]
87 | # random
88 | random.shuffle(original_imgs_path)
89 | mod = num_imgs % BATCH_SIZE
90 |     print('BATCH SIZE %d.' % BATCH_SIZE)
91 |     print('Train images number %d.' % num_imgs)
92 |     print('Train batches %d.' % (num_imgs // BATCH_SIZE))
93 |
94 | if mod > 0:
95 |         print('Train set has been trimmed by %d samples...\n' % mod)
96 | original_imgs_path = original_imgs_path[:-mod]
97 | batches = int(len(original_imgs_path) // BATCH_SIZE)
98 | return original_imgs_path, batches
99 |
100 |
101 | def get_image(path, height=256, width=256, mode='L'):
102 |     # load with PIL; convert() handles both 'L' and 'RGB' modes
103 |     image = Image.open(path).convert(mode)
104 |     if height is not None and width is not None:
105 |         # PIL expects (width, height); NEAREST matches interp='nearest'
106 |         image = image.resize((width, height), Image.NEAREST)
107 |     return np.array(image)
108 |
109 |
110 |
111 |
112 | def get_train_images_auto(paths, height=256, width=256, mode='RGB'):
113 | if isinstance(paths, str):
114 | paths = [paths]
115 | images = []
116 | for path in paths:
117 | image = get_image(path, height, width, mode=mode)
118 | if mode == 'L':
119 | image = np.reshape(image, [1, image.shape[0], image.shape[1]])
120 | else:
121 |             image = np.transpose(image, (2, 0, 1))  # HWC -> CHW; reshape would scramble the pixels
122 | images.append(image)
123 |
124 | images = np.stack(images, axis=0)
125 | images = torch.from_numpy(images).float()
126 | return images
127 |
128 |
129 | def get_test_images(paths, height=None, width=None, mode='RGB'):
130 | ImageToTensor = transforms.Compose([transforms.ToTensor()])
131 | if isinstance(paths, str):
132 | paths = [paths]
133 | images = []
134 | for path in paths:
135 | image = get_image(path, height, width, mode=mode)
136 | if mode == 'L':
137 | image = np.reshape(image, [1, image.shape[0], image.shape[1]])
138 | else:
139 | # test = ImageToTensor(image).numpy()
140 | # shape = ImageToTensor(image).size()
141 | image = ImageToTensor(image).float().numpy()*255
142 | images.append(image)
143 | images = np.stack(images, axis=0)
144 | images = torch.from_numpy(images).float()
145 | return images
146 |
147 |
148 | # colormap
149 | def colormap():
150 | return mpl.colors.LinearSegmentedColormap.from_list('cmap', ['#FFFFFF', '#98F5FF', '#00FF00', '#FFFF00','#FF0000', '#8B0000'], 256)
151 |
152 |
153 | def save_images(path, data):
154 |     # data is an H x W x C uint8 array; squeeze a single channel to H x W
155 |     if data.shape[2] == 1:
156 |         data = data.reshape([data.shape[0], data.shape[1]])
157 |     Image.fromarray(data).save(path)
158 |
--------------------------------------------------------------------------------
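A small sketch of the image helpers above; the path is a placeholder for any grayscale-readable image in the repository:

    import utils

    path = 'images/IV_images/IR1.jpg'  # hypothetical example path

    # training loader: images resized to 256x256, returned as (N, 1, H, W) floats
    train_batch = utils.get_train_images_auto([path], height=256, width=256, mode='L')

    # test loader: keeps the native resolution, pixel values in [0, 255]
    test_batch = utils.get_test_images([path], height=None, width=None, mode='L')

    print(train_batch.shape, test_batch.shape)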