├── README.md ├── DDConv.py ├── DDConv_3D.py ├── TEC_Net_T.py └── TEC_Net_T_3D.py /README.md: -------------------------------------------------------------------------------- 1 | # TEC-Net: Vision Transformer Embrace Convolutional Neural Networks for Medical Image Segmentation 2 | 3 | 4 | * **Paper Link:** [Read the Paper](https://arxiv.org/abs/2306.04086) 5 | * **Authors:** Rui Sun, Tao Lei, Weichuan Zhang, Yong Wan, Yong Xia, Asoke K. Nandi 6 | -------------------------------------------------------------------------------- /DDConv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamic Deformable Convolution 3 | """ 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from thop import * 8 | import functools 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | from torch.nn.modules.conv import _ConvNd 13 | from torch.nn.modules.utils import _pair 14 | from torch.nn.parameter import Parameter 15 | 16 | device = torch.device("cpu" ) 17 | 18 | class DDConv(nn.Module): 19 | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False): 20 | super(DDConv, self).__init__() 21 | self.kernel_size = kernel_size 22 | self.padding = padding 23 | self.stride = stride 24 | self.zero_padding = nn.ZeroPad2d(padding) 25 | self.conv = SConv2D(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias) 26 | 27 | self.p_conv = SConv2D(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride) 28 | nn.init.constant_(self.p_conv.weight, 0) 29 | self.p_conv.register_backward_hook(self._set_lr) 30 | 31 | self.modulation = modulation 32 | if modulation: 33 | self.m_conv = SConv2D(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride) 34 | nn.init.constant_(self.m_conv.weight, 0) 35 | self.m_conv.register_backward_hook(self._set_lr) 36 | 37 | @staticmethod 38 | def _set_lr(module, grad_input, grad_output): 39 | grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input))) 40 | grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output))) 41 | 42 | def forward(self, x): 43 | offset = self.p_conv(x) 44 | if self.modulation: 45 | m = torch.sigmoid(self.m_conv(x)) 46 | 47 | dtype = offset.data.type() 48 | ks = self.kernel_size 49 | N = offset.size(1) // 2 50 | 51 | if self.padding: 52 | x = self.zero_padding(x) 53 | 54 | p = self._get_p(offset, dtype).to(device) 55 | 56 | p = p.contiguous().permute(0, 2, 3, 1) 57 | q_lt = p.detach().floor() 58 | q_rb = q_lt + 1 59 | 60 | q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long() 61 | q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long() 62 | q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1) 63 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1) 64 | 65 | p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1) 66 | 67 | g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:])) 68 | g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:])) 69 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:])) 70 | g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:])) 71 | 72 | x_q_lt = self._get_x_q(x, q_lt, N) 73 | x_q_rb 
= self._get_x_q(x, q_rb, N) 74 | x_q_lb = self._get_x_q(x, q_lb, N) 75 | x_q_rt = self._get_x_q(x, q_rt, N) 76 | 77 | x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \ 78 | g_rb.unsqueeze(dim=1) * x_q_rb + \ 79 | g_lb.unsqueeze(dim=1) * x_q_lb + \ 80 | g_rt.unsqueeze(dim=1) * x_q_rt 81 | 82 | if self.modulation: 83 | m = m.contiguous().permute(0, 2, 3, 1) 84 | m = m.unsqueeze(dim=1) 85 | m = torch.cat([m for _ in range(x_offset.size(1))], dim=1) 86 | x_offset *= m 87 | 88 | x_offset = self._reshape_x_offset(x_offset, ks) 89 | out = self.conv(x_offset) 90 | 91 | return out 92 | 93 | def _get_p_n(self, N, dtype): 94 | p_n_x, p_n_y = torch.meshgrid( 95 | torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), 96 | torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1)) 97 | p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) 98 | p_n = p_n.view(1, 2*N, 1, 1).type(dtype) 99 | 100 | return p_n 101 | 102 | def _get_p_0(self, h, w, N, dtype): 103 | p_0_x, p_0_y = torch.meshgrid( 104 | torch.arange(1, h*self.stride+1, self.stride), 105 | torch.arange(1, w*self.stride+1, self.stride)) 106 | p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1) 107 | p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1) 108 | p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype) 109 | 110 | return p_0 111 | 112 | def _get_p(self, offset, dtype): 113 | N, h, w = offset.size(1)//2, offset.size(2), offset.size(3) 114 | 115 | p_n = self._get_p_n(N, dtype).to(device) 116 | p_0 = self._get_p_0(h, w, N, dtype).to(device) 117 | p = p_0 + p_n + offset 118 | return p 119 | 120 | def _get_x_q(self, x, q, N): 121 | b, h, w, _ = q.size() 122 | padded_w = x.size(3) 123 | c = x.size(1) 124 | x = x.contiguous().view(b, c, -1) 125 | index = q[..., :N]*padded_w + q[..., N:] 126 | index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1) 127 | x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N) 128 | return x_offset 129 | 130 | @staticmethod 131 | def _reshape_x_offset(x_offset, ks): 132 | b, c, h, w, N = x_offset.size() 133 | x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1) 134 | x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks) 135 | return x_offset 136 | 137 | class _routing(nn.Module): 138 | 139 | def __init__(self, in_channels, num_experts, dropout_rate): 140 | super(_routing, self).__init__() 141 | 142 | self.dropout = nn.Dropout(dropout_rate) 143 | self.fc = nn.Linear(in_channels, num_experts) 144 | 145 | def forward(self, x): 146 | x = torch.flatten(x) 147 | x = self.dropout(x) 148 | x = self.fc(x) 149 | return F.sigmoid(x) 150 | 151 | 152 | class SConv2D(_ConvNd): 153 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, 154 | bias=True, padding_mode='zeros', num_experts=8, dropout_rate=0.2): 155 | kernel_size = _pair(kernel_size) 156 | stride = _pair(stride) 157 | padding = _pair(padding) 158 | dilation = _pair(dilation) 159 | super(SConv2D, self).__init__( 160 | in_channels, out_channels, kernel_size, stride, padding, dilation, 161 | False, _pair(0), groups, bias, padding_mode) 162 | 163 | self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1)) 164 | self._routing_fn = _routing(in_channels, num_experts, dropout_rate) 165 | 166 | self.weight = Parameter(torch.Tensor( 167 | num_experts, out_channels, in_channels // groups, *kernel_size)) 168 | 169 | self.reset_parameters() 170 | 
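# SConv2D is a conditionally parameterized ("dynamic") convolution: `self.weight`
# stores `num_experts` candidate kernels, `_routing` turns the globally pooled input
# into per-sample sigmoid gates, and `forward` (below) applies the gate-weighted sum
# of the expert kernels to each sample individually. A minimal smoke test, given as an
# illustrative sketch assuming the default CPU setup and a 16-channel input:
#
#   conv = SConv2D(16, 32, kernel_size=3, padding=1)
#   y = conv(torch.rand(2, 16, 56, 56))   # expected shape: (2, 32, 56, 56)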
171 | def _conv_forward(self, input, weight): 172 | if self.padding_mode != 'zeros': 173 | return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode), 174 | weight, self.bias, self.stride, 175 | _pair(0), self.dilation, self.groups) 176 | return F.conv2d(input, weight, self.bias, self.stride, 177 | self.padding, self.dilation, self.groups) 178 | 179 | def forward(self, inputs): 180 | b, _, _, _ = inputs.size() 181 | res = [] 182 | for input in inputs: 183 | input = input.unsqueeze(0) 184 | pooled_inputs = self._avg_pooling(input) 185 | routing_weights = self._routing_fn(pooled_inputs) 186 | kernels = torch.sum(routing_weights[:, None, None, None, None] * self.weight, 0) 187 | out = self._conv_forward(input, kernels) 188 | res.append(out) 189 | return torch.cat(res, dim=0) 190 | -------------------------------------------------------------------------------- /DDConv_3D.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamic Deformable Convolution 3 | """ 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from thop import * 8 | import functools 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | from torch.nn.modules.conv import _ConvNd 13 | from torch.nn.modules.utils import _pair,_triple 14 | from torch.nn.parameter import Parameter 15 | 16 | device = torch.device("cpu" ) 17 | 18 | class DDConv_3D(nn.Module): 19 | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False): 20 | super(DDConv_3D, self).__init__() 21 | self.kernel_size = kernel_size 22 | self.padding = padding 23 | self.stride = stride 24 | self.zero_padding = nn.ConstantPad3d(padding,value=0) 25 | self.conv = SConv3D(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias) 26 | 27 | self.p_conv = SConv3D(inc, 3*kernel_size*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride) 28 | nn.init.constant_(self.p_conv.weight, 0) 29 | self.p_conv.register_backward_hook(self._set_lr) 30 | 31 | self.modulation = modulation 32 | if modulation: 33 | self.m_conv = SConv3D(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride) 34 | nn.init.constant_(self.m_conv.weight, 0) 35 | self.m_conv.register_backward_hook(self._set_lr) 36 | 37 | @staticmethod 38 | def _set_lr(module, grad_input, grad_output): 39 | grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input))) 40 | grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output))) 41 | 42 | def forward(self, x): 43 | offset = self.p_conv(x) 44 | if self.modulation: 45 | m = torch.sigmoid(self.m_conv(x)) 46 | 47 | dtype = offset.data.type() 48 | ks = self.kernel_size 49 | N = offset.size(1) // 3 50 | 51 | if self.padding: 52 | x = self.zero_padding(x) 53 | 54 | p = self._get_p(offset, dtype).to(device) 55 | 56 | p = p.contiguous().permute(0, 2, 3, 4, 1) 57 | q_lt = p.detach().floor() 58 | q_rb = q_lt + 1 59 | 60 | q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long() 61 | q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long() 62 | q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1) 63 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1) 64 | 65 | p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:2*N], 0, x.size(3)-1),torch.clamp(p[..., 2*N:], 0, x.size(4)-1)], dim=-1) 66 | 67 | g_lt = (1 + (q_lt[..., 
:N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:2*N].type_as(p) - p[..., N:2*N])) *(1 + (q_lt[..., 2*N:].type_as(p) - p[..., 2*N:])) 68 | g_rb = (1 + (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rb[..., N:2*N].type_as(p) - p[..., N:2*N])) *(1 + (q_rb[..., 2*N:].type_as(p) - p[..., 2*N:])) 69 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lb[..., N:2*N].type_as(p) - p[..., N:2*N])) *(1 + (q_lb[..., 2*N:].type_as(p) - p[..., 2*N:])) 70 | g_rt = (1 + (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:2*N].type_as(p) - p[..., N:2*N])) *(1 + (q_rt[..., 2*N:].type_as(p) - p[..., 2*N:])) 71 | 72 | x_q_lt = self._get_x_q(x, q_lt, N) 73 | x_q_rb = self._get_x_q(x, q_rb, N) 74 | x_q_lb = self._get_x_q(x, q_lb, N) 75 | x_q_rt = self._get_x_q(x, q_rt, N) 76 | 77 | x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \ 78 | g_rb.unsqueeze(dim=1) * x_q_rb + \ 79 | g_lb.unsqueeze(dim=1) * x_q_lb + \ 80 | g_rt.unsqueeze(dim=1) * x_q_rt 81 | 82 | if self.modulation: 83 | m = m.contiguous().permute(0, 2, 3, 1) 84 | m = m.unsqueeze(dim=1) 85 | m = torch.cat([m for _ in range(x_offset.size(1))], dim=1) 86 | x_offset *= m 87 | 88 | x_offset = self._reshape_x_offset(x_offset, ks) 89 | out = self.conv(x_offset) 90 | 91 | return out 92 | 93 | def _get_p_n(self, N, dtype): 94 | p_n_x, p_n_y, p_n_z= torch.meshgrid( 95 | torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), 96 | torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), 97 | torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1)) 98 | p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y),torch.flatten(p_n_z)], 0) 99 | p_n = p_n.view(1, 3*N, 1, 1,1).type(dtype) 100 | 101 | return p_n 102 | 103 | def _get_p_0(self, h, w,d, N, dtype): 104 | p_0_x, p_0_y,p_0_z = torch.meshgrid( 105 | torch.arange(1, h*self.stride+1, self.stride), 106 | torch.arange(1, w*self.stride+1, self.stride), 107 | torch.arange(1, d*self.stride+1, self.stride)) 108 | p_0_x = torch.flatten(p_0_x).view(1, 1, h, w,d).repeat(1, N, 1, 1,1) 109 | p_0_y = torch.flatten(p_0_y).view(1, 1, h, w,d).repeat(1, N, 1, 1,1) 110 | p_0_z = torch.flatten(p_0_z).view(1, 1, h, w,d).repeat(1, N, 1, 1,1) 111 | p_0 = torch.cat([p_0_x, p_0_y,p_0_z], 1).type(dtype) 112 | 113 | return p_0 114 | 115 | def _get_p(self, offset, dtype): 116 | N, h, w,d = offset.size(1)//3, offset.size(2), offset.size(3),offset.size(4) 117 | 118 | p_n = self._get_p_n(N, dtype).to(device) 119 | p_0 = self._get_p_0(h, w,d, N, dtype).to(device) 120 | p = p_0 + p_n + offset 121 | return p 122 | 123 | def _get_x_q(self, x, q, N): 124 | b, h, w,d, _ = q.size() 125 | padded_w = x.size(3) 126 | c = x.size(1) 127 | x = x.contiguous().view(b, c, -1) 128 | index = q[..., :N]*padded_w + q[..., N:2*N]+q[..., 2*N:] 129 | index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1,-1).contiguous().view(b, c, -1) 130 | x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w,d, N) 131 | return x_offset 132 | 133 | @staticmethod 134 | def _reshape_x_offset(x_offset, ks): 135 | b, c, h, w,d, N = x_offset.size() 136 | x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w,d*ks) for s in range(0, N, ks)], dim=-1) 137 | x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks,d*ks) 138 | return x_offset 139 | 140 | class _routing(nn.Module): 141 | 142 | def __init__(self, in_channels, num_experts, dropout_rate): 143 | super(_routing, self).__init__() 144 | 145 | self.dropout = nn.Dropout(dropout_rate) 146 | self.fc = nn.Linear(in_channels, 
num_experts) 147 | 148 | def forward(self, x): 149 | x = torch.flatten(x) 150 | x = self.dropout(x) 151 | x = self.fc(x) 152 | return F.sigmoid(x) 153 | 154 | 155 | class SConv3D(_ConvNd): 156 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, 157 | bias=True, padding_mode='zeros', num_experts=8, dropout_rate=0.2): 158 | kernel_size = _triple(kernel_size) 159 | stride = _triple(stride) 160 | padding = _triple(padding) 161 | dilation = _triple(dilation) 162 | super(SConv3D, self).__init__( 163 | in_channels, out_channels, kernel_size, stride, padding, dilation, 164 | False, _pair(0), groups, bias, padding_mode) 165 | 166 | self._avg_pooling = functools.partial(F.adaptive_avg_pool3d, output_size=(1, 1,1)) 167 | self._routing_fn = _routing(in_channels, num_experts, dropout_rate) 168 | 169 | 170 | self.weight = Parameter(torch.Tensor( 171 | num_experts, out_channels, in_channels // groups, *kernel_size)) 172 | 173 | self.reset_parameters() 174 | 175 | def _conv_forward(self, input, weight): 176 | if self.padding_mode != 'zeros': 177 | return F.conv3d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode), 178 | weight, self.bias, self.stride, 179 | _pair(0), self.dilation, self.groups) 180 | 181 | return F.conv3d(input, weight, self.bias, self.stride, 182 | self.padding, self.dilation, self.groups) 183 | 184 | def forward(self, inputs): 185 | b, _, _, _,_ = inputs.size() 186 | res = [] 187 | for input in inputs: 188 | input = input.unsqueeze(0) 189 | pooled_inputs = self._avg_pooling(input) 190 | 191 | routing_weights = self._routing_fn(pooled_inputs) 192 | 193 | 194 | kernels = torch.sum(routing_weights[:, None, None, None, None,None] * self.weight, 0) 195 | 196 | out = self._conv_forward(input, kernels) 197 | res.append(out) 198 | return torch.cat(res, dim=0) 199 | if __name__ == "__main__": 200 | with torch.no_grad(): 201 | input = torch.rand(2, 16, 112, 112, 80).to("cpu") 202 | model = DDConv_3D(16,16).to("cpu") 203 | 204 | out_result = model(input) 205 | print(out_result.shape) 206 | 207 | flops, params = profile(model, (input,)) 208 | 209 | print("-" * 50) 210 | print('FLOPs = ' + str(flops / 1000 ** 3) + ' G') 211 | print('Params = ' + str(params / 1000 ** 2) + ' M') 212 | -------------------------------------------------------------------------------- /TEC_Net_T.py: -------------------------------------------------------------------------------- 1 | """ 2 | CiT-Net-Tiny 3 | stage1 stage2 stage3 stage4 4 | size 56x56 28x28 14x14 7x7 5 | Unet dim 96 192 384 768 6 | Swin dim 96 192 384 768 7 | head 3 6 12 24 8 | num 2 2 6 2 9 | """ 10 | import torch 11 | import math 12 | import torch.nn as nn 13 | import torch.utils.checkpoint as checkpoint 14 | from einops import rearrange 15 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 16 | from thop import * 17 | from torch.nn import init 18 | import torch.nn.functional as F 19 | from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \ 20 | NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding 21 | from DDConv import DDConv 22 | from einops.layers.torch import Rearrange 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 27 | super().__init__() 28 | self.conv2patch = nn.Sequential( 29 | nn.Conv2d(3, embed_dim, kernel_size=4, stride=4), 30 | nn.GELU(), 31 | # nn.ReLU(), 32 | 
nn.BatchNorm2d(embed_dim) 33 | ) 34 | 35 | def forward(self, x): 36 | # # FIXME look at relaxing size constraints 37 | x = self.conv2patch(x) 38 | return x 39 | 40 | def flops(self): 41 | Ho, Wo = self.patches_resolution 42 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 43 | if self.norm is not None: 44 | flops += Ho * Wo * self.embed_dim 45 | return flops 46 | 47 | class Mlp(nn.Module): 48 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 49 | super().__init__() 50 | out_features = out_features or in_features 51 | hidden_features = hidden_features or in_features 52 | self.fc1 = nn.Linear(in_features, hidden_features) 53 | self.act = act_layer() 54 | self.fc2 = nn.Linear(hidden_features, out_features) 55 | self.drop = nn.Dropout(drop) 56 | 57 | def forward(self, x): 58 | x = self.fc1(x) 59 | x = self.act(x) 60 | x = self.drop(x) 61 | x = self.fc2(x) 62 | x = self.drop(x) 63 | return x 64 | 65 | class oneXone_conv(nn.Module): 66 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 67 | super(oneXone_conv, self).__init__() 68 | out_features = out_features or in_features 69 | hidden_features = hidden_features or in_features 70 | self.Conv1 = nn.Sequential( 71 | nn.Conv2d(in_features, hidden_features, kernel_size=1), 72 | nn.GELU(), 73 | nn.BatchNorm2d(hidden_features) 74 | ) 75 | self.Conv2 = nn.Sequential( 76 | nn.Conv2d(hidden_features, out_features, kernel_size=1), 77 | nn.GELU(), 78 | nn.BatchNorm2d(out_features) 79 | ) 80 | self.drop = nn.Dropout(drop) 81 | def forward(self, x): 82 | x = self.Conv1(x) 83 | x = self.drop(x) 84 | x = self.Conv2(x) 85 | x = self.drop(x) 86 | return x 87 | 88 | class GhostModule(nn.Module): 89 | def __init__(self, inp, oup=None, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True): 90 | super(GhostModule, self).__init__() 91 | oup = oup or inp 92 | init_channels = math.ceil(oup // ratio) 93 | new_channels = init_channels*(ratio-1) 94 | 95 | self.primary_conv = nn.Sequential( 96 | nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), 97 | nn.BatchNorm2d(init_channels), 98 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 99 | ) 100 | 101 | self.cheap_operation = nn.Sequential( 102 | nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), 103 | nn.BatchNorm2d(new_channels), 104 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 105 | ) 106 | 107 | def forward(self, x): 108 | x1 = self.primary_conv(x) 109 | x2 = self.cheap_operation(x1) 110 | out = torch.cat([x1, x2], dim=1) 111 | return out 112 | 113 | class GhostModule_Up(nn.Module): 114 | def __init__(self, inp, oup=None, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True): 115 | super(GhostModule_Up, self).__init__() 116 | oup = oup or inp 117 | init_channels = inp 118 | new_channels = init_channels 119 | 120 | self.primary_conv = nn.Sequential( 121 | nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), 122 | nn.BatchNorm2d(init_channels), 123 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 124 | ) 125 | 126 | self.cheap_operation = nn.Sequential( 127 | nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), 128 | nn.BatchNorm2d(new_channels), 129 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 130 | ) 131 | 132 | def forward(self, x): 133 | x1 = self.primary_conv(x) 134 | x2 = 
self.cheap_operation(x1) 135 | out = torch.cat([x1, x2], dim=1) 136 | return out 137 | 138 | def window_partition(x, window_size): 139 | B, H, W, C = x.shape 140 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 141 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 142 | return windows 143 | 144 | def window_reverse(windows, window_size, H, W): 145 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 146 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 147 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 148 | return x 149 | 150 | class CAM_Module(Module): 151 | def __init__(self, in_dim, dim, num_heads, qk_scale=None, C_lambda=1e-4, attn_drop=0., proj_drop=0.): 152 | super(CAM_Module, self).__init__() 153 | self.chanel_in = in_dim 154 | self.softmax = Softmax(dim=-1) 155 | self.c_lambda = C_lambda 156 | self.activaton = nn.Sigmoid() 157 | 158 | head_dim = dim // num_heads 159 | self.scale = qk_scale or head_dim ** -0.5 160 | self.attn_drop = nn.Dropout(attn_drop) 161 | self.proj = nn.Sequential(nn.Conv2d(dim//12, dim, kernel_size=3, padding=1, stride=1, bias=False, groups=dim//12), nn.GELU()) 162 | self.proj_drop = nn.Dropout(proj_drop) 163 | 164 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 165 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 166 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 167 | 168 | self.gamma = Parameter(torch.zeros(1)) 169 | 170 | def forward(self, x, mask=None): 171 | m_batchsize, N, C = x.size() 172 | height = int(N ** .5) 173 | width = int(N ** .5) 174 | 175 | x = x.view(m_batchsize, C, height, width) 176 | proj_query = self.query_conv(x).view(m_batchsize, C//12, -1) 177 | proj_key = self.key_conv(x).view(m_batchsize, C//12, -1).permute(0, 2, 1) 178 | proj_value = self.value_conv(x).view(m_batchsize, C//12, -1) 179 | 180 | q = proj_query * self.scale 181 | attn = (q @ proj_key) 182 | 183 | if mask is not None: 184 | nW = mask.shape[0] # num_windows 185 | attn = attn.view(m_batchsize // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 186 | attn = attn.view(-1, self.num_heads, N, N) 187 | attn = self.softmax(attn) 188 | else: 189 | attn = self.softmax(attn) 190 | 191 | attn = self.attn_drop(attn) 192 | x = (attn @ proj_value).reshape(m_batchsize, C//12, height, width) 193 | x = self.proj(x) 194 | x = x.reshape(m_batchsize, C, N).transpose(1, 2) 195 | x = self.proj_drop(x) 196 | 197 | out = self.gamma * x + x 198 | return out 199 | 200 | class PAM_Module(Module): 201 | def __init__(self, in_dim, dim, num_heads, qk_scale=None, P_lambda=1e-4, attn_drop=0., proj_drop=0.): 202 | super(PAM_Module, self).__init__() 203 | self.chanel_in = in_dim 204 | 205 | head_dim = dim // num_heads 206 | self.scale = qk_scale or head_dim ** -0.5 207 | self.attn_drop = nn.Dropout(attn_drop) 208 | self.proj = nn.Sequential(nn.Conv2d(dim//12, dim, kernel_size=3, padding=1, stride=1, bias=False, groups=dim//12), nn.GELU()) 209 | self.proj_drop = nn.Dropout(proj_drop) 210 | 211 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 212 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 213 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 214 | self.softmax = Softmax(dim=-1) 215 | 216 | 217 | self.p_lambda = P_lambda 218 
| self.activaton = nn.Sigmoid() 219 | self.gamma = Parameter(torch.zeros(1)) 220 | 221 | def forward(self, x, mask=None): 222 | m_batchsize, N, C = x.size() 223 | height = int(N ** .5) 224 | width = int(N ** .5) 225 | 226 | x = x.view(m_batchsize, C, height, width) 227 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) 228 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height) 229 | proj_value = self.value_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) 230 | 231 | q = proj_query * self.scale 232 | attn = (q @ proj_key) 233 | 234 | if mask is not None: 235 | nW = mask.shape[0] # num_windows 236 | attn = attn.view(m_batchsize // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 237 | attn = attn.view(-1, self.num_heads, N, N) 238 | attn = self.softmax(attn) 239 | else: 240 | attn = self.softmax(attn) 241 | 242 | attn = self.attn_drop(attn) 243 | x = (attn @ proj_value).reshape(m_batchsize, C//12, height, width) 244 | x = self.proj(x) 245 | x = x.reshape(m_batchsize, C, N).transpose(1, 2) 246 | x = self.proj_drop(x) 247 | 248 | out = self.gamma * x + x 249 | return out 250 | 251 | class CHAM_Module(Module): 252 | def __init__(self, in_dim, dim, num_heads, qk_scale=None, P_lambda=1e-4, attn_drop=0., proj_drop=0.): 253 | super(CHAM_Module, self).__init__() 254 | self.chanel_in = in_dim 255 | 256 | head_dim = dim // num_heads 257 | self.scale = qk_scale or head_dim ** -0.5 258 | self.attn_drop = nn.Dropout(attn_drop) 259 | self.proj = nn.Sequential(nn.Conv2d(dim//12, dim, kernel_size=3, padding=1, stride=1, bias=False, groups=dim//12), nn.GELU()) 260 | self.proj_drop = nn.Dropout(proj_drop) 261 | 262 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 263 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 264 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 265 | self.softmax = Softmax(dim=-1) 266 | 267 | self.p_lambda = P_lambda 268 | self.activaton = nn.Sigmoid() 269 | self.gamma = Parameter(torch.zeros(1)) 270 | 271 | def forward(self, x, mask=None): 272 | m_batchsize, N, C = x.size() 273 | height = int(N ** .5) 274 | width = int(N ** .5) 275 | 276 | x = x.view(m_batchsize, C, height, width) 277 | proj_query = self.query_conv(x).view(m_batchsize, C//12 * height, -1) 278 | proj_key = self.key_conv(x).view(m_batchsize, C//12 * height, -1).permute(0, 2, 1) 279 | proj_value = self.value_conv(x).view(m_batchsize, C//12 * height, -1) 280 | 281 | q = proj_query * self.scale 282 | attn = (q @ proj_key) 283 | 284 | if mask is not None: 285 | nW = mask.shape[0] # num_windows 286 | attn = attn.view(m_batchsize // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 287 | attn = attn.view(-1, self.num_heads, N, N) 288 | attn = self.softmax(attn) 289 | else: 290 | attn = self.softmax(attn) 291 | 292 | attn = self.attn_drop(attn) 293 | x = (attn @ proj_value).reshape(m_batchsize, C//12, height, width) 294 | x = self.proj(x) 295 | x = x.reshape(m_batchsize, C, N).transpose(1, 2) 296 | x = self.proj_drop(x) 297 | 298 | out = self.gamma * x + x 299 | return out 300 | 301 | 302 | class CWAM_Module(Module): 303 | def __init__(self, in_dim, dim, num_heads, qk_scale=None, P_lambda=1e-4, attn_drop=0., proj_drop=0.): 304 | super(CWAM_Module, self).__init__() 305 | self.chanel_in = in_dim 306 | 307 | head_dim = dim // num_heads 308 | self.scale = qk_scale or head_dim ** -0.5 309 | self.attn_drop = nn.Dropout(attn_drop) 310 | 
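# CWAM_Module is the fourth attention branch combined in WindowAttention_ACAM below:
# CAM_Module attends across channels, PAM_Module across spatial positions (H*W),
# CHAM_Module across the channel-height plane, and CWAM_Module across the
# channel-width plane. Each module rescales its output with its own learnable
# `gamma` (out = gamma * x + x), and the four branch outputs are summed with the
# weights gamma1..gamma4 inside WindowAttention_ACAM.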
self.proj = nn.Sequential(nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, bias=False, groups=dim), nn.GELU()) 311 | self.proj_drop = nn.Dropout(proj_drop) 312 | 313 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 314 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 315 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 316 | self.softmax = Softmax(dim=-1) 317 | 318 | self.p_lambda = P_lambda 319 | self.activaton = nn.Sigmoid() 320 | self.gamma = Parameter(torch.zeros(1)) 321 | 322 | def forward(self, x, mask=None): 323 | m_batchsize, N, C = x.size() 324 | height = int(N ** .5) 325 | width = int(N ** .5) 326 | 327 | proj_query = x.view(m_batchsize, C * width, -1) 328 | proj_key = x.view(m_batchsize, C * width, -1).permute(0, 2, 1) 329 | proj_value = x.view(m_batchsize, C * width, -1) 330 | 331 | q = proj_query * self.scale 332 | attn = (q @ proj_key) 333 | 334 | if mask is not None: 335 | nW = mask.shape[0] # num_windows 336 | attn = attn.view(m_batchsize // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 337 | attn = attn.view(-1, self.num_heads, N, N) 338 | attn = self.softmax(attn) 339 | else: 340 | attn = self.softmax(attn) 341 | 342 | attn = self.attn_drop(attn) 343 | x = (attn @ proj_value).reshape(m_batchsize, C, height, width) 344 | x = self.proj(x) 345 | x = x.reshape(m_batchsize, C, N).transpose(1, 2) 346 | x = self.proj_drop(x) 347 | 348 | out = self.gamma * x + x 349 | return out 350 | 351 | class WindowAttention_ACAM(nn.Module): 352 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 353 | super().__init__() 354 | self.dim = dim 355 | 356 | self.C_C = CAM_Module(self.dim, dim=dim, num_heads=num_heads) 357 | self.H_W = PAM_Module(self.dim, dim=dim, num_heads=num_heads) 358 | self.C_H = CHAM_Module(self.dim, dim=dim, num_heads=num_heads) 359 | self.C_W = CWAM_Module(self.dim, dim=dim, num_heads=num_heads) 360 | 361 | self.gamma1 = Parameter(torch.zeros(1)) 362 | self.gamma2 = Parameter(torch.zeros(1)) 363 | self.gamma3 = Parameter(torch.ones(1) * 0.5) 364 | self.gamma4 = Parameter(torch.ones(1) * 0.5) 365 | 366 | def _build_projection(self, dim_in, kernel_size=3, stride=1, padding=1): 367 | proj = nn.Sequential( 368 | nn.Conv2d(dim_in, dim_in, kernel_size, padding=padding, stride=stride, bias=False, groups=dim_in), 369 | Rearrange('b c h w -> b (h w) c'), 370 | nn.LayerNorm(dim_in)) 371 | return proj 372 | 373 | def forward(self, x, mask=None): 374 | x_out1 = self.C_C(x) 375 | 376 | x_out2 = self.H_W(x) 377 | 378 | x_out3 = self.C_H(x) 379 | 380 | x_out4 = self.C_W(x) 381 | 382 | x_out = (self.gamma1 * x_out1) + (self.gamma2 * x_out2) + (self.gamma3 * x_out3) + (self.gamma4 * x_out4) 383 | 384 | return x_out 385 | """ =============================================================================================================== """ 386 | 387 | class SwinTransformerBlock(nn.Module): 388 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 389 | mlp_ratio=2., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 390 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 391 | super().__init__() 392 | 393 | self.dim = dim 394 | self.input_resolution = input_resolution 395 | self.num_heads = num_heads 396 | self.window_size = window_size 397 | self.shift_size = shift_size 398 | self.mlp_ratio = mlp_ratio 399 | 400 | 401 | if min(self.input_resolution) <= 
self.window_size: 402 | self.shift_size = 0 403 | self.window_size = min(self.input_resolution) 404 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 405 | 406 | self.norm1 = norm_layer(dim) 407 | self.norm2 = norm_layer([dim, input_resolution[0], input_resolution[1]]) 408 | 409 | self.attn = WindowAttention_ACAM( 410 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 411 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 412 | 413 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 414 | 415 | if self.shift_size > 0: 416 | H, W = self.input_resolution 417 | img_mask = torch.zeros((1, H, W, 1)) 418 | h_slices = (slice(0, -self.window_size), 419 | slice(-self.window_size, -self.shift_size), 420 | slice(-self.shift_size, None)) 421 | w_slices = (slice(0, -self.window_size), 422 | slice(-self.window_size, -self.shift_size), 423 | slice(-self.shift_size, None)) 424 | cnt = 0 425 | for h in h_slices: 426 | for w in w_slices: 427 | img_mask[:, h, w, :] = cnt 428 | cnt += 1 429 | 430 | mask_windows = window_partition(img_mask, self.window_size) 431 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 432 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 433 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 434 | else: 435 | attn_mask = None 436 | 437 | self.register_buffer("attn_mask", attn_mask) 438 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 439 | self.mlp = GhostModule(inp=dim) 440 | 441 | def forward(self, x): 442 | B, C, H, W = x.shape 443 | 444 | shortcut1 = x 445 | x = x.view(B, H, W, C) 446 | x = self.norm1(x) 447 | 448 | if self.shift_size > 0: 449 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 450 | else: 451 | shifted_x = x 452 | 453 | x_windows = window_partition(shifted_x, self.window_size) 454 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 455 | 456 | attn_windows = self.attn(x_windows, mask=self.attn_mask) 457 | 458 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 459 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) 460 | 461 | if self.shift_size > 0: 462 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 463 | else: 464 | x = shifted_x 465 | x = x.view(B, C, H, W) 466 | x = shortcut1 + self.drop_path(x) 467 | shortcut2 = x 468 | 469 | x = self.norm2(x) 470 | 471 | x = shortcut2 + self.drop_path(self.mlp(x)) 472 | return x 473 | 474 | class PatchMerging(nn.Module): 475 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 476 | super().__init__() 477 | self.input_resolution = input_resolution 478 | h, w = input_resolution 479 | h = int(h/2) 480 | w = int(w/2) 481 | self.dim = dim 482 | self.norm = norm_layer([4*dim, h, w]) 483 | self.reduction = GhostModule(inp=4 * dim, oup=2 * dim, ratio=4) 484 | 485 | def forward(self, x): 486 | H, W = self.input_resolution 487 | B, C, H, W = x.shape 488 | 489 | x0 = x[:, :, 0::2, 0::2] # B C H/2 W/2 490 | x1 = x[:, :, 1::2, 0::2] # B C H/2 W/2 491 | x2 = x[:, :, 0::2, 1::2] # B C H/2 W/2 492 | x3 = x[:, :, 1::2, 1::2] # B C H/2 W/2 493 | x = torch.cat([x0, x1, x2, x3], 1) # B 4*C H/2 W/2 494 | 495 | x = self.norm(x) 496 | x = self.reduction(x) 497 | 498 | return x 499 | 500 | def extra_repr(self) -> str: 501 | return 
f"input_resolution={self.input_resolution}, dim={self.dim}" 502 | 503 | def flops(self): 504 | H, W = self.input_resolution 505 | flops = H * W * self.dim 506 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 507 | return flops 508 | 509 | class BasicLayer(nn.Module): 510 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 511 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 512 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 513 | 514 | super().__init__() 515 | self.dim = dim 516 | self.input_resolution = input_resolution 517 | self.depth = depth 518 | self.use_checkpoint = use_checkpoint 519 | 520 | self.blocks = nn.ModuleList([ 521 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 522 | num_heads=num_heads, window_size=window_size, 523 | shift_size=0 if (i % 2 == 0) else window_size // 2, 524 | mlp_ratio=mlp_ratio, 525 | qkv_bias=qkv_bias, qk_scale=qk_scale, 526 | drop=drop, attn_drop=attn_drop, 527 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 528 | norm_layer=norm_layer) 529 | for i in range(depth)]) 530 | 531 | if downsample is not None: 532 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 533 | else: 534 | self.downsample = None 535 | 536 | def forward(self, x): 537 | for blk in self.blocks: 538 | if self.use_checkpoint: 539 | x = checkpoint.checkpoint(blk, x) 540 | else: 541 | x = blk(x) 542 | if self.downsample is not None: 543 | x = self.downsample(x) 544 | return x 545 | 546 | def extra_repr(self) -> str: 547 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 548 | 549 | def flops(self): 550 | flops = 0 551 | for blk in self.blocks: 552 | flops += blk.flops() 553 | if self.downsample is not None: 554 | flops += self.downsample.flops() 555 | return flops 556 | 557 | class PatchExpand(nn.Module): 558 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 559 | super().__init__() 560 | self.input_resolution = input_resolution 561 | self.dim = dim 562 | self.expand = GhostModule_Up(inp=dim) if dim_scale == 2 else nn.Identity() 563 | self.norm = norm_layer([dim // dim_scale, input_resolution[0]*2, input_resolution[1]*2]) 564 | 565 | def forward(self, x): 566 | B, C, H, W = x.shape 567 | x = self.expand(x) 568 | x = rearrange(x, 'b (p1 p2 c) h w -> b c (h p1) (w p2)', p1=2, p2=2, c=C // 2) 569 | x = self.norm(x) 570 | 571 | return x 572 | 573 | class BasicLayer_up(nn.Module): 574 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 575 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 576 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): 577 | 578 | super().__init__() 579 | self.dim = dim 580 | self.input_resolution = input_resolution 581 | self.depth = depth 582 | self.use_checkpoint = use_checkpoint 583 | 584 | self.blocks = nn.ModuleList([ 585 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 586 | num_heads=num_heads, window_size=window_size, 587 | shift_size=0 if (i % 2 == 0) else window_size // 2, 588 | mlp_ratio=mlp_ratio, 589 | qkv_bias=qkv_bias, qk_scale=qk_scale, 590 | drop=drop, attn_drop=attn_drop, 591 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 592 | norm_layer=norm_layer) 593 | for i in range(depth)]) 594 | 595 | if upsample is not None: 596 | self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) 
597 | else: 598 | self.upsample = None 599 | 600 | def forward(self, x): 601 | for blk in self.blocks: 602 | if self.use_checkpoint: 603 | x = checkpoint.checkpoint(blk, x) 604 | else: 605 | x = blk(x) 606 | if self.upsample is not None: 607 | x = self.upsample(x) 608 | return x 609 | 610 | class FinalPatchExpand_X4(nn.Module): 611 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): 612 | super().__init__() 613 | self.input_resolution = input_resolution 614 | self.dim = dim 615 | self.dim_scale = dim_scale 616 | self.expand = oneXone_conv(in_features = dim, out_features = 16 * dim) if dim_scale == 2 else nn.Identity() 617 | self.output_dim = dim 618 | self.norm = norm_layer([6, input_resolution[0]*4, input_resolution[1]*4]) 619 | 620 | def forward(self, x): 621 | B, C, H, W = x.shape 622 | x = self.expand(x) 623 | x = rearrange(x, 'b (p1 p2 c) h w -> b c (h p1) (w p2)', p1=self.dim_scale, p2=self.dim_scale, c=C // (self.dim_scale ** 2)) 624 | x = self.norm(x) 625 | return x 626 | 627 | def init_weights(net, init_type='normal', gain=0.02): 628 | def init_func(m): 629 | classname = m.__class__.__name__ 630 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 631 | if init_type == 'normal': 632 | init.normal_(m.weight.data, 0.0, gain) 633 | elif init_type == 'xavier': 634 | init.xavier_normal_(m.weight.data, gain=gain) 635 | elif init_type == 'kaiming': 636 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 637 | elif init_type == 'orthogonal': 638 | init.orthogonal_(m.weight.data, gain=gain) 639 | else: 640 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 641 | if hasattr(m, 'bias') and m.bias is not None: 642 | init.constant_(m.bias.data, 0.0) 643 | elif classname.find('BatchNorm2d') != -1: 644 | init.normal_(m.weight.data, 1.0, gain) 645 | init.constant_(m.bias.data, 0.0) 646 | 647 | print('initialize network with %s' % init_type) 648 | net.apply(init_func) 649 | 650 | class conv_block(nn.Module): 651 | def __init__(self, ch_in, ch_out): 652 | super(conv_block, self).__init__() 653 | self.conv = nn.Sequential( 654 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 655 | nn.BatchNorm2d(ch_out), 656 | nn.ReLU(inplace=True), 657 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 658 | nn.BatchNorm2d(ch_out), 659 | nn.ReLU(inplace=True) 660 | ) 661 | 662 | def forward(self, x): 663 | x = self.conv(x) 664 | return x 665 | 666 | class up_conv(nn.Module): 667 | def __init__(self, ch_in, ch_out): 668 | super(up_conv, self).__init__() 669 | self.up = nn.Sequential( 670 | nn.Upsample(scale_factor=2), 671 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 672 | nn.BatchNorm2d(ch_out), 673 | nn.ReLU(inplace=True) 674 | ) 675 | 676 | def forward(self, x): 677 | x = self.up(x) 678 | return x 679 | 680 | class conv_block_DDConv(nn.Module): 681 | def __init__(self, ch_in, ch_out): 682 | super(conv_block_DDConv, self).__init__() 683 | self.conv = nn.Sequential( 684 | DDConv(ch_in, ch_out, kernel_size=3, stride=1, padding=1), 685 | 686 | nn.BatchNorm2d(ch_out), 687 | nn.ReLU(inplace=True), 688 | 689 | DDConv(ch_out, ch_out, kernel_size=3, stride=1, padding=1), 690 | 691 | nn.BatchNorm2d(ch_out), 692 | nn.ReLU(inplace=True) 693 | ) 694 | 695 | def forward(self, x): 696 | x = self.conv(x) 697 | return x 698 | 699 | class up_conv_DDConv(nn.Module): 700 | def __init__(self, ch_in, ch_out): 701 | super(up_conv_DDConv, 
self).__init__() 702 | self.up = nn.Sequential( 703 | nn.Upsample(scale_factor=2), 704 | 705 | DDConv(ch_in, ch_out, kernel_size=1), 706 | 707 | nn.BatchNorm2d(ch_out), 708 | nn.ReLU(inplace=True) 709 | ) 710 | 711 | def forward(self, x): 712 | x = self.up(x) 713 | return x 714 | 715 | class ConvMixerLayer(nn.Module): 716 | def __init__(self, dim, kernel_size=9): 717 | super().__init__() 718 | self.Resnet = nn.Sequential( 719 | nn.Conv2d(dim, dim, kernel_size=kernel_size, groups=dim, padding=4), 720 | nn.GELU(), 721 | nn.BatchNorm2d(dim) 722 | ) 723 | self.Conv_1x1 = nn.Sequential( 724 | nn.Conv2d(dim, dim, kernel_size=1), 725 | nn.GELU(), 726 | nn.BatchNorm2d(dim) 727 | ) 728 | def forward(self, x): 729 | x = x + self.Resnet(x) 730 | x = self.Conv_1x1(x) 731 | return x 732 | 733 | 734 | class ConvMixer(nn.Module): 735 | def __init__(self, dim=512, depth=1, kernel_size=9, patch_size=4, n_classes=1000): 736 | super().__init__() 737 | self.conv2d1 = nn.Sequential( 738 | nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size), 739 | nn.GELU(), 740 | nn.BatchNorm2d(dim) 741 | ) 742 | self.ConvMixer_blocks = nn.ModuleList([]) 743 | 744 | for _ in range(depth): 745 | self.ConvMixer_blocks.append(ConvMixerLayer(dim=dim, kernel_size=kernel_size)) 746 | 747 | self.head = nn.Sequential( 748 | nn.AdaptiveAvgPool2d((1, 1)), 749 | nn.Flatten(), 750 | nn.Linear(dim, n_classes) 751 | ) 752 | 753 | def forward(self, x): 754 | x = self.conv2d1(x) 755 | 756 | for ConvMixer_block in self.ConvMixer_blocks: 757 | x = ConvMixer_block(x) 758 | 759 | x = x 760 | return x 761 | 762 | class CIT(nn.Module): 763 | def __init__(self, img_size=224, patch_size=4, in_chans=3, out_chans=1, 764 | embed_dim=96, depths=[2, 2, 6, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], 765 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 766 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 767 | norm_layer=nn.LayerNorm, ape=True, patch_norm=True, 768 | use_checkpoint=False, final_upsample="expand_first", **kwargs): 769 | super().__init__() 770 | 771 | self.out_channel = out_chans 772 | self.num_layers = len(depths) 773 | self.ape = ape 774 | self.patch_norm = patch_norm 775 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 776 | self.size = int(img_size/(2 ** (self.num_layers + 1))) 777 | self.size_out = int(img_size/4) 778 | self.num_features_up = int(embed_dim * 2) 779 | self.mlp_ratio = mlp_ratio 780 | self.final_upsample = final_upsample 781 | 782 | self.window_size = window_size 783 | self.qkv_bias = qkv_bias 784 | self.qk_scale = qk_scale 785 | self.drop_rate = drop_rate 786 | self.attn_drop_rate = attn_drop_rate 787 | self.drop_path_rate = drop_path_rate 788 | self.norm_layer = nn.LayerNorm 789 | 790 | self.patch_embed = PatchEmbed( 791 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 792 | norm_layer=norm_layer if self.patch_norm else None) 793 | 794 | self.pos_drop = nn.Dropout(p=drop_rate) 795 | 796 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 797 | 798 | 799 | if self.final_upsample == "expand_first": 800 | self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size), 801 | dim_scale=4, dim=embed_dim) 802 | self.output = nn.Conv2d(in_channels=6, out_channels=1, kernel_size=1, bias=False) 803 | 804 | self.apply(self._init_weights) 805 | 806 | self.embed_dim = 96 807 | self.num_heads = 3 808 | self.depth35 = 6 809 | self.drop_path3 = dpr[4:10] 810 | self.drop_path4 = dpr[10:12] 
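# Stochastic-depth schedule: `dpr` spaces sum(depths)=12 drop-path rates linearly
# from 0 to drop_path_rate and is sliced per stage to match depths [2, 2, 6, 2]:
# dpr[0:2] -> layer1, dpr[2:4] -> layer2, dpr[4:10] -> layer3, dpr[10:12] -> layer4,
# with the same slices reused for the decoder layers (layer5/6/7). For the default
# drop_path_rate=0.1 this is roughly 0.0, 0.009, ..., 0.1 (a step of about 0.1/11).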
811 | 812 | print("CiT-Net-T----embed_dim:{}; num_heads:{}; depths:{}".format(self.embed_dim, num_heads, depths)) 813 | 814 | self.layer1 = BasicLayer(dim=self.embed_dim * 1, 815 | input_resolution=(56, 56), 816 | depth=2, 817 | num_heads=self.num_heads * 1, 818 | window_size=self.window_size, # 7 819 | mlp_ratio=self.mlp_ratio, # 4. 820 | qkv_bias=self.qkv_bias, # True 821 | qk_scale=self.qk_scale, # None 822 | drop=self.drop_rate, # 0. 823 | attn_drop=self.attn_drop_rate, # 0. 824 | drop_path=dpr[0:2], 825 | norm_layer=self.norm_layer, 826 | downsample=PatchMerging, 827 | use_checkpoint=False) 828 | 829 | self.layer2 = BasicLayer(dim=self.embed_dim * 2, 830 | input_resolution=(28, 28), 831 | depth=2, 832 | num_heads=self.num_heads * 2, 833 | window_size=self.window_size, # 7 834 | mlp_ratio=self.mlp_ratio, # 4. 835 | qkv_bias=self.qkv_bias, # True 836 | qk_scale=self.qk_scale, # None 837 | drop=self.drop_rate, # 0. 838 | attn_drop=self.attn_drop_rate, # 0. 839 | drop_path=dpr[2:4], 840 | norm_layer=self.norm_layer, 841 | downsample=PatchMerging, 842 | use_checkpoint=False) 843 | 844 | self.layer3 = BasicLayer(dim=self.embed_dim * 4, 845 | input_resolution=(14, 14), 846 | depth=self.depth35, 847 | num_heads=self.num_heads * 4, 848 | window_size=self.window_size, # 7 849 | mlp_ratio=self.mlp_ratio, # 4. 850 | qkv_bias=self.qkv_bias, # True 851 | qk_scale=self.qk_scale, # None 852 | drop=self.drop_rate, # 0. 853 | attn_drop=self.attn_drop_rate, # 0. 854 | drop_path=self.drop_path3, 855 | norm_layer=self.norm_layer, 856 | downsample=PatchMerging, 857 | use_checkpoint=False) 858 | 859 | self.layer4 = BasicLayer(dim=self.embed_dim * 8, 860 | input_resolution=(7, 7), 861 | depth=2, 862 | num_heads=self.num_heads * 8, 863 | window_size=self.window_size, # 7 864 | mlp_ratio=self.mlp_ratio, # 4. 865 | qkv_bias=self.qkv_bias, # True 866 | qk_scale=self.qk_scale, # None 867 | drop=self.drop_rate, # 0. 868 | attn_drop=self.attn_drop_rate, # 0. 869 | drop_path=self.drop_path4, 870 | norm_layer=self.norm_layer, 871 | downsample=None, 872 | use_checkpoint=False) 873 | 874 | self.norm = norm_layer([self.num_features, self.size, self.size]) 875 | 876 | self.Patch_Expand1 = PatchExpand(input_resolution=(7, 7), 877 | dim=self.embed_dim * 8, 878 | dim_scale=2, 879 | norm_layer=norm_layer) 880 | 881 | 882 | self.concat_linear1 = GhostModule(inp=self.embed_dim * 8, oup=self.embed_dim * 4) 883 | 884 | self.layer5 = BasicLayer_up(dim=self.embed_dim * 4, 885 | input_resolution=(14, 14), 886 | depth=self.depth35, 887 | num_heads=self.num_heads * 4, 888 | 889 | window_size=self.window_size, # 7 890 | mlp_ratio=self.mlp_ratio, # 4. 891 | qkv_bias=self.qkv_bias, # True 892 | qk_scale=self.qk_scale, # None 893 | drop=self.drop_rate, # 0. 894 | attn_drop=self.attn_drop_rate, # 0. 895 | 896 | drop_path=self.drop_path3, 897 | norm_layer=norm_layer, 898 | upsample=PatchExpand, 899 | use_checkpoint=False) 900 | 901 | self.concat_linear2 = GhostModule(inp=self.embed_dim * 4, oup=self.embed_dim * 2) 902 | 903 | self.layer6 = BasicLayer_up(dim=self.embed_dim * 2, 904 | input_resolution=(28, 28), 905 | depth=2, 906 | num_heads=self.num_heads * 2, 907 | 908 | window_size=self.window_size, # 7 909 | mlp_ratio=self.mlp_ratio, # 4. 910 | qkv_bias=self.qkv_bias, # True 911 | qk_scale=self.qk_scale, # None 912 | drop=self.drop_rate, # 0. 913 | attn_drop=self.attn_drop_rate, # 0. 
914 | 915 | drop_path=dpr[2:4], 916 | norm_layer=norm_layer, 917 | upsample=PatchExpand, 918 | use_checkpoint=False) 919 | 920 | self.concat_linear3 = GhostModule(inp=self.embed_dim * 2, oup=self.embed_dim * 1) 921 | 922 | self.layer7 = BasicLayer_up(dim=self.embed_dim * 1, 923 | input_resolution=(56, 56), 924 | depth=2, 925 | num_heads=self.num_heads * 1, 926 | 927 | window_size=self.window_size, # 7 928 | mlp_ratio=self.mlp_ratio, # 4. 929 | qkv_bias=self.qkv_bias, # True 930 | qk_scale=self.qk_scale, # None 931 | drop=self.drop_rate, # 0. 932 | attn_drop=self.attn_drop_rate, # 0. 933 | 934 | drop_path=dpr[0:2], 935 | norm_layer=norm_layer, 936 | upsample=None, 937 | use_checkpoint=False) 938 | 939 | self.norm_up = norm_layer([self.embed_dim, self.size_out, self.size_out]) 940 | self.patch = ConvMixer(dim=48, depth=5) # 修改ConvMixer层数 941 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 942 | self.Conv1e = conv_block(ch_in=48, ch_out=self.embed_dim * 1) 943 | self.Conv1s = conv_block(ch_in=48, ch_out=self.embed_dim * 1) 944 | self.Conv2e = conv_block_DDConv(ch_in=self.embed_dim * 1, ch_out=self.embed_dim * 2) 945 | self.Conv3e = conv_block_DDConv(ch_in=self.embed_dim * 2, ch_out=self.embed_dim * 4) 946 | self.Conv4e = conv_block_DDConv(ch_in=self.embed_dim * 4, ch_out=self.embed_dim * 8) 947 | self.Up4d = up_conv_DDConv(ch_in=self.embed_dim * 8, ch_out=self.embed_dim * 4) 948 | self.Up_conv4d = conv_block(ch_in=self.embed_dim * 8, ch_out=self.embed_dim * 4) 949 | self.Up3d = up_conv_DDConv(ch_in=self.embed_dim * 4, ch_out=self.embed_dim * 2) 950 | self.Up_conv3d = conv_block(ch_in=self.embed_dim * 4, ch_out=self.embed_dim * 2) 951 | self.Up2d = up_conv_DDConv(ch_in=self.embed_dim * 2, ch_out=self.embed_dim * 1) 952 | self.Up_conv2d = conv_block(ch_in=self.embed_dim * 2, ch_out=self.embed_dim * 1) 953 | self.Mid_Conv1 = nn.Conv2d(self.embed_dim * 2, self.embed_dim * 1, kernel_size=1, stride=1, padding=0) 954 | self.Mid_Conv2 = nn.Conv2d(self.embed_dim * 4, self.embed_dim * 2, kernel_size=1, stride=1, padding=0) 955 | self.Mid_Conv3 = nn.Conv2d(self.embed_dim * 8, self.embed_dim * 4, kernel_size=1, stride=1, padding=0) 956 | self.BN = nn.BatchNorm2d(1) 957 | self.CiT_Conv = nn.Conv2d(2, 1, kernel_size=1, stride=1, padding=0) 958 | 959 | def _init_weights(self, m): 960 | if isinstance(m, nn.Linear): 961 | trunc_normal_(m.weight, std=.02) 962 | if isinstance(m, nn.Linear) and m.bias is not None: 963 | nn.init.constant_(m.bias, 0) 964 | elif isinstance(m, nn.LayerNorm): 965 | nn.init.constant_(m.bias, 0) 966 | nn.init.constant_(m.weight, 1.0) 967 | 968 | @torch.jit.ignore 969 | def no_weight_decay(self): 970 | return {'absolute_pos_embed'} 971 | 972 | @torch.jit.ignore 973 | def no_weight_decay_keywords(self): 974 | return {'relative_position_bias_table'} 975 | 976 | 977 | def up_x4(self, x): 978 | 979 | B, C, H, W = x.shape 980 | if self.final_upsample == "expand_first": 981 | x = self.up(x) 982 | x = self.output(x) 983 | 984 | return x 985 | 986 | def forward(self, x): # 1,3,224,224 987 | x = self.patch(x) 988 | Cnn = x 989 | Swin = x 990 | 991 | Cnn = self.Conv1e(Cnn) # 1,96,56,56 992 | Swin = self.Conv1s(Swin) # 1,96,56,56 993 | Cnn1 = Cnn 994 | Swin1 = Swin 995 | Mid1 = torch.cat((Cnn1, Swin1), dim=1) 996 | Mid1 = self.Mid_Conv1(Mid1) 997 | 998 | Cnn = self.maxpool(Cnn) 999 | Cnn = self.Conv2e(Cnn) 1000 | Swin = self.layer1(Swin) # 28,28 1001 | Cnn2 = Cnn 1002 | Swin2 = Swin 1003 | Mid2 = torch.cat((Cnn2, Swin2), dim=1) 1004 | Mid2 = self.Mid_Conv2(Mid2) 1005 | 1006 | Cnn = 
self.maxpool(Cnn) 1007 | Cnn = self.Conv3e(Cnn) 1008 | Swin = self.layer2(Swin) # 14,14 1009 | Cnn3 = Cnn 1010 | Swin3 = Swin 1011 | Mid3 = torch.cat((Cnn3, Swin3), dim=1) 1012 | Mid3 = self.Mid_Conv3(Mid3) 1013 | 1014 | Cnn = self.maxpool(Cnn) 1015 | Cnn = self.Conv4e(Cnn) 1016 | Swin = self.layer3(Swin) # 7,7 1017 | Swin = self.layer4(Swin) # 7,7 1018 | Swin = self.norm(Swin) # B L C (1, 768, 7, 7) 1019 | Cnn4 = Cnn 1020 | Swin4 = Swin 1021 | 1022 | Cnn = self.Up4d(Cnn) 1023 | Cnn = torch.cat((Cnn, Mid3), dim=1) 1024 | Cnn = self.Up_conv4d(Cnn) 1025 | Swin = self.Patch_Expand1(Swin) 1026 | Swin = torch.cat([Swin, Mid3], 1) 1027 | Swin = self.concat_linear1(Swin) # 14,14 1028 | Cnn5 = Cnn 1029 | Swin5 = Swin 1030 | 1031 | Cnn = self.Up3d(Cnn) 1032 | Cnn = torch.cat((Cnn, Mid2), dim=1) 1033 | Cnn = self.Up_conv3d(Cnn) 1034 | Swin = self.layer5(Swin) 1035 | Swin = torch.cat([Swin, Mid2], 1) 1036 | Swin = self.concat_linear2(Swin) # 28,28 1037 | Cnn6 = Cnn 1038 | Swin6 = Swin 1039 | 1040 | Cnn7 = self.Up2d(Cnn) 1041 | Cnn7 = torch.cat((Cnn7, Mid1), dim=1) 1042 | Cnn7 = self.Up_conv2d(Cnn7) 1043 | 1044 | Swin7 = self.layer6(Swin) # 56,56 1045 | Swin7 = torch.cat([Swin7, Mid1], 1) # 56,56 1046 | Swin7 = self.concat_linear3(Swin7) # 56,56 1047 | Swin7 = self.layer7(Swin7) # 56,56 1048 | 1049 | CNN = self.up_x4(Cnn7) # 224,224 1050 | Swin = self.norm_up(Swin7) # B L C 1,96,56,56 1051 | SWIN = self.up_x4(Swin) # 224,224 1052 | 1053 | CNN_out = CNN 1054 | Trans_out = SWIN 1055 | 1056 | CNN = self.BN(CNN) 1057 | SWIN = self.BN(SWIN) 1058 | CiT = torch.cat((CNN, SWIN), dim=1) 1059 | CiT = self.CiT_Conv(CiT) 1060 | 1061 | CiT = torch.sigmoid(CiT) 1062 | CNN_out = torch.sigmoid(CNN_out) 1063 | Trans_out = torch.sigmoid(Trans_out) 1064 | 1065 | return CiT, CNN_out, Trans_out 1066 | 1067 | 1068 | 1069 | def flops(self): 1070 | flops = 0 1071 | flops += self.patch_embed.flops() 1072 | for i, layer in enumerate(self.layers): 1073 | flops += layer.flops() 1074 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 1075 | flops += self.num_features * self.num_classes 1076 | return flops 1077 | 1078 | 1079 | if __name__ == "__main__": 1080 | with torch.no_grad(): 1081 | input = torch.rand(1, 1, 224, 224).to("cpu") 1082 | model =CIT().to("cpu") 1083 | 1084 | out_result, _, _ = model(input) 1085 | print(out_result.shape) 1086 | 1087 | flops, params = profile(model, (input,)) 1088 | 1089 | print("-" * 50) 1090 | print('FLOPs = ' + str(flops / 1000 ** 3) + ' G') 1091 | print('Params = ' + str(params / 1000 ** 2) + ' M') -------------------------------------------------------------------------------- /TEC_Net_T_3D.py: -------------------------------------------------------------------------------- 1 | """ 2 | CiT-Net-Tiny 3 | stage1 stage2 stage3 stage4 4 | size 56x56 28x28 14x14 7x7 5 | Unet dim 96 192 384 768 6 | Swin dim 96 192 384 768 7 | head 3 6 12 24 8 | num 2 2 6 2 9 | """ 10 | import torch 11 | import math 12 | import torch.nn as nn 13 | import torch.utils.checkpoint as checkpoint 14 | from einops import rearrange 15 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 16 | from thop import * 17 | from torch.nn import init 18 | import torch.nn.functional as F 19 | from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \ 20 | NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding 21 | from DDConv import DDConv 22 | from einops.layers.torch 
import Rearrange 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 27 | super().__init__() 28 | self.conv2patch = nn.Sequential( 29 | nn.Conv2d(3, embed_dim, kernel_size=4, stride=4), 30 | nn.GELU(), 31 | # nn.ReLU(), 32 | nn.BatchNorm2d(embed_dim) 33 | ) 34 | 35 | def forward(self, x): 36 | # # FIXME look at relaxing size constraints 37 | x = self.conv2patch(x) 38 | return x 39 | 40 | def flops(self): 41 | Ho, Wo = self.patches_resolution 42 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 43 | if self.norm is not None: 44 | flops += Ho * Wo * self.embed_dim 45 | return flops 46 | 47 | class Mlp(nn.Module): 48 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 49 | super().__init__() 50 | out_features = out_features or in_features 51 | hidden_features = hidden_features or in_features 52 | self.fc1 = nn.Linear(in_features, hidden_features) 53 | self.act = act_layer() 54 | self.fc2 = nn.Linear(hidden_features, out_features) 55 | self.drop = nn.Dropout(drop) 56 | 57 | def forward(self, x): 58 | x = self.fc1(x) 59 | x = self.act(x) 60 | x = self.drop(x) 61 | x = self.fc2(x) 62 | x = self.drop(x) 63 | return x 64 | 65 | class oneXone_conv(nn.Module): 66 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 67 | super(oneXone_conv, self).__init__() 68 | out_features = out_features or in_features 69 | hidden_features = hidden_features or in_features 70 | self.Conv1 = nn.Sequential( 71 | nn.Conv2d(in_features, hidden_features, kernel_size=1), 72 | nn.GELU(), 73 | nn.BatchNorm2d(hidden_features) 74 | ) 75 | self.Conv2 = nn.Sequential( 76 | nn.Conv2d(hidden_features, out_features, kernel_size=1), 77 | nn.GELU(), 78 | nn.BatchNorm2d(out_features) 79 | ) 80 | self.drop = nn.Dropout(drop) 81 | def forward(self, x): 82 | x = self.Conv1(x) 83 | x = self.drop(x) 84 | x = self.Conv2(x) 85 | x = self.drop(x) 86 | return x 87 | 88 | class GhostModule(nn.Module): 89 | def __init__(self, inp, oup=None, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True): 90 | super(GhostModule, self).__init__() 91 | oup = oup or inp 92 | init_channels = math.ceil(oup // ratio) 93 | new_channels = init_channels*(ratio-1) 94 | 95 | self.primary_conv = nn.Sequential( 96 | nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), 97 | nn.BatchNorm2d(init_channels), 98 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 99 | ) 100 | 101 | self.cheap_operation = nn.Sequential( 102 | nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), 103 | nn.BatchNorm2d(new_channels), 104 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 105 | ) 106 | 107 | def forward(self, x): 108 | x1 = self.primary_conv(x) 109 | x2 = self.cheap_operation(x1) 110 | out = torch.cat([x1, x2], dim=1) 111 | return out 112 | 113 | class GhostModule_Up(nn.Module): 114 | def __init__(self, inp, oup=None, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True): 115 | super(GhostModule_Up, self).__init__() 116 | oup = oup or inp 117 | init_channels = inp 118 | new_channels = init_channels 119 | 120 | self.primary_conv = nn.Sequential( 121 | nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), 122 | nn.BatchNorm2d(init_channels), 123 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 124 | ) 125 | 126 | self.cheap_operation = 
nn.Sequential( 127 | nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), 128 | nn.BatchNorm2d(new_channels), 129 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 130 | ) 131 | 132 | def forward(self, x): 133 | x1 = self.primary_conv(x) 134 | x2 = self.cheap_operation(x1) 135 | out = torch.cat([x1, x2], dim=1) 136 | return out 137 | 138 | def window_partition(x, window_size): 139 | B, H, W, C = x.shape 140 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 141 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 142 | return windows 143 | 144 | def window_reverse(windows, window_size, H, W): 145 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 146 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 147 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 148 | return x 149 | 150 | class CAM_Module(Module): 151 | def __init__(self, in_dim, dim, num_heads, qk_scale=None, C_lambda=1e-4, attn_drop=0., proj_drop=0.): 152 | super(CAM_Module, self).__init__() 153 | self.chanel_in = in_dim 154 | self.softmax = Softmax(dim=-1) 155 | self.c_lambda = C_lambda 156 | self.activaton = nn.Sigmoid() 157 | 158 | head_dim = dim // num_heads 159 | self.scale = qk_scale or head_dim ** -0.5 160 | self.attn_drop = nn.Dropout(attn_drop) 161 | self.proj = nn.Sequential(nn.Conv2d(dim//12, dim, kernel_size=3, padding=1, stride=1, bias=False, groups=dim//12), nn.GELU()) 162 | self.proj_drop = nn.Dropout(proj_drop) 163 | 164 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 165 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 166 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 167 | 168 | self.gamma = Parameter(torch.zeros(1)) 169 | 170 | def forward(self, x, mask=None): 171 | m_batchsize, N, C = x.size() 172 | height = int(N ** .5) 173 | width = int(N ** .5) 174 | 175 | x = x.view(m_batchsize, C, height, width) 176 | proj_query = self.query_conv(x).view(m_batchsize, C//12, -1) 177 | proj_key = self.key_conv(x).view(m_batchsize, C//12, -1).permute(0, 2, 1) 178 | proj_value = self.value_conv(x).view(m_batchsize, C//12, -1) 179 | 180 | q = proj_query * self.scale 181 | attn = (q @ proj_key) 182 | 183 | if mask is not None: 184 | nW = mask.shape[0] # num_windows 185 | attn = attn.view(m_batchsize // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 186 | attn = attn.view(-1, self.num_heads, N, N) 187 | attn = self.softmax(attn) 188 | else: 189 | attn = self.softmax(attn) 190 | 191 | attn = self.attn_drop(attn) 192 | x = (attn @ proj_value).reshape(m_batchsize, C//12, height, width) 193 | x = self.proj(x) 194 | x = x.reshape(m_batchsize, C, N).transpose(1, 2) 195 | x = self.proj_drop(x) 196 | 197 | out = self.gamma * x + x 198 | return out 199 | 200 | class PAM_Module(Module): 201 | def __init__(self, in_dim, dim, num_heads, qk_scale=None, P_lambda=1e-4, attn_drop=0., proj_drop=0.): 202 | super(PAM_Module, self).__init__() 203 | self.chanel_in = in_dim 204 | 205 | head_dim = dim // num_heads 206 | self.scale = qk_scale or head_dim ** -0.5 207 | self.attn_drop = nn.Dropout(attn_drop) 208 | self.proj = nn.Sequential(nn.Conv2d(dim//12, dim, kernel_size=3, padding=1, stride=1, bias=False, groups=dim//12), nn.GELU()) 209 | self.proj_drop = nn.Dropout(proj_drop) 210 | 211 | self.query_conv = Conv2d(in_channels=in_dim, 
out_channels=in_dim//12, kernel_size=1) 212 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 213 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 214 | self.softmax = Softmax(dim=-1) 215 | 216 | 217 | self.p_lambda = P_lambda 218 | self.activaton = nn.Sigmoid() 219 | self.gamma = Parameter(torch.zeros(1)) 220 | 221 | def forward(self, x, mask=None): 222 | m_batchsize, N, C = x.size() 223 | height = int(N ** .5) 224 | width = int(N ** .5) 225 | 226 | x = x.view(m_batchsize, C, height, width) 227 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) 228 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height) 229 | proj_value = self.value_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) 230 | 231 | q = proj_query * self.scale 232 | attn = (q @ proj_key) 233 | 234 | if mask is not None: 235 | nW = mask.shape[0] # num_windows 236 | attn = attn.view(m_batchsize // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 237 | attn = attn.view(-1, self.num_heads, N, N) 238 | attn = self.softmax(attn) 239 | else: 240 | attn = self.softmax(attn) 241 | 242 | attn = self.attn_drop(attn) 243 | x = (attn @ proj_value).reshape(m_batchsize, C//12, height, width) 244 | x = self.proj(x) 245 | x = x.reshape(m_batchsize, C, N).transpose(1, 2) 246 | x = self.proj_drop(x) 247 | 248 | out = self.gamma * x + x 249 | return out 250 | 251 | class CHAM_Module(Module): 252 | def __init__(self, in_dim, dim, num_heads, qk_scale=None, P_lambda=1e-4, attn_drop=0., proj_drop=0.): 253 | super(CHAM_Module, self).__init__() 254 | self.chanel_in = in_dim 255 | 256 | head_dim = dim // num_heads 257 | self.scale = qk_scale or head_dim ** -0.5 258 | self.attn_drop = nn.Dropout(attn_drop) 259 | self.proj = nn.Sequential(nn.Conv2d(dim//12, dim, kernel_size=3, padding=1, stride=1, bias=False, groups=dim//12), nn.GELU()) 260 | self.proj_drop = nn.Dropout(proj_drop) 261 | 262 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 263 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 264 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 265 | self.softmax = Softmax(dim=-1) 266 | 267 | self.p_lambda = P_lambda 268 | self.activaton = nn.Sigmoid() 269 | self.gamma = Parameter(torch.zeros(1)) 270 | 271 | def forward(self, x, mask=None): 272 | m_batchsize, N, C = x.size() 273 | height = int(N ** .5) 274 | width = int(N ** .5) 275 | 276 | x = x.view(m_batchsize, C, height, width) 277 | proj_query = self.query_conv(x).view(m_batchsize, C//12 * height, -1) 278 | proj_key = self.key_conv(x).view(m_batchsize, C//12 * height, -1).permute(0, 2, 1) 279 | proj_value = self.value_conv(x).view(m_batchsize, C//12 * height, -1) 280 | 281 | q = proj_query * self.scale 282 | attn = (q @ proj_key) 283 | 284 | if mask is not None: 285 | nW = mask.shape[0] # num_windows 286 | attn = attn.view(m_batchsize // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 287 | attn = attn.view(-1, self.num_heads, N, N) 288 | attn = self.softmax(attn) 289 | else: 290 | attn = self.softmax(attn) 291 | 292 | attn = self.attn_drop(attn) 293 | x = (attn @ proj_value).reshape(m_batchsize, C//12, height, width) 294 | x = self.proj(x) 295 | x = x.reshape(m_batchsize, C, N).transpose(1, 2) 296 | x = self.proj_drop(x) 297 | 298 | out = self.gamma * x + x 299 | return out 300 | 301 | 302 | class CWAM_Module(Module): 303 | def 
__init__(self, in_dim, dim, num_heads, qk_scale=None, P_lambda=1e-4, attn_drop=0., proj_drop=0.): 304 | super(CWAM_Module, self).__init__() 305 | self.chanel_in = in_dim 306 | 307 | head_dim = dim // num_heads 308 | self.scale = qk_scale or head_dim ** -0.5 309 | self.attn_drop = nn.Dropout(attn_drop) 310 | self.proj = nn.Sequential(nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, bias=False, groups=dim), nn.GELU()) 311 | self.proj_drop = nn.Dropout(proj_drop) 312 | 313 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 314 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//12, kernel_size=1) 315 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 316 | self.softmax = Softmax(dim=-1) 317 | 318 | self.p_lambda = P_lambda 319 | self.activaton = nn.Sigmoid() 320 | self.gamma = Parameter(torch.zeros(1)) 321 | 322 | def forward(self, x, mask=None): 323 | m_batchsize, N, C = x.size() 324 | height = int(N ** .5) 325 | width = int(N ** .5) 326 | 327 | proj_query = x.view(m_batchsize, C * width, -1) 328 | proj_key = x.view(m_batchsize, C * width, -1).permute(0, 2, 1) 329 | proj_value = x.view(m_batchsize, C * width, -1) 330 | 331 | q = proj_query * self.scale 332 | attn = (q @ proj_key) 333 | 334 | if mask is not None: 335 | nW = mask.shape[0] # num_windows 336 | attn = attn.view(m_batchsize // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 337 | attn = attn.view(-1, self.num_heads, N, N) 338 | attn = self.softmax(attn) 339 | else: 340 | attn = self.softmax(attn) 341 | 342 | attn = self.attn_drop(attn) 343 | x = (attn @ proj_value).reshape(m_batchsize, C, height, width) 344 | x = self.proj(x) 345 | x = x.reshape(m_batchsize, C, N).transpose(1, 2) 346 | x = self.proj_drop(x) 347 | 348 | out = self.gamma * x + x 349 | return out 350 | 351 | class WindowAttention_ACAM(nn.Module): 352 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 353 | super().__init__() 354 | self.dim = dim 355 | 356 | self.C_C = CAM_Module(self.dim, dim=dim, num_heads=num_heads) 357 | self.H_W = PAM_Module(self.dim, dim=dim, num_heads=num_heads) 358 | self.C_H = CHAM_Module(self.dim, dim=dim, num_heads=num_heads) 359 | self.C_W = CWAM_Module(self.dim, dim=dim, num_heads=num_heads) 360 | 361 | self.gamma1 = Parameter(torch.zeros(1)) 362 | self.gamma2 = Parameter(torch.zeros(1)) 363 | self.gamma3 = Parameter(torch.ones(1) * 0.5) 364 | self.gamma4 = Parameter(torch.ones(1) * 0.5) 365 | 366 | def _build_projection(self, dim_in, kernel_size=3, stride=1, padding=1): 367 | proj = nn.Sequential( 368 | nn.Conv2d(dim_in, dim_in, kernel_size, padding=padding, stride=stride, bias=False, groups=dim_in), 369 | Rearrange('b c h w -> b (h w) c'), 370 | nn.LayerNorm(dim_in)) 371 | return proj 372 | 373 | def forward(self, x, mask=None): 374 | x_out1 = self.C_C(x) 375 | 376 | x_out2 = self.H_W(x) 377 | 378 | x_out3 = self.C_H(x) 379 | 380 | x_out4 = self.C_W(x) 381 | 382 | x_out = (self.gamma1 * x_out1) + (self.gamma2 * x_out2) + (self.gamma3 * x_out3) + (self.gamma4 * x_out4) 383 | 384 | return x_out 385 | """ =============================================================================================================== """ 386 | 387 | class SwinTransformerBlock(nn.Module): 388 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 389 | mlp_ratio=2., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 390 | act_layer=nn.GELU, 
norm_layer=nn.LayerNorm): 391 | super().__init__() 392 | 393 | self.dim = dim 394 | self.input_resolution = input_resolution 395 | self.num_heads = num_heads 396 | self.window_size = window_size 397 | self.shift_size = shift_size 398 | self.mlp_ratio = mlp_ratio 399 | 400 | 401 | if min(self.input_resolution) <= self.window_size: 402 | self.shift_size = 0 403 | self.window_size = min(self.input_resolution) 404 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 405 | 406 | self.norm1 = norm_layer(dim) 407 | self.norm2 = norm_layer([dim, input_resolution[0], input_resolution[1]]) 408 | 409 | self.attn = WindowAttention_ACAM( 410 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 411 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 412 | 413 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 414 | 415 | if self.shift_size > 0: 416 | H, W = self.input_resolution 417 | img_mask = torch.zeros((1, H, W, 1)) 418 | h_slices = (slice(0, -self.window_size), 419 | slice(-self.window_size, -self.shift_size), 420 | slice(-self.shift_size, None)) 421 | w_slices = (slice(0, -self.window_size), 422 | slice(-self.window_size, -self.shift_size), 423 | slice(-self.shift_size, None)) 424 | cnt = 0 425 | for h in h_slices: 426 | for w in w_slices: 427 | img_mask[:, h, w, :] = cnt 428 | cnt += 1 429 | 430 | mask_windows = window_partition(img_mask, self.window_size) 431 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 432 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 433 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 434 | else: 435 | attn_mask = None 436 | 437 | self.register_buffer("attn_mask", attn_mask) 438 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 439 | self.mlp = GhostModule(inp=dim) 440 | 441 | def forward(self, x): 442 | B, C, H, W = x.shape 443 | 444 | shortcut1 = x 445 | x = x.view(B, H, W, C) 446 | x = self.norm1(x) 447 | 448 | if self.shift_size > 0: 449 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 450 | else: 451 | shifted_x = x 452 | 453 | x_windows = window_partition(shifted_x, self.window_size) 454 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 455 | 456 | attn_windows = self.attn(x_windows, mask=self.attn_mask) 457 | 458 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 459 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) 460 | 461 | if self.shift_size > 0: 462 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 463 | else: 464 | x = shifted_x 465 | x = x.view(B, C, H, W) 466 | x = shortcut1 + self.drop_path(x) 467 | shortcut2 = x 468 | 469 | x = self.norm2(x) 470 | 471 | x = shortcut2 + self.drop_path(self.mlp(x)) 472 | return x 473 | 474 | class PatchMerging(nn.Module): 475 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 476 | super().__init__() 477 | self.input_resolution = input_resolution 478 | h, w = input_resolution 479 | h = int(h/2) 480 | w = int(w/2) 481 | self.dim = dim 482 | self.norm = norm_layer([4*dim, h, w]) 483 | self.reduction = GhostModule(inp=4 * dim, oup=2 * dim, ratio=4) 484 | 485 | def forward(self, x): 486 | H, W = self.input_resolution 487 | B, C, H, W = x.shape 488 | 489 | x0 = x[:, :, 0::2, 0::2] # B C H/2 W/2 490 | x1 = x[:, :, 1::2, 0::2] # B C H/2 W/2 491 | x2 = x[:, :, 0::2, 1::2] # B C H/2 W/2 492 | x3 = x[:, :, 1::2, 1::2] # B C H/2 W/2 493 | x = torch.cat([x0, x1, x2, x3], 1) # B 4*C H/2 W/2 494 | 495 | x = self.norm(x) 496 | x = self.reduction(x) 497 | 498 | return x 499 | 500 | def extra_repr(self) -> str: 501 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 502 | 503 | def flops(self): 504 | H, W = self.input_resolution 505 | flops = H * W * self.dim 506 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 507 | return flops 508 | 509 | class BasicLayer(nn.Module): 510 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 511 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 512 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 513 | 514 | super().__init__() 515 | self.dim = dim 516 | self.input_resolution = input_resolution 517 | self.depth = depth 518 | self.use_checkpoint = use_checkpoint 519 | 520 | self.blocks = nn.ModuleList([ 521 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 522 | num_heads=num_heads, window_size=window_size, 523 | shift_size=0 if (i % 2 == 0) else window_size // 2, 524 | mlp_ratio=mlp_ratio, 525 | qkv_bias=qkv_bias, qk_scale=qk_scale, 526 | drop=drop, attn_drop=attn_drop, 527 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 528 | norm_layer=norm_layer) 529 | for i in range(depth)]) 530 | 531 | if downsample is not None: 532 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 533 | else: 534 | self.downsample = None 535 | 536 | def forward(self, x): 537 | for blk in self.blocks: 538 | if self.use_checkpoint: 539 | x = checkpoint.checkpoint(blk, x) 540 | else: 541 | x = blk(x) 542 | if self.downsample is not None: 543 | x = self.downsample(x) 544 | return x 545 | 546 | def extra_repr(self) -> str: 547 | return 
f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 548 | 549 | def flops(self): 550 | flops = 0 551 | for blk in self.blocks: 552 | flops += blk.flops() 553 | if self.downsample is not None: 554 | flops += self.downsample.flops() 555 | return flops 556 | 557 | class PatchExpand(nn.Module): 558 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 559 | super().__init__() 560 | self.input_resolution = input_resolution 561 | self.dim = dim 562 | self.expand = GhostModule_Up(inp=dim) if dim_scale == 2 else nn.Identity() 563 | self.norm = norm_layer([dim // dim_scale, input_resolution[0]*2, input_resolution[1]*2]) 564 | 565 | def forward(self, x): 566 | B, C, H, W = x.shape 567 | x = self.expand(x) 568 | x = rearrange(x, 'b (p1 p2 c) h w -> b c (h p1) (w p2)', p1=2, p2=2, c=C // 2) 569 | x = self.norm(x) 570 | 571 | return x 572 | 573 | class BasicLayer_up(nn.Module): 574 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 575 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 576 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): 577 | 578 | super().__init__() 579 | self.dim = dim 580 | self.input_resolution = input_resolution 581 | self.depth = depth 582 | self.use_checkpoint = use_checkpoint 583 | 584 | self.blocks = nn.ModuleList([ 585 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 586 | num_heads=num_heads, window_size=window_size, 587 | shift_size=0 if (i % 2 == 0) else window_size // 2, 588 | mlp_ratio=mlp_ratio, 589 | qkv_bias=qkv_bias, qk_scale=qk_scale, 590 | drop=drop, attn_drop=attn_drop, 591 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 592 | norm_layer=norm_layer) 593 | for i in range(depth)]) 594 | 595 | if upsample is not None: 596 | self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) 597 | else: 598 | self.upsample = None 599 | 600 | def forward(self, x): 601 | for blk in self.blocks: 602 | if self.use_checkpoint: 603 | x = checkpoint.checkpoint(blk, x) 604 | else: 605 | x = blk(x) 606 | if self.upsample is not None: 607 | x = self.upsample(x) 608 | return x 609 | 610 | class FinalPatchExpand_X4(nn.Module): 611 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): 612 | super().__init__() 613 | self.input_resolution = input_resolution 614 | self.dim = dim 615 | self.dim_scale = dim_scale 616 | self.expand = oneXone_conv(in_features = dim, out_features = 16 * dim) if dim_scale == 2 else nn.Identity() 617 | self.output_dim = dim 618 | self.norm = norm_layer([6, input_resolution[0]*4, input_resolution[1]*4]) 619 | 620 | def forward(self, x): 621 | B, C, H, W = x.shape 622 | x = self.expand(x) 623 | x = rearrange(x, 'b (p1 p2 c) h w -> b c (h p1) (w p2)', p1=self.dim_scale, p2=self.dim_scale, c=C // (self.dim_scale ** 2)) 624 | x = self.norm(x) 625 | return x 626 | 627 | def init_weights(net, init_type='normal', gain=0.02): 628 | def init_func(m): 629 | classname = m.__class__.__name__ 630 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 631 | if init_type == 'normal': 632 | init.normal_(m.weight.data, 0.0, gain) 633 | elif init_type == 'xavier': 634 | init.xavier_normal_(m.weight.data, gain=gain) 635 | elif init_type == 'kaiming': 636 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 637 | elif init_type == 'orthogonal': 638 | init.orthogonal_(m.weight.data, gain=gain) 639 | else: 640 | 
raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 641 | if hasattr(m, 'bias') and m.bias is not None: 642 | init.constant_(m.bias.data, 0.0) 643 | elif classname.find('BatchNorm2d') != -1: 644 | init.normal_(m.weight.data, 1.0, gain) 645 | init.constant_(m.bias.data, 0.0) 646 | 647 | print('initialize network with %s' % init_type) 648 | net.apply(init_func) 649 | 650 | class conv_block(nn.Module): 651 | def __init__(self, ch_in, ch_out): 652 | super(conv_block, self).__init__() 653 | self.conv = nn.Sequential( 654 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 655 | nn.BatchNorm2d(ch_out), 656 | nn.ReLU(inplace=True), 657 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 658 | nn.BatchNorm2d(ch_out), 659 | nn.ReLU(inplace=True) 660 | ) 661 | 662 | def forward(self, x): 663 | x = self.conv(x) 664 | return x 665 | 666 | class up_conv(nn.Module): 667 | def __init__(self, ch_in, ch_out): 668 | super(up_conv, self).__init__() 669 | self.up = nn.Sequential( 670 | nn.Upsample(scale_factor=2), 671 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 672 | nn.BatchNorm2d(ch_out), 673 | nn.ReLU(inplace=True) 674 | ) 675 | 676 | def forward(self, x): 677 | x = self.up(x) 678 | return x 679 | 680 | class conv_block_DDConv(nn.Module): 681 | def __init__(self, ch_in, ch_out): 682 | super(conv_block_DDConv, self).__init__() 683 | self.conv = nn.Sequential( 684 | DDConv(ch_in, ch_out, kernel_size=3, stride=1, padding=1), 685 | 686 | nn.BatchNorm2d(ch_out), 687 | nn.ReLU(inplace=True), 688 | 689 | DDConv(ch_out, ch_out, kernel_size=3, stride=1, padding=1), 690 | 691 | nn.BatchNorm2d(ch_out), 692 | nn.ReLU(inplace=True) 693 | ) 694 | 695 | def forward(self, x): 696 | x = self.conv(x) 697 | return x 698 | 699 | class up_conv_DDConv(nn.Module): 700 | def __init__(self, ch_in, ch_out): 701 | super(up_conv_DDConv, self).__init__() 702 | self.up = nn.Sequential( 703 | nn.Upsample(scale_factor=2), 704 | 705 | DDConv(ch_in, ch_out, kernel_size=1), 706 | 707 | nn.BatchNorm2d(ch_out), 708 | nn.ReLU(inplace=True) 709 | ) 710 | 711 | def forward(self, x): 712 | x = self.up(x) 713 | return x 714 | 715 | class ConvMixerLayer(nn.Module): 716 | def __init__(self, dim, kernel_size=9): 717 | super().__init__() 718 | self.Resnet = nn.Sequential( 719 | nn.Conv2d(dim, dim, kernel_size=kernel_size, groups=dim, padding=4), 720 | nn.GELU(), 721 | nn.BatchNorm2d(dim) 722 | ) 723 | self.Conv_1x1 = nn.Sequential( 724 | nn.Conv2d(dim, dim, kernel_size=1), 725 | nn.GELU(), 726 | nn.BatchNorm2d(dim) 727 | ) 728 | def forward(self, x): 729 | x = x + self.Resnet(x) 730 | x = self.Conv_1x1(x) 731 | return x 732 | 733 | 734 | class ConvMixer(nn.Module): 735 | def __init__(self, dim=512, depth=1, kernel_size=9, patch_size=4, n_classes=1000): 736 | super().__init__() 737 | self.conv2d1 = nn.Sequential( 738 | nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size), 739 | nn.GELU(), 740 | nn.BatchNorm2d(dim) 741 | ) 742 | self.ConvMixer_blocks = nn.ModuleList([]) 743 | 744 | for _ in range(depth): 745 | self.ConvMixer_blocks.append(ConvMixerLayer(dim=dim, kernel_size=kernel_size)) 746 | 747 | self.head = nn.Sequential( 748 | nn.AdaptiveAvgPool2d((1, 1)), 749 | nn.Flatten(), 750 | nn.Linear(dim, n_classes) 751 | ) 752 | 753 | def forward(self, x): 754 | x = self.conv2d1(x) 755 | 756 | for ConvMixer_block in self.ConvMixer_blocks: 757 | x = ConvMixer_block(x) 758 | 759 | x = x 760 | return x 761 | 762 | class CIT(nn.Module): 763 
| def __init__(self, img_size=224, patch_size=4, in_chans=3, out_chans=1, 764 | embed_dim=96, depths=[2, 2, 6, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], 765 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 766 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 767 | norm_layer=nn.LayerNorm, ape=True, patch_norm=True, 768 | use_checkpoint=False, final_upsample="expand_first", **kwargs): 769 | super().__init__() 770 | 771 | self.out_channel = out_chans 772 | self.num_layers = len(depths) 773 | self.ape = ape 774 | self.patch_norm = patch_norm 775 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 776 | self.size = int(img_size/(2 ** (self.num_layers + 1))) 777 | self.size_out = int(img_size/4) 778 | self.num_features_up = int(embed_dim * 2) 779 | self.mlp_ratio = mlp_ratio 780 | self.final_upsample = final_upsample 781 | 782 | self.window_size = window_size 783 | self.qkv_bias = qkv_bias 784 | self.qk_scale = qk_scale 785 | self.drop_rate = drop_rate 786 | self.attn_drop_rate = attn_drop_rate 787 | self.drop_path_rate = drop_path_rate 788 | self.norm_layer = nn.LayerNorm 789 | 790 | self.patch_embed = PatchEmbed( 791 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 792 | norm_layer=norm_layer if self.patch_norm else None) 793 | 794 | self.pos_drop = nn.Dropout(p=drop_rate) 795 | 796 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 797 | 798 | 799 | if self.final_upsample == "expand_first": 800 | self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size), 801 | dim_scale=4, dim=embed_dim) 802 | self.output = nn.Conv2d(in_channels=6, out_channels=1, kernel_size=1, bias=False) 803 | 804 | self.apply(self._init_weights) 805 | 806 | self.embed_dim = 96 807 | self.num_heads = 3 808 | self.depth35 = 6 809 | self.drop_path3 = dpr[4:10] 810 | self.drop_path4 = dpr[10:12] 811 | 812 | print("CiT-Net-T----embed_dim:{}; num_heads:{}; depths:{}".format(self.embed_dim, num_heads, depths)) 813 | 814 | self.layer1 = BasicLayer(dim=self.embed_dim * 1, 815 | input_resolution=(56, 56), 816 | depth=2, 817 | num_heads=self.num_heads * 1, 818 | window_size=self.window_size, # 7 819 | mlp_ratio=self.mlp_ratio, # 4. 820 | qkv_bias=self.qkv_bias, # True 821 | qk_scale=self.qk_scale, # None 822 | drop=self.drop_rate, # 0. 823 | attn_drop=self.attn_drop_rate, # 0. 824 | drop_path=dpr[0:2], 825 | norm_layer=self.norm_layer, 826 | downsample=PatchMerging, 827 | use_checkpoint=False) 828 | 829 | self.layer2 = BasicLayer(dim=self.embed_dim * 2, 830 | input_resolution=(28, 28), 831 | depth=2, 832 | num_heads=self.num_heads * 2, 833 | window_size=self.window_size, # 7 834 | mlp_ratio=self.mlp_ratio, # 4. 835 | qkv_bias=self.qkv_bias, # True 836 | qk_scale=self.qk_scale, # None 837 | drop=self.drop_rate, # 0. 838 | attn_drop=self.attn_drop_rate, # 0. 839 | drop_path=dpr[2:4], 840 | norm_layer=self.norm_layer, 841 | downsample=PatchMerging, 842 | use_checkpoint=False) 843 | 844 | self.layer3 = BasicLayer(dim=self.embed_dim * 4, 845 | input_resolution=(14, 14), 846 | depth=self.depth35, 847 | num_heads=self.num_heads * 4, 848 | window_size=self.window_size, # 7 849 | mlp_ratio=self.mlp_ratio, # 4. 850 | qkv_bias=self.qkv_bias, # True 851 | qk_scale=self.qk_scale, # None 852 | drop=self.drop_rate, # 0. 853 | attn_drop=self.attn_drop_rate, # 0. 
854 | drop_path=self.drop_path3, 855 | norm_layer=self.norm_layer, 856 | downsample=PatchMerging, 857 | use_checkpoint=False) 858 | 859 | self.layer4 = BasicLayer(dim=self.embed_dim * 8, 860 | input_resolution=(7, 7), 861 | depth=2, 862 | num_heads=self.num_heads * 8, 863 | window_size=self.window_size, # 7 864 | mlp_ratio=self.mlp_ratio, # 4. 865 | qkv_bias=self.qkv_bias, # True 866 | qk_scale=self.qk_scale, # None 867 | drop=self.drop_rate, # 0. 868 | attn_drop=self.attn_drop_rate, # 0. 869 | drop_path=self.drop_path4, 870 | norm_layer=self.norm_layer, 871 | downsample=None, 872 | use_checkpoint=False) 873 | 874 | self.norm = norm_layer([self.num_features, self.size, self.size]) 875 | 876 | self.Patch_Expand1 = PatchExpand(input_resolution=(7, 7), 877 | dim=self.embed_dim * 8, 878 | dim_scale=2, 879 | norm_layer=norm_layer) 880 | 881 | 882 | self.concat_linear1 = GhostModule(inp=self.embed_dim * 8, oup=self.embed_dim * 4) 883 | 884 | self.layer5 = BasicLayer_up(dim=self.embed_dim * 4, 885 | input_resolution=(14, 14), 886 | depth=self.depth35, 887 | num_heads=self.num_heads * 4, 888 | 889 | window_size=self.window_size, # 7 890 | mlp_ratio=self.mlp_ratio, # 4. 891 | qkv_bias=self.qkv_bias, # True 892 | qk_scale=self.qk_scale, # None 893 | drop=self.drop_rate, # 0. 894 | attn_drop=self.attn_drop_rate, # 0. 895 | 896 | drop_path=self.drop_path3, 897 | norm_layer=norm_layer, 898 | upsample=PatchExpand, 899 | use_checkpoint=False) 900 | 901 | self.concat_linear2 = GhostModule(inp=self.embed_dim * 4, oup=self.embed_dim * 2) 902 | 903 | self.layer6 = BasicLayer_up(dim=self.embed_dim * 2, 904 | input_resolution=(28, 28), 905 | depth=2, 906 | num_heads=self.num_heads * 2, 907 | 908 | window_size=self.window_size, # 7 909 | mlp_ratio=self.mlp_ratio, # 4. 910 | qkv_bias=self.qkv_bias, # True 911 | qk_scale=self.qk_scale, # None 912 | drop=self.drop_rate, # 0. 913 | attn_drop=self.attn_drop_rate, # 0. 914 | 915 | drop_path=dpr[2:4], 916 | norm_layer=norm_layer, 917 | upsample=PatchExpand, 918 | use_checkpoint=False) 919 | 920 | self.concat_linear3 = GhostModule(inp=self.embed_dim * 2, oup=self.embed_dim * 1) 921 | 922 | self.layer7 = BasicLayer_up(dim=self.embed_dim * 1, 923 | input_resolution=(56, 56), 924 | depth=2, 925 | num_heads=self.num_heads * 1, 926 | 927 | window_size=self.window_size, # 7 928 | mlp_ratio=self.mlp_ratio, # 4. 929 | qkv_bias=self.qkv_bias, # True 930 | qk_scale=self.qk_scale, # None 931 | drop=self.drop_rate, # 0. 932 | attn_drop=self.attn_drop_rate, # 0. 
933 | 934 |                                    drop_path=dpr[0:2], 935 |                                    norm_layer=norm_layer, 936 |                                    upsample=None, 937 |                                    use_checkpoint=False) 938 | 939 |         self.norm_up = norm_layer([self.embed_dim, self.size_out, self.size_out]) 940 |         self.patch = ConvMixer(dim=48, depth=5)     # adjust the number of ConvMixer layers here 941 |         self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 942 |         self.Conv1e = conv_block(ch_in=48, ch_out=self.embed_dim * 1) 943 |         self.Conv1s = conv_block(ch_in=48, ch_out=self.embed_dim * 1) 944 |         self.Conv2e = conv_block_DDConv(ch_in=self.embed_dim * 1, ch_out=self.embed_dim * 2) 945 |         self.Conv3e = conv_block_DDConv(ch_in=self.embed_dim * 2, ch_out=self.embed_dim * 4) 946 |         self.Conv4e = conv_block_DDConv(ch_in=self.embed_dim * 4, ch_out=self.embed_dim * 8) 947 |         self.Up4d = up_conv_DDConv(ch_in=self.embed_dim * 8, ch_out=self.embed_dim * 4) 948 |         self.Up_conv4d = conv_block(ch_in=self.embed_dim * 8, ch_out=self.embed_dim * 4) 949 |         self.Up3d = up_conv_DDConv(ch_in=self.embed_dim * 4, ch_out=self.embed_dim * 2) 950 |         self.Up_conv3d = conv_block(ch_in=self.embed_dim * 4, ch_out=self.embed_dim * 2) 951 |         self.Up2d = up_conv_DDConv(ch_in=self.embed_dim * 2, ch_out=self.embed_dim * 1) 952 |         self.Up_conv2d = conv_block(ch_in=self.embed_dim * 2, ch_out=self.embed_dim * 1) 953 |         self.Mid_Conv1 = nn.Conv2d(self.embed_dim * 2, self.embed_dim * 1, kernel_size=1, stride=1, padding=0) 954 |         self.Mid_Conv2 = nn.Conv2d(self.embed_dim * 4, self.embed_dim * 2, kernel_size=1, stride=1, padding=0) 955 |         self.Mid_Conv3 = nn.Conv2d(self.embed_dim * 8, self.embed_dim * 4, kernel_size=1, stride=1, padding=0) 956 |         self.BN = nn.BatchNorm2d(1) 957 |         self.CiT_Conv = nn.Conv2d(2, 1, kernel_size=1, stride=1, padding=0) 958 | 959 |     def _init_weights(self, m): 960 |         if isinstance(m, nn.Linear): 961 |             trunc_normal_(m.weight, std=.02) 962 |             if isinstance(m, nn.Linear) and m.bias is not None: 963 |                 nn.init.constant_(m.bias, 0) 964 |         elif isinstance(m, nn.LayerNorm): 965 |             nn.init.constant_(m.bias, 0) 966 |             nn.init.constant_(m.weight, 1.0) 967 | 968 |     @torch.jit.ignore 969 |     def no_weight_decay(self): 970 |         return {'absolute_pos_embed'} 971 | 972 |     @torch.jit.ignore 973 |     def no_weight_decay_keywords(self): 974 |         return {'relative_position_bias_table'} 975 | 976 | 977 |     def up_x4(self, x): 978 | 979 |         B, C, H, W = x.shape 980 |         if self.final_upsample == "expand_first": 981 |             x = self.up(x) 982 |             x = self.output(x) 983 | 984 |         return x 985 | 986 |     def forward(self, x):    # 1,1,224,224 987 |         x = self.patch(x) 988 |         Cnn = x 989 |         Swin = x 990 | 991 |         Cnn = self.Conv1e(Cnn)    # 1,96,56,56 992 |         Swin = self.Conv1s(Swin)    # 1,96,56,56 993 |         Cnn1 = Cnn 994 |         Swin1 = Swin 995 |         Mid1 = torch.cat((Cnn1, Swin1), dim=1) 996 |         Mid1 = self.Mid_Conv1(Mid1)    # stage-1 skip feature shared by both decoders 997 | 998 |         Cnn = self.maxpool(Cnn) 999 |         Cnn = self.Conv2e(Cnn) 1000 |         Swin = self.layer1(Swin)    # 28,28 1001 |         Cnn2 = Cnn 1002 |         Swin2 = Swin 1003 |         Mid2 = torch.cat((Cnn2, Swin2), dim=1) 1004 |         Mid2 = self.Mid_Conv2(Mid2)    # stage-2 skip feature shared by both decoders 1005 | 1006 |         Cnn = self.maxpool(Cnn) 1007 |         Cnn = self.Conv3e(Cnn) 1008 |         Swin = self.layer2(Swin)    # 14,14 1009 |         Cnn3 = Cnn 1010 |         Swin3 = Swin 1011 |         Mid3 = torch.cat((Cnn3, Swin3), dim=1) 1012 |         Mid3 = self.Mid_Conv3(Mid3)    # stage-3 skip feature shared by both decoders 1013 | 1014 |         Cnn = self.maxpool(Cnn) 1015 |         Cnn = self.Conv4e(Cnn) 1016 |         Swin = self.layer3(Swin)    # 7,7 1017 |         Swin = self.layer4(Swin)    # 7,7 1018 |         Swin = self.norm(Swin)    # B L C (1, 768, 7, 7) 1019 |         Cnn4 = Cnn 1020 |         Swin4 = Swin 1021 | 1022 |         Cnn = self.Up4d(Cnn) 1023 |         Cnn = torch.cat((Cnn, Mid3), dim=1) 1024 |         Cnn = self.Up_conv4d(Cnn) 1025 |         Swin = self.Patch_Expand1(Swin) 1026 |
Swin = torch.cat([Swin, Mid3], 1) 1027 |         Swin = self.concat_linear1(Swin)    # 14,14 1028 |         Cnn5 = Cnn 1029 |         Swin5 = Swin 1030 | 1031 |         Cnn = self.Up3d(Cnn) 1032 |         Cnn = torch.cat((Cnn, Mid2), dim=1) 1033 |         Cnn = self.Up_conv3d(Cnn) 1034 |         Swin = self.layer5(Swin) 1035 |         Swin = torch.cat([Swin, Mid2], 1) 1036 |         Swin = self.concat_linear2(Swin)    # 28,28 1037 |         Cnn6 = Cnn 1038 |         Swin6 = Swin 1039 | 1040 |         Cnn7 = self.Up2d(Cnn) 1041 |         Cnn7 = torch.cat((Cnn7, Mid1), dim=1) 1042 |         Cnn7 = self.Up_conv2d(Cnn7) 1043 | 1044 |         Swin7 = self.layer6(Swin)    # 56,56 1045 |         Swin7 = torch.cat([Swin7, Mid1], 1)    # 56,56 1046 |         Swin7 = self.concat_linear3(Swin7)    # 56,56 1047 |         Swin7 = self.layer7(Swin7)    # 56,56 1048 | 1049 |         CNN = self.up_x4(Cnn7)    # 224,224 1050 |         Swin = self.norm_up(Swin7)    # B L C 1,96,56,56 1051 |         SWIN = self.up_x4(Swin)    # 224,224 1052 | 1053 |         CNN_out = CNN 1054 |         Trans_out = SWIN 1055 | 1056 |         CNN = self.BN(CNN)    # batch-normalize each single-channel map before fusion 1057 |         SWIN = self.BN(SWIN) 1058 |         CiT = torch.cat((CNN, SWIN), dim=1) 1059 |         CiT = self.CiT_Conv(CiT)    # fuse the two branch predictions with a 1x1 convolution 1060 | 1061 |         CiT = torch.sigmoid(CiT) 1062 |         CNN_out = torch.sigmoid(CNN_out) 1063 |         Trans_out = torch.sigmoid(Trans_out) 1064 | 1065 |         return CiT, CNN_out, Trans_out    # fused prediction plus the individual CNN and Transformer predictions 1066 | 1067 | 1068 | 1069 |     def flops(self):    # NOTE: references self.layers, self.patches_resolution and self.num_classes, which are not defined in CIT 1070 |         flops = 0 1071 |         flops += self.patch_embed.flops() 1072 |         for i, layer in enumerate(self.layers): 1073 |             flops += layer.flops() 1074 |         flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 1075 |         flops += self.num_features * self.num_classes 1076 |         return flops 1077 | 1078 | 1079 | if __name__ == "__main__": 1080 |     with torch.no_grad(): 1081 |         input = torch.rand(1, 1, 224, 224).to("cpu") 1082 |         model = CIT().to("cpu") 1083 | 1084 |         out_result, _, _ = model(input) 1085 |         print(out_result.shape) 1086 | 1087 |         flops, params = profile(model, (input,)) 1088 | 1089 |         print("-" * 50) 1090 |         print('FLOPs = ' + str(flops / 1000 ** 3) + ' G') 1091 |         print('Params = ' + str(params / 1000 ** 2) + ' M') --------------------------------------------------------------------------------
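Note on using the three outputs: the repository ships only the model definitions, so the three sigmoid maps returned by CIT.forward (the fused prediction plus the separate CNN and Transformer predictions) must be supervised by an external training script. The following is a minimal sketch of how such a joint loss could look; the BCE criterion, the 0.6/0.2/0.2 weights, the random "masks" tensor, and the import path are assumptions for illustration, not part of the released code.

import torch
# hypothetical import path; the CIT class is defined in TEC_Net_T.py (and TEC_Net_T_3D.py)
from TEC_Net_T import CIT

model = CIT()
criterion = torch.nn.BCELoss()  # outputs are already passed through sigmoid inside forward()

images = torch.rand(2, 1, 224, 224)                      # single-channel inputs, as in the __main__ example above
masks = torch.randint(0, 2, (2, 1, 224, 224)).float()    # hypothetical binary ground-truth masks

fused, cnn_out, trans_out = model(images)
# assumed loss weighting: emphasize the fused map, lightly supervise each branch
loss = (0.6 * criterion(fused, masks)
        + 0.2 * criterion(cnn_out, masks)
        + 0.2 * criterion(trans_out, masks))
loss.backward()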