├── Networks
│   └── networks.py
├── README.md
├── Test.py
├── Train.py
├── images
│   ├── ESIhighlyCitedPaper.png
│   ├── MRI.bmp
│   └── SPECT.bmp
└── losses
    └── __init__.py

--------------------------------------------------------------------------------
/Networks/networks.py:
--------------------------------------------------------------------------------
import torch

from torch import nn
import torch.nn.functional as F

from timm.models.layers import DropPath, to_2tuple, trunc_normal_


def upsample(src, tar):
    src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=True)
    return src


class Conv1(nn.Module):
    """1x1 convolution + BatchNorm + LeakyReLU."""

    def __init__(self, in_channel, out_channel):
        super(Conv1, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0, stride=1)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, input):
        out = self.conv(input)
        out = self.bn(out)
        out = self.relu(out)
        return out


class Convlutioanl(nn.Module):
    """3x3 convolution with replicate padding + BatchNorm + LeakyReLU."""

    def __init__(self, in_channel, out_channel):
        super(Convlutioanl, self).__init__()
        self.padding = (1, 1, 1, 1)
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=0, stride=1)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, input):
        out = F.pad(input, self.padding, 'replicate')
        out = self.conv(out)
        out = self.bn(out)
        out = self.relu(out)
        return out


class Conv5_5(nn.Module):
    """5x5 convolution with replicate padding + BatchNorm + LeakyReLU."""

    def __init__(self, in_channel, out_channel):
        super(Conv5_5, self).__init__()
        self.padding = (2, 2, 2, 2)
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=5, padding=0, stride=1)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, input):
        out = F.pad(input, self.padding, 'replicate')
        out = self.conv(out)
        out = self.bn(out)
        out = self.relu(out)
        return out


class Conv7_7(nn.Module):
    """7x7 convolution with replicate padding + BatchNorm + LeakyReLU."""

    def __init__(self, in_channel, out_channel):
        super(Conv7_7, self).__init__()
        self.padding = (3, 3, 3, 3)
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=7, padding=0, stride=1)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, input):
        out = F.pad(input, self.padding, 'replicate')
        out = self.conv(out)
        out = self.bn(out)
        out = self.relu(out)
        return out


class Convlutioanl_out(nn.Module):
    """1x1 convolution + sigmoid output head."""

    def __init__(self, in_channel, out_channel):
        super(Convlutioanl_out, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0, stride=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        out = self.conv(input)
        out = self.sigmoid(out)
        return out


class WindowAttention(nn.Module):
    """Window-based multi-head self-attention with relative position bias."""

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # relative position bias table: (2*Wh-1)*(2*Ww-1) entries per head
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        # pairwise relative position index for tokens inside a window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # qkv projection + attention scores + weighted sum + output projection for N tokens
        flops = 0
        flops += N * self.dim * 3 * self.dim
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        flops += N * self.dim * self.dim
        return flops


class Mlp(nn.Module):
    """Two-layer feed-forward network used inside each transformer block."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    # (num_windows * B, window_size, window_size, C) -> (B, H, W, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class SwinTransformerBlock(nn.Module):
    """Swin Transformer block: (shifted) window attention followed by an MLP, both with residual connections."""

    def __init__(self, dim, input_resolution, num_heads, window_size=1, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if the window does not fit into the input, do not partition or shift
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        # attention mask that separates windows crossing the shift boundary
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x, x_size):
        # x: (B, C, H, W) feature map; x_size: (H, W)
        H, W = x_size
        B, C, H, W = x.shape

        # NOTE: channel and spatial dimensions are reinterpreted with view() rather than permuted
        x = x.view(B, H, W, C)
        shortcut = x
        shape = x.view(H * W * B, C)
        x = self.norm1(shape)
        x = x.view(B, H, W, C)

        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        B, H, W, C = x.shape
        x = x.view(B, C, H, W)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        flops += self.dim * H * W
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        flops += self.dim * H * W
        return flops


class PatchEmbed(nn.Module):
    """Flattens a (B, C, H, W) feature map to (B, H*W, C) tokens; only used here to derive the patch grid size."""

    def __init__(self, img_size=120, patch_size=4, in_chans=6, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        flops = 0
        H, W = self.img_size
        if self.norm is not None:
            flops += H * W * self.embed_dim
        return flops


class BasicLayer(nn.Module):
    """A stack of Swin Transformer blocks; even-indexed blocks are unshifted, odd-indexed blocks are shifted."""

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):
        for blk in self.blocks:
            x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class Channel_attention(nn.Module):
    """Squeeze-and-excitation style channel attention; returns per-channel weights in (0, 1)."""

    def __init__(self, channel, reduction):
        super(Channel_attention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        out = self.avg_pool(input)
        out = self.fc(out)
        out = self.sigmoid(out)
        return out


class MODEL(nn.Module):
    """FATFusion network: two source branches are processed by 3x3/5x5/7x7 convolutions, channel attention,
    and Swin Transformer layers, then cross-fused and mapped to a single-channel fused image."""

    def __init__(self, in_channel=1, out_channel_16=16, out_channel_256=256, out_channel_32=32, out_channel=64,
                 out_channel_512=512, output_channel=1, out_channel_128=128,
                 img_size=120, patch_size=4, embed_dim=96, num_heads=8, window_size=1,
                 out_channel_448=448, out_channel_896=896, out_channel_336=336,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., patch_norm=True, depth=2,
                 downsample=None,
                 drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False):
        super(MODEL, self).__init__()

        self.convInput = Convlutioanl(in_channel, out_channel_16)
        self.conv5 = Conv5_5(out_channel_16, out_channel)
        self.conv7 = Conv7_7(out_channel, out_channel_256)
        self.conv64 = Convlutioanl(out_channel, out_channel_256)
        self.conv = Convlutioanl(out_channel * 2, out_channel)
        self.convolutional_out = Convlutioanl_out(out_channel_32, output_channel)
        self.conv16_16 = Convlutioanl(out_channel_16, out_channel_16)
        self.conv64_16 = Convlutioanl(out_channel, out_channel_16)
        self.conv256_16 = Convlutioanl(out_channel_256, out_channel_16)

        self.cam1 = Channel_attention(out_channel_16, 4)
        self.cam2 = Channel_attention(out_channel, 8)
        self.cam3 = Channel_attention(out_channel_256, 16)

        self.conv1 = Conv1(out_channel_32, out_channel_16)
        self.conv2 = Conv1(out_channel_128, out_channel)
        self.conv3 = Conv1(out_channel_512, out_channel_256)
        self.conv4 = Conv1(out_channel_336, out_channel_16)
        self.patch_norm = patch_norm

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        self.basicLayer1 = BasicLayer(dim=out_channel_16,
                                      input_resolution=(patches_resolution[0], patches_resolution[1]),
                                      depth=depth,
                                      num_heads=num_heads,
                                      window_size=window_size,
                                      mlp_ratio=mlp_ratio,
                                      qkv_bias=qkv_bias, qk_scale=qk_scale,
                                      drop=drop, attn_drop=attn_drop,
                                      drop_path=drop_path,
                                      norm_layer=norm_layer,
                                      downsample=downsample,
                                      use_checkpoint=use_checkpoint)

        self.basicLayer2 = BasicLayer(dim=out_channel,
                                      input_resolution=(patches_resolution[0], patches_resolution[1]),
                                      depth=depth,
                                      num_heads=num_heads,
                                      window_size=window_size,
                                      mlp_ratio=mlp_ratio,
                                      qkv_bias=qkv_bias, qk_scale=qk_scale,
                                      drop=drop, attn_drop=attn_drop,
                                      drop_path=drop_path,
                                      norm_layer=norm_layer,
                                      downsample=downsample,
                                      use_checkpoint=use_checkpoint)

        self.basicLayer3 = BasicLayer(dim=out_channel_256,
                                      input_resolution=(patches_resolution[0], patches_resolution[1]),
                                      depth=depth,
                                      num_heads=num_heads,
                                      window_size=window_size,
                                      mlp_ratio=mlp_ratio,
                                      qkv_bias=qkv_bias, qk_scale=qk_scale,
                                      drop=drop, attn_drop=attn_drop,
                                      drop_path=drop_path,
                                      norm_layer=norm_layer,
                                      downsample=downsample,
                                      use_checkpoint=use_checkpoint)

        self.basicLayer4 = BasicLayer(dim=out_channel_448,
                                      input_resolution=(patches_resolution[0], patches_resolution[1]),
                                      depth=depth,
                                      num_heads=num_heads,
                                      window_size=window_size,
                                      mlp_ratio=mlp_ratio,
                                      qkv_bias=qkv_bias, qk_scale=qk_scale,
                                      drop=drop, attn_drop=attn_drop,
                                      drop_path=drop_path,
                                      norm_layer=norm_layer,
                                      downsample=downsample,
                                      use_checkpoint=use_checkpoint)

    def forward(self, ir, vi):
        # ir / vi: the two source modalities (e.g. functional and anatomical images), each of shape (B, 1, H, W)
        convInput_A1 = self.convInput(ir)
        convInput_A2 = self.conv5(convInput_A1)
        convInput_A3 = self.conv7(convInput_A2)

        convInput_B1 = self.convInput(vi)
        convInput_B2 = self.conv5(convInput_B1)
        convInput_B3 = self.conv7(convInput_B2)

        # level 1: channel attention + cross-branch concatenation + transformer
        camA1 = self.cam1(convInput_A1)
        camA1_1 = torch.cat((camA1 * convInput_A1, convInput_B1), 1)
        convA1 = self.conv1(camA1_1)
        camA1_2 = self.cam1(convA1) * convA1

        encode_sizeA1 = (camA1_2.shape[2], camA1_2.shape[3])
        TransformerA1 = self.basicLayer1(camA1_2, encode_sizeA1)

        camB1 = self.cam1(convInput_B1)
        camB1_1 = torch.cat((camB1 * convInput_B1, convInput_A1), 1)
        convB1 = self.conv1(camB1_1)
        camB1_2 = self.cam1(convB1) * convB1

        encode_sizeB1 = (camB1_2.shape[2], camB1_2.shape[3])
        TransformerB1 = self.basicLayer1(camB1_2, encode_sizeB1)

        # level 2
        camA2 = self.cam2(convInput_A2)
        camA2_1 = torch.cat((camA2 * convInput_A2, convInput_B2), 1)
        convA2 = self.conv2(camA2_1)
        camA2_2 = self.cam2(convA2) * convA2

        encode_sizeA2 = (camA2_2.shape[2], camA2_2.shape[3])
        TransformerA2 = self.basicLayer2(camA2_2, encode_sizeA2)

        camB2 = self.cam2(convInput_B2)
        camB2_1 = torch.cat((camB2 * convInput_B2, convInput_A2), 1)
        convB2 = self.conv2(camB2_1)
        camB2_2 = self.cam2(convB2) * convB2

        encode_sizeB2 = (camB2_2.shape[2], camB2_2.shape[3])
        TransformerB2 = self.basicLayer2(camB2_2, encode_sizeB2)

        # level 3
        camA3 = self.cam3(convInput_A3)
        camA3_1 = torch.cat((camA3 * convInput_A3, convInput_B3), 1)
        convA3 = self.conv3(camA3_1)
        camA3_2 = self.cam3(convA3) * convA3

        encode_sizeA3 = (camA3_2.shape[2], camA3_2.shape[3])
        TransformerA3 = self.basicLayer3(camA3_2, encode_sizeA3)

        camB3 = self.cam3(convInput_B3)
        camB3_1 = torch.cat((camB3 * convInput_B3, convInput_A3), 1)
        convB3 = self.conv3(camB3_1)
        camB3_2 = self.cam3(convB3) * convB3

        encode_sizeB3 = (camB3_2.shape[2], camB3_2.shape[3])
        TransformerB3 = self.basicLayer3(camB3_2, encode_sizeB3)

        # aggregate the three levels of each branch and fuse the branches
        catA = torch.cat((TransformerA1, TransformerA2, TransformerA3), 1)
        convA4 = self.conv4(catA)

        catB = torch.cat((TransformerB1, TransformerB2, TransformerB3), 1)
        convB4 = self.conv4(catB)

        camA4 = self.cam1(convA4)
        camA4_1 = torch.cat((camA4 * convA4, convB4), 1)
        convA5 = self.conv1(camA4_1)
        camA4_2 = self.cam1(convA5) * convA5

        encode_sizeA4 = (camA4_2.shape[2], camA4_2.shape[3])
        TransformerA4 = self.basicLayer1(camA4_2, encode_sizeA4)

        camB4 = self.cam1(convB4)
        camB4_1 = torch.cat((camB4 * convB4, convA4), 1)
        convB5 = self.conv1(camB4_1)
        camB4_2 = self.cam1(convB5) * convB5

        encode_sizeB4 = (camB4_2.shape[2], camB4_2.shape[3])
        TransformerB4 = self.basicLayer1(camB4_2, encode_sizeB4)

        cat = torch.cat((TransformerA4, TransformerB4), 1)
        out = self.convolutional_out(cat)
        return out

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FATFusion: A Functional–Anatomical Transformer for Medical Image Fusion (IPM 2024)

This is the official implementation of the FATFusion model proposed in the paper ([FATFusion: A Functional–Anatomical Transformer for Medical Image Fusion](https://www.sciencedirect.com/science/article/pii/S0306457324000475)) with PyTorch.
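As a quick orientation, here is a minimal usage sketch (not the official `Train.py`/`Test.py` pipeline): it assumes the repository root is on the Python path and feeds two single-channel 120×120 tensors, matching the default `img_size`, through the untrained network.

```python
# Minimal sketch, assuming the repository root is importable; see Train.py / Test.py
# for the actual training and inference scripts.
import torch
from Networks.networks import MODEL

model = MODEL()   # defaults: in_channel=1, img_size=120, window_size=1, num_heads=8
model.eval()

ir = torch.rand(1, 1, 120, 120)   # e.g. functional image (SPECT), values in [0, 1]
vi = torch.rand(1, 1, 120, 120)   # e.g. anatomical image (MRI), values in [0, 1]

with torch.no_grad():
    fused = model(ir, vi)         # sigmoid output, shape (1, 1, 120, 120)
print(fused.shape)
```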
```
@ARTICLE{Tang_2024_FATFusion,
  author={Tang, Wei and He, Fazhi},
  journal={Information Processing \& Management},
  title={FATFusion: A Functional–Anatomical Transformer for Medical Image Fusion},
  year={2024},
  volume={61},
  number={4},
  pages={103687},
  doi={10.1016/j.ipm.2024.103687}}
```