├── Networks └── network.py ├── README.md ├── Test.py ├── Train.py ├── evaluation └── Q_Y.m ├── fusion results ├── 001.bmp └── 2.bmp ├── imgs ├── ablationTRM.jpg ├── alpha.jpg ├── fig3.jpg ├── fig4.jpg ├── fig6.jpg ├── gamma.jpg ├── lambda.jpg └── woSecondDARM.jpg ├── losses └── __init__.py ├── model_10.pth └── source images ├── ir ├── 1.bmp └── 2.bmp └── vi ├── 1.bmp └── 2.bmp /Networks/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 5 | 6 | class Convlutioanl(nn.Module): 7 | def __init__(self, in_channel, out_channel): 8 | super(Convlutioanl, self).__init__() 9 | self.padding=(2,2,2,2) 10 | self.conv=nn.Conv2d(in_channel,out_channel,kernel_size=5,padding=0,stride=1) 11 | self.bn=nn.BatchNorm2d(out_channel) 12 | self.relu=nn.ReLU(inplace=True) 13 | def forward(self, input): 14 | out=F.pad(input,self.padding,'replicate') 15 | out=self.conv(out) 16 | out=self.bn(out) 17 | out=self.relu(out) 18 | return out 19 | 20 | class Convlutioanl_out(nn.Module): 21 | def __init__(self, in_channel, out_channel): 22 | super(Convlutioanl_out, self).__init__() 23 | 24 | self.conv=nn.Conv2d(in_channel,out_channel,kernel_size=1,padding=0,stride=1) 25 | 26 | self.tanh=nn.Tanh() 27 | 28 | def forward(self, input): 29 | 30 | out=self.conv(input) 31 | 32 | out=self.tanh(out) 33 | return out 34 | class Fem(nn.Module): 35 | def __init__(self, in_channel, out_channel): 36 | super(Fem, self).__init__() 37 | self.padding = (1, 1, 1, 1) 38 | self.conv=nn.Conv2d(in_channel,out_channel,kernel_size=3,padding=0,stride=1) 39 | self.bn=nn.BatchNorm2d(out_channel) 40 | self.relu=nn.ReLU(inplace=True) 41 | 42 | def forward(self, input): 43 | out = F.pad(input, self.padding, 'replicate') 44 | out=self.conv(out) 45 | out=self.bn(out) 46 | out=self.relu(out) 47 | out = F.pad(out, self.padding, 'replicate') 48 | out=self.conv(out) 49 | out = self.bn(out) 50 | return out 51 | 52 | class Channel_attention(nn.Module): 53 | def __init__(self, channel, reduction=4): 54 | super(Channel_attention, self).__init__() 55 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 56 | self.fc=nn.Sequential( 57 | nn.Conv2d(channel,channel//reduction,1), 58 | nn.ReLU(inplace=True), 59 | nn.Conv2d(channel//reduction,channel,1)) 60 | self.sigmoid=nn.Sigmoid() 61 | def forward(self, input): 62 | out=self.avg_pool(input) 63 | out=self.fc(out) 64 | out=self.sigmoid(out) 65 | return out 66 | 67 | class Spatial_attention(nn.Module): 68 | def __init__(self, channel, reduction=4): 69 | super(Spatial_attention, self).__init__() 70 | self.body=nn.Sequential( 71 | nn.Conv2d(channel, channel//reduction,3,padding=1), 72 | nn.BatchNorm2d( channel//reduction), 73 | nn.ReLU(True), 74 | 75 | nn.Conv2d(channel // reduction, 1, 3, padding=1), 76 | nn.BatchNorm2d(1), 77 | nn.Sigmoid() 78 | ) 79 | def forward(self, input): 80 | return self.body(input) 81 | 82 | 83 | class WindowAttention(nn.Module): 84 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 85 | super().__init__() 86 | self.dim = dim 87 | self.window_size = window_size 88 | self.num_heads = num_heads 89 | head_dim = dim // num_heads 90 | self.scale = qk_scale or head_dim ** -0.5 91 | self.relative_position_bias_table = nn.Parameter( 92 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) 93 | coords_h = torch.arange(self.window_size[0]) 94 | 
coords_w = torch.arange(self.window_size[1]) 95 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) 96 | coords_flatten = torch.flatten(coords, 1) 97 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 98 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 99 | relative_coords[:, :, 0] += self.window_size[0] - 1 100 | relative_coords[:, :, 1] += self.window_size[1] - 1 101 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 102 | relative_position_index = relative_coords.sum(-1) 103 | self.register_buffer("relative_position_index", relative_position_index) 104 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 105 | self.attn_drop = nn.Dropout(attn_drop) 106 | self.proj = nn.Linear(dim, dim) 107 | self.proj_drop = nn.Dropout(proj_drop) 108 | trunc_normal_(self.relative_position_bias_table, std=.02) 109 | self.softmax = nn.Softmax(dim=-1) 110 | def forward(self, x, mask=None): 111 | B_, N, C = x.shape 112 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 113 | q, k, v = qkv[0], qkv[1], qkv[2] 114 | q = q * self.scale 115 | attn = (q @ k.transpose(-2, -1)) 116 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 117 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) 118 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 119 | attn = attn + relative_position_bias.unsqueeze(0) 120 | if mask is not None: 121 | nW = mask.shape[0] 122 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 123 | attn = attn.view(-1, self.num_heads, N, N) 124 | attn = self.softmax(attn) 125 | else: 126 | attn = self.softmax(attn) 127 | attn = self.attn_drop(attn) 128 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 129 | x = self.proj(x) 130 | x = self.proj_drop(x) 131 | return x 132 | def extra_repr(self) -> str: 133 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 134 | def flops(self, N): 135 | flops = 0 136 | flops += N * self.dim * 3 * self.dim 137 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 138 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 139 | flops += N * self.dim * self.dim 140 | return flops 141 | 142 | 143 | class Mlp(nn.Module): 144 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 145 | super().__init__() 146 | out_features = out_features or in_features 147 | hidden_features = hidden_features or in_features 148 | self.fc1 = nn.Linear(in_features, hidden_features) 149 | self.act = act_layer() 150 | self.fc2 = nn.Linear(hidden_features, out_features) 151 | self.drop = nn.Dropout(drop) 152 | def forward(self, x): 153 | x = self.fc1(x) 154 | x = self.act(x) 155 | x = self.drop(x) 156 | x = self.fc2(x) 157 | x = self.drop(x) 158 | return x 159 | 160 | def window_partition(x, window_size): 161 | B, H, W, C = x.shape 162 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 163 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 164 | return windows 165 | def window_reverse(windows, window_size, H, W): 166 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 167 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 168 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 169 | return x 170 | 171 | 
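# A minimal sanity check for the two helpers above (illustrative only; the toy
# shapes below are assumptions, not taken from the repository): window_partition
# and window_reverse are pure reshape/permute operations, so they are exact
# inverses of each other whenever H and W are divisible by window_size.
#
#     x = torch.randn(1, 4, 4, 8)                         # (B, H, W, C)
#     windows = window_partition(x, 2)                    # (B * num_windows, 2, 2, 8)
#     assert torch.equal(window_reverse(windows, 2, 4, 4), x)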
class SwinTransformerBlock(nn.Module): 172 | def __init__(self, dim, input_resolution, num_heads, window_size=1, shift_size=0, 173 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 174 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 175 | super().__init__() 176 | self.dim = dim 177 | self.input_resolution = input_resolution 178 | self.num_heads = num_heads 179 | self.window_size = window_size 180 | self.shift_size = shift_size 181 | self.mlp_ratio = mlp_ratio 182 | if min(self.input_resolution) <= self.window_size: 183 | self.shift_size = 0 184 | self.window_size = min(self.input_resolution) 185 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 186 | self.norm1 = norm_layer(dim) 187 | self.attn = WindowAttention( 188 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 189 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 190 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 191 | self.norm2 = norm_layer(dim) 192 | mlp_hidden_dim = int(dim * mlp_ratio) 193 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 194 | if self.shift_size > 0: 195 | attn_mask = self.calculate_mask(self.input_resolution) 196 | else: 197 | attn_mask = None 198 | self.register_buffer("attn_mask", attn_mask) 199 | def calculate_mask(self, x_size): 200 | H, W = x_size 201 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 202 | h_slices = (slice(0, -self.window_size), 203 | slice(-self.window_size, -self.shift_size), 204 | slice(-self.shift_size, None)) 205 | w_slices = (slice(0, -self.window_size), 206 | slice(-self.window_size, -self.shift_size), 207 | slice(-self.shift_size, None)) 208 | cnt = 0 209 | for h in h_slices: 210 | for w in w_slices: 211 | img_mask[:, h, w, :] = cnt 212 | cnt += 1 213 | mask_windows = window_partition(img_mask, self.window_size) 214 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 215 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 216 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 217 | return attn_mask 218 | def forward(self, x, x_size): 219 | B,C,H,W= x.shape 220 | x=x.view(B,H,W,C) 221 | shortcut = x 222 | shape=x.view(H*W*B,C) 223 | x = self.norm1(shape) 224 | x = x.view(B, H, W, C) 225 | if self.shift_size > 0: 226 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 227 | else: 228 | shifted_x = x 229 | x_windows = window_partition(shifted_x, self.window_size) 230 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 231 | if self.input_resolution == x_size: 232 | attn_windows = self.attn(x_windows, mask=self.attn_mask) 233 | else: 234 | attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) 235 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 236 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 237 | if self.shift_size > 0: 238 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 239 | else: 240 | x = shifted_x 241 | x = shortcut + self.drop_path(x) 242 | x = x + self.drop_path(self.mlp(self.norm2(x))) 243 | B,H,W,C=x.shape 244 | x=x.view(B,C,H,W) 245 | return x 246 | def extra_repr(self) -> str: 247 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 248 | 
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 249 | def flops(self): 250 | flops = 0 251 | H, W = self.input_resolution 252 | flops += self.dim * H * W 253 | nW = H * W / self.window_size / self.window_size 254 | flops += nW * self.attn.flops(self.window_size * self.window_size) 255 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 256 | flops += self.dim * H * W 257 | return flops 258 | 259 | class PatchEmbed(nn.Module): 260 | def __init__(self, img_size=120, patch_size=4, in_chans=6, embed_dim=96, norm_layer=None): 261 | super().__init__() 262 | img_size = to_2tuple(img_size) 263 | patch_size = to_2tuple(patch_size) 264 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 265 | self.img_size = img_size 266 | self.patch_size = patch_size 267 | self.patches_resolution = patches_resolution 268 | self.num_patches = patches_resolution[0] * patches_resolution[1] 269 | self.in_chans = in_chans 270 | self.embed_dim = embed_dim 271 | if norm_layer is not None: 272 | self.norm = norm_layer(embed_dim) 273 | else: 274 | self.norm = None 275 | def forward(self, x): 276 | x = x.flatten(2).transpose(1, 2) # B Ph*Pw C 277 | if self.norm is not None: 278 | x = self.norm(x) 279 | return x 280 | def flops(self): 281 | flops = 0 282 | H, W = self.img_size 283 | if self.norm is not None: 284 | flops += H * W * self.embed_dim 285 | return flops 286 | class BasicLayer(nn.Module): 287 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 288 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 289 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 290 | super().__init__() 291 | self.dim = dim 292 | self.input_resolution = input_resolution 293 | self.depth = depth 294 | self.use_checkpoint = use_checkpoint 295 | self.blocks = nn.ModuleList([ 296 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 297 | num_heads=num_heads, window_size=window_size, 298 | shift_size=0 if (i % 2 == 0) else window_size // 2, 299 | mlp_ratio=mlp_ratio, 300 | qkv_bias=qkv_bias, qk_scale=qk_scale, 301 | drop=drop, attn_drop=attn_drop, 302 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 303 | norm_layer=norm_layer) 304 | for i in range(depth)]) 305 | if downsample is not None: 306 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 307 | else: 308 | self.downsample = None 309 | def forward(self, x, x_size): 310 | for blk in self.blocks: 311 | x = blk(x, x_size) 312 | if self.downsample is not None: 313 | x = self.downsample(x) 314 | return x 315 | def extra_repr(self) -> str: 316 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 317 | def flops(self): 318 | flops = 0 319 | for blk in self.blocks: 320 | flops += blk.flops() 321 | if self.downsample is not None: 322 | flops += self.downsample.flops() 323 | return flops 324 | class MODEL(nn.Module): 325 | def __init__(self, img_size=120,patch_size=4,embed_dim=96,num_heads=8, window_size=1,in_channel=2, out_channel=16,output_channel=1, 326 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., patch_norm=True,depth=2,downsample=None, 327 | drop_path=0., norm_layer=nn.LayerNorm,use_checkpoint=False ): 328 | super(MODEL, self).__init__() 329 | self.convolutional = Convlutioanl(in_channel, out_channel) 330 | self.convolutional_out = Convlutioanl_out( out_channel,output_channel) 331 | 332 | self.fem = Fem(out_channel, 
out_channel) 333 | self.cam = Channel_attention(out_channel) 334 | self.sam = Spatial_attention(out_channel) 335 | self.relu=nn.ReLU(True) 336 | self.patch_norm = patch_norm 337 | self.patch_embed = PatchEmbed( 338 | img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, 339 | norm_layer=norm_layer if self.patch_norm else None) 340 | patches_resolution = self.patch_embed.patches_resolution 341 | self.patches_resolution = patches_resolution 342 | self.basicLayer=BasicLayer(dim= out_channel, 343 | input_resolution=(patches_resolution[0],patches_resolution[1]), 344 | depth=depth, 345 | num_heads=num_heads, 346 | window_size=window_size, 347 | mlp_ratio=mlp_ratio, 348 | qkv_bias=qkv_bias, qk_scale=qk_scale, 349 | drop=drop, attn_drop=attn_drop, 350 | drop_path=drop_path, 351 | norm_layer=norm_layer, 352 | downsample=downsample, 353 | use_checkpoint=use_checkpoint) 354 | def forward(self, input): 355 | convolutioanl = self.convolutional(input) 356 | fem=self.fem(convolutioanl) 357 | cam=self.cam(fem) 358 | sam=self.sam(fem) 359 | fem_cam=fem*cam 360 | fem_sam=fem*sam 361 | add=fem_cam+fem_sam+convolutioanl 362 | encode=self.relu(add) 363 | encode_size = (encode.shape[2], encode.shape[3]) 364 | Transformer=self.basicLayer(encode, encode_size) 365 | de_fem=self.fem(Transformer) 366 | de_cam=self.cam(de_fem) 367 | de_sam=self.sam(de_fem) 368 | de_fem_cam=de_fem*de_cam 369 | de_fem_sam=de_fem*de_sam 370 | de_add=de_fem_cam+de_fem_sam+ de_fem 371 | out=self.convolutional_out (de_add) 372 | return out 373 | 374 | 375 | 376 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DATFuse: Infrared and Visible Image Fusion via Dual Attention Transformer (IEEE TCSVT 2023) 2 | 3 | This is the official implementation of the DATFuse model proposed in the paper ([DATFuse: Infrared and Visible Image Fusion via Dual Attention Transformer](https://ieeexplore.ieee.org/document/10006826)) with Pytorch. 4 | 5 | 6 | ## Comparison with SOTA methods 7 | 8 | ### Fusion results on TNO dataset 9 | ![Image text](https://github.com/tthinking/DATFuse/blob/main/imgs/fig3.jpg) 10 | 11 | ### Fusion results on RoadScene dataset 12 | ![Image text](https://github.com/tthinking/DATFuse/blob/main/imgs/fig4.jpg) 13 | 14 | 15 | 16 | ## Ablation study on network structure 17 | ![Image text](https://github.com/tthinking/DATFuse/blob/main/imgs/fig6.jpg) 18 | 19 | 20 | ## Ablation study on the number of TRMs 21 | ![Image text](https://github.com/tthinking/DATFuse/blob/main/imgs/ablationTRM.jpg) 22 | 23 | ## Ablation study on the second DARM 24 | ![Image text](https://github.com/tthinking/DATFuse/blob/main/imgs/woSecondDARM.jpg) 25 | 26 | ## Impact of weight parameters in the loss function 27 | 28 | ### Impact of weight parameter α on fusion performance with λ and γ fixed as 100 and 10, respectively. 29 | ![Image text](https://github.com/tthinking/DATFuse/blob/main/imgs/alpha.jpg) 30 | 31 | ### Impact of weight parameter λ on fusion performance with α and γ fixed as 1 and 10, respectively. 32 | ![Image text](https://github.com/tthinking/DATFuse/blob/main/imgs/lambda.jpg) 33 | 34 | ### Impact of weight parameter γ on fusion performance with α and λ fixed as 1 and 100, respectively. 
35 | ![Image text](https://github.com/tthinking/DATFuse/blob/main/imgs/gamma.jpg) 36 | 37 | ## Computational efficiency comparisons 38 | 39 | ### Average running time for generating a fused image (Unit: seconds) 40 | 41 | 42 | 43 | | Method | TNO Dataset | RoadScene Dataset | 44 | | :---: | :---: | :---: | 45 | | MDLatLRR | 26.0727 | 11.7310 | 46 | |AUIF| 0.1119 | 0.0726 | 47 | |DenseFuse| 0.5663 | 0.3190 | 48 | |FusionGAN| 2.6796 | 1.1442 | 49 | |GANMcC| 5.6752 | 2.3813 | 50 | |RFN_Nest| 2.3096| 0.9423 | 51 | |CSF| 10.3311 |5.5395 | 52 | |MFEIF| 0.0793 |0.0494 | 53 | |PPTFusion| 1.4150 |0.8656 | 54 | |SwinFuse| 3.2687 | 1.6478 | 55 | |DATFuse| 0.0257 | 0.0141| 56 | 57 | # Cite the paper 58 | If this work is helpful to you, please cite it as:

59 |
@ARTICLE{Tang_2023_DATFuse,
68 |   author={Tang, Wei and He, Fazhi and Liu, Yu and Duan, Yansong and Si, Tongzhen},
69 |   journal={IEEE Transactions on Circuits and Systems for Video Technology}, 
70 |   title={DATFuse: Infrared and Visible Image Fusion via Dual Attention Transformer}, 
71 |   year={2023},
72 |   volume={33},
73 |   number={7},
74 |   pages={3159-3172},
75 |   doi={10.1109/TCSVT.2023.3234340}}
76 | 
77 | 78 | If you have any questions, feel free to contact me (weitang2021@whu.edu.cn). 79 | -------------------------------------------------------------------------------- /Test.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 4 | import torch 5 | import time 6 | import imageio 7 | import torchvision.transforms as transforms 8 | from Networks.network import MODEL as net 9 | import statistics 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 11 | 12 | device = torch.device('cuda:0') 13 | 14 | 15 | model = net(in_channel=2) 16 | 17 | model_path = "./model_10.pth" 18 | 19 | model = model.cuda() 20 | model.cuda() 21 | 22 | model.load_state_dict(torch.load(model_path)) 23 | 24 | 25 | def fusion(): 26 | fuse_time = [] 27 | for num in range(1,3): 28 | 29 | path1 = './source images/ir/{}.bmp'.format(num) 30 | path2 = './source images/vi/{}.bmp'.format(num) 31 | img1 = Image.open(path1).convert('L') 32 | img2 = Image.open(path2).convert('L') 33 | 34 | img1_org = img1 35 | img2_org = img2 36 | 37 | tran = transforms.ToTensor() 38 | 39 | img1_org = tran(img1_org) 40 | img2_org = tran(img2_org) 41 | input_img = torch.cat((img1_org, img2_org), 0).unsqueeze(0) 42 | 43 | input_img = input_img.cuda() 44 | 45 | model.eval() 46 | start = time.time() 47 | out = model(input_img) 48 | end = time.time() 49 | fuse_time.append(end - start) 50 | result = np.squeeze(out.detach().cpu().numpy()) 51 | result = (result * 255).astype(np.uint8) 52 | 53 | imageio.imwrite('./fusion results/{}.bmp'.format(num), result)  # write into the repo's 'fusion results' folder 54 | 55 | mean = statistics.mean(fuse_time[1:])  # average over all but the first (warm-up) run 56 | 57 | print(f'fuse avg time: {mean:.4f}') 58 | 59 | 60 | if __name__ == '__main__': 61 | 62 | fusion() 63 | -------------------------------------------------------------------------------- /Train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | import pandas as pd 5 | import joblib 6 | import glob 7 | from collections import OrderedDict 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import torchvision.transforms as transforms 12 | from PIL import Image 13 | from torch.utils.data import DataLoader, Dataset 14 | from Networks.network import MODEL as net 15 | from losses import ir_loss, vi_loss,ssim_loss,gra_loss 16 | device = torch.device('cuda:0') 17 | use_gpu = torch.cuda.is_available() 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--name', default='...', help='model name: (default: arch+timestamp)') 22 | parser.add_argument('--epochs', default=10, type=int) 23 | parser.add_argument('--batch_size', default=128, type=int) 24 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float) 25 | parser.add_argument('--betas', default=(0.9, 0.999), type=tuple) 26 | parser.add_argument('--eps', default=1e-8, type=float) 27 | parser.add_argument('--weight', default=[1, 1,10,100], type=float) 28 | args = parser.parse_args() 29 | return args 30 | 31 | class GetDataset(Dataset): 32 | def __init__(self, imageFolderDataset, transform=None): 33 | self.imageFolderDataset = imageFolderDataset 34 | self.transform = transform 35 | 36 | def __getitem__(self, index): 37 | 38 | ir='...' 39 | vi='...' 
40 | 41 | ir = Image.open(ir).convert('L') 42 | vi = Image.open(vi).convert('L') 43 | 44 | if self.transform is not None: 45 | tran = transforms.ToTensor() 46 | ir=tran(ir) 47 | vi= tran(vi) 48 | input = torch.cat((ir, vi), -3) 49 | return input, ir,vi 50 | 51 | def __len__(self): 52 | return len(self.imageFolderDataset) 53 | 54 | 55 | class AverageMeter(object): 56 | 57 | def __init__(self): 58 | self.reset() 59 | 60 | def reset(self): 61 | self.val = 0 62 | self.avg = 0 63 | self.sum = 0 64 | self.count = 0 65 | 66 | def update(self, val, n=1): 67 | self.val = val 68 | self.sum += val * n 69 | self.count += n 70 | self.avg = self.sum / self.count 71 | 72 | 73 | def train(args, train_loader_ir,train_loader_vi, model, criterion_ir, criterion_vi,criterion_ssim,criterion_gra,optimizer, epoch, scheduler=None): 74 | losses = AverageMeter() 75 | losses_ir = AverageMeter() 76 | losses_vi = AverageMeter() 77 | losses_ssim = AverageMeter() 78 | losses_gra= AverageMeter() 79 | weight = args.weight 80 | model.train() 81 | 82 | for i, (input,ir,vi) in tqdm(enumerate(train_loader_ir), total=len(train_loader_ir)): 83 | 84 | 85 | input = input.cuda() 86 | 87 | ir=ir.cuda() 88 | vi=vi.cuda() 89 | 90 | out = model(input) 91 | 92 | loss_ir = weight[0] * criterion_ir(out, ir) 93 | loss_vi = weight[1] * criterion_vi(out, vi) 94 | loss_ssim= weight[2] * criterion_ssim(out,ir, vi) 95 | loss_gra = weight[3] * criterion_gra(out, ir,vi) 96 | loss = loss_ir + loss_vi+loss_ssim+ loss_gra 97 | 98 | losses.update(loss.item(), input.size(0)) 99 | losses_ir.update(loss_ir.item(), input.size(0)) 100 | losses_vi.update(loss_vi.item(), input.size(0)) 101 | losses_ssim.update(loss_ssim.item(), input.size(0)) 102 | losses_gra.update(loss_gra.item(), input.size(0)) 103 | 104 | optimizer.zero_grad() 105 | loss.backward() 106 | optimizer.step() 107 | 108 | log = OrderedDict([ 109 | ('loss', losses.avg), 110 | ('loss_ir', losses_ir.avg), 111 | ('loss_vi', losses_vi.avg), 112 | ('loss_ssim', losses_ssim.avg), 113 | ('loss_gra', losses_gra.avg), 114 | ]) 115 | 116 | return log 117 | 118 | 119 | 120 | def main(): 121 | args = parse_args() 122 | 123 | if not os.path.exists('models/%s' %args.name): 124 | os.makedirs('models/%s' %args.name) 125 | 126 | 127 | with open('models/%s/args.txt' %args.name, 'w') as f: 128 | for arg in vars(args): 129 | print('%s: %s' %(arg, getattr(args, arg)), file=f) 130 | 131 | joblib.dump(args, 'models/%s/args.pkl' %args.name) 132 | cudnn.benchmark = True 133 | 134 | training_dir_ir = "..." 135 | folder_dataset_train_ir = glob.glob(training_dir_ir ) 136 | training_dir_vi = "..." 
137 | 138 | folder_dataset_train_vi= glob.glob(training_dir_vi ) 139 | 140 | transform_train = transforms.Compose([transforms.ToTensor(), 141 | transforms.Normalize((0.485, 0.456, 0.406), 142 | (0.229, 0.224, 0.225)) 143 | ]) 144 | 145 | dataset_train_ir = GetDataset(imageFolderDataset=folder_dataset_train_ir, 146 | transform=transform_train) 147 | dataset_train_vi = GetDataset(imageFolderDataset=folder_dataset_train_vi, 148 | transform=transform_train) 149 | 150 | train_loader_ir = DataLoader(dataset_train_ir, 151 | shuffle=True, 152 | batch_size=args.batch_size) 153 | train_loader_vi = DataLoader(dataset_train_vi, 154 | shuffle=True, 155 | batch_size=args.batch_size) 156 | model = net(in_channel=2) 157 | if use_gpu: 158 | model = model.cuda() 159 | model.cuda() 160 | 161 | else: 162 | model = model 163 | criterion_ir = ir_loss 164 | criterion_vi = vi_loss 165 | criterion_ssim = ssim_loss 166 | criterion_gra = gra_loss 167 | optimizer = optim.Adam(model.parameters(), lr=args.lr, 168 | betas=args.betas, eps=args.eps) 169 | log = pd.DataFrame(index=[], 170 | columns=['epoch', 171 | 'loss', 172 | 'loss_ir', 173 | 'loss_vi', 174 | 'loss_ssim', 175 | 'loss_gra', 176 | ]) 177 | 178 | for epoch in range(args.epochs): 179 | 180 | train_log = train(args, train_loader_ir,train_loader_vi, model, criterion_ir, criterion_vi,criterion_ssim,criterion_gra, optimizer, epoch) 181 | tmp = pd.Series([ 182 | epoch + 1, 183 | train_log['loss'], 184 | train_log['loss_ir'], 185 | train_log['loss_vi'], 186 | train_log['loss_ssim'], 187 | train_log['loss_gra'], 188 | ], index=['epoch', 'loss', 'loss_ir', 'loss_vi', 'loss_ssim', 'loss_gra']) 189 | 190 | log = log.append(tmp, ignore_index=True) 191 | log.to_csv('models/%s/log.csv' %args.name, index=False) 192 | 193 | if (epoch+1) % 1 == 0: 194 | torch.save(model.state_dict(), 'models/%s/model_{}.pth'.format(epoch+1) %args.name) 195 | 196 | if __name__ == '__main__': 197 | main() 198 | 199 | 200 | -------------------------------------------------------------------------------- /evaluation/Q_Y.m: -------------------------------------------------------------------------------- 1 | function res=metricYang(im1,im2,fim) 2 | 3 | 4 | im1=double(im1); 5 | im2=double(im2); 6 | fim=double(fim); 7 | 8 | [mssim1, ssim_map1, sigma1_sq1,sigma2_sq1] = ssim_yang(im1, im2); 9 | [mssim2, ssim_map2, sigma1_sq2,sigma2_sq2] = ssim_yang(im1, fim); 10 | [mssim3, ssim_map3, sigma1_sq3,sigma2_sq3] = ssim_yang(im2, fim); 11 | 12 | bin_map=ssim_map1>=0.75; 13 | 14 | buffer=sigma1_sq1+sigma2_sq1; 15 | test=(buffer==0); test=test*0.5; 16 | sigma1_sq1=sigma1_sq1+test; sigma2_sq1=sigma2_sq1+test; 17 | buffer=sigma1_sq1+sigma2_sq1; 18 | ramda=sigma1_sq1./buffer; 19 | 20 | Q1=(ramda.*ssim_map2+(1-ramda).*ssim_map3).*bin_map; 21 | 22 | Q2=(max(ssim_map2,ssim_map3)).*(~bin_map); 23 | 24 | Q=mean2(Q1+Q2); 25 | 26 | res=Q; 27 | 28 | 29 | function [mssim, ssim_map, sigma1_sq,sigma2_sq] = ssim_yang(img1, img2) 30 | 31 | [M N] = size(img1); 32 | if ((M < 11) | (N < 11)) 33 | ssim_index = -Inf; 34 | ssim_map = -Inf; 35 | return 36 | end 37 | window = fspecial('gaussian', 7, 1.5); % 38 | 39 | L = 255; % 40 | 41 | C1 = 2e-16; 42 | C2 = 2e-16; 43 | 44 | window = window/sum(sum(window)); 45 | img1 = double(img1); 46 | img2 = double(img2); 47 | mu1 = filter2(window, img1, 'valid'); 48 | mu2 = filter2(window, img2, 'valid'); 49 | mu1_sq = mu1.*mu1; 50 | mu2_sq = mu2.*mu2; 51 | mu1_mu2 = mu1.*mu2; 52 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; 53 | sigma2_sq = filter2(window, img2.*img2, 'valid') 
- mu2_sq; 54 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; 55 | if (C1 > 0 & C2 > 0) 56 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); 57 | else 58 | numerator1 = 2*mu1_mu2 + C1; 59 | numerator2 = 2*sigma12 + C2; 60 | denominator1 = mu1_sq + mu2_sq + C1; 61 | denominator2 = sigma1_sq + sigma2_sq + C2; 62 | ssim_map = ones(size(mu1)); 63 | 64 | index = (denominator1.*denominator2 > 0); 65 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index)); 66 | index = (denominator1 ~= 0) & (denominator2 == 0); 67 | ssim_map(index) = numerator1(index)./denominator1(index); 68 | end 69 | 70 | mssim = mean2(ssim_map); 71 | 72 | 73 | return 74 | -------------------------------------------------------------------------------- /fusion results/001.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/fusion results/001.bmp -------------------------------------------------------------------------------- /fusion results/2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/fusion results/2.bmp -------------------------------------------------------------------------------- /imgs/ablationTRM.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/imgs/ablationTRM.jpg -------------------------------------------------------------------------------- /imgs/alpha.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/imgs/alpha.jpg -------------------------------------------------------------------------------- /imgs/fig3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/imgs/fig3.jpg -------------------------------------------------------------------------------- /imgs/fig4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/imgs/fig4.jpg -------------------------------------------------------------------------------- /imgs/fig6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/imgs/fig6.jpg -------------------------------------------------------------------------------- /imgs/gamma.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/imgs/gamma.jpg -------------------------------------------------------------------------------- /imgs/lambda.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/imgs/lambda.jpg -------------------------------------------------------------------------------- /imgs/woSecondDARM.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/imgs/woSecondDARM.jpg -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch.nn.functional as F 4 | import torch 5 | from math import exp 6 | import torch.nn as nn 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | def ir_loss (fused_result,input_ir ): 11 | a=fused_result-input_ir 12 | b=torch.square(fused_result-input_ir) 13 | c=torch.mean(torch.square(fused_result-input_ir)) 14 | ir_loss=c 15 | return ir_loss 16 | 17 | def vi_loss (fused_result , input_vi): 18 | vi_loss=torch.mean(torch.square(fused_result-input_vi)) 19 | return vi_loss 20 | 21 | 22 | def ssim_loss (fused_result,input_ir,input_vi ): 23 | ssim_loss=ssim(fused_result,torch.maximum(input_ir,input_vi)) 24 | 25 | return ssim_loss 26 | 27 | 28 | def gra_loss( fused_result,input_ir, input_vi): 29 | gra_loss =torch.norm( Gradient(fused_result)- torch.maximum(Gradient(input_ir), Gradient(input_vi))) 30 | return gra_loss 31 | def gaussian(window_size, sigma): 32 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 33 | return gauss/gauss.sum() 34 | 35 | 36 | def create_window(window_size, channel=1): 37 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 38 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 39 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 40 | return window 41 | 42 | 43 | def ssim(img1, img2, window_size=11, window=None, val_range=None): 44 | if val_range is None: 45 | if torch.max(img1) > 128: 46 | max_val = 255 47 | else: 48 | max_val = 1 49 | 50 | if torch.min(img1) < -0.5: 51 | min_val = -1 52 | else: 53 | min_val = 0 54 | L = max_val - min_val 55 | else: 56 | L = val_range 57 | 58 | padd = 0 59 | (_, channel, height, width) = img1.size() 60 | if window is None: 61 | real_size = min(window_size, height, width) 62 | window = create_window(real_size, channel=channel).to(img1.device) 63 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 64 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 65 | mu1_sq = mu1.pow(2) 66 | mu2_sq = mu2.pow(2) 67 | mu1_mu2 = mu1 * mu2 68 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 69 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 70 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 71 | C1 = (0.01 * L) ** 2 72 | C2 = (0.03 * L) ** 2 73 | v1 = 2.0 * sigma12 + C2 74 | v2 = sigma1_sq + sigma2_sq + C2 75 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 76 | ret = ssim_map.mean() 77 | return 1-ret 78 | 79 | class gradient(nn.Module): 80 | def __init__(self): 81 | super(gradient, self).__init__() 82 | x_kernel = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]] 83 | y_kernel = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]] 84 | x_kernel = torch.FloatTensor(x_kernel).unsqueeze(0).unsqueeze(0) 85 | y_kernel = torch.FloatTensor(y_kernel).unsqueeze(0).unsqueeze(0) 86 | self.x_weight = nn.Parameter(data=x_kernel, requires_grad=False) 87 | self.y_weight = nn.Parameter(data=y_kernel, requires_grad=False) 88 | 89 | def forward(self, input): 90 | x_grad = torch.nn.functional.conv2d(input, self.x_weight, 
padding=1) 91 | y_grad = torch.nn.functional.conv2d(input, self.y_weight, padding=1) 92 | gradRes = torch.mean((x_grad + y_grad).float()) 93 | return gradRes 94 | 95 | def Gradient(x): 96 | gradient_model =gradient().to(device) 97 | g = gradient_model(x) 98 | return g -------------------------------------------------------------------------------- /model_10.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/model_10.pth -------------------------------------------------------------------------------- /source images/ir/1.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/source images/ir/1.bmp -------------------------------------------------------------------------------- /source images/ir/2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/source images/ir/2.bmp -------------------------------------------------------------------------------- /source images/vi/1.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/source images/vi/1.bmp -------------------------------------------------------------------------------- /source images/vi/2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/DATFuse/0cd93f15fb5552c7aef79b7be56726c863d6c5ba/source images/vi/2.bmp --------------------------------------------------------------------------------
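For reference, the four loss terms defined in losses/__init__.py are combined in Train.py as a weighted sum using the default --weight values [1, 1, 10, 100]. A minimal stand-alone sketch of that combination follows (the tensor shapes and random inputs are illustrative assumptions, not values from the repository):

import torch
from losses import ir_loss, vi_loss, ssim_loss, gra_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy single-channel tensors standing in for the fused, infrared, and visible images.
fused = torch.rand(1, 1, 64, 64, device=device)
ir = torch.rand(1, 1, 64, 64, device=device)
vi = torch.rand(1, 1, 64, 64, device=device)

w = [1, 1, 10, 100]  # default --weight in Train.py
total_loss = (w[0] * ir_loss(fused, ir)
              + w[1] * vi_loss(fused, vi)
              + w[2] * ssim_loss(fused, ir, vi)
              + w[3] * gra_loss(fused, ir, vi))
print(total_loss.item())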