├── Networks
│   └── network.py
├── README.md
├── Test.py
├── Train.py
├── evaluation
│   └── Q_Y.m
├── fusion results
│   ├── 001.bmp
│   └── 2.bmp
├── imgs
│   ├── ablationTRM.jpg
│   ├── alpha.jpg
│   ├── fig3.jpg
│   ├── fig4.jpg
│   ├── fig6.jpg
│   ├── gamma.jpg
│   ├── lambda.jpg
│   └── woSecondDARM.jpg
├── losses
│   └── __init__.py
├── model_10.pth
└── source images
    ├── ir
    │   ├── 1.bmp
    │   └── 2.bmp
    └── vi
        ├── 1.bmp
        └── 2.bmp

/Networks/network.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


# Shallow feature extraction: replicate padding -> 5x5 conv -> BN -> ReLU.
class Convlutioanl(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Convlutioanl, self).__init__()
        self.padding = (2, 2, 2, 2)
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=5, padding=0, stride=1)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        out = F.pad(input, self.padding, 'replicate')
        out = self.conv(out)
        out = self.bn(out)
        out = self.relu(out)
        return out


# Output head: 1x1 conv followed by Tanh, mapping fused features to a one-channel image.
class Convlutioanl_out(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Convlutioanl_out, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0, stride=1)
        self.tanh = nn.Tanh()

    def forward(self, input):
        out = self.conv(input)
        out = self.tanh(out)
        return out


# Feature enhancement module: two passes through the same 3x3 conv/BN pair
# (weights shared), with ReLU applied only after the first pass.
class Fem(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Fem, self).__init__()
        self.padding = (1, 1, 1, 1)
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=0, stride=1)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        out = F.pad(input, self.padding, 'replicate')
        out = self.conv(out)
        out = self.bn(out)
        out = self.relu(out)
        out = F.pad(out, self.padding, 'replicate')
        out = self.conv(out)
        out = self.bn(out)
        return out


# Channel attention: global average pooling -> bottleneck 1x1 convs -> sigmoid weights.
class Channel_attention(nn.Module):
    def __init__(self, channel, reduction=4):
        super(Channel_attention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        out = self.avg_pool(input)
        out = self.fc(out)
        out = self.sigmoid(out)
        return out


# Spatial attention: 3x3 conv bottleneck producing a single-channel sigmoid map.
class Spatial_attention(nn.Module):
    def __init__(self, channel, reduction=4):
        super(Spatial_attention, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 3, padding=1),
            nn.BatchNorm2d(channel // reduction),
            nn.ReLU(True),
            nn.Conv2d(channel // reduction, 1, 3, padding=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.body(input)


# Window-based multi-head self-attention with relative position bias (Swin Transformer style).
class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
        coords_h = torch.arange(self.window_size[0])
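        # Descriptive note: the following lines precompute a pairwise relative-position
        # index for the tokens inside one window. Each (query, key) offset (dy, dx) is
        # shifted to be non-negative and flattened into a single index into
        # relative_position_bias_table, i.e. one learned bias per attention head for
        # every possible relative offset within the window.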
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        flops = 0
        flops += N * self.dim * 3 * self.dim
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        flops += N * self.dim * self.dim
        return flops


# Two-layer feed-forward network (Linear -> GELU -> Dropout -> Linear -> Dropout) used in each Transformer block.
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


# Split a (B, H, W, C) feature map into non-overlapping window_size x window_size windows.
def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


# Merge windows back into a (B, H, W, C) feature map; inverse of window_partition.
def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

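# Illustrative round-trip check (a sketch, not part of the original file):
# window_partition and window_reverse are exact inverses for compatible shapes, e.g.
#
#     x = torch.randn(1, 8, 8, 16)                   # B, H, W, C
#     windows = window_partition(x, window_size=4)   # -> (4, 4, 4, 16)
#     assert torch.equal(window_reverse(windows, 4, 8, 8), x)
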
# Swin Transformer block: (shifted-)window multi-head self-attention followed by an MLP,
# each with a residual connection.
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=1, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x, x_size):
        B, C, H, W = x.shape
        x = x.view(B, H, W, C)
        shortcut = x
        shape = x.view(H * W * B, C)
        x = self.norm1(shape)
        x = x.view(B, H, W, C)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        B, H, W, C = x.shape
        x = x.view(B, C, H, W)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        flops += self.dim * H * W
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        flops += self.dim * H * W
        return flops


# Patch embedding: flattens the spatial dimensions into a token sequence (optionally LayerNorm-ed).
class PatchEmbed(nn.Module):
    def __init__(self, img_size=120, patch_size=4, in_chans=6, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        flops = 0
        H, W = self.img_size
        if self.norm is not None:
            flops += H * W * self.embed_dim
        return flops


# A stack of `depth` Swin Transformer blocks; odd-indexed blocks use shifted windows (shift = window_size // 2).
class BasicLayer(nn.Module):
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):
        for blk in self.blocks:
            x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


# DATFuse network: shallow 5x5 conv features, feature enhancement with channel and spatial
# attention, a Swin Transformer layer, a second attention stage, and a 1x1 conv + Tanh head.
class MODEL(nn.Module):
    def __init__(self, img_size=120, patch_size=4, embed_dim=96, num_heads=8, window_size=1,
                 in_channel=2, out_channel=16, output_channel=1,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 patch_norm=True, depth=2, downsample=None,
                 drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False):
        super(MODEL, self).__init__()
        self.convolutional = Convlutioanl(in_channel, out_channel)
        self.convolutional_out = Convlutioanl_out(out_channel, output_channel)

        self.fem = Fem(out_channel, out_channel)
        self.cam = Channel_attention(out_channel)
        self.sam = Spatial_attention(out_channel)
        self.relu = nn.ReLU(True)
        self.patch_norm = patch_norm
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        self.basicLayer = BasicLayer(dim=out_channel,
                                     input_resolution=(patches_resolution[0], patches_resolution[1]),
                                     depth=depth,
                                     num_heads=num_heads,
                                     window_size=window_size,
                                     mlp_ratio=mlp_ratio,
                                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                                     drop=drop, attn_drop=attn_drop,
                                     drop_path=drop_path,
                                     norm_layer=norm_layer,
                                     downsample=downsample,
                                     use_checkpoint=use_checkpoint)

    def forward(self, input):
        convolutioanl = self.convolutional(input)
        fem = self.fem(convolutioanl)
        cam = self.cam(fem)
        sam = self.sam(fem)
        fem_cam = fem * cam
        fem_sam = fem * sam
        add = fem_cam + fem_sam + convolutioanl
        encode = self.relu(add)
        encode_size = (encode.shape[2], encode.shape[3])
        Transformer = self.basicLayer(encode, encode_size)
        de_fem = self.fem(Transformer)
        de_cam = self.cam(de_fem)
        de_sam = self.sam(de_fem)
        de_fem_cam = de_fem * de_cam
        de_fem_sam = de_fem * de_sam
        de_add = de_fem_cam + de_fem_sam + de_fem
        out = self.convolutional_out(de_add)
        return out


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DATFuse: Infrared and Visible Image Fusion via Dual Attention Transformer (IEEE TCSVT 2023)

This is the official PyTorch implementation of the DATFuse model proposed in the paper [DATFuse: Infrared and Visible Image Fusion via Dual Attention Transformer](https://ieeexplore.ieee.org/document/10006826).

## Comparison with SOTA methods

### Fusion results on the TNO dataset

### Fusion results on the RoadScene dataset

## Ablation study on network structure

## Ablation study on the number of TRMs

## Ablation study on the second DARM

## Impact of weight parameters in the loss function

### Impact of weight parameter α on fusion performance, with λ and γ fixed at 100 and 10, respectively

### Impact of weight parameter λ on fusion performance, with α and γ fixed at 1 and 10, respectively

### Impact of weight parameter γ on fusion performance, with α and λ fixed at 1 and 100, respectively

## Computational efficiency comparisons

### Average running time for generating a fused image (unit: seconds)

| Method | TNO Dataset | RoadScene Dataset |
| :---: | :---: | :---: |
| MDLatLRR | 26.0727 | 11.7310 |
| AUIF | 0.1119 | 0.0726 |
| DenseFuse | 0.5663 | 0.3190 |
| FusionGAN | 2.6796 | 1.1442 |
| GANMcC | 5.6752 | 2.3813 |
| RFN-Nest | 2.3096 | 0.9423 |
| CSF | 10.3311 | 5.5395 |
| MFEIF | 0.0793 | 0.0494 |
| PPTFusion | 1.4150 | 0.8656 |
| SwinFuse | 3.2687 | 1.6478 |
| DATFuse | 0.0257 | 0.0141 |

# Cite the paper

If this work is helpful to you, please cite it as:
@ARTICLE{Tang_2023_DATFuse,
  author={Tang, Wei and He, Fazhi and Liu, Yu and Duan, Yansong and Si, Tongzhen},
  journal={IEEE Transactions on Circuits and Systems for Video Technology},
  title={DATFuse: Infrared and Visible Image Fusion via Dual Attention Transformer},
  year={2023},
  volume={33},
  number={7},
  pages={3159-3172},
  doi={10.1109/TCSVT.2023.3234340}}
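
## Minimal inference sketch

A minimal, illustrative example of fusing one infrared/visible pair with the released checkpoint; this sketch is not part of the original repository, and `Test.py` remains the authoritative script. It assumes that the default `MODEL` constructor arguments match `model_10.pth`, that the checkpoint stores a plain state dict, and that inputs are single-channel grayscale images scaled to [0, 1]; the output file name is arbitrary.

```python
import numpy as np
import torch
from PIL import Image

from Networks.network import MODEL


def load_gray(path):
    # Read a grayscale image and scale it to [0, 1], shaped 1 x 1 x H x W.
    img = np.asarray(Image.open(path).convert('L'), dtype=np.float32) / 255.0
    return torch.from_numpy(img)[None, None]


# Hyper-parameters below are the defaults from Networks/network.py; take the exact
# values from Train.py if they differ for the released checkpoint.
model = MODEL(in_channel=2, out_channel=16, output_channel=1)
model.load_state_dict(torch.load('model_10.pth', map_location='cpu'))
model.eval()

ir = load_gray('source images/ir/1.bmp')
vi = load_gray('source images/vi/1.bmp')

with torch.no_grad():
    # The network takes the channel-concatenated IR/VI pair and ends in Tanh,
    # so the fused result is mapped back from [-1, 1] to [0, 255] before saving.
    fused = model(torch.cat([ir, vi], dim=1))

fused = ((fused.squeeze().clamp(-1, 1) + 1) / 2 * 255).round().byte().numpy()
Image.fromarray(fused).save('fused_1.bmp')
```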