├── dataset.py
├── densefuse_net.py
├── main.py
├── readme.md
├── ssim.py
├── test.py
├── train.py
├── train_result
│   ├── TrainData.mat
│   ├── ValData.mat
│   ├── curve.png
│   └── model_weight.pkl
└── utils.py

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 13 16:03:42 2019

@author: win10
"""

import torch.utils.data as Data
import torchvision.transforms as transforms

from glob import glob
import os
from PIL import Image

class AEDataset(Data.Dataset):
    def __init__(self, root, resize=[256, 256], transform=None, gray=True):
        self.files = glob(os.path.join(root, '*.*'))
        self.resize = resize
        self.gray = gray
        self._tensor = transforms.ToTensor()
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        img = Image.open(self.files[index]).resize(self.resize)

        if self.gray:
            img = img.convert('L')

        img = self._tensor(img)
        if self.transform is not None:
            img = self.transform(img)

        return img

--------------------------------------------------------------------------------
/densefuse_net.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 12 21:03:03 2019

@author: win10
"""

import torch.nn as nn
import torch

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.layers = nn.Sequential()
        self.layers.add_module('Conv2', nn.Conv2d(64, 64, 3, 1, 1))
        self.layers.add_module('Act2',  nn.ReLU(inplace=True))
        self.layers.add_module('Conv3', nn.Conv2d(64, 32, 3, 1, 1))
        self.layers.add_module('Act3',  nn.ReLU(inplace=True))
        self.layers.add_module('Conv4', nn.Conv2d(32, 16, 3, 1, 1))
        self.layers.add_module('Act4',  nn.ReLU(inplace=True))
        self.layers.add_module('Conv5', nn.Conv2d(16, 1, 3, 1, 1))

    def forward(self, x):
        return self.layers(x)

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()

        self.Conv1 = nn.Conv2d(1, 16, 3, 1, 1)
        self.Relu = nn.ReLU(inplace=True)

        self.layers = nn.ModuleDict({
            'DenseConv1': nn.Conv2d(16, 16, 3, 1, 1),
            'DenseConv2': nn.Conv2d(32, 16, 3, 1, 1),
            'DenseConv3': nn.Conv2d(48, 16, 3, 1, 1)
        })

    def forward(self, x):
        x = self.Relu(self.Conv1(x))
        # dense connections: each output is concatenated with its input,
        # so the channel count grows 16 -> 32 -> 48 -> 64
        for i in range(len(self.layers)):
            out = self.layers['DenseConv' + str(i + 1)](x)
            x = torch.cat([x, out], 1)
        return x

class DenseFuseNet(nn.Module):

    def __init__(self):
        super(DenseFuseNet, self).__init__()

        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.decoder(self.encoder(x))
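A quick shape check of the dense encoder (a sketch, not part of the original repo): each `DenseConv` output is concatenated with its input, so channels grow 16 → 32 → 48 → 64 before the decoder maps back to one channel.

```python
import torch
from densefuse_net import DenseFuseNet

net = DenseFuseNet()
x = torch.rand(1, 1, 256, 256)   # one grayscale 256x256 image
print(net.encoder(x).shape)      # torch.Size([1, 64, 256, 256])
print(net(x).shape)              # torch.Size([1, 1, 256, 256])
```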
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sat Jun 15 17:32:38 2019

@author: win10
"""
import torch

from densefuse_net import DenseFuseNet
from utils import test

device = 'cuda'

model = DenseFuseNet().to(device)
model.load_state_dict(torch.load('./train_result/model_weight.pkl')['weight'])

test_path = './images/IV_images/'
test(test_path, model, mode='add')

--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# DenseFuse (PyTorch)

---

A PyTorch implementation of DenseFuse.
## 📝 Table of Contents
- [About](#-about)
- [Usage](#-usage)

## 🧐 About
This is a PyTorch implementation of DenseFuse, proposed in the paper:
[H. Li, X. J. Wu, “DenseFuse: A Fusion Approach to Infrared and Visible Images,” IEEE Trans. Image Process., vol. 28, no. 5, pp. 2614–2623, May 2019.](https://arxiv.org/abs/1804.08361)

The code is written for torch 1.1.0 and uses pytorch-ssim.

## 🎈 Usage

### Quick start
1. Clone this repo and unpack it.
2. Download the [test dataset](https://github.com/hli1221/imagefusion_densefuse/tree/master/images) and put the test images in `./images/IV_images`.
3. Run `main.py`.

### Training
A pretrained model is available in `./train_result/model_weight.pkl`. It was trained on MS-COCO 2014: of the 82,783 images, 1,000 (COCO_train2014_000000574951 ~ COCO_train2014_000000581921.jpg) are held out for validation and the rest are used for training. During training, all images are resized to 256×256 and converted to grayscale. The model is optimized by Adam with a learning rate of 1e-4; the batch size and number of epochs are 2 and 4, respectively. The loss function is MSE + λ·(1 − SSIM), with λ = 1. The experiments ran on a 2080 Ti GPU with 32 GB RAM and took about 2 hours.

To retrain the network, download MS-COCO 2014 and run `train.py`.
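### Fusion in a few lines
A minimal sketch of fusing one infrared/visible pair with the pretrained weights (the file names `IR1.png`/`VIS1.png` are assumptions — adjust them to your test images):

```python
import torch
from densefuse_net import DenseFuseNet
from utils import load_img, fusion

device = 'cuda'
model = DenseFuseNet().to(device)
model.load_state_dict(torch.load('./train_result/model_weight.pkl')['weight'])

ir  = load_img('./images/IV_images/IR1.png',  img_type='gray').to(device)
vis = load_img('./images/IV_images/VIS1.png', img_type='gray').to(device)
fused = fusion(ir, vis, model, mode='add')   # or mode='l1', window_width=1
```

--------------------------------------------------------------------------------
/ssim.py:
--------------------------------------------------------------------------------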
import torch
import torch.nn.functional as F
from math import exp


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2

    # stability constants for inputs scaled to [0, 1]
    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # reuse the cached window unless the channel count or dtype changed
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

        if img1.is_cuda:
            window = window.cuda(img1.get_device())
        window = window.type_as(img1)

        self.window = window
        self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sat Jun 15 15:05:50 2019

@author: win10
"""
from glob import glob
import os
import string

import torch

from densefuse_net import DenseFuseNet
from utils import test_gray, test_rgb

test_path = './images/IV_images/'

def test(test_path, model, img_type='gray', save_path='./test_result/', mode='l1', window_width=1):
    img_list = glob(test_path + '*')
    img_num = len(img_list)//2
    suffix = img_list[0].split('.')[-1]
    # pair images by stripping trailing digits from their base names (e.g. IR1/VIS1);
    # sorted for a deterministic pairing order
    img_name_list = sorted(set([os.path.basename(f).split('.')[0].strip(string.digits) for f in img_list]))

    if img_type == 'gray':
        fusion_phase = test_gray()
    elif img_type == 'rgb':
        fusion_phase = test_rgb()

    for i in range(img_num):
        img1_path = test_path + img_name_list[0] + str(i+1) + '.' + suffix
        img2_path = test_path + img_name_list[1] + str(i+1) + '.' + suffix
        save_name = 'fusion' + str(i+1) + '_' + img_type + '_' + mode + '.' + suffix
        fusion_phase.get_fusion(img1_path, img2_path, model,
                                save_path=save_path, save_name=save_name, mode=mode, window_width=window_width)

# `model` was undefined here in the original; load it the same way main.py does
model = DenseFuseNet().to('cuda')
model.load_state_dict(torch.load('./train_result/model_weight.pkl')['weight'])
test(test_path, model, mode='add')
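A quick sanity check for `ssim.py` (a sketch; assumes inputs scaled to [0, 1], which matches `ToTensor()` output):

```python
import torch
from ssim import SSIM, ssim

x = torch.rand(2, 1, 64, 64)
print(ssim(x, x).item())                            # identical images -> 1.0
print((1 - SSIM()(x, torch.rand_like(x))).item())   # the training-style SSIM loss
```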
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 13 15:59:40 2019

@author: win10
"""
import os
os.chdir(r'D:\py_code\densefuse_pytorch')

from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import torch

from densefuse_net import DenseFuseNet
from dataset import AEDataset
from ssim import SSIM
from utils import mkdir

import scipy.io as scio
import numpy as np
from matplotlib import pyplot as plt

# Parameters
root = 'D:/coco/train2014_train/'
root_val = 'D:/coco/train2014_val/'
train_path = './train_result/'
epochs = 4
batch_size = 2
device = 'cuda'
lr = 1e-4
lambd = 1
loss_interval = 1000
model_interval = 1000

# Dataset
data = AEDataset(root, resize=[256, 256], transform=None, gray=True)
loader = DataLoader(data, batch_size=batch_size, shuffle=True)
data_val = AEDataset(root_val, resize=[256, 256], transform=None, gray=True)
loader_val = DataLoader(data_val, batch_size=100, shuffle=True)

# Model
model = DenseFuseNet().to(device)
print(model)
optimizer = optim.Adam(model.parameters(), lr=lr)
MSE_fun = nn.MSELoss()
SSIM_fun = SSIM()

# Training
mse_train = []
ssim_train = []
loss_train = []
mse_val = []
ssim_val = []
loss_val = []
mkdir(train_path)
print('============ Training Begins ===============')
for iteration in range(epochs):
    for index, img in enumerate(loader):
        img = img.to(device)

        optimizer.zero_grad()
        img_recon = model(img)
        mse_loss = MSE_fun(img, img_recon)
        ssim_loss = 1 - SSIM_fun(img, img_recon)
        loss = mse_loss + lambd*ssim_loss
        loss.backward()
        optimizer.step()

        if index % loss_interval == 0:
            print('[%d,%d] - Train - MSE: %.10f, SSIM: %.10f' %
                  (iteration, index, mse_loss.item(), ssim_loss.item()))
            mse_train.append(mse_loss.item())
            ssim_train.append(ssim_loss.item())
            loss_train.append(loss.item())

            with torch.no_grad():
                tmp1, tmp2 = 0.0, 0.0
                for _, img in enumerate(loader_val):
                    img = img.to(device)
                    img_recon = model(img)
                    tmp1 += (MSE_fun(img, img_recon)*img.shape[0]).item()
                    # accumulate the 1-SSIM loss, matching the training objective
                    tmp2 += ((1 - SSIM_fun(img, img_recon))*img.shape[0]).item()
                tmp3 = tmp1 + lambd*tmp2
                # the original appended tmp1 to all three lists; use the matching sums
                mse_val.append(tmp1/len(data_val))
                ssim_val.append(tmp2/len(data_val))
                loss_val.append(tmp3/len(data_val))
                print('[%d,%d] - Validation - MSE: %.10f, SSIM: %.10f' %
                      (iteration, index, mse_val[-1], ssim_val[-1]))
            scio.savemat(os.path.join(train_path, 'TrainData.mat'),
                         {'mse_train': np.array(mse_train),
                          'ssim_train': np.array(ssim_train),
                          'loss_train': np.array(loss_train)})
            scio.savemat(os.path.join(train_path, 'ValData.mat'),
                         {'mse_val': np.array(mse_val),
                          'ssim_val': np.array(ssim_val),
                          'loss_val': np.array(loss_val)})

            plt.figure(figsize=[12, 8])
            plt.subplot(2, 3, 1), plt.semilogy(mse_train), plt.title('mse train')
            plt.subplot(2, 3, 2), plt.semilogy(ssim_train), plt.title('ssim train')
            plt.subplot(2, 3, 3), plt.semilogy(loss_train), plt.title('loss train')
            plt.subplot(2, 3, 4), plt.semilogy(mse_val), plt.title('mse val')
            plt.subplot(2, 3, 5), plt.semilogy(ssim_val), plt.title('ssim val')
            plt.subplot(2, 3, 6), plt.semilogy(loss_val), plt.title('loss val')
            plt.savefig(os.path.join(train_path, 'curve.png'), dpi=90)
            plt.close()  # avoid accumulating open figures across checkpoints

        if index % model_interval == 0:
            torch.save({'weight': model.state_dict(), 'epoch': iteration, 'batch_index': index},
                       os.path.join(train_path, 'model_weight_new.pkl'))
            print('[%d,%d] - model is saved -' % (iteration, index))
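To inspect the logged curves afterwards, a minimal sketch (the key names come from the `scio.savemat` calls in `train.py`):

```python
import scipy.io as scio

train_log = scio.loadmat('./train_result/TrainData.mat')
val_log   = scio.loadmat('./train_result/ValData.mat')
print(train_log['mse_train'].ravel()[:5])   # first few logged training MSE values
print(val_log['loss_val'].ravel()[-1])      # latest validation loss
```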
--------------------------------------------------------------------------------
/train_result/TrainData.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shuangxu96/densefuse_pytorch/54e2ee245d8fc02d7afc3926a59ac7d1537e6c6a/train_result/TrainData.mat

--------------------------------------------------------------------------------
/train_result/ValData.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shuangxu96/densefuse_pytorch/54e2ee245d8fc02d7afc3926a59ac7d1537e6c6a/train_result/ValData.mat

--------------------------------------------------------------------------------
/train_result/curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shuangxu96/densefuse_pytorch/54e2ee245d8fc02d7afc3926a59ac7d1537e6c6a/train_result/curve.png

--------------------------------------------------------------------------------
/train_result/model_weight.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shuangxu96/densefuse_pytorch/54e2ee245d8fc02d7afc3926a59ac7d1537e6c6a/train_result/model_weight.pkl

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 12 22:32:22 2019

@author: win10
"""
from PIL import Image
import os
import string
from glob import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

_tensor = transforms.ToTensor()
_pil_rgb = transforms.ToPILImage('RGB')
_pil_gray = transforms.ToPILImage()
device = 'cuda'

def mkdir(path):
    os.makedirs(path, exist_ok=True)

def load_img(img_path, img_type='gray'):
    img = Image.open(img_path)
    if img_type == 'gray':
        img = img.convert('L')
    return _tensor(img).unsqueeze(0)

class Strategy(nn.Module):
    def __init__(self, mode='add', window_width=1):
        super().__init__()
        self.mode = mode
        if self.mode == 'l1':
            self.window_width = window_width

    def forward(self, y1, y2):
        if self.mode == 'add':
            return (y1 + y2)/2

        if self.mode == 'l1':
            # l1-norm activity maps, box-averaged over a (2w+1)x(2w+1) window
            ActivityMap1 = y1.abs()
            ActivityMap2 = y2.abs()

            kernel = torch.ones(2*self.window_width+1, 2*self.window_width+1)/(2*self.window_width+1)**2
            kernel = kernel.to(device).type(torch.float32)[None, None, :, :]
            kernel = kernel.expand(y1.shape[1], y1.shape[1], 2*self.window_width+1, 2*self.window_width+1)
            ActivityMap1 = F.conv2d(ActivityMap1, kernel, padding=self.window_width)
            ActivityMap2 = F.conv2d(ActivityMap2, kernel, padding=self.window_width)
            # small epsilon guards against 0/0 where both activity maps vanish
            eps = 1e-8
            WeightMap1 = ActivityMap1/(ActivityMap1 + ActivityMap2 + eps)
            WeightMap2 = ActivityMap2/(ActivityMap1 + ActivityMap2 + eps)
            return WeightMap1*y1 + WeightMap2*y2

def fusion(x1, x2, model, mode='l1', window_width=1):
    with torch.no_grad():
        fusion_layer = Strategy(mode, window_width).to(device)
        feature1 = model.encoder(x1)
        feature2 = model.encoder(x2)
        feature_fusion = fusion_layer(feature1, feature2)
        return model.decoder(feature_fusion).squeeze(0).detach().cpu()

class Test:
    def __init__(self):
        pass

    def load_imgs(self, img1_path, img2_path, device):
        img1 = load_img(img1_path, img_type=self.img_type).to(device)
        img2 = load_img(img2_path, img_type=self.img_type).to(device)
        return img1, img2

    def save_imgs(self, save_path, save_name, img_fusion):
        mkdir(save_path)
        save_path =
os.path.join(save_path, save_name)
        img_fusion.save(save_path)

class test_gray(Test):
    def __init__(self):
        self.img_type = 'gray'  # was 'rgray', which silently skipped grayscale conversion

    def get_fusion(self, img1_path, img2_path, model,
                   save_path='./test_result/', save_name='none', mode='l1', window_width=1):
        img1, img2 = self.load_imgs(img1_path, img2_path, device)

        img_fusion = fusion(x1=img1, x2=img2, model=model, mode=mode, window_width=window_width)
        img_fusion = _pil_gray(img_fusion)

        self.save_imgs(save_path, save_name, img_fusion)
        return img_fusion

class test_rgb(Test):
    def __init__(self):
        self.img_type = 'rgb'

    def get_fusion(self, img1_path, img2_path, model,
                   save_path='./test_result/', save_name='none', mode='l1', window_width=1):
        img1, img2 = self.load_imgs(img1_path, img2_path, device)

        # fuse each RGB channel independently, then stack back into one image
        img_fusion = _pil_rgb(torch.cat(
            [fusion(img1[:, i, :, :][:, None, :, :],
                    img2[:, i, :, :][:, None, :, :], model,
                    mode=mode, window_width=window_width)
             for i in range(3)],
            dim=0))

        self.save_imgs(save_path, save_name, img_fusion)
        return img_fusion


def test(test_path, model, img_type='gray', save_path='./test_result/', mode='l1', window_width=1):
    img_list = glob(test_path + '*')
    img_num = len(img_list)//2
    suffix = img_list[0].split('.')[-1]
    # pair images by stripping trailing digits from their base names (e.g. IR1/VIS1);
    # sorted for a deterministic pairing order
    img_name_list = sorted(set([os.path.basename(f).split('.')[0].strip(string.digits) for f in img_list]))

    if img_type == 'gray':
        fusion_phase = test_gray()
    elif img_type == 'rgb':
        fusion_phase = test_rgb()

    for i in range(img_num):
        img1_path = test_path + img_name_list[0] + str(i+1) + '.' + suffix
        img2_path = test_path + img_name_list[1] + str(i+1) + '.' + suffix
        save_name = 'fusion' + str(i+1) + '_' + img_type + '_' + mode + '.' + suffix
        fusion_phase.get_fusion(img1_path, img2_path, model,
                                save_path=save_path, save_name=save_name, mode=mode, window_width=window_width)
--------------------------------------------------------------------------------
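A short sketch of the two fusion strategies on dummy encoder features (assumes a CUDA device, since `utils.py` hardcodes `device = 'cuda'` for the averaging kernel in `l1` mode):

```python
import torch
from utils import Strategy

f1 = torch.rand(1, 64, 32, 32).cuda()
f2 = torch.rand(1, 64, 32, 32).cuda()

fused_add = Strategy(mode='add')(f1, f2)                 # simple average
fused_l1  = Strategy(mode='l1', window_width=1)(f1, f2)  # activity-weighted
print(fused_add.shape, fused_l1.shape)  # both torch.Size([1, 64, 32, 32])
```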