├── dataset.py
├── densefuse_net.py
├── main.py
├── readme.md
├── ssim.py
├── test.py
├── train.py
├── train_result
│   ├── TrainData.mat
│   ├── ValData.mat
│   ├── curve.png
│   └── model_weight.pkl
└── utils.py
/dataset.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 13 16:03:42 2019

@author: win10
"""

import torch.utils.data as Data
import torchvision.transforms as transforms

from glob import glob
import os
from PIL import Image

class AEDataset(Data.Dataset):
    def __init__(self, root, resize=[256, 256], transform=None, gray=True):
        self.files = glob(os.path.join(root, '*.*'))
        self.resize = resize
        self.gray = gray
        self._tensor = transforms.ToTensor()
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        img = Image.open(self.files[index]).resize(self.resize)

        if self.gray:
            img = img.convert('L')

        img = self._tensor(img)
        if self.transform is not None:
            img = self.transform(img)

        return img
--------------------------------------------------------------------------------
/densefuse_net.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 12 21:03:03 2019

@author: win10
"""

import torch.nn as nn
import torch

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.layers = nn.Sequential()
        self.layers.add_module('Conv2', nn.Conv2d(64, 64, 3, 1, 1))
        self.layers.add_module('Act2' , nn.ReLU(inplace=True))
        self.layers.add_module('Conv3', nn.Conv2d(64, 32, 3, 1, 1))
        self.layers.add_module('Act3' , nn.ReLU(inplace=True))
        self.layers.add_module('Conv4', nn.Conv2d(32, 16, 3, 1, 1))
        self.layers.add_module('Act4' , nn.ReLU(inplace=True))
        self.layers.add_module('Conv5', nn.Conv2d(16, 1, 3, 1, 1))

    def forward(self, x):
        return self.layers(x)

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()

        self.Conv1 = nn.Conv2d(1, 16, 3, 1, 1)
        self.Relu = nn.ReLU(inplace=True)

        # Dense block: each conv consumes the concatenation of all previous
        # feature maps, so input channels grow 16 -> 32 -> 48 -> 64.
        self.layers = nn.ModuleDict({
            'DenseConv1': nn.Conv2d(16, 16, 3, 1, 1),
            'DenseConv2': nn.Conv2d(32, 16, 3, 1, 1),
            'DenseConv3': nn.Conv2d(48, 16, 3, 1, 1)
        })

    def forward(self, x):
        x = self.Relu(self.Conv1(x))
        for i in range(len(self.layers)):
            out = self.layers['DenseConv' + str(i + 1)](x)
            x = torch.cat([x, out], 1)  # dense concatenation along channels
        return x

class DenseFuseNet(nn.Module):

    def __init__(self):
        super(DenseFuseNet, self).__init__()

        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.decoder(self.encoder(x))
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sat Jun 15 17:32:38 2019

@author: win10
"""
import torch

from densefuse_net import DenseFuseNet
from utils import test

device = 'cuda'

model = DenseFuseNet().to(device)
model.load_state_dict(torch.load('./train_result/model_weight.pkl')['weight'])

test_path = './images/IV_images/'
test(test_path, model, mode='add')
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# DenseFuse (PyTorch)

---

A PyTorch implementation of DenseFuse.

## 📝 Table of Contents
- [About](#about)
- [Usage](#usage)

## 🧐 About
This is a PyTorch implementation of DenseFuse, as proposed in the paper
[H. Li and X.-J. Wu, “DenseFuse: A Fusion Approach to Infrared and Visible Images,” IEEE Trans. Image Process., vol. 28, no. 5, pp. 2614–2623, May 2019.](https://arxiv.org/abs/1804.08361)

The code is written with torch 1.1.0 and pytorch-ssim.

## 🎈 Usage

### Quick start
1. Clone this repo and unpack it.
2. Download the [test dataset](https://github.com/hli1221/imagefusion_densefuse/tree/master/images) and put the test images in `./images/IV_images`.
3. Run `main.py`; a minimal sketch of what it does is shown below.
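
For reference, `main.py` simply loads the pretrained weights and fuses every image pair under the test folder with the `add` strategy; fused results are written to `./test_result/`:

```python
import torch

from densefuse_net import DenseFuseNet
from utils import test

device = 'cuda'  # the scripts assume a CUDA-capable GPU

model = DenseFuseNet().to(device)
model.load_state_dict(torch.load('./train_result/model_weight.pkl')['weight'])

test('./images/IV_images/', model, mode='add')
```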

### Training
A pretrained model is available at `./train_result/model_weight.pkl`. It was trained on MS-COCO 2014, which contains 82,783 images in total; 1,000 images (COCO_train2014_000000574951 ~ COCO_train2014_000000581921.jpg) were held out for validation and the rest were used for training. During training, all images are resized to 256×256 and converted to grayscale. The model is optimized by Adam with a learning rate of 1e-4, a batch size of 2, and 4 epochs. The loss function is MSE + λ·(1 − SSIM) with λ = 1. The experiments ran on a 2080 Ti GPU with 32 GB RAM, and training took about 2 hours.
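
For reference, the per-batch objective computed in `train.py` is equivalent to:

```python
import torch.nn as nn
from ssim import SSIM

MSE_fun = nn.MSELoss()
SSIM_fun = SSIM()
lambd = 1

def reconstruction_loss(img, img_recon):
    # pixel-wise MSE plus the SSIM distance (1 - SSIM), weighted by lambda
    return MSE_fun(img, img_recon) + lambd * (1 - SSIM_fun(img, img_recon))
```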

If you want to retrain the network, download MS-COCO 2014 and run `train.py`.
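
The dataset locations and hyper-parameters are hard-coded near the top of `train.py`; point them at your local COCO split before training (the paths below are the author's own):

```python
# parameters at the top of train.py
root = 'D:/coco/train2014_train/'      # training images
root_val = 'D:/coco/train2014_val/'    # validation images
train_path = './train_result/'         # where logs, curves and weights are written
epochs = 4
batch_size = 2
```

Note that `train.py` saves its checkpoints as `model_weight_new.pkl`, so the shipped `model_weight.pkl` is not overwritten.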
--------------------------------------------------------------------------------
/ssim.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sat Jun 15 15:05:50 2019

@author: win10
"""
import os
import string
from glob import glob

import torch

from densefuse_net import DenseFuseNet
from utils import test_gray, test_rgb

test_path = './images/IV_images/'

def test(test_path, model, img_type='gray', save_path='./test_result/', mode='l1', window_width=1):
    img_list = glob(test_path + '*')
    img_num = len(img_list) // 2
    suffix = img_list[0].split('.')[-1]
    # strip directory, extension and trailing digits to recover the two source-image prefixes
    img_name_list = sorted(set(os.path.basename(f).split('.')[0].strip(string.digits) for f in img_list))

    if img_type == 'gray':
        fusion_phase = test_gray()
    elif img_type == 'rgb':
        fusion_phase = test_rgb()

    for i in range(img_num):
        img1_path = test_path + img_name_list[0] + str(i+1) + '.' + suffix
        img2_path = test_path + img_name_list[1] + str(i+1) + '.' + suffix
        save_name = 'fusion' + str(i+1) + '_' + img_type + '_' + mode + '.' + suffix
        fusion_phase.get_fusion(img1_path, img2_path, model,
                                save_path=save_path, save_name=save_name, mode=mode, window_width=window_width)

if __name__ == '__main__':
    # load the pretrained weights before fusing, as in main.py
    device = 'cuda'
    model = DenseFuseNet().to(device)
    model.load_state_dict(torch.load('./train_result/model_weight.pkl')['weight'])
    test(test_path, model, mode='add')
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 13 15:59:40 2019

@author: win10
"""
import os
# os.chdir(r'D:\py_code\densefuse_pytorch')  # machine-specific working directory; adjust or remove

from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import torch

from densefuse_net import DenseFuseNet
from dataset import AEDataset
from ssim import SSIM
from utils import mkdir

import scipy.io as scio
import numpy as np
from matplotlib import pyplot as plt

# Parameters
root = 'D:/coco/train2014_train/'      # training images (hard-coded; change to your path)
root_val = 'D:/coco/train2014_val/'    # validation images
train_path = './train_result/'
epochs = 4
batch_size = 2
device = 'cuda'
lr = 1e-4
lambd = 1
loss_interval = 1000   # log and validate every `loss_interval` batches
model_interval = 1000  # save a checkpoint every `model_interval` batches
# Dataset
data = AEDataset(root, resize=[256, 256], transform=None, gray=True)
loader = DataLoader(data, batch_size=batch_size, shuffle=True)
data_val = AEDataset(root_val, resize=[256, 256], transform=None, gray=True)
loader_val = DataLoader(data_val, batch_size=100, shuffle=True)

# Model
model = DenseFuseNet().to(device)
print(model)
optimizer = optim.Adam(model.parameters(), lr=lr)
MSE_fun = nn.MSELoss()
SSIM_fun = SSIM()

# Training
mse_train = []
ssim_train = []
loss_train = []
mse_val = []
ssim_val = []
loss_val = []
mkdir(train_path)
print('============ Training Begins ===============')
for iteration in range(epochs):
    for index, img in enumerate(loader):
        img = img.to(device)

        optimizer.zero_grad()
        img_recon = model(img)
        mse_loss = MSE_fun(img, img_recon)
        ssim_loss = 1 - SSIM_fun(img, img_recon)
        loss = mse_loss + lambd*ssim_loss
        loss.backward()
        optimizer.step()

        if index % loss_interval == 0:
            print('[%d,%d] - Train - MSE: %.10f, SSIM: %.10f' %
                  (iteration, index, mse_loss.item(), ssim_loss.item()))
            mse_train.append(mse_loss.item())
            ssim_train.append(ssim_loss.item())
            loss_train.append(loss.item())

            # Validation: accumulate per-image MSE and SSIM distance over the hold-out set
            with torch.no_grad():
                tmp1, tmp2 = .0, .0
                for _, img in enumerate(loader_val):
                    img = img.to(device)
                    img_recon = model(img)
                    tmp1 += (MSE_fun(img, img_recon)*img.shape[0]).item()
                    tmp2 += ((1 - SSIM_fun(img, img_recon))*img.shape[0]).item()  # SSIM distance, matching the training loss
                tmp3 = tmp1 + lambd*tmp2
                mse_val.append(tmp1/len(data_val))
                ssim_val.append(tmp2/len(data_val))
                loss_val.append(tmp3/len(data_val))
                print('[%d,%d] - Validation - MSE: %.10f, SSIM: %.10f' %
                      (iteration, index, mse_val[-1], ssim_val[-1]))
            scio.savemat(os.path.join(train_path, 'TrainData.mat'),
                         {'mse_train': np.array(mse_train),
                          'ssim_train': np.array(ssim_train),
                          'loss_train': np.array(loss_train)})
            scio.savemat(os.path.join(train_path, 'ValData.mat'),
                         {'mse_val': np.array(mse_val),
                          'ssim_val': np.array(ssim_val),
                          'loss_val': np.array(loss_val)})

            # Plot training/validation curves on a log scale and save them
            plt.figure(figsize=[12, 8])
            plt.subplot(2, 3, 1), plt.semilogy(mse_train), plt.title('mse train')
            plt.subplot(2, 3, 2), plt.semilogy(ssim_train), plt.title('ssim train')
            plt.subplot(2, 3, 3), plt.semilogy(loss_train), plt.title('loss train')
            plt.subplot(2, 3, 4), plt.semilogy(mse_val), plt.title('mse val')
            plt.subplot(2, 3, 5), plt.semilogy(ssim_val), plt.title('ssim val')
            plt.subplot(2, 3, 6), plt.semilogy(loss_val), plt.title('loss val')

            plt.savefig(os.path.join(train_path, 'curve.png'), dpi=90)
            plt.close()  # free the figure so repeated plotting does not accumulate memory

        if index % model_interval == 0:
            torch.save({'weight': model.state_dict(), 'epoch': iteration, 'batch_index': index},
                       os.path.join(train_path, 'model_weight_new.pkl'))
            print('[%d,%d] - model is saved -' % (iteration, index))
--------------------------------------------------------------------------------
/train_result/TrainData.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shuangxu96/densefuse_pytorch/54e2ee245d8fc02d7afc3926a59ac7d1537e6c6a/train_result/TrainData.mat
--------------------------------------------------------------------------------
/train_result/ValData.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shuangxu96/densefuse_pytorch/54e2ee245d8fc02d7afc3926a59ac7d1537e6c6a/train_result/ValData.mat
--------------------------------------------------------------------------------
/train_result/curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shuangxu96/densefuse_pytorch/54e2ee245d8fc02d7afc3926a59ac7d1537e6c6a/train_result/curve.png
--------------------------------------------------------------------------------
/train_result/model_weight.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shuangxu96/densefuse_pytorch/54e2ee245d8fc02d7afc3926a59ac7d1537e6c6a/train_result/model_weight.pkl
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 12 22:32:22 2019

@author: win10
"""
from PIL import Image
import os
import string
from glob import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

_tensor = transforms.ToTensor()
_pil_rgb = transforms.ToPILImage('RGB')
_pil_gray = transforms.ToPILImage()
device = 'cuda'

def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)

def load_img(img_path, img_type='gray'):
    img = Image.open(img_path)
    if img_type == 'gray':
        img = img.convert('L')
    return _tensor(img).unsqueeze(0)

class Strategy(nn.Module):
    def __init__(self, mode='add', window_width=1):
        super().__init__()
        self.mode = mode
        if self.mode == 'l1':
            self.window_width = window_width

    def forward(self, y1, y2):
        if self.mode == 'add':
            # addition strategy: simple average of the two feature maps
            return (y1 + y2) / 2

        if self.mode == 'l1':
            # l1-norm strategy: weight each feature map by its local activity
            ActivityMap1 = y1.abs()
            ActivityMap2 = y2.abs()

            kernel = torch.ones(2*self.window_width+1, 2*self.window_width+1)/(2*self.window_width+1)**2
            kernel = kernel.to(device).type(torch.float32)[None, None, :, :]
            kernel = kernel.expand(y1.shape[1], y1.shape[1], 2*self.window_width+1, 2*self.window_width+1)
            ActivityMap1 = F.conv2d(ActivityMap1, kernel, padding=self.window_width)
            ActivityMap2 = F.conv2d(ActivityMap2, kernel, padding=self.window_width)
            WeightMap1 = ActivityMap1/(ActivityMap1 + ActivityMap2)
            WeightMap2 = ActivityMap2/(ActivityMap1 + ActivityMap2)
            return WeightMap1*y1 + WeightMap2*y2

def fusion(x1, x2, model, mode='l1', window_width=1):
    # fuse two inputs in feature space: encode, merge, decode
    with torch.no_grad():
        fusion_layer = Strategy(mode, window_width).to(device)
        feature1 = model.encoder(x1)
        feature2 = model.encoder(x2)
        feature_fusion = fusion_layer(feature1, feature2)
        return model.decoder(feature_fusion).squeeze(0).detach().cpu()

class Test:
    def __init__(self):
        pass

    def load_imgs(self, img1_path, img2_path, device):
        img1 = load_img(img1_path, img_type=self.img_type).to(device)
        img2 = load_img(img2_path, img_type=self.img_type).to(device)
        return img1, img2

    def save_imgs(self, save_path, save_name, img_fusion):
        mkdir(save_path)
        save_path = os.path.join(save_path, save_name)
        img_fusion.save(save_path)

class test_gray(Test):
    def __init__(self):
        self.img_type = 'gray'

    def get_fusion(self, img1_path, img2_path, model,
                   save_path='./test_result/', save_name='none', mode='l1', window_width=1):
        img1, img2 = self.load_imgs(img1_path, img2_path, device)

        img_fusion = fusion(x1=img1, x2=img2, model=model, mode=mode, window_width=window_width)
        img_fusion = _pil_gray(img_fusion)

        self.save_imgs(save_path, save_name, img_fusion)
        return img_fusion

class test_rgb(Test):
    def __init__(self):
        self.img_type = 'rgb'

    def get_fusion(self, img1_path, img2_path, model,
                   save_path='./test_result/', save_name='none', mode='l1', window_width=1):
        img1, img2 = self.load_imgs(img1_path, img2_path, device)

        # fuse each RGB channel independently, then stack the results
        img_fusion = _pil_rgb(torch.cat(
            [fusion(img1[:, i, :, :][:, None, :, :],
                    img2[:, i, :, :][:, None, :, :], model,
                    mode=mode, window_width=window_width)
             for i in range(3)],
            dim=0))

        self.save_imgs(save_path, save_name, img_fusion)
        return img_fusion


def test(test_path, model, img_type='gray', save_path='./test_result/', mode='l1', window_width=1):
    img_list = glob(test_path + '*')
    img_num = len(img_list) // 2
    suffix = img_list[0].split('.')[-1]
    # strip directory, extension and trailing digits to recover the two source-image prefixes
    img_name_list = sorted(set(os.path.basename(f).split('.')[0].strip(string.digits) for f in img_list))

    if img_type == 'gray':
        fusion_phase = test_gray()
    elif img_type == 'rgb':
        fusion_phase = test_rgb()

    for i in range(img_num):
        img1_path = test_path + img_name_list[0] + str(i+1) + '.' + suffix
        img2_path = test_path + img_name_list[1] + str(i+1) + '.' + suffix
        save_name = 'fusion' + str(i+1) + '_' + img_type + '_' + mode + '.' + suffix
        fusion_phase.get_fusion(img1_path, img2_path, model,
                                save_path=save_path, save_name=save_name, mode=mode, window_width=window_width)
--------------------------------------------------------------------------------