├── README.md
├── config
│   ├── DenseFuse.yaml
│   └── VIF_Net.yaml
├── core
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── crop_datasets.py
│   │   └── fusion_datasets.py
│   ├── loss
│   │   ├── Dist_Loss.py
│   │   ├── SSIM_Loss.py
│   │   ├── TV_Loss.py
│   │   ├── VIF_SSIM_Loss.py
│   │   └── __init__.py
│   ├── model
│   │   ├── DenseFuse.py
│   │   ├── VIF_Net.py
│   │   ├── __init__.py
│   │   └── model.py
│   └── util
│       ├── __init__.py
│       └── utils.py
├── datasets
│   ├── TNO
│   │   ├── Inf
│   │   │   └── test.jpg
│   │   └── Vis
│   │       └── test.jpg
│   └── TNO_crop
│       ├── Inf
│       │   └── test.jpg
│       └── Vis
│           └── test.jpg
├── img
│   ├── TensorBoard_0.png
│   ├── TensorBoard_1.png
│   └── results.jpg
├── run.py
├── test
│   └── results.jpg
├── tools
│   ├── __init__.py
│   ├── test.py
│   └── train.py
└── work_dirs
    ├── TensorBoard_0.png
    └── TensorBoard_1.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Pytorch_Image_Fusion
  Pixel-level fusion of multi-source images based on the PyTorch framework, including reproductions of several networks.
  For details, see 👉 https://blog.csdn.net/qq_36449741/article/details/104406931

![results](https://github.com/ChangeZH/Pytorch_Image_Fusion/blob/main/test/results.jpg)

## 环境要求 / Environmental Requirements

```
conda create -n PIF python=3.7
conda activate PIF
conda install pytorch=1.6.0 torchvision -c pytorch
pip install tqdm pyyaml tensorboardX opencv-python
```

## 数据集 / Dataset
⚡ TNO dataset download 👉 link: https://pan.baidu.com/s/1-6b-0onDCEPHAMUWyEkmtA extraction code: `PIF0`

Make sure the image file names are exactly the same across the different sensor-type folders.
A patch-cropping script is provided at `./core/dataset/crop_datasets.py`; edit its entry point:
```python
# Lines 45-49 of ./core/dataset/crop_datasets.py

if __name__ == '__main__':
    crop(path_dict={'Vis': '../../datasets/TNO/Vis/', 'Inf': '../../datasets/TNO/Inf/'},  # dict mapping each data type to its path, e.g. the 'Vis' data is under '../../datasets/TNO/Vis/' and the 'Inf' data under '../../datasets/TNO/Inf/'
         crop_sizes=[64, 128, 256],  # list of patch sizes; several sizes may be given
         overlap_sizes=[32, 64, 128],  # list of overlaps, one per patch size; an overlap must not exceed its patch size
         save_path='')  # string: directory where the patches are saved
```
After editing, run `python crop_datasets.py` to crop the data into patches.

## 参数设置 / Parameter Setting

```yaml
# This file is ./config/VIF_Net.yaml

PROJECT:  # project parameters
  name: 'VIF_Net_Image_Fusion'  # project name
  save_path: './work_dirs/'  # save path; trained weights go into a folder named after the project under this path

TRAIN_DATASET:  # training-set parameters
  root_dir: './datasets/TNO_crop/'  # root directory of the training set
  sensors: [ 'Vis', 'Inf' ]  # data types contained in the training set
  channels: 1  # number of channels of the training images
  input_size: 128  # size of the training images
  mean: [ 0.485, 0.456, 0.406 ]  # normalization mean of the training images (currently unused)
  std: [ 0.229, 0.224, 0.225 ]  # normalization std of the training images (currently unused)

TRAIN:  # training parameters
  batch_size: 32  # batch size
  gpu_id: 0  # id of the GPU used for training
  max_epoch: 200  # maximum number of training epochs
  lr: 0.01  # learning rate
  gamma: 0.01  # learning-rate decay factor
  milestones: [ 100, 150, 175 ]  # epochs at which the learning rate decays
  opt: Adam  # optimizer
  loss_func: ['VIF_SSIM_Loss', 'TV_Loss']  # loss functions used for training
  val_interval: 1  # save the weights every this many epochs
  debug_interval: 100  # visualize every this many batches; results appear in TensorBoard
  resume: None  # path of the weights to load when resuming an interrupted run
  loss_weights: [ 1000, 1 ]  # weights of the two VIF_Net losses

TEST_DATASET:  # test-set parameters
  root_dir: './datasets/TNO/'  # root directory of the test set
  sensors: [ 'Vis', 'Inf' ]  # data types contained in the test set
  channels: 1  # number of channels of the test images
  input_size: 512  # size of the test images
  mean: [ 0.485, 0.456, 0.406 ]  # normalization mean of the test images (currently unused)
  std: [ 0.229, 0.224, 0.225 ]  # normalization std of the test images (currently unused)

TEST:  # test parameters
  batch_size: 2  # test batch size
  weight_path: './work_dirs/VIF_Net_Image_Fusion/model_50.pth'  # path of the weights loaded for testing
  save_path: './test/'  # path where test results are saved

MODEL:  # model parameters
  model_name: 'VIF_Net'  # model name
  input_channels: 1  # number of model input channels
  out_channels: 16  # number of output channels of every layer
  input_sensors: [ 'Vis', 'Inf' ]  # input data types of the model
  coder_layers: 4  # number of encoder layers
  decoder_layers: 4  # number of decoder layers
```

## 训练与测试 / Training And Testing

### 训练 / Training
Run `python run.py --train` to train. The trained weights are saved under the specified path.

#### 训练DenseFuse / Training DenseFuse
Run `python run.py --train --config ./config/DenseFuse.yaml`. The trained weights are saved under the specified path.

#### 训练VIF_Net / Training VIF_Net
Run `python run.py --train --config ./config/VIF_Net.yaml`. The trained weights are saved under the specified path.

#### tensorboardX进行训练可视化 / Visualizing Training with tensorboardX
Run `tensorboard --logdir=XXX` to visualize training, replacing `XXX` with the directory the model is saved to. For example, given the following config:
```yaml
PROJECT:
  name: 'VIF_Net_Image_Fusion'
  save_path: './work_dirs/'
```
  you can run `tensorboard --logdir=./work_dirs/VIF_Net_Image_Fusion/`. Before training again, it is best to delete the old `events` files.
![SCALARS](https://github.com/ChangeZH/Pytorch_Image_Fusion/blob/main/work_dirs/TensorBoard_0.png)
![IMAGES](https://github.com/ChangeZH/Pytorch_Image_Fusion/blob/main/work_dirs/TensorBoard_1.png)
  In the image panel above, every three rows form one group: the first two rows are the input data and the third row is the fusion result.

### 测试 / Testing
  Run `python run.py --test` to test. The results are saved in batches under the specified path.

## 预训练模型 / Pre-training Model
- [x] ⚡ VIF_Net 👉 link: https://pan.baidu.com/s/1avjiuNTovsoFmUWd5aPpzg extraction code: PIF2
- [x] ⚡ DenseFuse 👉 link: https://pan.baidu.com/s/1MzlbMhIvrFB7HxPAWdCdmQ extraction code: PIF3

## 计划中 / To Do
- [x] VIF_Net 👉 https://blog.csdn.net/qq_36449741/article/details/104562999
- [x] DenseFuse 👉 https://blog.csdn.net/qq_36449741/article/details/104776319
--------------------------------------------------------------------------------
/config/DenseFuse.yaml:
--------------------------------------------------------------------------------
PROJECT:
  name: 'DenseFuse_Image_Fusion'
  save_path: './work_dirs/'

TRAIN_DATASET:
  root_dir: './datasets/TNO/'
  sensors: [ 'Vis', 'Inf' ]
  channels: 1
  input_size: 256
  mean: [ 0.485, 0.456, 0.406 ]
  std: [ 0.229, 0.224, 0.225 ]

TRAIN:
  batch_size: 8
  gpu_id: 1
  max_epoch: 2000
  lr: 0.01
  gamma: 0.01
  milestones: [ 1000, 1500, 1750 ]
  opt: Adam
  loss_func: ['Dist_Loss', 'SSIM_Loss']
  val_interval: 1
  debug_interval: 100
  resume: None
  loss_weights: [ 1, 1000 ]

TEST_DATASET:
  root_dir: './datasets/TNO/'
  sensors: [ 'Vis', 'Inf' ]
  channels: 1
  input_size: 256
  mean: [ 0.485, 0.456, 0.406 ]
  std: [ 0.229, 0.224, 0.225 ]

TEST:
  batch_size: 2
  weight_path: './work_dirs/DenseFuse_Image_Fusion/model_2000.pth'
  save_path: './test/'

MODEL:
  model_name: 'DenseFuse'
  input_channels: 1
  out_channels: 16
  input_sensors: [ 'Vis', 'Inf' ]
  coder_layers: 4
  decoder_layers: 4
--------------------------------------------------------------------------------
/config/VIF_Net.yaml:
--------------------------------------------------------------------------------
PROJECT:
  name: 'VIF_Net_Image_Fusion'
  save_path: './work_dirs/'

TRAIN_DATASET:
  root_dir: './datasets/TNO/'
  sensors: [ 'Vis', 'Inf' ]
  channels: 1
  input_size: 256
  mean: [ 0.485, 0.456, 0.406 ]
  std: [ 0.229, 0.224, 0.225 ]

TRAIN:
  batch_size: 8
  gpu_id: 0
  max_epoch: 200
  lr: 0.01
  gamma: 0.01
  milestones: [ 100, 150, 175 ]
  opt: Adam
  loss_func: ['VIF_SSIM_Loss', 'TV_Loss']
  val_interval: 1
  debug_interval: 100
  resume: None
  loss_weights: [ 1000, 1 ]

TEST_DATASET:
  root_dir: './datasets/TNO/'
  sensors: [ 'Vis', 'Inf' ]
  channels: 1
  input_size: 256
  mean: [ 0.485, 0.456, 0.406 ]
  std: [ 0.229, 0.224, 0.225 ]

TEST:
  batch_size: 2
  weight_path: './work_dirs/VIF_Net_Image_Fusion/model_100.pth'
  save_path: './test/'

MODEL:
  model_name: 'VIF_Net'
  input_channels: 1
  out_channels: 16
  input_sensors: [ 'Vis', 'Inf' ]
  coder_layers: 4
  decoder_layers: 4
--------------------------------------------------------------------------------
/core/dataset/__init__.py:
--------------------------------------------------------------------------------
from .fusion_datasets import Fusion_Datasets

__all__ = [
    'Fusion_Datasets',
]
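Both YAML files are parsed with `load_config` from `core/util/utils.py` and handed section-by-section to the dataset and model constructors. A minimal sketch of that wiring, assuming it is run from the repository root with the sample TNO images in place:

```python
import torchvision.transforms as transforms
from core.dataset import Fusion_Datasets
from core.util import load_config

configs = load_config('./config/VIF_Net.yaml')
dataset_configs = configs['TRAIN_DATASET']
# The same Resize + ToTensor pipeline that run.py builds.
base_transforms = transforms.Compose(
    [transforms.Resize((dataset_configs['input_size'], dataset_configs['input_size'])),
     transforms.ToTensor()])
datasets = Fusion_Datasets(dataset_configs, base_transforms)
sample = datasets[0]  # dict keyed by sensor, e.g. {'Vis': tensor, 'Inf': tensor}
print({sensor: img.shape for sensor, img in sample.items()})
```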
--------------------------------------------------------------------------------
/core/dataset/crop_datasets.py:
--------------------------------------------------------------------------------
import os
import cv2
import itertools
import numpy as np


def crop(path_dict, crop_sizes, overlap_sizes, save_path):
    num = 0
    sensors = [i for i in path_dict]
    img_list = {i: os.listdir(path_dict[i]) for i in path_dict}
    for i in sensors:
        if not os.path.exists(os.path.join(save_path, i)):
            os.mkdir(os.path.join(save_path, i))

    for name in img_list[sensors[0]]:
        img = {i: cv2.imread(os.path.join(path_dict[i], name)) for i in sensors}
        img_shape = img[sensors[0]].shape
        for index in range(len(crop_sizes)):
            crop_size = crop_sizes[index]
            overlap_size = overlap_sizes[index]
            # Window origins advance by (crop_size - overlap_size); the image border
            # is appended, and clipping keeps the last window inside the image.
            y_min = np.arange(0, img_shape[0], crop_size - overlap_size)
            y_min = np.array(list(y_min) + [img_shape[0]])
            x_min = np.arange(0, img_shape[1], crop_size - overlap_size)
            x_min = np.array(list(x_min) + [img_shape[1]])
            y_min = np.unique(np.clip(y_min, a_min=0, a_max=img_shape[0] - crop_size))
            x_min = np.unique(np.clip(x_min, a_min=0, a_max=img_shape[1] - crop_size))
            crop_bboxes = []
            for bbox in itertools.product(y_min, x_min):
                if bbox not in crop_bboxes:
                    crop_bboxes.append(bbox)
                else:
                    continue
            for bbox in crop_bboxes:
                num += 1
                crop_img = {i: img[i][bbox[0]:bbox[0] + crop_size, bbox[1]:bbox[1] + crop_size] for i in img}
                for i in sensors:
                    cv2.imwrite(os.path.join(save_path, i,
                                             name.split('.')[0] + '_' + str(bbox[0]) + '_' + str(bbox[1]) + '.jpg'),
                                crop_img[i])

    print(num)
    return 0


if __name__ == '__main__':
    crop(path_dict={'Vis': '../../datasets/TNO/Vis/', 'Inf': '../../datasets/TNO/Inf/'},
         crop_sizes=[64, 128, 256],
         overlap_sizes=[32, 64, 128],
         save_path='')
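The clip-and-unique arithmetic in `crop()` walks windows with stride `crop_size - overlap_size` and snaps the final window flush with the image border. A small worked example with made-up image dimensions:

```python
import numpy as np

# Hypothetical 270x360 image, crop_size=128, overlap_size=64 (stride 64).
H, W, crop_size, overlap_size = 270, 360, 128, 64
stride = crop_size - overlap_size
y_min = np.unique(np.clip(np.append(np.arange(0, H, stride), H), 0, H - crop_size))
x_min = np.unique(np.clip(np.append(np.arange(0, W, stride), W), 0, W - crop_size))
print(y_min)  # [  0  64 128 142] -> the last window ends exactly at row 270
print(x_min)  # [  0  64 128 192 232] -> the last window ends exactly at column 360
```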
--------------------------------------------------------------------------------
/core/dataset/fusion_datasets.py:
--------------------------------------------------------------------------------
import os
import cv2
import numpy as np
from PIL import Image
from collections import Counter
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset


class Fusion_Datasets(Dataset):
    """docstring for Fusion_Datasets"""

    def __init__(self, configs, transform=None):
        super(Fusion_Datasets, self).__init__()
        self.root_dir = configs['root_dir']
        self.transform = transform
        self.channels = configs['channels']
        self.sensors = configs['sensors']
        self.img_list = {i: os.listdir(os.path.join(self.root_dir, i)) for i in self.sensors}
        self.img_path = {i: [os.path.join(self.root_dir, i, j) for j in os.listdir(os.path.join(self.root_dir, i))]
                         for i in self.sensors}

    def __getitem__(self, index):
        img_data = {}
        for i in self.sensors:
            img = Image.open(self.img_path[i][index])
            # print(self.img_path[i][index])
            if self.channels == 1:
                img = img.convert('L')
            elif self.channels == 3:
                img = img.convert('RGB')
            if self.transform is not None:
                img = self.transform(img)
            img_data.update({i: img})
        return img_data

    def __len__(self):
        img_num = [len(self.img_list[i]) for i in self.img_list]
        img_counter = Counter(img_num)
        assert len(img_counter) == 1, 'Sensors Have Different Lengths'
        return img_num[0]


if __name__ == '__main__':
    datasets = Fusion_Datasets({'root_dir': '../../datasets/TNO/', 'sensors': ['Vis', 'Inf'], 'channels': 1},
                               transform=transforms.Compose([transforms.Resize((512, 512)),
                                                             transforms.ToTensor()]))
    train = DataLoader(datasets, 1, False)
    print(len(train))
    for i, data in enumerate(train):
        print(data)
--------------------------------------------------------------------------------
/core/loss/Dist_Loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class Dist_Loss(nn.Module):
    """docstring for Dist_Loss"""

    def __init__(self, sensors, p=2):
        super(Dist_Loss, self).__init__()
        self.p = p

    def forward(self, input_images, output_images):
        return sum([torch.dist(input_images[sensor], output_images[sensor], p=self.p) for sensor in input_images])
--------------------------------------------------------------------------------
/core/loss/SSIM_Loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class SSIM_Loss(nn.Module):
    """docstring for SSIM_Loss"""

    def __init__(self, sensors, num_channels=3, C=9e-4, device='cuda:0'):
        super(SSIM_Loss, self).__init__()
        self.sensor = sensors
        self.num_channels = num_channels
        self.device = device
        self.c = C

    def forward(self, input_images, output_images):
        batch_size, num_channels = input_images[self.sensor[0]].shape[0], input_images[self.sensor[0]].shape[1]
        ssim_loss = 0
        for sensor in input_images:
            for batch in range(batch_size):
                input_image, output_image = input_images[sensor][batch], output_images[sensor][batch]

                input_image_mean = torch.mean(input_image, dim=[1, 2])
                output_image_mean = torch.mean(output_image, dim=[1, 2])
                C = torch.ones_like(input_image_mean) * self.c

                input_image_var = torch.mean(input_image ** 2, dim=[1, 2]) - input_image_mean ** 2
                input_image_std = input_image_var ** .5

                output_image_var = torch.mean(output_image ** 2, dim=[1, 2]) - output_image_mean ** 2
                output_image_std = output_image_var ** .5

                input_output_var = torch.mean(input_image * output_image,
                                              dim=[1, 2]) - input_image_mean * output_image_mean

                l = (2 * input_image_mean * output_image_mean + C) / \
                    (input_image_mean ** 2 + output_image_mean ** 2 + C)
                c = (2 * input_image_std * output_image_std + C) / \
                    (input_image_std ** 2 + output_image_std ** 2 + C)
                s = (input_output_var + 2 * C) / (input_image_std * output_image_std + 2 * C)

                ssim_loss += 1 - l * c * s

        return ssim_loss.mean()


if __name__ == '__main__':
    loss = SSIM_Loss(sensors=['Vis'], num_channels=3, C=9e-4, device='cpu')
    vis_images = torch.rand(2, 3, 256, 256)
    fusion_images = torch.rand(2, 3, 256, 256)
    print(loss({'Vis': vis_images}, {'Vis': fusion_images}))
--------------------------------------------------------------------------------
/core/loss/TV_Loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class TV_Loss(nn.Module):
    """docstring for TV_Loss"""

    def __init__(self, sensors, num_inputs=2):
        super(TV_Loss, self).__init__()
        self.num_inputs = num_inputs

    def forward(self, input_images, output_images):
        input_images = [input_images[i] for i in input_images]
        fusion_images = output_images['Fusion']
        tv_loss = 0
        for i in range(self.num_inputs):
            input_image = input_images[i]
            H, W = input_image.shape[2], input_image.shape[3]
            R = input_image - fusion_images
            L_tv = torch.pow(R[:, :, 1:H, :] - R[:, :, 0:H - 1, :], 2).sum() + \
                   torch.pow(R[:, :, :, 1:W] - R[:, :, :, 0:W - 1], 2).sum()
            tv_loss += L_tv
        return tv_loss


if __name__ == '__main__':
    loss = TV_Loss(sensors=['Vis', 'Inf'], num_inputs=2)
    vis_images = torch.rand(2, 1, 256, 256)
    inf_images = torch.rand(2, 1, 256, 256)
    fusion_images = torch.rand(2, 1, 256, 256)
    print(loss({'Vis': vis_images, 'Inf': inf_images}, {'Fusion': fusion_images}))
--------------------------------------------------------------------------------
/core/loss/VIF_SSIM_Loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class VIF_SSIM_Loss(nn.Module):
    """docstring for VIF_SSIM_Loss"""

    def __init__(self, sensors, kernal_size=11, num_channels=1, C=9e-4, device='cuda:0'):
        super(VIF_SSIM_Loss, self).__init__()
        self.sensors = sensors
        self.kernal_size = kernal_size
        self.num_channels = num_channels
        self.device = device
        self.c = C

        self.avg_kernal = torch.ones(num_channels, 1, self.kernal_size, self.kernal_size) / (self.kernal_size) ** 2
        self.avg_kernal = self.avg_kernal.to(device)

    def forward(self, input_images, output_images):
        vis_images, inf_images, fusion_images = input_images[self.sensors[0]], input_images[self.sensors[1]], \
                                                output_images['Fusion']
        batch_size, num_channels = vis_images.shape[0], vis_images.shape[1]

        vis_images_mean = F.conv2d(vis_images, self.avg_kernal, stride=self.kernal_size, groups=num_channels)
        vis_images_var = torch.abs(F.conv2d(vis_images ** 2, self.avg_kernal, stride=self.kernal_size,
                                            groups=num_channels) - vis_images_mean ** 2)

        inf_images_mean = F.conv2d(inf_images, self.avg_kernal, stride=self.kernal_size, groups=num_channels)
        inf_images_var = torch.abs(F.conv2d(inf_images ** 2, self.avg_kernal, stride=self.kernal_size,
                                            groups=num_channels) - inf_images_mean ** 2)

        fusion_images_mean = F.conv2d(fusion_images, self.avg_kernal, stride=self.kernal_size, groups=num_channels)
        fusion_images_var = torch.abs(F.conv2d(fusion_images ** 2, self.avg_kernal, stride=self.kernal_size,
                                               groups=num_channels) - fusion_images_mean ** 2)

        vis_fusion_images_var = F.conv2d(vis_images * fusion_images, self.avg_kernal, stride=self.kernal_size,
                                         groups=num_channels) - vis_images_mean * fusion_images_mean
        inf_fusion_images_var = F.conv2d(inf_images * fusion_images, self.avg_kernal, stride=self.kernal_size,
                                         groups=num_channels) - inf_images_mean * fusion_images_mean

        C = torch.ones_like(fusion_images_mean) * self.c

        ssim_l_vis_fusion = (2 * vis_images_mean * fusion_images_mean + C) / \
                            (vis_images_mean ** 2 + fusion_images_mean ** 2 + C)
        ssim_l_inf_fusion = (2 * inf_images_mean * fusion_images_mean + C) / \
                            (inf_images_mean ** 2 + fusion_images_mean ** 2 + C)

        ssim_s_vis_fusion = (vis_fusion_images_var + C) / (vis_images_var + fusion_images_var + C)
        ssim_s_inf_fusion = (inf_fusion_images_var + C) / (inf_images_var + fusion_images_var + C)

        # In each window, score the fusion against whichever source has the
        # larger local mean (VIF_Net's window-wise source selection).
        score_vis_inf_fusion = (vis_images_mean > inf_images_mean) * ssim_l_vis_fusion * ssim_s_vis_fusion + \
                               (vis_images_mean <= inf_images_mean) * ssim_l_inf_fusion * ssim_s_inf_fusion

        ssim_loss = 1 - score_vis_inf_fusion.mean()

        return ssim_loss


if __name__ == '__main__':
    loss = VIF_SSIM_Loss(sensors=['Vis', 'Inf'], kernal_size=8, num_channels=1, C=9e-4, device='cpu')
    vis_images = torch.rand(2, 1, 256, 256)
    inf_images = torch.rand(2, 1, 256, 256)
    fusion_images = torch.rand(2, 1, 256, 256)
    print(loss({'Vis': vis_images, 'Inf': inf_images}, {'Fusion': fusion_images}))
--------------------------------------------------------------------------------
/core/loss/__init__.py:
--------------------------------------------------------------------------------
from .TV_Loss import TV_Loss
from .SSIM_Loss import SSIM_Loss
from .Dist_Loss import Dist_Loss
from .VIF_SSIM_Loss import VIF_SSIM_Loss

__all__ = ['VIF_SSIM_Loss', 'SSIM_Loss', 'Dist_Loss', 'TV_Loss']
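All four losses share one calling convention, which is how `tools/train.py` drives them: a dict of input tensors keyed by sensor name, and a dict of outputs (with a `'Fusion'` key where the fused image is needed). A minimal sketch with random tensors standing in for real batches:

```python
import torch
from core.loss import TV_Loss, VIF_SSIM_Loss

sensors = ['Vis', 'Inf']
loss_func = [VIF_SSIM_Loss(sensors=sensors, device='cpu'), TV_Loss(sensors=sensors)]
loss_weights = [1000, 1]  # the VIF_Net weighting from the config

inputs = {sensor: torch.rand(2, 1, 256, 256) for sensor in sensors}
outputs = {'Fusion': torch.rand(2, 1, 256, 256)}
loss_batch = sum(w * l(inputs, outputs) for l, w in zip(loss_func, loss_weights))
print(loss_batch)
```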
--------------------------------------------------------------------------------
/core/model/DenseFuse.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from core.model import *


class DenseFuse(nn.Module):
    """docstring for DenseFuse"""

    def __init__(self, config):
        super(DenseFuse, self).__init__()
        self.config = config
        self.coder = nn.ModuleDict(
            {'Encoder': dense(config['input_channels'], config['out_channels'], config['coder_layers'])})
        # The shared encoder emits coder_layers * out_channels channels, so the first
        # decoder layer clamps its in_channels with min(...); deeper layers follow the
        # same halving schedule as VIF_Net, ending at input_channels in the last layer.
        self.decoder = nn.ModuleList(
            [nn.Sequential(nn.Conv2d(
                in_channels=min(config['coder_layers'] * config['out_channels'],
                                len(config['input_sensors']) * config['coder_layers'] * config[
                                    'out_channels'] // 2 ** i),
                out_channels=len(config['input_sensors']) * config['coder_layers'] * config['out_channels'] // 2 ** (
                        i + 1), kernel_size=3, padding=1),
                nn.BatchNorm2d(len(config['input_sensors']) * config['coder_layers'] * config['out_channels'] // 2 ** (
                        i + 1)),
                nn.ReLU()) if i != config['decoder_layers'] - 1 else nn.Sequential(
                nn.Conv2d(
                    in_channels=len(config['input_sensors']) * config['coder_layers'] * config[
                        'out_channels'] // 2 ** i,
                    out_channels=config['input_channels'], kernel_size=3, padding=1),
                nn.BatchNorm2d(config['input_channels']),
                nn.ReLU()) for i in range(config['decoder_layers'])])

    def forward(self, inputs, fusion_mode='L1'):
        feats = {}
        for sensor in self.config['input_sensors']:
            feats.update({sensor: self.coder['Encoder'](inputs[sensor])})
        if fusion_mode == 'Add':
            feats = Add_Fusion_Layer(feats)
        elif fusion_mode == 'L1':
            feats = L1_Fusion_Layer(feats)
        for block in self.decoder:
            feats = {sensor: block(feats[sensor]) for sensor in feats}
        return feats
--------------------------------------------------------------------------------
/core/model/VIF_Net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from core.model import dense


class VIF_Net(nn.Module):
    """docstring for VIF_Net"""

    def __init__(self, config):
        super(VIF_Net, self).__init__()
        self.config = config
        self.coder = nn.ModuleDict(
            {sensor: dense(config['input_channels'], config['out_channels'], config['coder_layers']) for sensor in
             config['input_sensors']})
        self.decoder = nn.ModuleList(
            [nn.Sequential(nn.Conv2d(
                in_channels=len(config['input_sensors']) * config['coder_layers'] * config['out_channels'] // 2 ** i,
                out_channels=len(config['input_sensors']) * config['coder_layers'] * config['out_channels'] // 2 ** (
                        i + 1), kernel_size=3, padding=1),
                nn.BatchNorm2d(len(config['input_sensors']) * config['coder_layers'] * config['out_channels'] // 2 ** (
                        i + 1)),
                nn.ReLU()) if i != config['decoder_layers'] - 1 else nn.Sequential(
                nn.Conv2d(
                    in_channels=len(config['input_sensors']) * config['coder_layers'] * config[
                        'out_channels'] // 2 ** i,
                    out_channels=config['input_channels'], kernel_size=3, padding=1),
                nn.BatchNorm2d(config['input_channels']),
                nn.ReLU()) for i in range(config['decoder_layers'])])

    def forward(self, inputs):
        feats = {}
        for sensor in self.config['input_sensors']:
            feats.update({sensor: self.coder[sensor](inputs[sensor])})
        feats = torch.cat([feats[sensor] for sensor in self.config['input_sensors']], dim=1)
        for block in self.decoder:
            feats = block(feats)
        outputs = {'Fusion': feats}
        for sensor in inputs:
            outputs.update({sensor: inputs[sensor]})
        return outputs
--------------------------------------------------------------------------------
/core/model/__init__.py:
--------------------------------------------------------------------------------
from .model import *
from .VIF_Net import VIF_Net
from .DenseFuse import DenseFuse

__all__ = [
    'dense', 'VIF_Net', 'DenseFuse', 'Add_Fusion_Layer', 'L1_Fusion_Layer'
]
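With the default VIF_Net config, each encoder emits `coder_layers * out_channels = 64` channels per sensor, the two encodings are concatenated to 128 channels, and the decoder halves that width layer by layer down to the single fused channel. A quick shape check on random CPU inputs:

```python
import torch
from core.model import VIF_Net

model_configs = {'model_name': 'VIF_Net', 'input_channels': 1, 'out_channels': 16,
                 'input_sensors': ['Vis', 'Inf'], 'coder_layers': 4, 'decoder_layers': 4}
model = VIF_Net(model_configs).eval()
inputs = {sensor: torch.rand(1, 1, 128, 128) for sensor in model_configs['input_sensors']}
with torch.no_grad():
    outputs = model(inputs)
print(outputs['Fusion'].shape)  # torch.Size([1, 1, 128, 128])
```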
--------------------------------------------------------------------------------
/core/model/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class dense(nn.Module):
    """docstring for dense"""

    def __init__(self, in_channels, out_channels, num_layers):
        super(dense, self).__init__()
        self.dense_block = nn.ModuleList([nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()) if i == 0 else nn.Sequential(
            nn.Conv2d(in_channels=out_channels * i, out_channels=out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()) for i in range(num_layers)])

    def forward(self, inputs):
        # feats[0] is the raw input, which only feeds the first block; every later
        # block consumes the concatenation of all previous block outputs.
        feats = [inputs]
        for block in self.dense_block:
            feat = block(torch.cat(feats, dim=1)) if len(feats) == 1 else block(torch.cat(feats[1:], dim=1))
            feats.append(feat)

        return torch.cat(feats[1:], dim=1)


def Add_Fusion_Layer(inputs):
    inputs.update({'Fusion': torch.stack([inputs[sensor] for sensor in inputs], dim=0).sum(0)})
    return inputs


def L1_Fusion_Layer(inputs, kernal_size=3):
    # Per-sensor weights are local averages of the activations, normalized across
    # sensors; the 9e-4 terms guard against division by zero.
    avgpool = torch.nn.AvgPool2d(kernal_size, 1, (kernal_size - 1) // 2)
    weights = {sensor: avgpool(inputs[sensor]) for sensor in inputs}
    weights_sum = torch.stack([weights[sensor] for sensor in weights], dim=0).sum(0)
    weights = {sensor: (weights[sensor] + torch.ones_like(weights_sum) * 9e-4) /
                       (weights_sum + torch.ones_like(weights_sum) * 9e-4) for sensor in inputs}
    inputs.update({'Fusion': torch.stack([weights[sensor] * inputs[sensor] for sensor in inputs], dim=0).sum(0)})
    return inputs
--------------------------------------------------------------------------------
/core/util/__init__.py:
--------------------------------------------------------------------------------
from .utils import *

__all__ = [
    'load_config', 'debug', 'count_parameters'
]
--------------------------------------------------------------------------------
/core/util/utils.py:
--------------------------------------------------------------------------------
import yaml
import torch


def load_config(filename):
    with open(filename, 'r') as f:
        config = yaml.safe_load(f)
    return config


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def debug(model_config, dataset_config, input_images, fusion_images):
    batch_size, _, _, _ = input_images[model_config['input_sensors'][0]].shape
    input_imgs = {sensor: [] for sensor in model_config['input_sensors']}
    fusion_imgs = []

    dev = input_images[model_config['input_sensors'][0]].device
    for batch in range(batch_size):
        img = {sensor: input_images[sensor][batch, :, :, :] for sensor in model_config['input_sensors']}
        fusion = fusion_images['Fusion'][batch, :, :, :]
        channels = fusion.shape[0]
        # std = torch.Tensor(dataset_config['std']).to(dev).view(channels, 1, 1).expand_as(fusion) if channels == 3 \
        #     else torch.Tensor([sum(dataset_config['std']) / len(dataset_config['std'])]).to(dev).view(
        #     channels, 1, 1).expand_as(fusion)
        # mean = torch.Tensor(dataset_config['mean']).to(dev).view(channels, 1, 1).expand_as(fusion) if channels == 3 \
        #     else torch.Tensor([sum(dataset_config['mean']) / len(dataset_config['mean'])]).to(dev).view(
        #     channels, 1, 1).expand_as(fusion)
        # img = {sensor: img[sensor] * std + mean for sensor in model_config['input_sensors']}
        # fusion = fusion * std + mean
        img = {sensor: img[sensor] for sensor in model_config['input_sensors']}

        for sensor in model_config['input_sensors']:
            input_imgs[sensor].append(img[sensor])
        fusion_imgs.append(fusion)

    input_imgs = {sensor: torch.stack(input_imgs[sensor], dim=0).to(dev) for sensor in model_config['input_sensors']}
    fusion_imgs = torch.stack(fusion_imgs, dim=0).to(dev)
    return input_imgs, fusion_imgs
--------------------------------------------------------------------------------
/datasets/TNO/Inf/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChangeZH/Pytorch_Image_Fusion/aabe70d6eaef6549e850101b16e6813389a0b78e/datasets/TNO/Inf/test.jpg
--------------------------------------------------------------------------------
/datasets/TNO/Vis/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChangeZH/Pytorch_Image_Fusion/aabe70d6eaef6549e850101b16e6813389a0b78e/datasets/TNO/Vis/test.jpg
--------------------------------------------------------------------------------
/datasets/TNO_crop/Inf/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChangeZH/Pytorch_Image_Fusion/aabe70d6eaef6549e850101b16e6813389a0b78e/datasets/TNO_crop/Inf/test.jpg
--------------------------------------------------------------------------------
/datasets/TNO_crop/Vis/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChangeZH/Pytorch_Image_Fusion/aabe70d6eaef6549e850101b16e6813389a0b78e/datasets/TNO_crop/Vis/test.jpg
--------------------------------------------------------------------------------
/img/TensorBoard_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChangeZH/Pytorch_Image_Fusion/aabe70d6eaef6549e850101b16e6813389a0b78e/img/TensorBoard_0.png
--------------------------------------------------------------------------------
/img/TensorBoard_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChangeZH/Pytorch_Image_Fusion/aabe70d6eaef6549e850101b16e6813389a0b78e/img/TensorBoard_1.png
--------------------------------------------------------------------------------
/img/results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChangeZH/Pytorch_Image_Fusion/aabe70d6eaef6549e850101b16e6813389a0b78e/img/results.jpg
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import torch
import argparse
from core.model import *
from tools import train, test
from core.dataset import Fusion_Datasets
import torchvision.transforms as transforms
from core.util import load_config, count_parameters


def get_args():
    parser = argparse.ArgumentParser(description='run')

    parser.add_argument('--config', type=str, default='./config/DenseFuse.yaml')
    parser.add_argument('--train', action='store_true', default=False)
    parser.add_argument('--test', action='store_true', default=False)

    args = parser.parse_args()
    return args


def runner(args):
    configs = load_config(args.config)
    project_configs = configs['PROJECT']
    model_configs = configs['MODEL']
    train_configs = configs['TRAIN']
    test_configs = configs['TEST']
    train_dataset_configs = configs['TRAIN_DATASET']
    test_dataset_configs = configs['TEST_DATASET']
    input_size = train_dataset_configs['input_size'] if args.train else test_dataset_configs['input_size']

    if train_dataset_configs['channels'] == 3:
        base_transforms = transforms.Compose(
            [transforms.Resize((input_size, input_size)),
             transforms.ToTensor()])  # ,
        # transforms.Normalize(mean=train_dataset_configs['mean'], std=train_dataset_configs['std'])])
    elif train_dataset_configs['channels'] == 1:
        base_transforms = transforms.Compose(
            [transforms.Resize((input_size, input_size)),
             transforms.ToTensor()])  # ,
        # transforms.Normalize(mean=[sum(train_dataset_configs['mean']) / len(train_dataset_configs['mean'])],
        #                      std=[sum(train_dataset_configs['std']) / len(train_dataset_configs['std'])])])

    train_datasets = Fusion_Datasets(train_dataset_configs, base_transforms)
    test_datasets = Fusion_Datasets(test_dataset_configs, base_transforms)

    model = eval(model_configs['model_name'])(model_configs)
    print('Model Para:', count_parameters(model))

    if train_configs['resume'] != 'None':
        checkpoint = torch.load(train_configs['resume'])
        model.load_state_dict(checkpoint['model'].state_dict())

    if args.train:
        train(model, train_datasets, test_datasets, configs)
    if args.test:
        test(model, test_datasets, configs, load_weight_path=True)


if __name__ == '__main__':
    args = get_args()
    runner(args)
--------------------------------------------------------------------------------
/test/results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChangeZH/Pytorch_Image_Fusion/aabe70d6eaef6549e850101b16e6813389a0b78e/test/results.jpg
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
from .train import train
from .test import test

__all__ = ['train', 'test']
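`tools.test` can also be invoked directly rather than through `run.py`, for example to write the fused images to a custom folder. A minimal sketch, assuming the config's `weight_path` points at an existing checkpoint and the target folder exists:

```python
import torchvision.transforms as transforms
from core.dataset import Fusion_Datasets
from core.model import VIF_Net
from core.util import load_config
from tools import test

configs = load_config('./config/VIF_Net.yaml')
input_size = configs['TEST_DATASET']['input_size']
base_transforms = transforms.Compose(
    [transforms.Resize((input_size, input_size)), transforms.ToTensor()])
test_datasets = Fusion_Datasets(configs['TEST_DATASET'], base_transforms)
model = VIF_Net(configs['MODEL'])
test(model, test_datasets, configs, load_weight_path=True, save_path='./test/')
```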
--------------------------------------------------------------------------------
/tools/test.py:
--------------------------------------------------------------------------------
import os
import torch
from tqdm import tqdm
from core.util import debug
from torch.utils.data import DataLoader
import torchvision.transforms as transforms


@torch.no_grad()
def test(model, fusion_datasets, configs, load_weight_path=False, save_path=None):
    model.eval()

    if load_weight_path:
        assert configs['TEST']['weight_path'] != 'None', 'Test Needs To Resume A Checkpoint'
        weight_path = configs['TEST']['weight_path']
        checkpoint = torch.load(weight_path)
        model.load_state_dict(checkpoint['model'].state_dict())
    is_use_gpu = torch.cuda.is_available()

    test_dataloader = DataLoader(fusion_datasets, batch_size=configs['TEST']['batch_size'], shuffle=False)
    test_num_iter = len(test_dataloader)
    dtransforms = transforms.Compose([transforms.ToPILImage()])

    with tqdm(total=test_num_iter) as test_bar:
        for iter, data in enumerate(test_dataloader):

            if is_use_gpu:
                model = model.cuda()
                data = {sensor: data[sensor].cuda() for sensor in data}

            fusion_image = model(data)

            # Stitch the inputs and the fusion side by side and save one jpg per sample.
            input_imgs, fusion_imgs = debug(configs['MODEL'], configs['TEST_DATASET'], data, fusion_image)
            input_imgs = [input_imgs[sensor] for sensor in configs['MODEL']['input_sensors']]
            imgs = input_imgs + [fusion_imgs]
            imgs = torch.cat(imgs, dim=3)
            for batch in range(imgs.shape[0]):
                if save_path is None:
                    save_path = configs['TEST']['save_path']
                name = os.path.join(save_path, str(len(os.listdir(save_path))))
                img = imgs[batch].cpu()
                img = dtransforms(img)
                img.save(f'{name}.jpg')
            test_bar.update(1)
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
import os
import torch
from tqdm import tqdm
from core.loss import *
from core.util import debug
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader


def train(model, train_datasets, test_datasets, configs):
    if not os.path.exists(os.path.join(configs['PROJECT']['save_path'], configs['PROJECT']['name'])):
        os.mkdir(os.path.join(configs['PROJECT']['save_path'], configs['PROJECT']['name']))

    model.train()

    train_writer = SummaryWriter(log_dir=os.path.join(configs['PROJECT']['save_path'], configs['PROJECT']['name']))
    print('Run Tensorboard:\n tensorboard --logdir=' + configs['PROJECT']['save_path'] + '/' + configs['PROJECT'][
        'name'] + '/')

    if configs['TRAIN']['resume'] == 'None':
        start_epoch = 1
    else:
        start_epoch = torch.load(configs['TRAIN']['resume'])['epoch'] + 1

    is_use_gpu = torch.cuda.is_available()

    optimizer = eval('torch.optim.' + configs['TRAIN']['opt'])(model.parameters(), configs['TRAIN']['lr'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=configs['TRAIN']['milestones'],
                                                     gamma=configs['TRAIN']['gamma'])

    train_dataloader = DataLoader(train_datasets, batch_size=configs['TRAIN']['batch_size'], shuffle=True)
    train_num_iter = len(train_dataloader)

    loss_func = [eval(l)(sensors=configs['TRAIN_DATASET']['sensors']) for l in configs['TRAIN']['loss_func']]

    all_iter = 0
    for epoch in range(start_epoch, configs['TRAIN']['max_epoch'] + 1):

        loss_epoch = 0

        with tqdm(total=train_num_iter) as train_bar:
            for iter, data in enumerate(train_dataloader):

                if is_use_gpu:
                    model = model.cuda(configs['TRAIN']['gpu_id'])
                    data = {sensor: data[sensor].cuda(configs['TRAIN']['gpu_id']) for sensor in data}

                fusion_image = model(data)

                loss = [l(data, fusion_image) * configs['TRAIN']['loss_weights'][loss_func.index(l)] for l in
                        loss_func]

                loss_batch = sum(loss)

                loss_epoch += loss_batch.item()
                optimizer.zero_grad()
                loss_batch.backward()
                optimizer.step()

                train_writer.add_scalar('loss', loss_batch, global_step=all_iter)
                train_bar.set_description(
                    'Epoch: {}/{}. TRAIN. Iter: {}/{}. All loss: {:.5f}'.format(
                        epoch, configs['TRAIN']['max_epoch'], iter + 1, train_num_iter,
                        loss_epoch / (iter + 1)))
                if configs['TRAIN']['debug_interval'] is not None and all_iter % configs['TRAIN'][
                    'debug_interval'] == 0:
                    input_imgs, fusion_imgs = debug(configs['MODEL'], configs['TRAIN_DATASET'], data, fusion_image)
                    input_imgs = [input_imgs[sensor] for sensor in configs['MODEL']['input_sensors']]
                    imgs = input_imgs + [fusion_imgs]
                    train_writer.add_image('debug', torch.cat(imgs, dim=2), all_iter, dataformats='NCHW')

                all_iter += 1
                train_bar.update(1)

        scheduler.step()

        train_writer.add_scalar('lr', optimizer.state_dict()['param_groups'][0]['lr'], global_step=epoch)

        if configs['TRAIN']['val_interval'] is not None and epoch % configs['TRAIN']['val_interval'] == 0:
            torch.save({'model': model, 'epoch': epoch},
                       os.path.join(configs['PROJECT']['save_path'], configs['PROJECT']['name'],
                                    f'model_{epoch}.pth'))
--------------------------------------------------------------------------------
/work_dirs/TensorBoard_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChangeZH/Pytorch_Image_Fusion/aabe70d6eaef6549e850101b16e6813389a0b78e/work_dirs/TensorBoard_0.png
--------------------------------------------------------------------------------
/work_dirs/TensorBoard_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChangeZH/Pytorch_Image_Fusion/aabe70d6eaef6549e850101b16e6813389a0b78e/work_dirs/TensorBoard_1.png
--------------------------------------------------------------------------------
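Because `tools/train.py` pickles the entire model under the checkpoint's `'model'` key, a trained network can also be reloaded for one-off fusion without the DataLoader machinery. A closing sketch, assuming a trained VIF_Net checkpoint at the config's `weight_path` (the output file name is an arbitrary choice):

```python
import torch
import torchvision.transforms as transforms
from PIL import Image
from core.model import VIF_Net
from core.util import load_config

configs = load_config('./config/VIF_Net.yaml')
model = VIF_Net(configs['MODEL'])
checkpoint = torch.load(configs['TEST']['weight_path'], map_location='cpu')
model.load_state_dict(checkpoint['model'].state_dict())
model.eval()

input_size = configs['TEST_DATASET']['input_size']
tf = transforms.Compose([transforms.Resize((input_size, input_size)), transforms.ToTensor()])
data = {sensor: tf(Image.open(f'./datasets/TNO/{sensor}/test.jpg').convert('L')).unsqueeze(0)
        for sensor in configs['MODEL']['input_sensors']}
with torch.no_grad():
    fusion = model(data)['Fusion'][0].clamp(0, 1)
transforms.ToPILImage()(fusion).save('./test/fusion_test.jpg')
```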