├── Networks └── networks.py ├── README.md ├── Test.py ├── Train.py ├── images ├── ESIhighlyCitedPaper.png ├── MRI.bmp └── SPECT.bmp └── losses └── __init__.py /Networks/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | 8 | def upsample(src, tar): 9 | src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=True) 10 | return src 11 | class Conv1(nn.Module): 12 | def __init__(self, in_channel, out_channel): 13 | super(Conv1, self).__init__() 14 | 15 | self.conv=nn.Conv2d(in_channel,out_channel,kernel_size=1,padding=0,stride=1) 16 | self.bn=nn.BatchNorm2d(out_channel) 17 | self.relu=nn.LeakyReLU(inplace=True) 18 | 19 | def forward(self, input): 20 | 21 | out=self.conv(input) 22 | out=self.bn(out) 23 | out=self.relu(out) 24 | return out 25 | class Convlutioanl(nn.Module): 26 | def __init__(self, in_channel, out_channel): 27 | super(Convlutioanl, self).__init__() 28 | self.padding=(1,1,1,1) 29 | self.conv=nn.Conv2d(in_channel,out_channel,kernel_size=3,padding=0,stride=1) 30 | self.bn=nn.BatchNorm2d(out_channel) 31 | self.relu=nn.LeakyReLU(inplace=True) 32 | 33 | def forward(self, input): 34 | out=F.pad(input,self.padding,'replicate') 35 | out=self.conv(out) 36 | out=self.bn(out) 37 | out=self.relu(out) 38 | return out 39 | 40 | class Conv5_5(nn.Module): 41 | def __init__(self, in_channel, out_channel): 42 | super(Conv5_5, self).__init__() 43 | self.padding=(2,2,2,2) 44 | self.conv=nn.Conv2d(in_channel,out_channel,kernel_size=5,padding=0,stride=1) 45 | self.bn=nn.BatchNorm2d(out_channel) 46 | self.relu=nn.LeakyReLU(inplace=True) 47 | 48 | def forward(self, input): 49 | out=F.pad(input,self.padding,'replicate') 50 | out=self.conv(out) 51 | out=self.bn(out) 52 | out=self.relu(out) 53 | return out 54 | 55 | class Conv7_7(nn.Module): 56 | def __init__(self, in_channel, out_channel): 57 | super(Conv7_7, self).__init__() 58 | self.padding=(3,3,3,3) 59 | self.conv=nn.Conv2d(in_channel,out_channel,kernel_size=7,padding=0,stride=1) 60 | self.bn=nn.BatchNorm2d(out_channel) 61 | self.relu=nn.LeakyReLU(inplace=True) 62 | 63 | def forward(self, input): 64 | out=F.pad(input,self.padding,'replicate') 65 | out=self.conv(out) 66 | out=self.bn(out) 67 | out=self.relu(out) 68 | return out 69 | 70 | class Convlutioanl_out(nn.Module): 71 | def __init__(self, in_channel, out_channel): 72 | super(Convlutioanl_out, self).__init__() 73 | # self.padding=(2,2,2,2) 74 | self.conv=nn.Conv2d(in_channel,out_channel,kernel_size=1,padding=0,stride=1) 75 | # self.bn=nn.BatchNorm2d(out_channel) 76 | self.sigmoid=nn.Sigmoid() 77 | 78 | def forward(self, input): 79 | # out=F.pad(input,self.padding,'replicate') 80 | out=self.conv(input) 81 | # out=self.bn(out) 82 | out=self.sigmoid(out) 83 | return out 84 | 85 | class WindowAttention(nn.Module): 86 | 87 | 88 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 89 | 90 | super().__init__() 91 | self.dim = dim 92 | self.window_size = window_size 93 | self.num_heads = num_heads 94 | head_dim = dim // num_heads 95 | self.scale = qk_scale or head_dim ** -0.5 96 | 97 | self.relative_position_bias_table = nn.Parameter( 98 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) 99 | coords_h = torch.arange(self.window_size[0]) 100 | coords_w = torch.arange(self.window_size[1]) 101 | coords = 
torch.stack(torch.meshgrid([coords_h, coords_w])) 102 | coords_flatten = torch.flatten(coords, 1) 103 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 104 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 105 | relative_coords[:, :, 0] += self.window_size[0] - 1 106 | relative_coords[:, :, 1] += self.window_size[1] - 1 107 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 108 | relative_position_index = relative_coords.sum(-1) 109 | self.register_buffer("relative_position_index", relative_position_index) 110 | 111 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 112 | self.attn_drop = nn.Dropout(attn_drop) 113 | self.proj = nn.Linear(dim, dim) 114 | 115 | self.proj_drop = nn.Dropout(proj_drop) 116 | 117 | trunc_normal_(self.relative_position_bias_table, std=.02) 118 | self.softmax = nn.Softmax(dim=-1) 119 | 120 | def forward(self, x, mask=None): 121 | 122 | B_, N, C = x.shape 123 | 124 | A=self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads) 125 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 126 | q, k, v = qkv[0], qkv[1], qkv[2] 127 | 128 | q = q * self.scale 129 | attn = (q @ k.transpose(-2, -1)) 130 | 131 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 132 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) 133 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 134 | attn = attn + relative_position_bias.unsqueeze(0) 135 | 136 | if mask is not None: 137 | nW = mask.shape[0] 138 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 139 | attn = attn.view(-1, self.num_heads, N, N) 140 | attn = self.softmax(attn) 141 | else: 142 | attn = self.softmax(attn) 143 | 144 | attn = self.attn_drop(attn) 145 | 146 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 147 | x = self.proj(x) 148 | x = self.proj_drop(x) 149 | return x 150 | 151 | def extra_repr(self) -> str: 152 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 153 | 154 | def flops(self, N): 155 | 156 | flops = 0 157 | 158 | flops += N * self.dim * 3 * self.dim 159 | 160 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 161 | 162 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 163 | 164 | flops += N * self.dim * self.dim 165 | return flops 166 | 167 | 168 | class Mlp(nn.Module): 169 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 170 | super().__init__() 171 | out_features = out_features or in_features 172 | hidden_features = hidden_features or in_features 173 | self.fc1 = nn.Linear(in_features, hidden_features) 174 | self.act = act_layer() 175 | self.fc2 = nn.Linear(hidden_features, out_features) 176 | self.drop = nn.Dropout(drop) 177 | 178 | def forward(self, x): 179 | x = self.fc1(x) 180 | x = self.act(x) 181 | x = self.drop(x) 182 | x = self.fc2(x) 183 | x = self.drop(x) 184 | return x 185 | 186 | 187 | 188 | def window_partition(x, window_size): 189 | 190 | B, H, W, C = x.shape 191 | 192 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 193 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 194 | return windows 195 | 196 | 197 | def window_reverse(windows, window_size, H, W): 198 | 199 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 200 | x = 
windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 201 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 202 | return x 203 | 204 | class SwinTransformerBlock(nn.Module): 205 | 206 | 207 | def __init__(self, dim, input_resolution, num_heads, window_size=1, shift_size=0, 208 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 209 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 210 | super().__init__() 211 | self.dim = dim 212 | self.input_resolution = input_resolution 213 | self.num_heads = num_heads 214 | self.window_size = window_size 215 | self.shift_size = shift_size 216 | self.mlp_ratio = mlp_ratio 217 | if min(self.input_resolution) <= self.window_size: 218 | 219 | self.shift_size = 0 220 | self.window_size = min(self.input_resolution) 221 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 222 | 223 | self.norm1 = norm_layer(dim) 224 | self.attn = WindowAttention( 225 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 226 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 227 | 228 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 229 | self.norm2 = norm_layer(dim) 230 | mlp_hidden_dim = int(dim * mlp_ratio) 231 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 232 | 233 | if self.shift_size > 0: 234 | attn_mask = self.calculate_mask(self.input_resolution) 235 | else: 236 | attn_mask = None 237 | 238 | self.register_buffer("attn_mask", attn_mask) 239 | 240 | def calculate_mask(self, x_size): 241 | 242 | H, W = x_size 243 | img_mask = torch.zeros((1, H, W, 1)) 244 | h_slices = (slice(0, -self.window_size), 245 | slice(-self.window_size, -self.shift_size), 246 | slice(-self.shift_size, None)) 247 | w_slices = (slice(0, -self.window_size), 248 | slice(-self.window_size, -self.shift_size), 249 | slice(-self.shift_size, None)) 250 | cnt = 0 251 | for h in h_slices: 252 | for w in w_slices: 253 | img_mask[:, h, w, :] = cnt 254 | cnt += 1 255 | 256 | mask_windows = window_partition(img_mask, self.window_size) 257 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 258 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 259 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 260 | 261 | return attn_mask 262 | 263 | def forward(self, x, x_size): 264 | H, W = x_size 265 | 266 | B,C,H,W= x.shape 267 | 268 | x=x.view(B,H,W,C) 269 | shortcut = x 270 | shape=x.view(H*W*B,C) 271 | x = self.norm1(shape) 272 | x = x.view(B, H, W, C) 273 | 274 | if self.shift_size > 0: 275 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 276 | else: 277 | shifted_x = x 278 | 279 | x_windows = window_partition(shifted_x, self.window_size) 280 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 281 | if self.input_resolution == x_size: 282 | attn_windows = self.attn(x_windows, mask=self.attn_mask) 283 | else: 284 | attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) 285 | 286 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 287 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) 288 | if self.shift_size > 0: 289 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 290 | else: 291 | x = shifted_x 292 | 293 | 294 | x = shortcut + 
self.drop_path(x) 295 | x = x + self.drop_path(self.mlp(self.norm2(x))) 296 | B,H,W,C=x.shape 297 | x=x.view(B,C,H,W) 298 | 299 | 300 | return x 301 | 302 | def extra_repr(self) -> str: 303 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 304 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 305 | 306 | def flops(self): 307 | flops = 0 308 | H, W = self.input_resolution 309 | 310 | flops += self.dim * H * W 311 | 312 | nW = H * W / self.window_size / self.window_size 313 | flops += nW * self.attn.flops(self.window_size * self.window_size) 314 | 315 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 316 | 317 | flops += self.dim * H * W 318 | return flops 319 | 320 | 321 | 322 | class PatchEmbed(nn.Module): 323 | 324 | 325 | def __init__(self, img_size=120, patch_size=4, in_chans=6, embed_dim=96, norm_layer=None): 326 | super().__init__() 327 | img_size = to_2tuple(img_size) 328 | patch_size = to_2tuple(patch_size) 329 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 330 | self.img_size = img_size 331 | self.patch_size = patch_size 332 | self.patches_resolution = patches_resolution 333 | self.num_patches = patches_resolution[0] * patches_resolution[1] 334 | 335 | self.in_chans = in_chans 336 | self.embed_dim = embed_dim 337 | 338 | if norm_layer is not None: 339 | self.norm = norm_layer(embed_dim) 340 | else: 341 | self.norm = None 342 | 343 | def forward(self, x): 344 | x = x.flatten(2).transpose(1, 2) 345 | if self.norm is not None: 346 | x = self.norm(x) 347 | return x 348 | 349 | def flops(self): 350 | flops = 0 351 | H, W = self.img_size 352 | if self.norm is not None: 353 | flops += H * W * self.embed_dim 354 | return flops 355 | 356 | class BasicLayer(nn.Module): 357 | 358 | 359 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 360 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 361 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 362 | 363 | super().__init__() 364 | self.dim = dim 365 | self.input_resolution = input_resolution 366 | self.depth = depth 367 | self.use_checkpoint = use_checkpoint 368 | 369 | 370 | self.blocks = nn.ModuleList([ 371 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 372 | num_heads=num_heads, window_size=window_size, 373 | shift_size=0 if (i % 2 == 0) else window_size // 2, 374 | mlp_ratio=mlp_ratio, 375 | qkv_bias=qkv_bias, qk_scale=qk_scale, 376 | drop=drop, attn_drop=attn_drop, 377 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 378 | norm_layer=norm_layer) 379 | for i in range(depth)]) 380 | 381 | if downsample is not None: 382 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 383 | else: 384 | self.downsample = None 385 | 386 | def forward(self, x, x_size): 387 | for blk in self.blocks: 388 | 389 | x = blk(x, x_size) 390 | if self.downsample is not None: 391 | x = self.downsample(x) 392 | return x 393 | 394 | def extra_repr(self) -> str: 395 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 396 | 397 | def flops(self): 398 | flops = 0 399 | for blk in self.blocks: 400 | flops += blk.flops() 401 | if self.downsample is not None: 402 | flops += self.downsample.flops() 403 | return flops 404 | 405 | class Channel_attention(nn.Module): 406 | def __init__(self, channel, reduction): 407 | super(Channel_attention, self).__init__() 408 
| self.avg_pool = nn.AdaptiveAvgPool2d(1) 409 | self.fc=nn.Sequential( 410 | nn.Conv2d(channel,channel//reduction,1), 411 | nn.ReLU(inplace=True), 412 | nn.Conv2d(channel//reduction,channel,1)) 413 | self.sigmoid=nn.Sigmoid() 414 | def forward(self, input): 415 | out=self.avg_pool(input) 416 | out=self.fc(out) 417 | out=self.sigmoid(out) 418 | return out 419 | 420 | class MODEL(nn.Module): 421 | def __init__(self, in_channel=1,out_channel_16=16,out_channel_256=256, out_channel_32=32,out_channel=64,out_channel_512=512,output_channel=1,out_channel_128=128, 422 | img_size=120, patch_size=4, embed_dim=96, num_heads=8, window_size=1,out_channel_448=448,out_channel_896=896,out_channel_336=336, 423 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., patch_norm=True, depth=2, 424 | downsample=None, 425 | drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False 426 | ): 427 | super(MODEL, self).__init__() 428 | 429 | self.convInput= Convlutioanl(in_channel, out_channel_16) 430 | self.conv5 = Conv5_5(out_channel_16, out_channel) 431 | self.conv7 = Conv7_7( out_channel, out_channel_256) 432 | self.conv64 = Convlutioanl(out_channel, out_channel_256) 433 | self.conv = Convlutioanl(out_channel*2, out_channel) 434 | self.convolutional_out =Convlutioanl_out(out_channel_32, output_channel) 435 | self.conv16_16= Convlutioanl(out_channel_16, out_channel_16) 436 | self.conv64_16 = Convlutioanl(out_channel, out_channel_16) 437 | self.conv256_16 = Convlutioanl(out_channel_256, out_channel_16) 438 | 439 | self.cam1 = Channel_attention(out_channel_16,4) 440 | self.cam2 = Channel_attention(out_channel,8) 441 | self.cam3 = Channel_attention(out_channel_256,16) 442 | 443 | 444 | self.conv1=Conv1( out_channel_32, out_channel_16) 445 | 446 | self.conv2=Conv1( out_channel_128, out_channel) 447 | self.conv3 = Conv1(out_channel_512, out_channel_256) 448 | self.conv4 = Conv1(out_channel_336, out_channel_16) 449 | self.patch_norm = patch_norm 450 | 451 | self.patch_embed = PatchEmbed( 452 | img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, 453 | norm_layer=norm_layer if self.patch_norm else None) 454 | num_patches = self.patch_embed.num_patches 455 | patches_resolution = self.patch_embed.patches_resolution 456 | self.patches_resolution = patches_resolution 457 | self.basicLayer1 = BasicLayer(dim=out_channel_16, 458 | input_resolution=(patches_resolution[0], patches_resolution[1]), 459 | depth=depth, 460 | num_heads=num_heads, 461 | window_size=window_size, 462 | mlp_ratio=mlp_ratio, 463 | qkv_bias=qkv_bias, qk_scale=qk_scale, 464 | drop=drop, attn_drop=attn_drop, 465 | drop_path=drop_path, 466 | norm_layer=norm_layer, 467 | downsample=downsample, 468 | use_checkpoint=use_checkpoint) 469 | 470 | self.basicLayer2 = BasicLayer(dim=out_channel, 471 | input_resolution=(patches_resolution[0], patches_resolution[1]), 472 | depth=depth, 473 | num_heads=num_heads, 474 | window_size=window_size, 475 | mlp_ratio=mlp_ratio, 476 | qkv_bias=qkv_bias, qk_scale=qk_scale, 477 | drop=drop, attn_drop=attn_drop, 478 | drop_path=drop_path, 479 | norm_layer=norm_layer, 480 | downsample=downsample, 481 | use_checkpoint=use_checkpoint) 482 | 483 | self.basicLayer3 = BasicLayer(dim=out_channel_256, 484 | input_resolution=(patches_resolution[0], patches_resolution[1]), 485 | depth=depth, 486 | num_heads=num_heads, 487 | window_size=window_size, 488 | mlp_ratio=mlp_ratio, 489 | qkv_bias=qkv_bias, qk_scale=qk_scale, 490 | drop=drop, attn_drop=attn_drop, 491 | drop_path=drop_path, 492 | 
norm_layer=norm_layer, 493 | downsample=downsample, 494 | use_checkpoint=use_checkpoint) 495 | 496 | self.basicLayer4 = BasicLayer(dim=out_channel_448, 497 | input_resolution=(patches_resolution[0], patches_resolution[1]), 498 | depth=depth, 499 | num_heads=num_heads, 500 | window_size=window_size, 501 | mlp_ratio=mlp_ratio, 502 | qkv_bias=qkv_bias, qk_scale=qk_scale, 503 | drop=drop, attn_drop=attn_drop, 504 | drop_path=drop_path, 505 | norm_layer=norm_layer, 506 | downsample=downsample, 507 | use_checkpoint=use_checkpoint) 508 | 509 | def forward(self, ir,vi): 510 | 511 | convInput_A1 = self.convInput(ir) 512 | convInput_A2 = self.conv5(convInput_A1 ) 513 | convInput_A3 = self.conv7(convInput_A2) 514 | 515 | 516 | 517 | convInput_B1 = self.convInput(vi) 518 | convInput_B2 = self.conv5(convInput_B1) 519 | convInput_B3 = self.conv7(convInput_B2) 520 | 521 | camA1=self.cam1(convInput_A1) 522 | camA1_1 = torch.cat((camA1*convInput_A1, convInput_B1),1) 523 | convA1=self.conv1( camA1_1) 524 | camA1_2=self.cam1(convA1)* convA1 525 | 526 | encode_sizeA1 = ( camA1_2.shape[2], camA1_2.shape[3]) 527 | TransformerA1 = self.basicLayer1( camA1_2, encode_sizeA1) 528 | 529 | camB1 = self.cam1(convInput_B1) 530 | camB1_1 = torch.cat((camB1 * convInput_B1, convInput_A1), 1) 531 | convB1 = self.conv1(camB1_1) 532 | camB1_2 = self.cam1(convB1) * convB1 533 | 534 | encode_sizeB1 = (camB1_2.shape[2], camB1_2.shape[3]) 535 | TransformerB1 = self.basicLayer1(camB1_2, encode_sizeB1) 536 | 537 | camA2= self.cam2(convInput_A2) 538 | 539 | camA2_1 = torch.cat((camA2 * convInput_A2, convInput_B2), 1) 540 | convA2 = self.conv2(camA2_1) 541 | camA2_2 = self.cam2(convA2) * convA2 542 | 543 | encode_sizeA2 = (camA2_2.shape[2], camA2_2.shape[3]) 544 | TransformerA2 = self.basicLayer2(camA2_2, encode_sizeA2) 545 | 546 | camB2= self.cam2(convInput_B2) 547 | camB2_1 = torch.cat((camB2* convInput_B2, convInput_A2), 1) 548 | convB2= self.conv2(camB2_1) 549 | camB2_2 = self.cam2(convB2) * convB2 550 | 551 | encode_sizeB2= (camB2_2.shape[2], camB2_2.shape[3]) 552 | TransformerB2 = self.basicLayer2(camB2_2, encode_sizeB2) 553 | 554 | camA3=self.cam3(convInput_A3) 555 | camA3_1 = torch.cat(( camA3* convInput_A3, convInput_B3), 1) 556 | convA3 = self.conv3(camA3_1) 557 | camA3_2 = self.cam3(convA3) * convA3 558 | 559 | encode_sizeA3 = (camA3_2.shape[2], camA3_2.shape[3]) 560 | TransformerA3 = self.basicLayer3(camA3_2, encode_sizeA3) 561 | 562 | camB3=self.cam3(convInput_B3) 563 | camB3_1 = torch.cat((camB3 * convInput_B3, convInput_A3), 1) 564 | convB3 = self.conv3(camB3_1) 565 | camB3_2 = self.cam3(convB3) * convB3 566 | 567 | encode_sizeB3 = (camB3_2.shape[2], camB3_2.shape[3]) 568 | TransformerB3 = self.basicLayer3(camB3_2, encode_sizeB3) 569 | 570 | 571 | catA=torch.cat(( TransformerA1, TransformerA2, TransformerA3),1) 572 | convA4=self.conv4( catA) 573 | 574 | catB = torch.cat((TransformerB1, TransformerB2, TransformerB3), 1) 575 | 576 | convB4 = self.conv4(catB) 577 | 578 | camA4=self.cam1( convA4) 579 | camA4_1 = torch.cat((camA4 * convA4, convB4 ), 1) 580 | convA5 = self.conv1(camA4_1) 581 | camA4_2 = self.cam1(convA5) * convA5 582 | 583 | encode_sizeA4 = (camA4_2.shape[2], camA4_2.shape[3]) 584 | TransformerA4 = self.basicLayer1(camA4_2, encode_sizeA4) 585 | 586 | camB4 =self.cam1(convB4) 587 | camB4_1 = torch.cat(( camB4* convB4, convA4), 1) 588 | convB5 = self.conv1(camB4_1) 589 | camB4_2 = self.cam1(convB5) * convB5 590 | 591 | encode_sizeB4 = (camB4_2.shape[2], camB4_2.shape[3]) 592 | TransformerB4 = 
self.basicLayer1(camB4_2, encode_sizeB4) 593 | 594 | 595 | 596 | cat=torch.cat(( TransformerA4 , TransformerB4 ), 1) 597 | 598 | 599 | out = self.convolutional_out(cat ) 600 | return out 601 | 602 | 603 | 604 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FATFusion: A Functional–Anatomical Transformer for Medical Image Fusion (IPM 2024). 2 | 3 | This is the official implementation of the FATFusion model proposed in the paper ([FATFusion: A Functional–Anatomical Transformer for Medical Image Fusion](https://www.sciencedirect.com/science/article/pii/S0306457324000475)) with Pytorch. 4 | 5 |

# Requirements

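Package versions are not pinned in this repository. Based on the imports in `Networks/networks.py`, `Test.py`, `Train.py`, and `losses/__init__.py`, running the code requires at least:

- PyTorch and torchvision
- timm (for `DropPath`, `to_2tuple`, `trunc_normal_`)
- numpy, Pillow, opencv-python, imageio
- pandas, tqdm, joblib, thop

(This list is inferred from the source files; exact versions are not stated in the repository.)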
17 | # Tips: 18 | Dealing with RGB input: 19 | Refer to [DPCN-Fusion](https://github.com/tthinking/DPCN-Fusion/blob/master/test.py). 20 | 21 | Dataset is [here](http://www.med.harvard.edu/AANLIB/home.html). 22 | 23 | The Trained Model is [here](https://drive.google.com/drive/folders/137ntn1LPZt67gg-fP5yI5XN1QO37Qyhb). 24 | 25 | 26 | # Cite the paper 27 | If this work is helpful to you, please cite it as:

@ARTICLE{Tang_2024_FATFusion,
  author={Tang, Wei and He, Fazhi},
  journal={Information Processing & Management},
  title={FATFusion: A Functional–Anatomical Transformer for Medical Image Fusion},
  year={2024},
  volume={61},
  number={4},
  pages={103687},
  doi={10.1016/j.ipm.2024.103687}}

46 | 47 | If you have any questions, feel free to contact me (weitang2021@whu.edu.cn). 48 | -------------------------------------------------------------------------------- /Test.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 4 | import torch 5 | import cv2 6 | import time 7 | import imageio 8 | import math 9 | 10 | import torchvision.transforms as transforms 11 | 12 | from Networks.networks import MODEL as net 13 | from thop import profile 14 | 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 16 | 17 | device = torch.device('cuda:0') 18 | 19 | 20 | model = net(in_channel=1) 21 | 22 | model_path = "./model.pth" 23 | use_gpu = torch.cuda.is_available() 24 | 25 | if use_gpu: 26 | print('GPU Mode Acitavted') 27 | model = model.cuda() 28 | model.cuda() 29 | 30 | model.load_state_dict(torch.load(model_path)) 31 | print(model) 32 | else: 33 | print('CPU Mode Acitavted') 34 | state_dict = torch.load(model_path, map_location='cpu') 35 | 36 | model.load_state_dict(state_dict) 37 | 38 | 39 | def fusion_gray(): 40 | 41 | 42 | 43 | path1 = './images/SPECT.bmp' 44 | 45 | path2 = './images/MRI.bmp' 46 | 47 | 48 | img1 = Image.open(path1).convert('L') 49 | img2 = Image.open(path2).convert('L') 50 | 51 | img1_read = np.array(img1) 52 | img2_read = np.array(img2) 53 | h = img1_read.shape[0] 54 | w = img1_read.shape[1] 55 | img1_org = img1 56 | img2_org = img2 57 | tran = transforms.ToTensor() 58 | img1_org = tran(img1_org) 59 | img2_org = tran(img2_org) 60 | 61 | if use_gpu: 62 | img1_org = img1_org.cuda() 63 | img2_org = img2_org.cuda() 64 | else: 65 | img1_org = img1_org 66 | img2_org = img2_org 67 | img1_org = img1_org.unsqueeze(0) 68 | img2_org = img2_org.unsqueeze(0) 69 | 70 | model.eval() 71 | out = model(img1_org, img2_org ) 72 | 73 | d_map_1_4 = np.squeeze(out.detach().cpu().numpy()) 74 | 75 | decision_1_4 = (d_map_1_4 * 255).astype(np.uint8) 76 | 77 | 78 | imageio.imwrite( 79 | './result/result.bmp',decision_1_4) 80 | 81 | 82 | 83 | 84 | 85 | if __name__ == '__main__': 86 | 87 | fusion_gray() 88 | -------------------------------------------------------------------------------- /Train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from tqdm import tqdm 5 | import pandas as pd 6 | 7 | import glob 8 | 9 | from collections import OrderedDict 10 | import torch 11 | import joblib 12 | import torch.backends.cudnn as cudnn 13 | 14 | import torch.optim as optim 15 | 16 | import torchvision.transforms as transforms 17 | from PIL import Image 18 | from torch.utils.data import DataLoader, Dataset 19 | from Networks.networks import MODEL as net 20 | 21 | from losses import CharbonnierLoss_IR,CharbonnierLoss_VI, tv_vi,tv_ir 22 | 23 | device = torch.device('cuda:0') 24 | use_gpu = torch.cuda.is_available() 25 | 26 | if use_gpu: 27 | print('GPU Mode Acitavted') 28 | else: 29 | print('CPU Mode Acitavted') 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser() 34 | 35 | parser.add_argument('--name', default='...', help='model name: (default: arch+timestamp)') 36 | parser.add_argument('--epochs', default=10, type=int) 37 | parser.add_argument('--batch_size', default=8, type=int) 38 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float) 39 | parser.add_argument('--weight', default=[0.03,1000,10,100], type=float) 40 | 41 | parser.add_argument('--gamma', default=0.9, type=float) 42 | 
parser.add_argument('--betas', default=(0.9, 0.999), type=tuple) 43 | parser.add_argument('--eps', default=1e-8, type=float) 44 | parser.add_argument('--weight-decay', default=5e-4, type=float) 45 | parser.add_argument('--num_queries', default=100, type=int, 46 | help="Number of query slots") 47 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 48 | help="Type of positional embedding to use on top of the image features") 49 | parser.add_argument('--alpha', default=300, type=int, 50 | help='number of new channel increases per depth (default: 300)') 51 | args = parser.parse_args() 52 | 53 | return args 54 | 55 | 56 | def rotate(image, s): 57 | if s == 0: 58 | image = image 59 | if s == 1: 60 | HF = transforms.RandomHorizontalFlip(p=1) 61 | image = HF(image) 62 | if s == 2: 63 | VF = transforms.RandomVerticalFlip(p=1) 64 | image = VF(image) 65 | return image 66 | 67 | 68 | def color2gray(image, s): 69 | if s == 0: 70 | image = image 71 | if s ==1: 72 | l = image.convert('L') 73 | n = np.array(l) 74 | image = np.expand_dims(n, axis=2) 75 | image = np.concatenate((image, image, image), axis=-1) 76 | image = Image.fromarray(image).convert('RGB') 77 | return image 78 | 79 | 80 | class GetDataset(Dataset): 81 | def __init__(self, imageFolderDataset, transform=None): 82 | self.imageFolderDataset = imageFolderDataset 83 | self.transform = transform 84 | 85 | def __getitem__(self, index): 86 | 87 | 88 | ir =... 89 | vi = ... 90 | 91 | if self.transform is not None: 92 | tran = transforms.ToTensor() 93 | ir = tran(ir) 94 | 95 | vi = tran(vi) 96 | 97 | 98 | return ir,vi 99 | 100 | def __len__(self): 101 | return len(self.imageFolderDataset) 102 | 103 | 104 | class AverageMeter(object): 105 | 106 | def __init__(self): 107 | self.reset() 108 | 109 | def reset(self): 110 | self.val = 0 111 | self.avg = 0 112 | self.sum = 0 113 | self.count = 0 114 | 115 | def update(self, val, n=1): 116 | self.val = val 117 | self.sum += val * n 118 | self.count += n 119 | self.avg = self.sum / self.count 120 | 121 | 122 | def train(args, train_loader_ir,train_loader_vi, model, criterion_CharbonnierLoss_IR, criterion_CharbonnierLoss_VI, criterion_tv_ir, criterion_tv_vi,optimizer, epoch, scheduler=None): 123 | losses = AverageMeter() 124 | losses_CharbonnierLoss_IR = AverageMeter() 125 | losses_CharbonnierLoss_VI = AverageMeter() 126 | losses_tv_ir= AverageMeter() 127 | losses_tv_vi = AverageMeter() 128 | weight = args.weight 129 | model.train() 130 | 131 | for i, (ir,vi) in tqdm(enumerate(train_loader_ir), total=len(train_loader_ir)): 132 | 133 | if use_gpu: 134 | 135 | ir = ir.cuda() 136 | vi = vi.cuda() 137 | 138 | else: 139 | 140 | ir = ir 141 | vi = vi 142 | 143 | out = model(ir,vi) 144 | 145 | CharbonnierLoss_IR = weight[0] * criterion_CharbonnierLoss_IR(out, ir) 146 | CharbonnierLoss_VI = weight[1] * criterion_CharbonnierLoss_VI(out, vi) 147 | loss_tv_ir = weight[2] * criterion_tv_ir(out, ir) 148 | loss_tv_vi = weight[3] * criterion_tv_vi(out, vi) 149 | loss = CharbonnierLoss_IR +CharbonnierLoss_VI + loss_tv_ir + loss_tv_vi 150 | 151 | losses.update(loss.item(), ir.size(0)) 152 | losses_CharbonnierLoss_IR.update(CharbonnierLoss_IR.item(), ir.size(0)) 153 | losses_CharbonnierLoss_VI.update(CharbonnierLoss_VI.item(), ir.size(0)) 154 | losses_tv_ir.update(loss_tv_ir.item(), ir.size(0)) 155 | losses_tv_vi.update(loss_tv_vi.item(), ir.size(0)) 156 | 157 | optimizer.zero_grad() 158 | loss.backward() 159 | optimizer.step() 160 | 161 | log = OrderedDict([ 162 | ('loss', 
losses.avg), 163 | ('CharbonnierLoss_IR', losses_CharbonnierLoss_IR.avg), 164 | ('CharbonnierLoss_VI', losses_CharbonnierLoss_VI.avg), 165 | ('loss_tv_ir', losses_tv_ir.avg), 166 | ('loss_tv_vi', losses_tv_vi.avg), 167 | ]) 168 | return log 169 | 170 | 171 | def main(): 172 | args = parse_args() 173 | 174 | if not os.path.exists('models/%s' %args.name): 175 | os.makedirs('models/%s' %args.name) 176 | 177 | print('Config -----') 178 | for arg in vars(args): 179 | print('%s: %s' %(arg, getattr(args, arg))) 180 | print('------------') 181 | 182 | with open('models/%s/args.txt' %args.name, 'w') as f: 183 | for arg in vars(args): 184 | print('%s: %s' %(arg, getattr(args, arg)), file=f) 185 | 186 | joblib.dump(args, 'models/%s/args.pkl' %args.name) 187 | cudnn.benchmark = True 188 | 189 | training_dir_ir = ... 190 | folder_dataset_train_ir = glob.glob(training_dir_ir + "*.bmp") 191 | training_dir_vi =... 192 | 193 | folder_dataset_train_vi = glob.glob(training_dir_vi + "*.bmp") 194 | 195 | transform_train = transforms.Compose([transforms.ToTensor(), 196 | transforms.Normalize((0.485, 0.456, 0.406), 197 | (0.229, 0.224, 0.225)) 198 | ]) 199 | 200 | dataset_train_ir = GetDataset(imageFolderDataset=folder_dataset_train_ir, 201 | transform=transform_train) 202 | dataset_train_vi = GetDataset(imageFolderDataset=folder_dataset_train_vi, 203 | transform=transform_train) 204 | 205 | train_loader_ir = DataLoader(dataset_train_ir, 206 | shuffle=True, 207 | batch_size=args.batch_size) 208 | train_loader_vi = DataLoader(dataset_train_vi, 209 | shuffle=True, 210 | batch_size=args.batch_size) 211 | model = net() 212 | if use_gpu: 213 | model = model.cuda() 214 | model.cuda() 215 | 216 | else: 217 | model = model 218 | criterion_CharbonnierLoss_IR = CharbonnierLoss_IR 219 | criterion_CharbonnierLoss_VI = CharbonnierLoss_VI 220 | criterion_tv_ir = tv_ir 221 | criterion_tv_vi = tv_vi 222 | optimizer = optim.Adam(model.parameters(), lr=args.lr, 223 | betas=args.betas, eps=args.eps) 224 | 225 | log = pd.DataFrame(index=[], 226 | columns=['epoch', 227 | 228 | 'loss', 229 | 'CharbonnierLoss_IR', 230 | 'CharbonnierLoss_VI', 231 | 'loss_tv_ir', 232 | 'loss_tv_vi', 233 | ]) 234 | 235 | for epoch in range(args.epochs): 236 | print('Epoch [%d/%d]' % (epoch + 1, args.epochs)) 237 | 238 | train_log = train(args, train_loader_ir, train_loader_vi, model, criterion_CharbonnierLoss_IR, criterion_CharbonnierLoss_VI, 239 | criterion_tv_ir, criterion_tv_vi, optimizer, epoch) # 训练集 240 | 241 | 242 | print('loss: %.4f - CharbonnierLoss_IR: %.4f -CharbonnierLoss_VI: %.4f - loss_tv_ir: %.4f - loss_tv_vi: %.4f ' 243 | % (train_log['loss'], 244 | train_log['CharbonnierLoss_IR'], 245 | train_log['CharbonnierLoss_VI'], 246 | train_log['loss_tv_ir'], 247 | train_log['loss_tv_vi'], 248 | 249 | )) 250 | 251 | tmp = pd.Series([ 252 | epoch + 1, 253 | 254 | train_log['loss'], 255 | train_log['CharbonnierLoss_IR'], 256 | train_log['CharbonnierLoss_VI'], 257 | train_log['loss_tv_ir'], 258 | train_log['loss_tv_vi'], 259 | 260 | ], index=['epoch', 'loss', 'CharbonnierLoss_IR', 'CharbonnierLoss_VI', 'loss_tv_ir', 'loss_tv_vi']) 261 | 262 | log = log.append(tmp, ignore_index=True) 263 | log.to_csv('models/%s/log.csv' %args.name, index=False) 264 | 265 | 266 | if (epoch+1) % 1 == 0: 267 | torch.save(model.state_dict(), 'models/%s/model_{}.pth'.format(epoch+1) %args.name) 268 | 269 | 270 | if __name__ == '__main__': 271 | main() 272 | 273 | 274 | -------------------------------------------------------------------------------- 
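A note on usage (added here for illustration, not part of the repository): since the dataset paths and the `__getitem__` loading logic in Train.py are left as `...` placeholders, a quick way to verify that the network and the loss terms fit together is a standalone smoke test on random tensors. The sketch below assumes the repository root is on the Python path and that torch and timm are installed; it uses the default `--weight` values from Train.py, and variable names such as `fused` and `w` are ours.

```python
# Illustrative smoke test (not part of the repository): check that MODEL and the
# loss terms from losses/__init__.py wire together on random grayscale inputs.
import torch

from Networks.networks import MODEL
from losses import CharbonnierLoss_IR, CharbonnierLoss_VI, tv_ir, tv_vi

model = MODEL(in_channel=1).eval()      # single-channel inputs, as in Test.py

ir = torch.rand(1, 1, 120, 120)         # functional image (e.g. SPECT), batch of 1
vi = torch.rand(1, 1, 120, 120)         # anatomical image (e.g. MRI); 120x120 matches the default img_size

with torch.no_grad():
    fused = model(ir, vi)               # sigmoid output, same spatial size as the inputs

    w = [0.03, 1000, 10, 100]           # default --weight values in Train.py
    loss = (w[0] * CharbonnierLoss_IR(fused, ir)
            + w[1] * CharbonnierLoss_VI(fused, vi)
            + w[2] * tv_ir(fused, ir)
            + w[3] * tv_vi(fused, vi))

print(fused.shape, float(loss))         # expected shape: torch.Size([1, 1, 120, 120])
```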
/images/ESIhighlyCitedPaper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/FATFusion/9873c347c83198f137c27084f87463f6f06b838a/images/ESIhighlyCitedPaper.png -------------------------------------------------------------------------------- /images/MRI.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/FATFusion/9873c347c83198f137c27084f87463f6f06b838a/images/MRI.bmp -------------------------------------------------------------------------------- /images/SPECT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tthinking/FATFusion/9873c347c83198f137c27084f87463f6f06b838a/images/SPECT.bmp -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | 5 | import torch.nn as nn 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | class TVLoss(nn.Module): 11 | def __init__(self,TVLoss_weight=1): 12 | super(TVLoss,self).__init__() 13 | self.TVLoss_weight = TVLoss_weight 14 | 15 | def forward(self,x): 16 | batch_size = x.size()[0] 17 | h_x = x.size()[2] 18 | w_x = x.size()[3] 19 | count_h = self._tensor_size(x[:,:,1:,:]) 20 | count_w = self._tensor_size(x[:,:,:,1:]) 21 | h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() 22 | w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() 23 | return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size 24 | 25 | def _tensor_size(self,t): 26 | return t.size()[1]*t.size()[2]*t.size()[3] 27 | 28 | def tv_loss(x): 29 | tv_loss =TVLoss().to(device) 30 | tv_loss = tv_loss(x) 31 | 32 | return tv_loss 33 | 34 | def tv_vi (fused_result,input_vi ): 35 | tv_vi=torch.norm((tv_loss(fused_result)-tv_loss(input_vi)),1) 36 | 37 | return tv_vi 38 | 39 | 40 | def tv_ir (fused_result,input_ir): 41 | tv_r=torch.norm((tv_loss(fused_result)-tv_loss(input_ir)),1) 42 | 43 | return tv_r 44 | 45 | def CharbonnierLoss_IR(f,ir): 46 | eps = 1e-3 47 | loss=torch.mean(torch.sqrt((f-ir)**2+eps**2)) 48 | return loss 49 | 50 | def CharbonnierLoss_VI(f,vi): 51 | eps = 1e-3 52 | loss=torch.mean(torch.sqrt((f-vi)**2+eps**2)) 53 | return loss 54 | 55 | --------------------------------------------------------------------------------
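For reference, combining these definitions with the weighting applied in Train.py, the training objective can be written (our notation, with fused output $F$, source images $I_{ir}$ and $I_{vi}$, and default weights $(\lambda_1,\lambda_2,\lambda_3,\lambda_4)=(0.03,\,1000,\,10,\,100)$) as

$$
\mathcal{L} \;=\; \lambda_1\,\mathrm{Char}(F, I_{ir}) \;+\; \lambda_2\,\mathrm{Char}(F, I_{vi})
\;+\; \lambda_3\,\bigl|\mathrm{TV}(F) - \mathrm{TV}(I_{ir})\bigr|
\;+\; \lambda_4\,\bigl|\mathrm{TV}(F) - \mathrm{TV}(I_{vi})\bigr|,
$$

where $\mathrm{Char}(a,b)=\operatorname{mean}\sqrt{(a-b)^2+\varepsilon^2}$ with $\varepsilon=10^{-3}$ (the CharbonnierLoss_IR/VI functions above), and $\mathrm{TV}(\cdot)$ is the scalar total-variation measure computed by `TVLoss` (squared horizontal and vertical neighbour differences, normalized by element count and batch size).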