├── lib
│   ├── __init__.py
│   ├── model.py
│   ├── utils.py
│   ├── dataset.py
│   └── loss.py
├── img
│   ├── structure.png
│   ├── fuse_result1.png
│   └── fuse_result2.png
├── inference.py
├── train.py
├── readme.md
└── opts.py
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/img/structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SunnerLi/DeepFuse.pytorch/HEAD/img/structure.png
--------------------------------------------------------------------------------
/img/fuse_result1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SunnerLi/DeepFuse.pytorch/HEAD/img/fuse_result1.png
--------------------------------------------------------------------------------
/img/fuse_result2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SunnerLi/DeepFuse.pytorch/HEAD/img/fuse_result2.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
from lib.utils import fusePostProcess
from lib.loss import MEF_SSIM_Loss
from lib.model import DeepFuse
from opts import TestOptions

import torchvision_sunner.transforms as sunnertransforms
import torchvision.transforms as transforms

import torch
import cv2

"""
    This script defines the inference procedure of DeepFuse

    Author: SunnerLi
"""

def inference(opts):
    # Load the two images and convert them into the YCrCb color space
    ops = transforms.Compose([
        sunnertransforms.Resize((opts.H, opts.W)),
        sunnertransforms.ToTensor(),
        sunnertransforms.ToFloat(),
        sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
        sunnertransforms.Normalize(),
    ])
    img1 = cv2.imread(opts.image1)
    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2YCrCb)
    img1 = torch.unsqueeze(ops(img1), 0)
    img2 = cv2.imread(opts.image2)
    img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2YCrCb)
    img2 = torch.unsqueeze(ops(img2), 0)

    # Load the pre-trained model
    model = DeepFuse()
    state = torch.load(opts.model)
    model.load_state_dict(state['model'])
    model.to(opts.device)
    model.eval()
    criterion = MEF_SSIM_Loss().to(opts.device)

    # Fuse!
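    # Note (added comment): the loss module is reused at inference time only to
    # obtain y_hat, the MEF-SSIM "ground truth" luminance, so that
    # fusePostProcess can render it next to the network output for comparison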
    with torch.no_grad():
        # Forward
        img1, img2 = img1.to(opts.device), img2.to(opts.device)
        img1_lum = img1[:, 0:1]
        img2_lum = img2[:, 0:1]
        model.setInput(img1_lum, img2_lum)
        y_f = model.forward()
        _, y_hat = criterion(y_1 = img1_lum, y_2 = img2_lum, y_f = y_f)

        # Save the image
        img = fusePostProcess(y_f, y_hat, img1, img2, single = False)
        cv2.imwrite(opts.res, img[0])

if __name__ == '__main__':
    opts = TestOptions().parse()
    inference(opts)
--------------------------------------------------------------------------------
/lib/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch

"""
    This script defines the DeepFuse model and its related modules

    Author: SunnerLi
"""

# -------------------------------------------------------------------------------------------------------
# Define layers
# -------------------------------------------------------------------------------------------------------
class ConvLayer(nn.Module):
    def __init__(self, in_channels = 1, out_channels = 16, kernel_size = 5, last = nn.ReLU):
        super().__init__()
        # Keep the spatial size unchanged for the two kernel sizes the model uses
        if kernel_size == 5:
            padding = 2
        elif kernel_size == 7:
            padding = 3
        else:
            raise ValueError("Only kernel sizes 5 and 7 are supported")
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = kernel_size, stride = 1, padding = padding),
            nn.BatchNorm2d(out_channels),
            last()
        )

    def forward(self, x):
        out = self.main(x)
        return out

class FusionLayer(nn.Module):
    # The fusion strategy of DeepFuse is a simple element-wise addition of the feature maps
    def forward(self, x, y):
        return x + y

# -------------------------------------------------------------------------------------------------------
# Define model
# -------------------------------------------------------------------------------------------------------
class DeepFuse(nn.Module):
    def __init__(self, device = 'cpu'):
        super().__init__()
        self.layer1 = ConvLayer(1, 16, 5, last = nn.LeakyReLU)
        self.layer2 = ConvLayer(16, 32, 7)
        self.layer3 = FusionLayer()
        self.layer4 = ConvLayer(32, 32, 7, last = nn.LeakyReLU)
        self.layer5 = ConvLayer(32, 16, 5, last = nn.LeakyReLU)
        self.layer6 = ConvLayer(16, 1, 5, last = nn.Tanh)
        self.device = device
        self.to(self.device)

    def setInput(self, y_1, y_2):
        self.y_1 = y_1
        self.y_2 = y_2

    def forward(self):
        # Both luminance inputs pass through the same (shared-weight) feature extractor
        c11 = self.layer1(self.y_1[:, 0:1])
        c12 = self.layer1(self.y_2[:, 0:1])
        c21 = self.layer2(c11)
        c22 = self.layer2(c12)
        f_m = self.layer3(c21, c22)
        c3  = self.layer4(f_m)
        c4  = self.layer5(c3)
        c5  = self.layer6(c4)
        return c5
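
# Usage sketch (added for illustration; shapes are assumptions, not from the
# original file). The model takes the Y (luminance) channel of both exposures
# through the two-step setInput/forward API used by train.py and inference.py:
#
#   model = DeepFuse(device = 'cpu')
#   y_1 = torch.rand(1, 1, 64, 64) * 2 - 1   # under-exposed luminance in [-1, 1]
#   y_2 = torch.rand(1, 1, 64, 64) * 2 - 1   # over-exposed luminance in [-1, 1]
#   model.setInput(y_1, y_2)
#   y_f = model.forward()                    # fused luminance, same shape as the inputs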
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from lib.dataset import BracketedDataset
from lib.utils import fusePostProcess
from lib.loss import MEF_SSIM_Loss
from lib.model import DeepFuse
from opts import TrainOptions

import torchvision_sunner.transforms as sunnertransforms
import torchvision_sunner.data as sunnerData
import torchvision.transforms as transforms

from matplotlib import pyplot as plt
from torch.optim import Adam
from tqdm import tqdm

import numpy as np
import torch
import cv2
import os

"""
    This script defines the training procedure of DeepFuse

    Author: SunnerLi
"""

def train(opts):
    # Create the loader
    loader = sunnerData.DataLoader(
        dataset = BracketedDataset(
            root = opts.folder,
            crop_size = opts.crop_size,
            transform = transforms.Compose([
                sunnertransforms.ToTensor(),
                sunnertransforms.ToFloat(),
                sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
                sunnertransforms.Normalize(),
            ])
        ), batch_size = opts.batch_size, shuffle = True, num_workers = 0
    )

    # Create the model
    model = DeepFuse(device = opts.device)
    criterion = MEF_SSIM_Loss().to(opts.device)
    optimizer = Adam(model.parameters(), lr = 0.0001)

    # Load the pre-trained model if it exists
    if os.path.exists(opts.resume):
        state = torch.load(opts.resume)
        Loss_list = state['loss']
        model.load_state_dict(state['model'])
    else:
        Loss_list = []

    # Train
    bar = tqdm(range(opts.epoch))
    for ep in bar:
        loss_list = []
        for (patch1, patch2) in loader:
            # Extract the luminance and move to the computation device
            patch1, patch2 = patch1.to(opts.device), patch2.to(opts.device)
            patch1_lum = patch1[:, 0:1]
            patch2_lum = patch2[:, 0:1]

            # Forward and compute the loss
            model.setInput(patch1_lum, patch2_lum)
            y_f = model.forward()
            loss, y_hat = criterion(y_1 = patch1_lum, y_2 = patch2_lum, y_f = y_f)
            loss_list.append(loss.item())
            bar.set_description("Epoch: %d  Loss: %.6f" % (ep, loss_list[-1]))

            # Update the parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Loss_list.append(np.mean(loss_list))

        # Save an example fusion image periodically
        if ep % opts.record_epoch == 0:
            img = fusePostProcess(y_f, y_hat, patch1, patch2, single = False)
            cv2.imwrite(os.path.join(opts.det, 'image', str(ep) + ".png"), img[0])

        # Save the model (a numbered snapshot every 1/5 of the epochs, otherwise the latest one)
        if ep % (opts.epoch // 5) == 0:
            model_name = str(ep) + ".pth"
        else:
            model_name = "latest.pth"
        state = {
            'model': model.state_dict(),
            'loss' : Loss_list
        }
        torch.save(state, os.path.join(opts.det, 'model', model_name))

    # Plot the loss curve
    plt.clf()
    plt.plot(Loss_list, '-')
    plt.title("loss curve")
    plt.savefig(os.path.join(opts.det, 'image', "curve.png"))

if __name__ == '__main__':
    opts = TrainOptions().parse()
    train(opts)
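
# Checkpoint sketch (added for illustration; "latest.pth" is an assumed output
# path from a previous run). Every saved .pth file is a plain dict holding the
# model weights and the per-epoch loss history, so it can be inspected or
# resumed outside this script:
#
#   state = torch.load("train_result/model/latest.pth")
#   model = DeepFuse()
#   model.load_state_dict(state['model'])
#   print("epochs trained:", len(state['loss']))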
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import cv2

"""
    This script defines the fundamental functions which are used in the other scripts

    Author: SunnerLi
"""

L1_NORM = lambda b: torch.sum(torch.abs(b))

def INFO(string):
    print("[ DeepFuse ] %s" % (string))

def weightedFusion(cr1, cr2, cb1, cb2):
    """
        Perform the weighted fusion for the Cb and Cr channels (paper equation 6)

        Arg:    cr1 (torch.Tensor)  - The Cr slice of the 1st image
                cr2 (torch.Tensor)  - The Cr slice of the 2nd image
                cb1 (torch.Tensor)  - The Cb slice of the 1st image
                cb2 (torch.Tensor)  - The Cb slice of the 2nd image
        Ret:    The fused Cr slice and Cb slice
    """
    # Fuse the Cr channel
    cr_up   = (cr1 * L1_NORM(cr1 - 127.5) + cr2 * L1_NORM(cr2 - 127.5))
    cr_down = L1_NORM(cr1 - 127.5) + L1_NORM(cr2 - 127.5)
    cr_fuse = cr_up / cr_down

    # Fuse the Cb channel
    cb_up   = (cb1 * L1_NORM(cb1 - 127.5) + cb2 * L1_NORM(cb2 - 127.5))
    cb_down = L1_NORM(cb1 - 127.5) + L1_NORM(cb2 - 127.5)
    cb_fuse = cb_up / cb_down

    return cr_fuse, cb_fuse

def fusePostProcess(y_f, y_hat, img1, img2, single = True):
    """
        Perform the post-fusion process on both images with the generated luminance slice

        Arg:    y_f    (torch.Tensor) - The generated luminance slice
                y_hat  (torch.Tensor) - The ground-truth luminance slice which is computed by the MEF-SSIM formula
                img1   (torch.Tensor) - The 1st image tensor (in YCrCb format)
                img2   (torch.Tensor) - The 2nd image tensor (in YCrCb format)
                single (Bool)         - Whether to return the fusion result only
        Ret:    The fused output image
    """
    with torch.no_grad():
        # Recover the value range: [-1, 1] -> [0, 255]
        y_f   = (y_f   + 1) * 127.5
        y_hat = (y_hat + 1) * 127.5
        img1  = (img1  + 1) * 127.5
        img2  = (img2  + 1) * 127.5

        # Weighted fusion for the Cb and Cr channels
        cr_fuse, cb_fuse = weightedFusion(
            cr1 = img1[:, 1:2],
            cr2 = img2[:, 1:2],
            cb1 = img1[:, 2:3],
            cb2 = img2[:, 2:3]
        )

        # YCrCb -> BGR
        fuse_out = torch.zeros_like(img1)
        fuse_out[:, 0:1] = y_f
        fuse_out[:, 1:2] = cr_fuse
        fuse_out[:, 2:3] = cb_fuse
        fuse_out = fuse_out.transpose(1, 2).transpose(2, 3).cpu().numpy()
        fuse_out = fuse_out.astype(np.uint8)
        for i, m in enumerate(fuse_out):
            fuse_out[i] = cv2.cvtColor(m, cv2.COLOR_YCrCb2BGR)

        # Combine the outputs into [img1 | img2 | fusion | ground truth]
        if not single:
            out1 = img1.transpose(1, 2).transpose(2, 3).cpu().numpy().astype(np.uint8)
            for i, m in enumerate(out1):
                out1[i] = cv2.cvtColor(m, cv2.COLOR_YCrCb2BGR)
            out2 = img2.transpose(1, 2).transpose(2, 3).cpu().numpy().astype(np.uint8)
            for i, m in enumerate(out2):
                out2[i] = cv2.cvtColor(m, cv2.COLOR_YCrCb2BGR)
            out3 = torch.zeros_like(img1)
            out3[:, 0:1] = y_hat
            out3[:, 1:2] = cr_fuse
            out3[:, 2:3] = cb_fuse
            out3 = out3.transpose(1, 2).transpose(2, 3).cpu().numpy()
            out3 = out3.astype(np.uint8)
            for i, m in enumerate(out3):
                out3[i] = cv2.cvtColor(m, cv2.COLOR_YCrCb2BGR)
            out = np.concatenate((out1, out2, fuse_out, out3), 2)
        else:
            out = fuse_out
        return out
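
# Usage sketch (added for illustration; the variable values are assumptions).
# The chrominance fusion above is paper equation 6 with tau = 127.5:
#
#   Cr_fuse = (Cr1*||Cr1 - tau||_1 + Cr2*||Cr2 - tau||_1) / (||Cr1 - tau||_1 + ||Cr2 - tau||_1)
#
# and fusePostProcess stitches the panels the way inference.py does:
#
#   panel = fusePostProcess(y_f, y_hat, img1, img2, single = False)
#   cv2.imwrite("result.png", panel[0])    # [img1 | img2 | fusion | ground truth]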
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# DeepFuse.pytorch

### A re-implementation of the ICCV 2017 DeepFuse paper idea

[![Packagist](https://img.shields.io/badge/Pytorch-0.4.1-red.svg)]()
[![Packagist](https://img.shields.io/badge/OpenCV-3.4.3-green.svg)]()
[![Packagist](https://img.shields.io/badge/Torchvision_sunner-18.9.15-yellow.svg)](https://github.com/SunnerLi/Torchvision_sunner)

![](https://github.com/SunnerLi/DeepFuse.pytorch/blob/master/img/structure.png)

Abstract
---
Multi-exposure fusion is a critical issue in computer vision. The technique can also be adopted in smartphones to produce images with high lighting quality. However, the original authors did not release an official implementation. In this repository, we try to reproduce the idea of DeepFuse [1] and fuse an under-exposed image and an over-exposed image in an appropriate manner.

Result
---
![](https://github.com/SunnerLi/DeepFuse.pytorch/blob/master/img/fuse_result1.png)

The above image shows a training result. The leftmost sub-figure is the under-exposed image, and the second sub-figure is the over-exposed image. The third one is the rendered result, and the rightmost figure is the ground truth, which is computed with the MEF-SSIM loss concept. As you can see, the rough information of both the dark region and the bright region is preserved. The following image is another example.

![](https://github.com/SunnerLi/DeepFuse.pytorch/blob/master/img/fuse_result2.png)

Idea
---
**You should notice that this is not the official implementation.** There are several differences between this repository and the paper:
1. Since the dataset that the authors used cannot be obtained, we use the [HDR-Eye dataset](https://mmspg.epfl.ch/hdr-eye?fbclid=IwAR1YLuQvcpu6yM2MsV60LcbURFopzIqqUBKlBUjvbNCQBXxB3iMzgm0Uy8o) [2], which can also serve the multiple-exposure fusion problem.
2. Rather than using a _64x64_ patch size, we set the patch size to _256x256_.
3. We only train for 20 epochs (30000 iterations for each epoch).
4. The calculation of y^hat is different. The detail can be found [here](https://github.com/SunnerLi/DeepFuse.pytorch/blob/master/lib/loss.py#L102).

Usage
---
The detail of the parameters can be found [here](https://github.com/SunnerLi/DeepFuse.pytorch/blob/master/opts.py). You can simply use the following command to train DeepFuse:

```
python3 train.py --folder ./SunnerDataset/HDREyeDataset/images/Bracketed_images --batch_size 8 --epoch 15000
```
Or you can download the pre-trained model [here](https://drive.google.com/file/d/1NYlYeDCyu_KxAjsl9m9X9IXq588rCIq7/view?usp=sharing). Furthermore, to run inference on two images:

```
python3 inference.py --image1 <image1_path> --image2 <image2_path> --model train_result/model/latest.pth --res result.png
```

Notice
---
After checking on several machines, we found that the program might get stuck in the `cv2.cvtColor` function. We infer that the reason is that OpenCV cannot be perfectly embedded in the multiprocessing mechanism provided by PyTorch. As a result, we set `num_workers` to zero [here](https://github.com/SunnerLi/DeepFuse.pytorch/blob/master/train.py#L38) to avoid the issue. If your machine doesn't encounter this issue, you can raise the number to accelerate the loading process.
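For example, the loader construction in `train.py` would become the following (a minimal sketch; `4` is an arbitrary illustrative value, and `transform` is the same `Compose(...)` already defined in `train.py`):

```
loader = sunnerData.DataLoader(
    dataset = BracketedDataset(
        root = opts.folder,
        crop_size = opts.crop_size,
        transform = transform,
    ),
    batch_size = opts.batch_size,
    shuffle = True,
    num_workers = 4,    # raise from 0 only if OpenCV behaves in worker processes on your machine
)
```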

Reference
---
[1] K. R. Prabhakar, V. S. Srikar, and R. V. Babu. DeepFuse: A deep unsupervised approach for exposure fusion with extreme exposure image pairs. In 2017 IEEE International Conference on Computer Vision (ICCV), pages 4724-4732, 2017.
[2] H. Nemoto, P. Korshunov, P. Hanhart, and T. Ebrahimi. Visual attention in LDR and HDR images. In 9th International Workshop on Video Processing and Quality Metrics for Consumer Electronics (VPQM), number EPFL-CONF-203873, 2015.
--------------------------------------------------------------------------------
/lib/dataset.py:
--------------------------------------------------------------------------------
import torch.utils.data as Data

from tqdm import tqdm
from glob import glob
import numpy as np
import random
import cv2
import os

"""
    This script defines the implementation of the data loader
    Notice: you should follow the format below:

        root --+------ IMAGE_1_FOLDER --+----- under_exposure_image1
               |                        |
               |                        +----- under_exposure_image2
               |                        |
               |                        +----- over_exposure_image1
               |                        |
               |                        +----- over_exposure_image2
               |
               +------ IMAGE_2_FOLDER --+----- under_exposure_image1
               |                        |
               |                        +----- under_exposure_image2
               |                        |
               |                        +----- over_exposure_image1
               |                        |
               |                        +----- over_exposure_image2
               ...

    In the root folder, each scene is represented by a sub-folder
    In each sub-folder, there are several under-exposure images and over-exposure images
    The program will randomly select one under- and one over-exposure image to crop and return

    Author: SunnerLi
"""

class BracketedDataset(Data.Dataset):
    def __init__(self, root, crop_size = 64, transform = None):
        self.files = glob(os.path.join(root, '*/'))
        self.crop_size = crop_size
        self.transform = transform
        self.under_exposure_imgs = []
        self.over_exposure_imgs = []
        self.statistic()

    def statistic(self):
        bar = tqdm(self.files)
        for folder_name in bar:
            bar.set_description("  Collecting the over-exposure and under-exposure image lists...")
            # Get the mean brightness of each image in the folder
            mean_list = []
            imgs_list = glob(os.path.join(folder_name, '*'))
            for img_name in imgs_list:
                img = cv2.imread(img_name)
                mean = np.mean(img)
                mean_list.append(mean)
            mean = np.mean(mean_list)

            # Split the images: brighter than the folder mean is over-exposed, otherwise under-exposed
            under_list = []
            over_list = []
            for i, m in enumerate(mean_list):
                img = cv2.imread(imgs_list[i])
                img = cv2.resize(img, (1200, 800))
                if m > mean:
                    over_list.append(img)
                else:
                    under_list.append(img)
            assert len(under_list) > 0 and len(over_list) > 0

            # Store the result
            self.under_exposure_imgs.append(under_list)
            self.over_exposure_imgs.append(over_list)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # Randomly select one under-exposure image and one over-exposure image
        under_img = self.under_exposure_imgs[index][random.randint(0, len(self.under_exposure_imgs[index]) - 1)]
        over_img  = self.over_exposure_imgs[index][random.randint(0, len(self.over_exposure_imgs[index]) - 1)]
        under_img = cv2.cvtColor(under_img, cv2.COLOR_BGR2YCrCb)
        over_img  = cv2.cvtColor(over_img, cv2.COLOR_BGR2YCrCb)

        # Transform
        if self.transform:
            under_img = self.transform(under_img)
            over_img = self.transform(over_img)

        # Crop the same patch location from both images
        _, h, w = under_img.shape
        y = random.randint(0, h - self.crop_size)
        x = random.randint(0, w - self.crop_size)
        under_patch = under_img[:, y:y + self.crop_size, x:x + self.crop_size]
        over_patch  = over_img [:, y:y + self.crop_size, x:x + self.crop_size]
        return under_patch, over_patch
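
# Usage sketch (added for illustration; the dataset path is an assumption).
# Note that the crop step above unpacks a CHW shape, so a transform that moves
# channels first (as the Compose in train.py does) is expected:
#
#   transform = transforms.Compose([
#       sunnertransforms.ToTensor(),
#       sunnertransforms.ToFloat(),
#       sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
#       sunnertransforms.Normalize(),
#   ])
#   dataset = BracketedDataset(root = "./HDREyeDataset/images/Bracketed_images",
#                              crop_size = 256, transform = transform)
#   under_patch, over_patch = dataset[0]   # CHW patches cropped at the same location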
--------------------------------------------------------------------------------
/lib/loss.py:
--------------------------------------------------------------------------------
from math import exp
import torch.nn.functional as F
import torch.nn as nn
import torch

"""
    This script defines the MEF-SSIM loss function which is mentioned in the DeepFuse paper
    The code is heavily borrowed from: https://github.com/Po-Hsun-Su/pytorch-ssim

    Author: SunnerLi
"""

L2_NORM = lambda b: torch.sqrt(torch.sum((b + 1e-8) ** 2))

class MEF_SSIM_Loss(nn.Module):
    def __init__(self, window_size = 11, size_average = True):
        """
            Constructor
        """
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = self.create_window(window_size, self.channel)

    def gaussian(self, window_size, sigma):
        """
            Get the Gaussian kernel which will be used in the SSIM computation
        """
        gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
        return gauss/gauss.sum()

    def create_window(self, window_size, channel):
        """
            Create the Gaussian window
        """
        _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
        return window

    def _ssim(self, img1, img2, window, window_size, channel, size_average = True):
        """
            Compute the SSIM for the given two images
            The original source is here: https://stackoverflow.com/questions/39051451/ssim-ms-ssim-for-tensorflow
        """
        mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
        mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1*mu2

        sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
        sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
        sigma12   = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

        C1 = 0.01**2
        C2 = 0.03**2

        ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

        if size_average:
            return ssim_map.mean()
        else:
            return ssim_map.mean(1).mean(1).mean(1)

    def w_fn(self, y):
        """
            Return the weighting function that MEF-SSIM defines
            We use the power energy function as the paper describes: https://ece.uwaterloo.ca/~k29ma/papers/15_TIP_MEF.pdf

            Arg:    y (torch.Tensor) - The structure tensor
            Ret:    The weight of the given structure
        """
        out = torch.sqrt(torch.sum(y ** 2))
        return out

    def forward(self, y_1, y_2, y_f):
        """
            Compute the MEF-SSIM for the given image pair and the output image
            y_1 and y_2 are interchangeable

            Arg:    y_1 (torch.Tensor) - The LDR image
                    y_2 (torch.Tensor) - Another LDR image in the same stack
                    y_f (torch.Tensor) - The fused HDR image
            Ret:    The loss value and the desired (ground-truth) patch y_hat
        """
        miu_y = (y_1 + y_2) / 2

        # Get the c_hat (the strongest contrast among the source patches)
        c_1 = L2_NORM(y_1 - miu_y)
        c_2 = L2_NORM(y_2 - miu_y)
        c_hat = torch.max(torch.stack([c_1, c_2]))

        # Get the s_hat (the weighted-average unit structure)
        s_1 = (y_1 - miu_y) / L2_NORM(y_1 - miu_y)
        s_2 = (y_2 - miu_y) / L2_NORM(y_2 - miu_y)
        s_bar = (self.w_fn(y_1) * s_1 + self.w_fn(y_2) * s_2) / (self.w_fn(y_1) + self.w_fn(y_2))
        s_hat = s_bar / L2_NORM(s_bar)
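
        # MEF-SSIM background (added comment): each source patch decomposes as
        #   y_k = c_k * s_k + l_k,  with contrast c_k = ||y_k - mu_y||,
        #   unit structure s_k = (y_k - mu_y) / ||y_k - mu_y|| and mean intensity l_k.
        # The desired patch keeps the strongest contrast c_hat and the weighted,
        # re-normalized structure s_hat computed above.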
        # =============================================================================================
        # < Get the y_hat >
        #
        # Rather than outputting c_hat * s_hat directly as y_hat, we shift it by the mean of the
        # over-exposure image and the mean image
        # The result is much better than with the original formula
        # =============================================================================================
        y_hat = c_hat * s_hat
        y_hat += (y_2 + miu_y) / 2

        # Check if we need to re-create the Gaussian window
        (_, channel, _, _) = y_hat.size()
        if channel == self.channel and self.window.data.type() == y_hat.data.type():
            window = self.window
        else:
            window = self.create_window(self.window_size, channel)
            window = window.to(y_f.device)
            window = window.type_as(y_hat)
            self.window = window
            self.channel = channel

        # Compute the SSIM between y_hat and y_f
        score = self._ssim(y_hat, y_f, window, self.window_size, channel, self.size_average)
        return 1 - score, y_hat
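
# Usage sketch (added for illustration; shapes and values are assumptions).
# The module returns both the loss (1 - MEF-SSIM score) and the desired patch
# y_hat that fusePostProcess renders as the "ground truth" panel:
#
#   criterion = MEF_SSIM_Loss()
#   y_1 = torch.rand(1, 1, 64, 64) * 2 - 1
#   y_2 = torch.rand(1, 1, 64, 64) * 2 - 1
#   y_f = torch.rand(1, 1, 64, 64) * 2 - 1
#   loss, y_hat = criterion(y_1 = y_1, y_2 = y_2, y_f = y_f)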
"image") 57 | model_folder_name = os.path.join(det_name, "model") 58 | if not os.path.exists(self.opts.det): 59 | os.mkdir(self.opts.det) 60 | if not os.path.exists(image_folder_name): 61 | os.mkdir(image_folder_name) 62 | if not os.path.exists(model_folder_name): 63 | os.mkdir(model_folder_name) 64 | return self.opts 65 | 66 | ########################################################################################################################################### 67 | 68 | class TestOptions(): 69 | """ 70 | Argument Explaination 71 | ====================================================================================================================== 72 | Symbol Type Default Explaination 73 | ---------------------------------------------------------------------------------------------------------------------- 74 | --image1 Str X The path of under-exposure image 75 | --image2 Str X The path of over-exposure image 76 | --model Str model.pth The path of pre-trained model 77 | --res Str result.png The path to store the fusing image 78 | --H Int 400 The height of the result image 79 | --W Int 600 The width of the result image 80 | ---------------------------------------------------------------------------------------------------------------------- 81 | """ 82 | def __init__(self): 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('--image1' , type = str, required = True) 85 | parser.add_argument('--image2' , type = str, required = True) 86 | parser.add_argument('--model' , type = str, default = "model.pth") 87 | parser.add_argument('--res' , type = str, default = 'result.png') 88 | parser.add_argument('--H' , type = int, default = 400) 89 | parser.add_argument('--W' , type = int, default = 600) 90 | self.opts = parser.parse_args() 91 | self.opts.device = 'cuda' if torch.cuda.is_available() else 'cpu' 92 | 93 | def parse(self): 94 | presentParameters(vars(self.opts)) 95 | return self.opts --------------------------------------------------------------------------------