├── .idea ├── .gitignore ├── DeepLearningPlugAndPlayModule.iml ├── deployment.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── 上采样 └── DySample.py ├── 即插即用卷积 ├── ARConv.py ├── CAMixer.py ├── CFBConv.py ├── CKGConv.py ├── DCNv4.py ├── DEConv.py ├── DynamicConv.py ├── LDConv.py ├── MorphologyConv.py ├── PConv.py ├── SCConv.py ├── StarConv.py └── WTConv.py ├── 注意力模块 ├── AGF.py ├── ASSA.py ├── AgentAttention.py ├── CA.py ├── CBAM.py ├── CGA.py ├── CGLU.py ├── DANet.py ├── DAT.py ├── ECA.py ├── FcaNet.py ├── LGAG.py ├── MAB.py ├── MCA.py ├── MCPA.py ├── MEGA.py ├── MLKA.py ├── MSPA.py ├── NonLocal.py ├── SA-Net.py ├── SENet.py ├── SLAB.py ├── SRA.py └── SimAM.py └── 特征提取or融合or对齐模块 ├── AFF.py ├── CAFM.py ├── CCFF.py ├── CCMF.py ├── CGAFusion.py ├── CSAM.py ├── FARM.py ├── FCA.py ├── GLSA.py ├── LSK.py ├── MFII.py ├── PSFM.py ├── SMFA.py └── SSFF.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | # 基于编辑器的 HTTP 客户端请求 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/DeepLearningPlugAndPlayModule.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /上采样/DySample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def normal_init(module, mean=0, std=1, bias=0): 7 | if hasattr(module, 'weight') and module.weight is not None: 8 | nn.init.normal_(module.weight, mean, std) 9 | if hasattr(module, 'bias') and module.bias is not None: 10 | nn.init.constant_(module.bias, bias) 11 | 12 | 13 | def constant_init(module, val, bias=0): 14 | if hasattr(module, 'weight') and module.weight is not None: 15 | nn.init.constant_(module.weight, val) 16 | if hasattr(module, 'bias') and module.bias is not None: 17 | nn.init.constant_(module.bias, bias) 18 | 19 | 20 | class DySample(nn.Module): 21 | def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False): 22 | 
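        """
        DySample ("Learning to Upsample by Learning to Sample", ICCV 2023) dynamic upsampler.

        Args (as used in this module):
            in_channels: channels of the feature map to upsample; must be divisible by
                `groups` (and also by `scale ** 2` when style='pl').
            scale: upsampling factor.
            style: 'lp' predicts sampling offsets at the input resolution and pixel-shuffles
                the sampling grid; 'pl' pixel-shuffles the features first and predicts
                offsets at the upsampled resolution.
            groups: number of offset groups that share one sampling grid.
            dyscope: if True, adds a learnable sigmoid-gated "scope" that rescales the
                predicted offsets.

        Shape-only sanity check:
            >>> up = DySample(64, scale=2, style='lp', groups=4)
            >>> up(torch.rand(1, 64, 32, 32)).shape
            torch.Size([1, 64, 64, 64])
        """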
super().__init__() 23 | self.scale = scale 24 | self.style = style 25 | self.groups = groups 26 | assert style in ['lp', 'pl'] 27 | if style == 'pl': 28 | assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0 29 | assert in_channels >= groups and in_channels % groups == 0 30 | 31 | if style == 'pl': 32 | in_channels = in_channels // scale ** 2 33 | out_channels = 2 * groups 34 | else: 35 | out_channels = 2 * groups * scale ** 2 36 | 37 | self.offset = nn.Conv2d(in_channels, out_channels, 1) 38 | normal_init(self.offset, std=0.001) 39 | if dyscope: 40 | self.scope = nn.Conv2d(in_channels, out_channels, 1, bias=False) 41 | constant_init(self.scope, val=0.) 42 | 43 | self.register_buffer('init_pos', self._init_pos()) 44 | 45 | def _init_pos(self): 46 | h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale 47 | return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1) 48 | 49 | def sample(self, x, offset): 50 | B, _, H, W = offset.shape 51 | offset = offset.view(B, 2, -1, H, W) 52 | coords_h = torch.arange(H) + 0.5 53 | coords_w = torch.arange(W) + 0.5 54 | coords = torch.stack(torch.meshgrid([coords_w, coords_h]) 55 | ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device) 56 | normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1) 57 | coords = 2 * (coords + offset) / normalizer - 1 58 | coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view( 59 | B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1) 60 | return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear', 61 | align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W) 62 | 63 | def forward_lp(self, x): 64 | if hasattr(self, 'scope'): 65 | offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos 66 | else: 67 | offset = self.offset(x) * 0.25 + self.init_pos 68 | return self.sample(x, offset) 69 | 70 | def forward_pl(self, x): 71 | x_ = F.pixel_shuffle(x, self.scale) 72 | if hasattr(self, 'scope'): 73 | offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos 74 | else: 75 | offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos 76 | return self.sample(x, offset) 77 | 78 | def forward(self, x): 79 | if self.style == 'pl': 80 | return self.forward_pl(x) 81 | return self.forward_lp(x) -------------------------------------------------------------------------------- /即插即用卷积/ARConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ARConv(nn.Module): 6 | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, l_max=9, w_max=9, flag=False, modulation=True): 7 | super(ARConv, self).__init__() 8 | self.lmax = l_max 9 | self.wmax = w_max 10 | self.inc = inc 11 | self.outc = outc 12 | self.kernel_size = kernel_size 13 | self.padding = padding 14 | self.stride = stride 15 | self.zero_padding = nn.ZeroPad2d(padding) 16 | self.flag = flag 17 | self.modulation = modulation 18 | self.i_list = [33, 35, 53, 37, 73, 55, 57, 75, 77] 19 | self.convs = nn.ModuleList( 20 | [ 21 | nn.Conv2d(inc, outc, kernel_size=(i // 10, i % 10), stride=(i // 10, i % 10), padding=0) 22 | for i in self.i_list 23 | ] 24 | ) 25 | self.m_conv = nn.Sequential( 26 | nn.Conv2d(inc, outc, kernel_size=3, padding=1, stride=stride), 27 | nn.LeakyReLU(), 28 | 
nn.Dropout2d(0.3), 29 | nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=stride), 30 | nn.LeakyReLU(), 31 | nn.Dropout2d(0.3), 32 | nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=stride), 33 | nn.Tanh() 34 | ) 35 | self.b_conv = nn.Sequential( 36 | nn.Conv2d(inc, outc, kernel_size=3, padding=1, stride=stride), 37 | nn.LeakyReLU(), 38 | nn.Dropout2d(0.3), 39 | nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=stride), 40 | nn.LeakyReLU(), 41 | nn.Dropout2d(0.3), 42 | nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=stride) 43 | ) 44 | self.p_conv = nn.Sequential( 45 | nn.Conv2d(inc, inc, kernel_size=3, padding=1, stride=stride), 46 | nn.BatchNorm2d(inc), 47 | nn.LeakyReLU(), 48 | nn.Dropout2d(0), 49 | nn.Conv2d(inc, inc, kernel_size=3, padding=1, stride=stride), 50 | nn.BatchNorm2d(inc), 51 | nn.LeakyReLU(), 52 | ) 53 | self.l_conv = nn.Sequential( 54 | nn.Conv2d(inc, 1, kernel_size=3, padding=1, stride=stride), 55 | nn.BatchNorm2d(1), 56 | nn.LeakyReLU(), 57 | nn.Dropout2d(0), 58 | nn.Conv2d(1, 1, 1), 59 | nn.BatchNorm2d(1), 60 | nn.Sigmoid() 61 | ) 62 | self.w_conv = nn.Sequential( 63 | nn.Conv2d(inc, 1, kernel_size=3, padding=1, stride=stride), 64 | nn.BatchNorm2d(1), 65 | nn.LeakyReLU(), 66 | nn.Dropout2d(0), 67 | nn.Conv2d(1, 1, 1), 68 | nn.BatchNorm2d(1), 69 | nn.Sigmoid() 70 | ) 71 | self.dropout1 = nn.Dropout(0.3) 72 | self.dropout2 = nn.Dropout2d(0.3) 73 | self.hook_handles = [] 74 | self.hook_handles.append(self.m_conv[0].register_full_backward_hook(self._set_lr)) 75 | self.hook_handles.append(self.m_conv[1].register_full_backward_hook(self._set_lr)) 76 | self.hook_handles.append(self.b_conv[0].register_full_backward_hook(self._set_lr)) 77 | self.hook_handles.append(self.b_conv[1].register_full_backward_hook(self._set_lr)) 78 | self.hook_handles.append(self.p_conv[0].register_full_backward_hook(self._set_lr)) 79 | self.hook_handles.append(self.p_conv[1].register_full_backward_hook(self._set_lr)) 80 | self.hook_handles.append(self.l_conv[0].register_full_backward_hook(self._set_lr)) 81 | self.hook_handles.append(self.l_conv[1].register_full_backward_hook(self._set_lr)) 82 | self.hook_handles.append(self.w_conv[0].register_full_backward_hook(self._set_lr)) 83 | self.hook_handles.append(self.w_conv[1].register_full_backward_hook(self._set_lr)) 84 | 85 | self.reserved_NXY = nn.Parameter(torch.tensor([3, 3], dtype=torch.int32), requires_grad=False) 86 | 87 | @staticmethod 88 | def _set_lr(module, grad_input, grad_output): 89 | grad_input = tuple(g * 0.1 if g is not None else None for g in grad_input) 90 | grad_output = tuple(g * 0.1 if g is not None else None for g in grad_output) 91 | return grad_input 92 | 93 | def remove_hooks(self): 94 | for handle in self.hook_handles: 95 | handle.remove() # 移除钩子函数 96 | self.hook_handles.clear() # 清空句柄列表 97 | 98 | def forward(self, x, epoch, hw_range): 99 | assert isinstance(hw_range, list) and len( 100 | hw_range) == 2, "hw_range should be a list with 2 elements, represent the range of h w" 101 | scale = hw_range[1] // 9 102 | if hw_range[0] == 1 and hw_range[1] == 3: 103 | scale = 1 104 | m = self.m_conv(x) 105 | bias = self.b_conv(x) 106 | offset = self.p_conv(x * 100) 107 | l = self.l_conv(offset) * (hw_range[1] - 1) + 1 # b, 1, h, w 108 | w = self.w_conv(offset) * (hw_range[1] - 1) + 1 # b, 1, h, w 109 | if epoch <= 100: 110 | mean_l = l.mean(dim=0).mean(dim=1).mean(dim=1) 111 | mean_w = w.mean(dim=0).mean(dim=1).mean(dim=1) 112 | N_X = int(mean_l // scale) 113 | N_Y = int(mean_w // scale) 114 | 115 | def phi(x): 116 
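                # phi: round an even candidate size down to the nearest odd value, so that
                # (N_X, N_Y) always indexes one of the odd-sized kernels in self.i_list
                # (3, 5 or 7 after the min/max clamps below)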
| if x % 2 == 0: 117 | x -= 1 118 | return x 119 | 120 | N_X, N_Y = phi(N_X), phi(N_Y) 121 | N_X, N_Y = max(N_X, 3), max(N_Y, 3) 122 | N_X, N_Y = min(N_X, 7), min(N_Y, 7) 123 | if epoch == 100: 124 | self.reserved_NXY = self.reserved_NXY = nn.Parameter( 125 | torch.tensor([N_X, N_Y], dtype=torch.int32, device=x.device), 126 | requires_grad=False 127 | ) 128 | else: 129 | N_X = self.reserved_NXY[0] 130 | N_Y = self.reserved_NXY[1] 131 | 132 | N = N_X * N_Y 133 | # print(N_X, N_Y) 134 | l = l.repeat([1, N, 1, 1]) 135 | w = w.repeat([1, N, 1, 1]) 136 | offset = torch.cat((l, w), dim=1) 137 | dtype = offset.data.type() 138 | if self.padding: 139 | x = self.zero_padding(x) 140 | p = self._get_p(offset, dtype, N_X, N_Y) # (b, 2*N, h, w) 141 | p = p.contiguous().permute(0, 2, 3, 1) # (b, h, w, 2*N) 142 | q_lt = p.detach().floor() 143 | q_rb = q_lt + 1 144 | q_lt = torch.cat( 145 | [ 146 | torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), 147 | torch.clamp(q_lt[..., N:], 0, x.size(3) - 1), 148 | ], 149 | dim=-1, 150 | ).long() 151 | q_rb = torch.cat( 152 | [ 153 | torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), 154 | torch.clamp(q_rb[..., N:], 0, x.size(3) - 1), 155 | ], 156 | dim=-1, 157 | ).long() 158 | q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1) 159 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1) 160 | # clip p 161 | p = torch.cat( 162 | [ 163 | torch.clamp(p[..., :N], 0, x.size(2) - 1), 164 | torch.clamp(p[..., N:], 0, x.size(3) - 1), 165 | ], 166 | dim=-1, 167 | ) 168 | # bilinear kernel (b, h, w, N) 169 | g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * ( 170 | 1 + (q_lt[..., N:].type_as(p) - p[..., N:]) 171 | ) 172 | g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * ( 173 | 1 - (q_rb[..., N:].type_as(p) - p[..., N:]) 174 | ) 175 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * ( 176 | 1 - (q_lb[..., N:].type_as(p) - p[..., N:]) 177 | ) 178 | g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * ( 179 | 1 + (q_rt[..., N:].type_as(p) - p[..., N:]) 180 | ) 181 | # (b, c, h, w, N) 182 | x_q_lt = self._get_x_q(x, q_lt, N) 183 | x_q_rb = self._get_x_q(x, q_rb, N) 184 | x_q_lb = self._get_x_q(x, q_lb, N) 185 | x_q_rt = self._get_x_q(x, q_rt, N) 186 | # (b, c, h, w, N) 187 | x_offset = ( 188 | g_lt.unsqueeze(dim=1) * x_q_lt 189 | + g_rb.unsqueeze(dim=1) * x_q_rb 190 | + g_lb.unsqueeze(dim=1) * x_q_lb 191 | + g_rt.unsqueeze(dim=1) * x_q_rt 192 | ) 193 | x_offset = self._reshape_x_offset(x_offset, N_X, N_Y) 194 | x_offset = self.dropout2(x_offset) 195 | x_offset = self.convs[self.i_list.index(N_X * 10 + N_Y)](x_offset) 196 | out = x_offset * m + bias 197 | return out 198 | 199 | def _get_p_n(self, N, dtype, n_x, n_y): 200 | p_n_x, p_n_y = torch.meshgrid( 201 | torch.arange(-(n_x - 1) // 2, (n_x - 1) // 2 + 1), 202 | torch.arange(-(n_y - 1) // 2, (n_y - 1) // 2 + 1), 203 | ) 204 | p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) 205 | p_n = p_n.view(1, 2 * N, 1, 1).type(dtype) 206 | return p_n 207 | 208 | def _get_p_0(self, h, w, N, dtype): 209 | p_0_x, p_0_y = torch.meshgrid( 210 | torch.arange(1, h * self.stride + 1, self.stride), 211 | torch.arange(1, w * self.stride + 1, self.stride), 212 | ) 213 | p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1) 214 | p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1) 215 | p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype) 216 | return p_0 217 | 218 | def _get_p(self, offset, dtype, n_x, n_y): 219 | N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3) 220 | L, W = 
offset.split([N, N], dim=1) 221 | L = L / n_x 222 | W = W / n_y 223 | offsett = torch.cat([L, W], dim=1) 224 | p_n = self._get_p_n(N, dtype, n_x, n_y) 225 | p_n = p_n.repeat([1, 1, h, w]) 226 | p_0 = self._get_p_0(h, w, N, dtype) 227 | p = p_0 + offsett * p_n 228 | return p 229 | 230 | def _get_x_q(self, x, q, N): 231 | b, h, w, _ = q.size() 232 | padded_w = x.size(3) 233 | c = x.size(1) 234 | x = x.contiguous().view(b, c, -1) 235 | index = q[..., :N] * padded_w + q[..., N:] 236 | index = ( 237 | index.contiguous() 238 | .unsqueeze(dim=1) 239 | .expand(-1, c, -1, -1, -1) 240 | .contiguous() 241 | .view(b, c, -1) 242 | ) 243 | x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N) 244 | return x_offset 245 | 246 | @staticmethod 247 | def _reshape_x_offset(x_offset, n_x, n_y): 248 | b, c, h, w, N = x_offset.size() 249 | x_offset = torch.cat([x_offset[..., s:s + n_y].contiguous().view(b, c, h, w * n_y) for s in range(0, N, n_y)], 250 | dim=-1) 251 | x_offset = x_offset.contiguous().view(b, c, h * n_x, w * n_y) 252 | return x_offset -------------------------------------------------------------------------------- /即插即用卷积/CFBConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from mmcv.cnn import (Conv2d,ConvModule) 5 | from mmcv.runner import BaseModule 6 | 7 | from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,trunc_normal_init,normal_init) 8 | from timm.models.layers import DropPath 9 | from mmseg.models.decode_heads.decode_head import BaseDecodeHead 10 | from mmseg.ops import resize 11 | 12 | #BN->Conv->GELU->drop->Conv2->drop 13 | class MLP(BaseModule): 14 | def __init__(self, 15 | in_channels, 16 | hidden_channels=None, 17 | out_channels=None, 18 | drop_rate=0.): 19 | super(MLP,self).__init__() 20 | hidden_channels = hidden_channels or in_channels 21 | out_channels = out_channels or in_channels 22 | self.norm = nn.SyncBatchNorm(in_channels, eps=1e-06) 23 | self.conv1 = nn.Conv2d(in_channels, hidden_channels, 3, 1, 1) 24 | self.act = nn.GELU() 25 | self.conv2 = nn.Conv2d(hidden_channels, out_channels, 3, 1, 1) 26 | self.drop = nn.Dropout(drop_rate) 27 | 28 | self.apply(self._init_weights) 29 | 30 | def _init_weights(self, m): 31 | if isinstance(m, nn.Linear): 32 | trunc_normal_init(m.weight, std=.02) 33 | if m.bias is not None: 34 | constant_init(m.bias, val=0) 35 | elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)): 36 | constant_init(m.weight, val=1.0) 37 | constant_init(m.bias, val=0) 38 | elif isinstance(m, nn.Conv2d): 39 | kaiming_init(m.weight) 40 | if m.bias is not None: 41 | constant_init(m.bias, val=0) 42 | 43 | def forward(self, x): 44 | x = self.norm(x) 45 | x = self.conv1(x) 46 | x = self.act(x) 47 | x = self.drop(x) 48 | x = self.conv2(x) 49 | x = self.drop(x) 50 | return x 51 | 52 | class ConvolutionalAttention(BaseModule): 53 | """ 54 | The ConvolutionalAttention implementation 55 | Args: 56 | in_channels (int, optional): The input channels. 57 | inter_channels (int, optional): The channels of intermediate feature. 58 | out_channels (int, optional): The output channels. 59 | num_heads (int, optional): The num of heads in attention. 
Default: 8 60 | """ 61 | 62 | def __init__(self, 63 | in_channels, 64 | out_channels, 65 | inter_channels, 66 | num_heads=8): 67 | super(ConvolutionalAttention,self).__init__() 68 | assert out_channels % num_heads == 0, \ 69 | "out_channels ({}) should be be a multiple of num_heads ({})".format(out_channels, num_heads) 70 | self.in_channels = in_channels 71 | self.out_channels = out_channels 72 | self.inter_channels = inter_channels 73 | self.num_heads = num_heads 74 | self.norm = nn.SyncBatchNorm(in_channels) 75 | 76 | self.kv =nn.Parameter(torch.zeros(inter_channels, in_channels, 7, 1)) 77 | self.kv3 =nn.Parameter(torch.zeros(inter_channels, in_channels, 1, 7)) 78 | trunc_normal_init(self.kv, std=0.001) 79 | trunc_normal_init(self.kv3, std=0.001) 80 | 81 | self.apply(self._init_weights) 82 | 83 | def _init_weights(self, m): 84 | if isinstance(m, nn.Linear): 85 | trunc_normal_init(m.weight, std=.001) 86 | if m.bias is not None: 87 | constant_init(m.bias, val=0.) 88 | elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)): 89 | constant_init(m.weight, val=1.) 90 | constant_init(m.bias, val=.0) 91 | elif isinstance(m, nn.Conv2d): 92 | trunc_normal_init(m.weight, std=.001) 93 | if m.bias is not None: 94 | constant_init(m.bias, val=0.) 95 | 96 | 97 | def _act_dn(self, x): 98 | x_shape = x.shape # n,c_inter,h,w 99 | h, w = x_shape[2], x_shape[3] 100 | x = x.reshape( 101 | [x_shape[0], self.num_heads, self.inter_channels // self.num_heads, -1]) #n,c_inter,h,w -> n,heads,c_inner//heads,hw 102 | x = F.softmax(x, dim=3) 103 | x = x / (torch.sum(x, dim =2, keepdim=True) + 1e-06) 104 | x = x.reshape([x_shape[0], self.inter_channels, h, w]) 105 | return x 106 | 107 | def forward(self, x): 108 | """ 109 | Args: 110 | x (Tensor): The input tensor. (n,c,h,w) 111 | cross_k (Tensor, optional): The dims is (n*144, c_in, 1, 1) 112 | cross_v (Tensor, optional): The dims is (n*c_in, 144, 1, 1) 113 | """ 114 | x = self.norm(x) 115 | x1 = F.conv2d( 116 | x, 117 | self.kv, 118 | bias=None, 119 | stride=1, 120 | padding=(3,0)) 121 | x1 = self._act_dn(x1) 122 | x1 = F.conv2d( 123 | x1, self.kv.transpose(1, 0), bias=None, stride=1, 124 | padding=(3,0)) 125 | x3 = F.conv2d( 126 | x, 127 | self.kv3, 128 | bias=None, 129 | stride=1, 130 | padding=(0,3)) 131 | x3 = self._act_dn(x3) 132 | x3 = F.conv2d( 133 | x3, self.kv3.transpose(1, 0), bias=None, stride=1,padding=(0,3)) 134 | x=x1+x3 135 | return x 136 | 137 | class CFBlock(BaseModule): 138 | """ 139 | The CFBlock implementation based on PaddlePaddle. 140 | Args: 141 | in_channels (int, optional): The input channels. 142 | out_channels (int, optional): The output channels. 143 | num_heads (int, optional): The num of heads in attention. Default: 8 144 | drop_rate (float, optional): The drop rate in MLP. Default:0. 145 | drop_path_rate (float, optional): The drop path rate in CFBlock. Default: 0.2 146 | """ 147 | 148 | def __init__(self, 149 | in_channels, 150 | out_channels, 151 | num_heads=8, 152 | drop_rate=0., 153 | drop_path_rate=0.): 154 | super(CFBlock,self).__init__() 155 | in_channels_l = in_channels 156 | out_channels_l = out_channels 157 | self.attn_l = ConvolutionalAttention( 158 | in_channels_l, 159 | out_channels_l, 160 | inter_channels=64, 161 | num_heads=num_heads) 162 | self.mlp_l = MLP(out_channels_l, drop_rate=drop_rate) 163 | self.drop_path = DropPath( 164 | drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() 165 | 166 | def _init_weights_kaiming(self, m): 167 | if isinstance(m, nn.Linear): 168 | trunc_normal_init(m.weight, std=.02) 169 | if m.bias is not None: 170 | constant_init(m.bias, val=0) 171 | elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)): 172 | constant_init(m.weight, val=1.0) 173 | constant_init(m.bias, val=0) 174 | elif isinstance(m, nn.Conv2d): 175 | kaiming_init(m.weight) 176 | if m.bias is not None: 177 | constant_init(m.bias, val=0) 178 | 179 | def forward(self, x): 180 | x_res = x 181 | x = x_res + self.drop_path(self.attn_l(x)) 182 | x = x + self.drop_path(self.mlp_l(x)) 183 | return x -------------------------------------------------------------------------------- /即插即用卷积/CKGConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch_geometric as pyg 4 | import torch_geometric.graphgym.register as register 5 | from torch_geometric.graphgym.config import cfg 6 | from torch_geometric.graphgym.models.gnn import GNNPreMP 7 | from torch_geometric.graphgym.models.layer import (new_layer_config, 8 | BatchNorm1dNode) 9 | from torch_geometric.graphgym.register import register_network 10 | from functools import partial 11 | 12 | 13 | 14 | 15 | class FeatureEncoder(torch.nn.Module): 16 | """ 17 | Encoding node and edge features 18 | 19 | Args: 20 | dim_in (int): Input feature dimension 21 | """ 22 | def __init__(self, dim_in): 23 | super(FeatureEncoder, self).__init__() 24 | self.dim_in = dim_in 25 | if cfg.dataset.node_encoder: 26 | # Encode integer node features via nn.Embeddings 27 | NodeEncoder = register.node_encoder_dict[ 28 | cfg.dataset.node_encoder_name] 29 | self.node_encoder = NodeEncoder(cfg.gnn.dim_inner) 30 | if cfg.dataset.node_encoder_bn: 31 | self.node_encoder_bn = BatchNorm1dNode( 32 | new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False, 33 | has_bias=False, cfg=cfg)) 34 | # Update dim_in to reflect the new dimension fo the node features 35 | self.dim_in = cfg.gnn.dim_inner 36 | 37 | if cfg.dataset.edge_encoder: 38 | # Hard-limit max edge dim for PNA. 
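            # (presumably a memory/compute guard: PNA-style layers build per-edge
            #  aggregator towers, so the edge embedding is capped at 128 channels below)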
39 | if 'PNA' in cfg.gt.layer_type: 40 | cfg.gnn.dim_edge = min(128, cfg.gnn.dim_inner) 41 | else: 42 | cfg.gnn.dim_edge = cfg.gnn.dim_inner 43 | # Encode integer edge features via nn.Embeddings 44 | EdgeEncoder = register.edge_encoder_dict[ 45 | cfg.dataset.edge_encoder_name] 46 | self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge) 47 | if cfg.dataset.edge_encoder_bn: 48 | self.edge_encoder_bn = BatchNorm1dNode( 49 | new_layer_config(cfg.gnn.dim_edge, -1, -1, has_act=False, 50 | has_bias=False, cfg=cfg)) 51 | 52 | def forward(self, batch): 53 | for module in self.children(): 54 | batch = module(batch) 55 | return batch 56 | 57 | 58 | # class PosencEncoder(torch.nn.Module): 59 | class Stem(torch.nn.Module): 60 | """ 61 | Encoding node and edge Positional Encoding 62 | Args: 63 | dim_in (int): Input feature dimension 64 | """ 65 | def __init__(self, dim_in, **kwargs): 66 | super().__init__() 67 | 68 | 69 | node_stem = cfg.gt.stem.get('node_stem', None) 70 | if node_stem is None: # to be compatible with previous versions 71 | node_stem = cfg.gt.get('node_pe_encoder', None) 72 | 73 | if node_stem is None: node_stem='' 74 | node_stem = node_stem.split("+") 75 | self.node_stem = nn.Sequential(*[register.node_encoder_dict[enc](out_dim=cfg.gnn.dim_inner, 76 | ) 77 | for enc in node_stem if enc != '']) 78 | 79 | edge_stem = cfg.gt.stem.get('edge_stem', None) 80 | if edge_stem is None: # to be compatible with previous versions 81 | edge_stem = cfg.gt.get('edge_pe_encoder', None) 82 | 83 | if edge_stem is None: edge_stem='' 84 | edge_stem = edge_stem.split("+") 85 | self.edge_stem = nn.Sequential(*[register.edge_encoder_dict[enc](out_dim=cfg.gnn.dim_inner, 86 | ) 87 | for enc in edge_stem if enc != '']) 88 | 89 | def forward(self, batch): 90 | for module in self.children(): 91 | batch = module(batch) 92 | return batch 93 | 94 | 95 | @register_network('CKGConvNet') 96 | class CKGraphConvNet(torch.nn.Module): 97 | ''' 98 | The proposed Continuous Kernel Convolution Networks 99 | ''' 100 | 101 | def __init__(self, dim_in, dim_out): 102 | super().__init__() 103 | 104 | if cfg.gnn.dim_inner == -1: 105 | cfg.gnn.dim_inner = cfg.gt.dim_hidden 106 | 107 | self.feat_enc = FeatureEncoder(dim_in) 108 | dim_in = self.feat_enc.dim_in 109 | # self.pe_encoder = PosencEncoder(dim_in) 110 | self.stem = Stem(dim_in) 111 | 112 | # pre-backbone normalization 113 | pre_backbone_norm = cfg.gt.get('pre_backbone_norm', False) 114 | norm_fn = nn.Identity 115 | if cfg.gt.batch_norm: norm_fn = partial(nn.BatchNorm1d, momentum=cfg.gt.bn_momentum) 116 | if cfg.gt.layer_norm: norm_fn = nn.LayerNorm 117 | graph_norm = False 118 | if cfg.gt.get('graph_norm', False): 119 | norm_fn = pyg.nn.GraphNorm 120 | graph_norm = True 121 | 122 | if pre_backbone_norm: 123 | self.pre_backbone_norm = NormalizationLayer(norm_fn(cfg.gt.dim_hidden), graph_norm=graph_norm) 124 | else: 125 | self.pre_backbone_norm = nn.Identity() 126 | 127 | if cfg.gnn.layers_pre_mp > 0: 128 | self.pre_mp = GNNPreMP( 129 | dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) 130 | dim_in = cfg.gnn.dim_inner 131 | 132 | assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \ 133 | "The inner and hidden dims must match." 
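        # every stage below (feature encoder, PE stems, optional pre-MP GNN, the CKGConv
        # blocks and the prediction head) operates at this single shared width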
134 | 135 | global_model_type = cfg.gt.get('layer_type', "CKGraphConvMLP") 136 | # global_model_type = "GritTransformer" 137 | 138 | ConvBlock = register.layer_dict.get(global_model_type) 139 | kernel_size = cfg.gt.get('kernel_size', -1) 140 | dilation = cfg.gt.get('dilation', 1) 141 | if isinstance(kernel_size, str): 142 | kernel_size = [int(i) for i in kernel_size.split(',')] 143 | else: 144 | kernel_size = [kernel_size] * cfg.gt.layers 145 | 146 | 147 | if isinstance(dilation, str): 148 | dilation = [int(i) for i in dilation.split(',')] 149 | elif isinstance(dilation, tuple): 150 | pass 151 | else: 152 | dilation = [dilation] * cfg.gt.layers 153 | 154 | 155 | layers = [] 156 | for l in range(cfg.gt.layers): 157 | layers.append(ConvBlock( 158 | in_dim=cfg.gt.dim_hidden, 159 | out_dim=cfg.gt.dim_hidden, 160 | num_heads=cfg.gt.n_heads, 161 | kernel_size=kernel_size[l], 162 | dilation=dilation[l], 163 | dropout=cfg.gt.dropout, 164 | act=cfg.gnn.act, 165 | attn_dropout=cfg.gt.attn_dropout, 166 | layer_norm=cfg.gt.layer_norm, 167 | batch_norm=cfg.gt.batch_norm, 168 | residual=True, 169 | norm_e=cfg.gt.attn.norm_e, 170 | O_e=cfg.gt.attn.O_e, 171 | cfg=cfg.gt, 172 | out_norm=l==cfg.gt.layers-1 173 | # log_attn_weights=cfg.train.mode == 'log-attn-weights', 174 | )) 175 | 176 | # if global_model_type == "Norm-Res-GritTransformer" or global_model_type == "PreNormGritTransformer": 177 | # layers.append(register.layer_dict["GeneralNormLayer"]\ 178 | # (dim=cfg.gt.dim_hidden, 179 | # layer_norm=cfg.gt.layer_norm, 180 | # batch_norm=cfg.gt.batch_norm, 181 | # cfg=cfg.gt 182 | # )) 183 | 184 | self.layers = torch.nn.Sequential(*layers) 185 | 186 | # pre-backbone normalization 187 | post_backbone_norm = cfg.gt.get('post_backbone_norm', False) 188 | if post_backbone_norm: 189 | self.post_backbone_norm = NormalizationLayer(norm_fn(cfg.gt.dim_hidden), graph_norm=graph_norm) 190 | else: 191 | self.post_backbone_norm = nn.Identity() 192 | 193 | GNNHead = register.head_dict[cfg.gnn.head] 194 | self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out) 195 | 196 | def forward(self, batch): 197 | 198 | for module in self.children(): 199 | batch = module(batch) 200 | 201 | return batch 202 | 203 | 204 | 205 | class NormalizationLayer(nn.Module): 206 | def __init__(self, norm_layer, graph_norm=False): 207 | super().__init__() 208 | self.norm_layer = norm_layer 209 | self.graph_norm = graph_norm 210 | 211 | def forward(self, batch): 212 | if self.graph_norm: 213 | batch.x = self.norm_layer(batch.x, batch.batch) 214 | 215 | batch.x = self.norm_layer(batch.x) 216 | return batch -------------------------------------------------------------------------------- /即插即用卷积/DCNv4.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Function 6 | from torch.nn.init import xavier_uniform_, constant_ 7 | 8 | """ 9 | 完整的代码实现可能涉及到一些C++/CUDA层的扩展,直接在Python中书写非常复杂。 10 | 由于完整实现通常会包含在setup.py或ops子目录中,以下是一个基于PyTorch框架的核心实现(假设未使用C++/CUDA扩展)。 11 | 具体的CUDA实现细节需要从官方仓库中提取。 12 | """ 13 | 14 | class DCNv4Function(Function): 15 | @staticmethod 16 | def forward(ctx, input, offset_mask, kernel_h, kernel_w, stride_h, stride_w, 17 | pad_h, pad_w, dilation_h, dilation_w, group, group_channels, 18 | offset_scale, im2col_step, remove_center): 19 | ctx.save_for_backward(input, offset_mask) 20 | ctx.kernel_h, ctx.kernel_w = kernel_h, kernel_w 21 | ctx.stride_h, ctx.stride_w = stride_h, 
stride_w 22 | ctx.pad_h, ctx.pad_w = pad_h, pad_w 23 | ctx.dilation_h, ctx.dilation_w = dilation_h, dilation_w 24 | ctx.group = group 25 | ctx.group_channels = group_channels 26 | ctx.offset_scale = offset_scale 27 | ctx.im2col_step = im2col_step 28 | ctx.remove_center = remove_center 29 | 30 | # Placeholder for forward computation (mocked as a linear operation) 31 | # Replace this with an efficient im2col + convolution implementation if needed 32 | output = torch.nn.functional.linear(input, offset_mask) 33 | return output 34 | 35 | @staticmethod 36 | def backward(ctx, grad_output): 37 | input, offset_mask = ctx.saved_tensors 38 | # Placeholder for backward computation 39 | grad_input = grad_offset_mask = None 40 | if ctx.needs_input_grad[0]: 41 | grad_input = torch.nn.functional.linear(grad_output, offset_mask.t()) 42 | if ctx.needs_input_grad[1]: 43 | grad_offset_mask = torch.nn.functional.linear(input.t(), grad_output) 44 | 45 | return (grad_input, grad_offset_mask, None, None, None, None, 46 | None, None, None, None, None, None, None, None, None) 47 | 48 | 49 | class CenterFeatureScaleModule(nn.Module): 50 | def forward(self, query, center_feature_scale_proj_weight, center_feature_scale_proj_bias): 51 | center_feature_scale = F.linear(query, weight=center_feature_scale_proj_weight, 52 | bias=center_feature_scale_proj_bias).sigmoid() 53 | return center_feature_scale 54 | 55 | 56 | class DCNv4(nn.Module): 57 | def __init__(self, channels=64, kernel_size=3, stride=1, pad=1, dilation=1, 58 | group=4, offset_scale=1.0, dw_kernel_size=None, center_feature_scale=False, 59 | remove_center=False, output_bias=True, without_pointwise=False, **kwargs): 60 | super().__init__() 61 | if channels % group != 0: 62 | raise ValueError(f'channels must be divisible by group, but got {channels} and {group}') 63 | self.offset_scale = offset_scale 64 | self.channels = channels 65 | self.kernel_size = kernel_size 66 | self.stride = stride 67 | self.dilation = dilation 68 | self.pad = pad 69 | self.group = group 70 | self.group_channels = channels // group 71 | self.offset_scale = offset_scale 72 | self.dw_kernel_size = dw_kernel_size 73 | self.center_feature_scale = center_feature_scale 74 | self.remove_center = int(remove_center) 75 | self.without_pointwise = without_pointwise 76 | 77 | self.K = group * (kernel_size * kernel_size - self.remove_center) 78 | if dw_kernel_size is not None: 79 | self.offset_mask_dw = nn.Conv2d(channels, channels, dw_kernel_size, stride=1, 80 | padding=(dw_kernel_size - 1) // 2, groups=channels) 81 | self.offset_mask = nn.Linear(channels, int(math.ceil((self.K * 3) / 8) * 8)) 82 | if not without_pointwise: 83 | self.value_proj = nn.Linear(channels, channels) 84 | self.output_proj = nn.Linear(channels, channels, bias=output_bias) 85 | self._reset_parameters() 86 | 87 | if center_feature_scale: 88 | self.center_feature_scale_proj_weight = nn.Parameter( 89 | torch.zeros((group, channels), dtype=torch.float)) 90 | self.center_feature_scale_proj_bias = nn.Parameter( 91 | torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group,)) 92 | self.center_feature_scale_module = CenterFeatureScaleModule() 93 | 94 | def _reset_parameters(self): 95 | constant_(self.offset_mask.weight.data, 0.) 96 | constant_(self.offset_mask.bias.data, 0.) 97 | if not self.without_pointwise: 98 | xavier_uniform_(self.value_proj.weight.data) 99 | constant_(self.value_proj.bias.data, 0.) 
100 | xavier_uniform_(self.output_proj.weight.data) 101 | if self.output_proj.bias is not None: 102 | constant_(self.output_proj.bias.data, 0.) 103 | 104 | def forward(self, input, shape=None): 105 | N, L, C = input.shape 106 | if shape is not None: 107 | H, W = shape 108 | else: 109 | H, W = int(L**0.5), int(L**0.5) 110 | 111 | x = input 112 | if not self.without_pointwise: 113 | x = self.value_proj(x) 114 | x = x.reshape(N, H, W, -1) 115 | if self.dw_kernel_size is not None: 116 | offset_mask_input = self.offset_mask_dw(input.view(N, H, W, C).permute(0, 3, 1, 2)) 117 | offset_mask_input = offset_mask_input.permute(0, 2, 3, 1).view(N, L, C) 118 | else: 119 | offset_mask_input = input 120 | offset_mask = self.offset_mask(offset_mask_input).reshape(N, H, W, -1) 121 | 122 | x_proj = x 123 | x = DCNv4Function.apply( 124 | x, offset_mask, self.kernel_size, self.kernel_size, self.stride, self.stride, 125 | self.pad, self.pad, self.dilation, self.dilation, self.group, self.group_channels, 126 | self.offset_scale, 256, self.remove_center 127 | ) 128 | 129 | if self.center_feature_scale: 130 | center_feature_scale = self.center_feature_scale_module( 131 | x, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) 132 | center_feature_scale = center_feature_scale[..., None].repeat( 133 | 1, 1, 1, 1, self.channels // self.group).flatten(-2) 134 | x = x * (1 - center_feature_scale) + x_proj * center_feature_scale 135 | 136 | x = x.view(N, L, -1) 137 | 138 | if not self.without_pointwise: 139 | x = self.output_proj(x) 140 | return x 141 | -------------------------------------------------------------------------------- /即插即用卷积/DEConv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from einops.layers.torch import Rearrange 5 | 6 | 7 | class Conv2d_cd(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 9 | padding=1, dilation=1, groups=1, bias=False, theta=1.0): 10 | super(Conv2d_cd, self).__init__() 11 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, 12 | dilation=dilation, groups=groups, bias=bias) 13 | self.theta = theta 14 | 15 | def get_weight(self): 16 | conv_weight = self.conv.weight 17 | conv_shape = conv_weight.shape 18 | conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight) 19 | conv_weight_cd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0) 20 | conv_weight_cd[:, :, :] = conv_weight[:, :, :] 21 | conv_weight_cd[:, :, 4] = conv_weight[:, :, 4] - conv_weight[:, :, :].sum(2) 22 | conv_weight_cd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])( 23 | conv_weight_cd) 24 | return conv_weight_cd, self.conv.bias 25 | 26 | 27 | class Conv2d_ad(nn.Module): 28 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 29 | padding=1, dilation=1, groups=1, bias=False, theta=1.0): 30 | super(Conv2d_ad, self).__init__() 31 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, 32 | dilation=dilation, groups=groups, bias=bias) 33 | self.theta = theta 34 | 35 | def get_weight(self): 36 | conv_weight = self.conv.weight 37 | conv_shape = conv_weight.shape 38 | conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight) 39 | conv_weight_ad = conv_weight - self.theta * conv_weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]] 40 | conv_weight_ad 
= Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])( 41 | conv_weight_ad) 42 | return conv_weight_ad, self.conv.bias 43 | 44 | 45 | class Conv2d_rd(nn.Module): 46 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 47 | padding=2, dilation=1, groups=1, bias=False, theta=1.0): 48 | 49 | super(Conv2d_rd, self).__init__() 50 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, 51 | dilation=dilation, groups=groups, bias=bias) 52 | self.theta = theta 53 | 54 | def forward(self, x): 55 | 56 | if math.fabs(self.theta - 0.0) < 1e-8: 57 | out_normal = self.conv(x) 58 | return out_normal 59 | else: 60 | conv_weight = self.conv.weight 61 | conv_shape = conv_weight.shape 62 | if conv_weight.is_cuda: 63 | conv_weight_rd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 5 * 5).fill_(0) 64 | else: 65 | conv_weight_rd = torch.zeros(conv_shape[0], conv_shape[1], 5 * 5) 66 | conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight) 67 | conv_weight_rd[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = conv_weight[:, :, 1:] 68 | conv_weight_rd[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -conv_weight[:, :, 1:] * self.theta 69 | conv_weight_rd[:, :, 12] = conv_weight[:, :, 0] * (1 - self.theta) 70 | conv_weight_rd = conv_weight_rd.view(conv_shape[0], conv_shape[1], 5, 5) 71 | out_diff = nn.functional.conv2d(input=x, weight=conv_weight_rd, bias=self.conv.bias, 72 | stride=self.conv.stride, padding=self.conv.padding, groups=self.conv.groups) 73 | 74 | return out_diff 75 | 76 | 77 | class Conv2d_hd(nn.Module): 78 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 79 | padding=1, dilation=1, groups=1, bias=False, theta=1.0): 80 | super(Conv2d_hd, self).__init__() 81 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, 82 | dilation=dilation, groups=groups, bias=bias) 83 | 84 | def get_weight(self): 85 | conv_weight = self.conv.weight 86 | conv_shape = conv_weight.shape 87 | conv_weight_hd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0) 88 | conv_weight_hd[:, :, [0, 3, 6]] = conv_weight[:, :, :] 89 | conv_weight_hd[:, :, [2, 5, 8]] = -conv_weight[:, :, :] 90 | conv_weight_hd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])( 91 | conv_weight_hd) 92 | return conv_weight_hd, self.conv.bias 93 | 94 | 95 | class Conv2d_vd(nn.Module): 96 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 97 | padding=1, dilation=1, groups=1, bias=False): 98 | super(Conv2d_vd, self).__init__() 99 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, 100 | dilation=dilation, groups=groups, bias=bias) 101 | 102 | def get_weight(self): 103 | conv_weight = self.conv.weight 104 | conv_shape = conv_weight.shape 105 | conv_weight_vd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0) 106 | conv_weight_vd[:, :, [0, 1, 2]] = conv_weight[:, :, :] 107 | conv_weight_vd[:, :, [6, 7, 8]] = -conv_weight[:, :, :] 108 | conv_weight_vd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])( 109 | conv_weight_vd) 110 | return conv_weight_vd, self.conv.bias 111 | 112 | 113 | class DEConv(nn.Module): 114 | def __init__(self, dim): 115 | super(DEConv, self).__init__() 116 | self.conv1_1 = Conv2d_cd(dim, dim, 3, bias=True) 117 | self.conv1_2 = Conv2d_hd(dim, dim, 3, 
bias=True)
118 |         self.conv1_3 = Conv2d_vd(dim, dim, 3, bias=True)
119 |         self.conv1_4 = Conv2d_ad(dim, dim, 3, bias=True)
120 |         self.conv1_5 = nn.Conv2d(dim, dim, 3, padding=1, bias=True)
121 | 
122 |     def forward(self, x):
123 |         w1, b1 = self.conv1_1.get_weight()
124 |         w2, b2 = self.conv1_2.get_weight()
125 |         w3, b3 = self.conv1_3.get_weight()
126 |         w4, b4 = self.conv1_4.get_weight()
127 |         w5, b5 = self.conv1_5.weight, self.conv1_5.bias
128 | 
129 |         w = w1 + w2 + w3 + w4 + w5
130 |         b = b1 + b2 + b3 + b4 + b5
131 |         res = nn.functional.conv2d(input=x, weight=w, bias=b, stride=1, padding=1, groups=1)
132 | 
133 |         return res
--------------------------------------------------------------------------------
/即插即用卷积/DynamicConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from timm.models.layers import CondConv2d
5 | 
6 | class DynamicConv(nn.Module):
7 |     """
8 |     Dynamic Conv layer
9 |     Brings extra, input-conditioned parameters (expert weights) into the network
10 |     """
11 | 
12 |     def __init__(self, in_features, out_features, kernel_size=1, stride=1, padding='', dilation=1,
13 |                  groups=1, bias=False, num_experts=4):
14 |         super().__init__()
15 |         print('+++', num_experts)
16 |         self.routing = nn.Linear(in_features, num_experts)
17 |         self.cond_conv = CondConv2d(in_features, out_features, kernel_size, stride, padding, dilation,
18 |                                     groups, bias, num_experts)
19 | 
20 |     def forward(self, x):
21 |         pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)  # CondConv routing
22 |         routing_weights = torch.sigmoid(self.routing(pooled_inputs))
23 |         x = self.cond_conv(x, routing_weights)
24 |         return x
--------------------------------------------------------------------------------
/即插即用卷积/LDConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from einops import rearrange
5 | 
6 | class LDConv(nn.Module):
7 |     def __init__(self, inc, outc, num_param, stride=1, bias=None):
8 |         super(LDConv, self).__init__()
9 |         self.num_param = num_param
10 |         self.stride = stride
11 |         self.conv = nn.Sequential(nn.Conv2d(inc, outc, kernel_size=(num_param, 1), stride=(num_param, 1), bias=bias),
12 |                                   nn.BatchNorm2d(outc),
13 |                                   nn.SiLU())  # this conv adds BN and SiLU to match the original Conv block in YOLOv5
14 |         self.p_conv = nn.Conv2d(inc, 2 * num_param, kernel_size=3, padding=1, stride=stride)
15 |         nn.init.constant_(self.p_conv.weight, 0)
16 |         self.p_conv.register_full_backward_hook(self._set_lr)
17 |         self.register_buffer("p_n", self._get_p_n(N=self.num_param))
18 | 
19 |     @staticmethod
20 |     def _set_lr(module, grad_input, grad_output):
21 |         grad_input = tuple(g * 0.1 if g is not None else None for g in grad_input)
22 |         return grad_input  # a full backward hook must return the new grad_input for the scaling to apply
23 | 
24 |     def forward(self, x):
25 |         # N is num_param.
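        # forward pipeline, as implemented below:
        #   1) p_conv predicts 2*N per-pixel offsets from x;
        #   2) p = p_0 (regular pixel grid) + p_n (initial sampled shape) + offset gives
        #      N sampling coordinates per output location;
        #   3) the input is bilinearly resampled at those coordinates (q_lt/q_rb/q_lb/q_rt);
        #   4) the N samples are stacked along the height axis and reduced by the
        #      (num_param, 1) column convolution in self.conv.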
26 | offset = self.p_conv(x) 27 | dtype = offset.data.type() 28 | N = offset.size(1) // 2 29 | # (b, 2N, h, w) 30 | p = self._get_p(offset, dtype) 31 | 32 | # (b, h, w, 2N) 33 | p = p.contiguous().permute(0, 2, 3, 1) 34 | q_lt = p.detach().floor() 35 | q_rb = q_lt + 1 36 | 37 | q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)], 38 | dim=-1).long() 39 | q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)], 40 | dim=-1).long() 41 | q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1) 42 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1) 43 | 44 | # clip p 45 | p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1) 46 | 47 | # bilinear kernel (b, h, w, N) 48 | g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:])) 49 | g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:])) 50 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:])) 51 | g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:])) 52 | 53 | # resampling the features based on the modified coordinates. 54 | x_q_lt = self._get_x_q(x, q_lt, N) 55 | x_q_rb = self._get_x_q(x, q_rb, N) 56 | x_q_lb = self._get_x_q(x, q_lb, N) 57 | x_q_rt = self._get_x_q(x, q_rt, N) 58 | 59 | # bilinear 60 | x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \ 61 | g_rb.unsqueeze(dim=1) * x_q_rb + \ 62 | g_lb.unsqueeze(dim=1) * x_q_lb + \ 63 | g_rt.unsqueeze(dim=1) * x_q_rt 64 | 65 | x_offset = self._reshape_x_offset(x_offset, self.num_param) 66 | out = self.conv(x_offset) 67 | 68 | return out 69 | 70 | # generating the inital sampled shapes for the LDConv with different sizes. 
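    # e.g. num_param = 5: base_int = round(sqrt(5)) = 2, row_number = 2, mod_number = 1,
    # so the initial footprint is a 2x2 grid plus one extra point in a third row; this is
    # how LDConv supports arbitrary (non-square) numbers of sampling points.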
71 | def _get_p_n(self, N): 72 | base_int = round(math.sqrt(self.num_param)) 73 | row_number = self.num_param // base_int 74 | mod_number = self.num_param % base_int 75 | p_n_x, p_n_y = torch.meshgrid( 76 | torch.arange(0, row_number), 77 | torch.arange(0, base_int)) 78 | p_n_x = torch.flatten(p_n_x) 79 | p_n_y = torch.flatten(p_n_y) 80 | if mod_number > 0: 81 | mod_p_n_x, mod_p_n_y = torch.meshgrid( 82 | torch.arange(row_number, row_number + 1), 83 | torch.arange(0, mod_number)) 84 | 85 | mod_p_n_x = torch.flatten(mod_p_n_x) 86 | mod_p_n_y = torch.flatten(mod_p_n_y) 87 | p_n_x, p_n_y = torch.cat((p_n_x, mod_p_n_x)), torch.cat((p_n_y, mod_p_n_y)) 88 | p_n = torch.cat([p_n_x, p_n_y], 0) 89 | p_n = p_n.view(1, 2 * N, 1, 1) 90 | return p_n 91 | 92 | # no zero-padding 93 | def _get_p_0(self, h, w, N, dtype): 94 | p_0_x, p_0_y = torch.meshgrid( 95 | torch.arange(0, h * self.stride, self.stride), 96 | torch.arange(0, w * self.stride, self.stride)) 97 | 98 | p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1) 99 | p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1) 100 | p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype) 101 | 102 | return p_0 103 | 104 | def _get_p(self, offset, dtype): 105 | N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3) 106 | 107 | # (1, 2N, 1, 1) 108 | # p_n = self._get_p_n(N, dtype) 109 | # (1, 2N, h, w) 110 | p_0 = self._get_p_0(h, w, N, dtype) 111 | p = p_0 + self.p_n + offset 112 | return p 113 | 114 | def _get_x_q(self, x, q, N): 115 | b, h, w, _ = q.size() 116 | padded_w = x.size(3) 117 | c = x.size(1) 118 | # (b, c, h*w) 119 | x = x.contiguous().view(b, c, -1) 120 | 121 | # (b, h, w, N) 122 | index = q[..., :N] * padded_w + q[..., N:] # offset_x*w + offset_y 123 | # (b, c, h*w*N) 124 | index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1) 125 | 126 | x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N) 127 | 128 | return x_offset 129 | 130 | # Stacking resampled features in the row direction. 131 | @staticmethod 132 | def _reshape_x_offset(x_offset, num_param): 133 | b, c, h, w, n = x_offset.size() 134 | # using Conv3d 135 | # x_offset = x_offset.permute(0,1,4,2,3), then Conv3d(c,c_out, kernel_size =(num_param,1,1),stride=(num_param,1,1),bias= False) 136 | # using 1 × 1 Conv 137 | # x_offset = x_offset.permute(0,1,4,2,3), then, x_offset.view(b,c×num_param,h,w) finally, Conv2d(c×num_param,c_out, kernel_size =1,stride=1,bias= False) 138 | # using the column conv as follow, then, Conv2d(inc, outc, kernel_size=(num_param, 1), stride=(num_param, 1), bias=bias) 139 | 140 | x_offset = rearrange(x_offset, 'b c h w n -> b c (h n) w') 141 | return x_offset -------------------------------------------------------------------------------- /即插即用卷积/MorphologyConv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def fixed_padding(inputs, kernel_size, dilation): 8 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 9 | pad_total = kernel_size_effective - 1 10 | pad_beg = pad_total // 2 11 | pad_end = pad_total - pad_beg 12 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 13 | return padded_inputs 14 | 15 | 16 | class MorphologyConv(nn.Module): 17 | ''' 18 | Base class for morpholigical operators 19 | For now, only supports stride=1, dilation=1, kernel_size H==W, and padding='same'. 
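    dilation2d takes a (soft) maximum of (x + w) over each kernel window, while erosion2d
    takes a (soft) minimum of (x - w); with soft_max=True the hard max is replaced by
    logsumexp(beta * t) / beta.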
20 | ''' 21 | 22 | def __init__(self, in_channels, out_channels, kernel_size=5, soft_max=True, beta=15, type=None): 23 | ''' 24 | in_channels: scalar 25 | out_channels: scalar, the number of the morphological neure. 26 | kernel_size: scalar, the spatial size of the morphological neure. 27 | soft_max: bool, using the soft max rather the torch.max(), ref: Dense Morphological Networks: An Universal Function Approximator (Mondal et al. (2019)). 28 | beta: scalar, used by soft_max. 29 | type: str, dilation2d or erosion2d. 30 | ''' 31 | super(MorphologyConv, self).__init__() 32 | self.in_channels = in_channels 33 | self.out_channels = out_channels 34 | self.kernel_size = kernel_size 35 | self.soft_max = soft_max 36 | self.beta = beta 37 | self.type = type 38 | 39 | self.weight = nn.Parameter(torch.ones(out_channels, in_channels, kernel_size, kernel_size), requires_grad=True) 40 | self.unfold = nn.Unfold(kernel_size, dilation=1, padding=0, stride=1) 41 | 42 | def forward(self, x): 43 | ''' 44 | x: tensor of shape (B,C,H,W) 45 | ''' 46 | # padding 47 | x = fixed_padding(x, self.kernel_size, dilation=1) 48 | 49 | # unfold 50 | x = self.unfold(x) # (B, Cin*kH*kW, L), where L is the numbers of patches 51 | x = x.unsqueeze(1) # (B, 1, Cin*kH*kW, L) 52 | L = x.size(-1) 53 | L_sqrt = int(math.sqrt(L)) 54 | 55 | # erosion 56 | weight = self.weight.view(self.out_channels, -1) # (Cout, Cin*kH*kW) 57 | weight = weight.unsqueeze(0).unsqueeze(-1) # (1, Cout, Cin*kH*kW, 1) 58 | 59 | if self.type == 'erosion2d': 60 | x = weight - x # (B, Cout, Cin*kH*kW, L) 61 | elif self.type == 'dilation2d': 62 | x = weight + x # (B, Cout, Cin*kH*kW, L) 63 | else: 64 | raise ValueError 65 | 66 | if not self.soft_max: 67 | x, _ = torch.max(x, dim=2, keepdim=False) # (B, Cout, L) 68 | else: 69 | x = torch.logsumexp(x * self.beta, dim=2, keepdim=False) / self.beta # (B, Cout, L) 70 | 71 | if self.type == 'erosion2d': 72 | x = -1 * x 73 | 74 | # instead of fold, we use view to avoid copy 75 | x = x.view(-1, self.out_channels, L_sqrt, L_sqrt) # (B, Cout, L/2, L/2) 76 | 77 | return x 78 | 79 | 80 | class Dilation2d(MorphologyConv): 81 | def __init__(self, in_channels, out_channels, kernel_size=5, soft_max=True, beta=20): 82 | super(Dilation2d, self).__init__(in_channels, out_channels, kernel_size, soft_max, beta, 'dilation2d') 83 | 84 | class Erosion2d(MorphologyConv): 85 | def __init__(self, in_channels, out_channels, kernel_size=5, soft_max=True, beta=20): 86 | super(Erosion2d, self).__init__(in_channels, out_channels, kernel_size, soft_max, beta, 'erosion2d') -------------------------------------------------------------------------------- /即插即用卷积/PConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | class Partial_conv3(nn.Module): 6 | 7 | def __init__(self, dim, n_div, forward): 8 | super().__init__() 9 | self.dim_conv3 = dim // n_div 10 | self.dim_untouched = dim - self.dim_conv3 11 | self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False) 12 | 13 | if forward == 'slicing': 14 | self.forward = self.forward_slicing 15 | elif forward == 'split_cat': 16 | self.forward = self.forward_split_cat 17 | else: 18 | raise NotImplementedError 19 | 20 | def forward_slicing(self, x: Tensor) -> Tensor: 21 | # only for inference 22 | x = x.clone() # !!! 
Keep the original input intact for the residual connection later 23 | x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :]) 24 | 25 | return x 26 | 27 | def forward_split_cat(self, x: Tensor) -> Tensor: 28 | # for training/inference 29 | x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1) 30 | x1 = self.partial_conv3(x1) 31 | x = torch.cat((x1, x2), 1) 32 | 33 | return x -------------------------------------------------------------------------------- /即插即用卷积/SCConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | 6 | class GroupBatchnorm2d(nn.Module): 7 | def __init__(self, c_num: int, 8 | group_num: int = 16, 9 | eps: float = 1e-10 10 | ): 11 | super(GroupBatchnorm2d, self).__init__() 12 | assert c_num >= group_num 13 | self.group_num = group_num 14 | self.weight = nn.Parameter(torch.randn(c_num, 1, 1)) 15 | self.bias = nn.Parameter(torch.zeros(c_num, 1, 1)) 16 | self.eps = eps 17 | 18 | def forward(self, x): 19 | N, C, H, W = x.size() 20 | x = x.view(N, self.group_num, -1) 21 | mean = x.mean(dim=2, keepdim=True) 22 | std = x.std(dim=2, keepdim=True) 23 | x = (x - mean) / (std + self.eps) 24 | x = x.view(N, C, H, W) 25 | return x * self.weight + self.bias 26 | 27 | 28 | class SRU(nn.Module): 29 | def __init__(self, 30 | oup_channels: int, 31 | group_num: int = 16, 32 | gate_treshold: float = 0.5, 33 | torch_gn: bool = True 34 | ): 35 | super().__init__() 36 | 37 | self.gn = nn.GroupNorm(num_channels=oup_channels, num_groups=group_num) if torch_gn else GroupBatchnorm2d( 38 | c_num=oup_channels, group_num=group_num) 39 | self.gate_treshold = gate_treshold 40 | self.sigomid = nn.Sigmoid() 41 | 42 | def forward(self, x): 43 | gn_x = self.gn(x) 44 | w_gamma = self.gn.weight / sum(self.gn.weight) 45 | w_gamma = w_gamma.view(1, -1, 1, 1) 46 | reweigts = self.sigomid(gn_x * w_gamma) 47 | # Gate 48 | w1 = torch.where(reweigts > self.gate_treshold, torch.ones_like(reweigts), reweigts) # 大于门限值的设为1,否则保留原值 49 | w2 = torch.where(reweigts > self.gate_treshold, torch.zeros_like(reweigts), reweigts) # 大于门限值的设为0,否则保留原值 50 | x_1 = w1 * x 51 | x_2 = w2 * x 52 | y = self.reconstruct(x_1, x_2) 53 | return y 54 | 55 | def reconstruct(self, x_1, x_2): 56 | x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1) 57 | x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1) 58 | return torch.cat([x_11 + x_22, x_12 + x_21], dim=1) 59 | 60 | 61 | class CRU(nn.Module): 62 | ''' 63 | alpha: 0 0. 
else nn.Identity() 25 | 26 | def forward(self, x): 27 | input = x 28 | x = self.dwconv(x) 29 | x1, x2 = self.f1(x), self.f2(x) 30 | x = self.act(x1) * x2 31 | x = self.dwconv2(self.g(x)) 32 | x = input + self.drop_path(x) 33 | return x -------------------------------------------------------------------------------- /即插即用卷积/WTConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pywt 5 | 6 | def create_wavelet_filter(wave, in_size, out_size, type=torch.float): 7 | w = pywt.Wavelet(wave) 8 | dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type) 9 | dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type) 10 | dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1), 11 | dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1), 12 | dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1), 13 | dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0) 14 | 15 | dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1) 16 | 17 | rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0]) 18 | rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0]) 19 | rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1), 20 | rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1), 21 | rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1), 22 | rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0) 23 | 24 | rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1) 25 | 26 | return dec_filters, rec_filters 27 | 28 | def wavelet_transform(x, filters): 29 | b, c, h, w = x.shape 30 | pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1) 31 | x = F.conv2d(x, filters, stride=2, groups=c, padding=pad) 32 | x = x.reshape(b, c, 4, h // 2, w // 2) 33 | return x 34 | 35 | 36 | def inverse_wavelet_transform(x, filters): 37 | b, c, _, h_half, w_half = x.shape 38 | pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1) 39 | x = x.reshape(b, c * 4, h_half, w_half) 40 | x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad) 41 | return x 42 | 43 | 44 | class _ScaleModule(nn.Module): 45 | def __init__(self, dims, init_scale=1.0, init_bias=0): 46 | super(_ScaleModule, self).__init__() 47 | self.dims = dims 48 | self.weight = nn.Parameter(torch.ones(*dims) * init_scale) 49 | self.bias = None 50 | 51 | def forward(self, x): 52 | return torch.mul(self.weight, x) 53 | 54 | 55 | class WTConv2d(nn.Module): 56 | def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'): 57 | super(WTConv2d, self).__init__() 58 | 59 | assert in_channels == out_channels 60 | 61 | self.in_channels = in_channels 62 | self.wt_levels = wt_levels 63 | self.stride = stride 64 | self.dilation = 1 65 | 66 | self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float) 67 | self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False) 68 | self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False) 69 | 70 | self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1, groups=in_channels, bias=bias) 71 | self.base_scale = _ScaleModule([1 ,in_channels ,1 ,1]) 72 | 73 | self.wavelet_convs = nn.ModuleList( 74 | [nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1, groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)] 75 | ) 76 | self.wavelet_scale = nn.ModuleList( 77 | [_ScaleModule([1 ,in_channels * 4 ,1 ,1], init_scale=0.1) for _ in 
range(self.wt_levels)] 78 | ) 79 | 80 | if self.stride > 1: 81 | self.do_stride = nn.AvgPool2d(kernel_size=1, stride=stride) 82 | else: 83 | self.do_stride = None 84 | 85 | def forward(self, x): 86 | 87 | x_ll_in_levels = [] 88 | x_h_in_levels = [] 89 | shapes_in_levels = [] 90 | 91 | curr_x_ll = x 92 | 93 | for i in range(self.wt_levels): 94 | curr_shape = curr_x_ll.shape 95 | shapes_in_levels.append(curr_shape) 96 | if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0): 97 | curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2) 98 | curr_x_ll = F.pad(curr_x_ll, curr_pads) 99 | 100 | curr_x = wavelet_transform(curr_x_ll, self.wt_filter) 101 | curr_x_ll = curr_x[: ,: ,0 ,: ,:] 102 | 103 | shape_x = curr_x.shape 104 | curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4]) 105 | curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag)) 106 | curr_x_tag = curr_x_tag.reshape(shape_x) 107 | 108 | x_ll_in_levels.append(curr_x_tag[: ,: ,0 ,: ,:]) 109 | x_h_in_levels.append(curr_x_tag[: ,: ,1:4 ,: ,:]) 110 | 111 | next_x_ll = 0 112 | 113 | for i in range(self.wt_levels-1, -1, -1): 114 | curr_x_ll = x_ll_in_levels.pop() 115 | curr_x_h = x_h_in_levels.pop() 116 | curr_shape = shapes_in_levels.pop() 117 | 118 | curr_x_ll = curr_x_ll + next_x_ll 119 | 120 | curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2) 121 | next_x_ll = inverse_wavelet_transform(curr_x, self.iwt_filter) 122 | 123 | next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]] 124 | 125 | x_tag = next_x_ll 126 | assert len(x_ll_in_levels) == 0 127 | 128 | x = self.base_scale(self.base_conv(x)) 129 | x = x + x_tag 130 | 131 | if self.do_stride is not None: 132 | x = self.do_stride(x) 133 | 134 | return x -------------------------------------------------------------------------------- /注意力模块/ASSA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 4 | from einops import rearrange, repeat 5 | 6 | 7 | def conv(in_channels, out_channels, kernel_size, bias=False, stride=1): 8 | return nn.Conv2d( 9 | in_channels, out_channels, kernel_size, 10 | padding=(kernel_size // 2), bias=bias, stride=stride) 11 | 12 | 13 | ######################################### 14 | class ConvBlock(nn.Module): 15 | def __init__(self, in_channel, out_channel, strides=1): 16 | super(ConvBlock, self).__init__() 17 | self.strides = strides 18 | self.in_channel = in_channel 19 | self.out_channel = out_channel 20 | self.block = nn.Sequential( 21 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1), 22 | nn.LeakyReLU(inplace=True), 23 | nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=strides, padding=1), 24 | nn.LeakyReLU(inplace=True), 25 | ) 26 | self.conv11 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, padding=0) 27 | 28 | def forward(self, x): 29 | out1 = self.block(x) 30 | out2 = self.conv11(x) 31 | out = out1 + out2 32 | return out 33 | 34 | 35 | class LinearProjection(nn.Module): 36 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True): 37 | super().__init__() 38 | inner_dim = dim_head * heads 39 | self.heads = heads 40 | self.to_q = nn.Linear(dim, inner_dim, bias=bias) 41 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias) 42 | self.dim = dim 43 | self.inner_dim = inner_dim 44 | 45 | def forward(self, x, attn_kv=None): 46 | B_, N, C = x.shape 47 | if attn_kv is not None: 48 | attn_kv = 
attn_kv.unsqueeze(0).repeat(B_, 1, 1) 49 | else: 50 | attn_kv = x 51 | N_kv = attn_kv.size(1) 52 | q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 53 | kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 54 | q = q[0] 55 | k, v = kv[0], kv[1] 56 | return q, k, v 57 | 58 | 59 | ######################################### 60 | ########### window-based self-attention ############# 61 | class WindowAttention(nn.Module): 62 | def __init__(self, dim, win_size, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., 63 | proj_drop=0.): 64 | 65 | super().__init__() 66 | self.dim = dim 67 | self.win_size = win_size # Wh, Ww 68 | self.num_heads = num_heads 69 | head_dim = dim // num_heads 70 | self.scale = qk_scale or head_dim ** -0.5 71 | 72 | # define a parameter table of relative position bias 73 | self.relative_position_bias_table = nn.Parameter( 74 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 75 | 76 | # get pair-wise relative position index for each token inside the window 77 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1] 78 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1] 79 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 80 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 81 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 82 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 83 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0 84 | relative_coords[:, :, 1] += self.win_size[1] - 1 85 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1 86 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 87 | self.register_buffer("relative_position_index", relative_position_index) 88 | trunc_normal_(self.relative_position_bias_table, std=.02) 89 | 90 | if token_projection == 'linear': 91 | self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias) 92 | else: 93 | raise Exception("Projection error!") 94 | 95 | self.token_projection = token_projection 96 | self.attn_drop = nn.Dropout(attn_drop) 97 | self.proj = nn.Linear(dim, dim) 98 | self.proj_drop = nn.Dropout(proj_drop) 99 | 100 | self.softmax = nn.Softmax(dim=-1) 101 | 102 | def forward(self, x, attn_kv=None, mask=None): 103 | B_, N, C = x.shape 104 | q, k, v = self.qkv(x, attn_kv) 105 | q = q * self.scale 106 | attn = (q @ k.transpose(-2, -1)) 107 | 108 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 109 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH 110 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 111 | ratio = attn.size(-1) // relative_position_bias.size(-1) 112 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio) 113 | 114 | attn = attn + relative_position_bias.unsqueeze(0) 115 | 116 | if mask is not None: 117 | nW = mask.shape[0] 118 | mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio) 119 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0) 120 | attn = attn.view(-1, self.num_heads, N, N * ratio) 121 | attn = self.softmax(attn) 122 | else: 123 | attn = self.softmax(attn) 124 | 125 | attn = self.attn_drop(attn) 126 | 127 | x = (attn @ v).transpose(1, 2).reshape(B_, 
N, C) 128 | x = self.proj(x) 129 | x = self.proj_drop(x) 130 | return x 131 | 132 | def extra_repr(self) -> str: 133 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}' 134 | 135 | 136 | ########### window-based self-attention ############# 137 | class WindowAttention_sparse(nn.Module): 138 | def __init__(self, dim, win_size, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., 139 | proj_drop=0.): 140 | 141 | super().__init__() 142 | self.dim = dim 143 | self.win_size = win_size # Wh, Ww 144 | self.num_heads = num_heads 145 | head_dim = dim // num_heads 146 | self.scale = qk_scale or head_dim ** -0.5 147 | 148 | # define a parameter table of relative position bias 149 | self.relative_position_bias_table = nn.Parameter( 150 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 151 | 152 | # get pair-wise relative position index for each token inside the window 153 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1] 154 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1] 155 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 156 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 157 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 158 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 159 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0 160 | relative_coords[:, :, 1] += self.win_size[1] - 1 161 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1 162 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 163 | self.register_buffer("relative_position_index", relative_position_index) 164 | trunc_normal_(self.relative_position_bias_table, std=.02) 165 | 166 | if token_projection == 'linear': 167 | self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias) 168 | else: 169 | raise Exception("Projection error!") 170 | 171 | self.token_projection = token_projection 172 | self.attn_drop = nn.Dropout(attn_drop) 173 | self.proj = nn.Linear(dim, dim) 174 | self.proj_drop = nn.Dropout(proj_drop) 175 | 176 | self.softmax = nn.Softmax(dim=-1) 177 | self.relu = nn.ReLU() 178 | self.w = nn.Parameter(torch.ones(2)) 179 | 180 | def forward(self, x, attn_kv=None, mask=None): 181 | B_, N, C = x.shape 182 | q, k, v = self.qkv(x, attn_kv) 183 | q = q * self.scale 184 | attn = (q @ k.transpose(-2, -1)) 185 | 186 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 187 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH 188 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 189 | ratio = attn.size(-1) // relative_position_bias.size(-1) 190 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio) 191 | 192 | attn = attn + relative_position_bias.unsqueeze(0) 193 | 194 | if mask is not None: 195 | nW = mask.shape[0] 196 | mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio) 197 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0) 198 | attn = attn.view(-1, self.num_heads, N, N * ratio) 199 | attn0 = self.softmax(attn) 200 | attn1 = self.relu(attn) ** 2 # b,h,w,c 201 | else: 202 | attn0 = self.softmax(attn) 203 | attn1 = self.relu(attn) ** 2 204 | w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w)) 205 | w2 = 
torch.exp(self.w[1]) / torch.sum(torch.exp(self.w)) 206 | attn = attn0 * w1 + attn1 * w2 207 | attn = self.attn_drop(attn) 208 | 209 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 210 | x = self.proj(x) 211 | x = self.proj_drop(x) 212 | return x 213 | 214 | def extra_repr(self) -> str: 215 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}' 216 | 217 | 218 | ########### self-attention ############# 219 | class Attention(nn.Module): 220 | def __init__(self, dim, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., 221 | proj_drop=0.): 222 | 223 | super().__init__() 224 | self.dim = dim 225 | self.num_heads = num_heads 226 | head_dim = dim // num_heads 227 | self.scale = qk_scale or head_dim ** -0.5 228 | 229 | self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias) 230 | 231 | self.token_projection = token_projection 232 | self.attn_drop = nn.Dropout(attn_drop) 233 | self.proj = nn.Linear(dim, dim) 234 | self.proj_drop = nn.Dropout(proj_drop) 235 | 236 | self.softmax = nn.Softmax(dim=-1) 237 | 238 | def forward(self, x, attn_kv=None, mask=None): 239 | B_, N, C = x.shape 240 | q, k, v = self.qkv(x, attn_kv) 241 | q = q * self.scale 242 | attn = (q @ k.transpose(-2, -1)) 243 | if mask is not None: 244 | nW = mask.shape[0] 245 | # mask = repeat(mask, 'nW m n -> nW m (n d)',d = ratio) 246 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 247 | attn = attn.view(-1, self.num_heads, N, N) 248 | attn = self.softmax(attn) 249 | else: 250 | attn = self.softmax(attn) 251 | 252 | attn = self.attn_drop(attn) 253 | 254 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 255 | x = self.proj(x) 256 | x = self.proj_drop(x) 257 | return x 258 | 259 | def extra_repr(self) -> str: 260 | return f'dim={self.dim}, num_heads={self.num_heads}' -------------------------------------------------------------------------------- /注意力模块/AgentAttention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import trunc_normal_ 4 | import numpy as np 5 | 6 | 7 | def img2windows(img, H_sp, W_sp): 8 | """ 9 | img: B C H W 10 | """ 11 | B, C, H, W = img.shape 12 | img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) 13 | img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp * W_sp, C) 14 | return img_perm 15 | 16 | def windows2img(img_splits_hw, H_sp, W_sp, H, W): 17 | """ 18 | img_splits_hw: B' H W C 19 | """ 20 | B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp)) 21 | 22 | img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) 23 | img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 24 | return img 25 | 26 | 27 | class AgentAttention(nn.Module): 28 | def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, attn_drop=0., proj_drop=0., 29 | agent_num=49, **kwargs): 30 | super().__init__() 31 | self.dim = dim 32 | self.dim_out = dim_out or dim 33 | self.resolution = resolution 34 | self.split_size = split_size 35 | self.num_heads = num_heads 36 | head_dim = dim // num_heads 37 | self.agent_num = agent_num 38 | self.scale = head_dim ** -0.5 39 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 40 | # self.scale = qk_scale or head_dim ** -0.5 41 | if idx == -1: 42 | H_sp, W_sp = self.resolution, self.resolution 43 | elif idx == 0: 44 | H_sp, W_sp = self.resolution, 
self.split_size 45 | elif idx == 1: 46 | W_sp, H_sp = self.resolution, self.split_size 47 | else: 48 | print("ERROR MODE", idx) 49 | exit(0) 50 | self.H_sp = H_sp 51 | self.W_sp = W_sp 52 | self.get_v = nn.Conv2d(dim, dim, kernel_size=(3, 3), stride=(1, 1), padding=1, groups=dim) 53 | 54 | self.attn_drop = nn.Dropout(attn_drop) 55 | 56 | self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7)) 57 | self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7)) 58 | self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, H_sp, 1)) 59 | self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, W_sp)) 60 | self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, H_sp, 1, agent_num)) 61 | self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, W_sp, agent_num)) 62 | trunc_normal_(self.an_bias, std=.02) 63 | trunc_normal_(self.na_bias, std=.02) 64 | trunc_normal_(self.ah_bias, std=.02) 65 | trunc_normal_(self.aw_bias, std=.02) 66 | trunc_normal_(self.ha_bias, std=.02) 67 | trunc_normal_(self.wa_bias, std=.02) 68 | pool_size = int(agent_num ** 0.5) 69 | self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)) 70 | self.softmax = nn.Softmax(dim=-1) 71 | 72 | def im2cswin(self, x): 73 | B, N, C = x.shape 74 | H = W = int(np.sqrt(N)) 75 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W) 76 | x = img2windows(x, self.H_sp, self.W_sp) 77 | # x = x.reshape(-1, self.H_sp * self.W_sp, C).contiguous() 78 | return x 79 | 80 | def get_lepe(self, x, func): 81 | B, N, C = x.shape 82 | H = W = int(np.sqrt(N)) 83 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W) 84 | 85 | H_sp, W_sp = self.H_sp, self.W_sp 86 | x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) 87 | x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp) ### B', C, H', W' 88 | 89 | lepe = func(x) ### B', C, H', W' 90 | lepe = lepe.reshape(-1, C, H_sp * W_sp).permute(0, 2, 1).contiguous() 91 | 92 | x = x.reshape(-1, C, self.H_sp * self.W_sp).permute(0, 2, 1).contiguous() 93 | return x, lepe 94 | 95 | def forward(self, qkv): 96 | """ 97 | x: B L C 98 | """ 99 | q, k, v = qkv[0], qkv[1], qkv[2] 100 | 101 | ### Img2Window 102 | H = W = self.resolution 103 | B, L, C = q.shape 104 | assert L == H * W, "flatten img_tokens has wrong size" 105 | 106 | q = self.im2cswin(q) 107 | k = self.im2cswin(k) 108 | v, lepe = self.get_lepe(v, self.get_v) 109 | # q, k, v = (rearrange(x, "b h n c -> b n (h c)", h=self.num_heads) for x in [q, k, v]) 110 | 111 | b, n, c = q.shape 112 | h, w = self.H_sp, self.W_sp 113 | num_heads, head_dim = self.num_heads, self.dim // self.num_heads 114 | 115 | agent_tokens = self.pool(q.reshape(b, h, w, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1) 116 | q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 117 | k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 118 | v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 119 | agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3) 120 | 121 | position_bias1 = nn.functional.interpolate(self.an_bias, size=(self.H_sp, self.W_sp), mode='bilinear') 122 | position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1) 123 | position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1) 124 | position_bias = position_bias1 + position_bias2 125 | agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias) 126 | agent_attn = 
self.attn_drop(agent_attn) 127 | agent_v = agent_attn @ v 128 | 129 | agent_bias1 = nn.functional.interpolate(self.na_bias, size=(self.H_sp, self.W_sp), mode='bilinear') 130 | agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1) 131 | agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1) 132 | agent_bias = agent_bias1 + agent_bias2 133 | q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias) 134 | q_attn = self.attn_drop(q_attn) 135 | x = q_attn @ agent_v 136 | 137 | x = x.transpose(1, 2).reshape(b, n, c) 138 | x = x + lepe 139 | x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C) 140 | 141 | return x -------------------------------------------------------------------------------- /注意力模块/CA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class h_sigmoid(nn.Module): 6 | def __init__(self, inplace=True): 7 | super(h_sigmoid, self).__init__() 8 | self.relu = nn.ReLU6(inplace=inplace) 9 | 10 | def forward(self, x): 11 | return self.relu(x + 3) / 6 12 | 13 | 14 | class h_swish(nn.Module): 15 | def __init__(self, inplace=True): 16 | super(h_swish, self).__init__() 17 | self.sigmoid = h_sigmoid(inplace=inplace) 18 | 19 | def forward(self, x): 20 | return x * self.sigmoid(x) 21 | 22 | 23 | class CoordAtt(nn.Module): 24 | def __init__(self, inp, oup, reduction=32): 25 | super(CoordAtt, self).__init__() 26 | self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) 27 | self.pool_w = nn.AdaptiveAvgPool2d((1, None)) 28 | 29 | mip = max(8, inp // reduction) 30 | 31 | self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0) 32 | self.bn1 = nn.BatchNorm2d(mip) 33 | self.act = h_swish() 34 | 35 | self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 36 | self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 37 | 38 | def forward(self, x): 39 | identity = x 40 | 41 | n, c, h, w = x.size() 42 | x_h = self.pool_h(x) 43 | x_w = self.pool_w(x).permute(0, 1, 3, 2) 44 | 45 | y = torch.cat([x_h, x_w], dim=2) 46 | y = self.conv1(y) 47 | y = self.bn1(y) 48 | y = self.act(y) 49 | 50 | x_h, x_w = torch.split(y, [h, w], dim=2) 51 | x_w = x_w.permute(0, 1, 3, 2) 52 | 53 | a_h = self.conv_h(x_h).sigmoid() 54 | a_w = self.conv_w(x_w).sigmoid() 55 | 56 | out = identity * a_w * a_h 57 | 58 | return out -------------------------------------------------------------------------------- /注意力模块/CBAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Channel Attention Module 6 | class ChannelAttention(nn.Module): 7 | def __init__(self, in_planes, ratio=16): 8 | super(ChannelAttention, self).__init__() 9 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # 全局平均池化 10 | self.max_pool = nn.AdaptiveMaxPool2d(1) # 全局最大池化 11 | 12 | # 使用1x1卷积代替全连接层,减少参数量,ratio用于降维 13 | self.fc = nn.Sequential( 14 | nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), 15 | nn.ReLU(), 16 | nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 17 | ) 18 | self.sigmoid = nn.Sigmoid() # 激活函数 19 | 20 | def forward(self, x): 21 | avg_out = self.fc(self.avg_pool(x)) # 全局平均池化分支 22 | max_out = self.fc(self.max_pool(x)) # 全局最大池化分支 23 | out = avg_out + max_out # 融合两个池化分支 24 | return self.sigmoid(out) # 返回通道权重 25 | 26 | 27 | # Spatial Attention Module 28 | class SpatialAttention(nn.Module): 29 | def __init__(self, 
kernel_size=7): 30 | super(SpatialAttention, self).__init__() 31 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False) 32 | self.sigmoid = nn.Sigmoid() 33 | 34 | def forward(self, x): 35 | avg_out = torch.mean(x, dim=1, keepdim=True) # 计算特征图的平均值 36 | max_out, _ = torch.max(x, dim=1, keepdim=True) # 计算特征图的最大值 37 | x = torch.cat([avg_out, max_out], dim=1) # 沿通道维度拼接 38 | x = self.conv1(x) # 通过卷积层 39 | return self.sigmoid(x) # 返回空间权重 40 | 41 | 42 | # CBAM Module 43 | class CBAM(nn.Module): 44 | def __init__(self, in_planes, ratio=16, kernel_size=7): 45 | super(CBAM, self).__init__() 46 | self.channel_attention = ChannelAttention(in_planes, ratio) # 通道注意力模块 47 | self.spatial_attention = SpatialAttention(kernel_size) # 空间注意力模块 48 | 49 | def forward(self, x): 50 | out = self.channel_attention(x) * x # 先应用通道注意力 51 | out = self.spatial_attention(out) * out # 再应用空间注意力 52 | return out # 返回结果 -------------------------------------------------------------------------------- /注意力模块/CGA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import itertools 4 | 5 | class Conv2d_BN(torch.nn.Sequential): 6 | def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, 7 | groups=1, bn_weight_init=1, resolution=-10000): 8 | super().__init__() 9 | self.add_module('c', torch.nn.Conv2d( 10 | a, b, ks, stride, pad, dilation, groups, bias=False)) 11 | self.add_module('bn', torch.nn.BatchNorm2d(b)) 12 | torch.nn.init.constant_(self.bn.weight, bn_weight_init) 13 | torch.nn.init.constant_(self.bn.bias, 0) 14 | 15 | @torch.no_grad() 16 | def fuse(self): 17 | c, bn = self._modules.values() 18 | w = bn.weight / (bn.running_var + bn.eps)**0.5 19 | w = c.weight * w[:, None, None, None] 20 | b = bn.bias - bn.running_mean * bn.weight / \ 21 | (bn.running_var + bn.eps)**0.5 22 | m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( 23 | 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) 24 | m.weight.data.copy_(w) 25 | m.bias.data.copy_(b) 26 | return m 27 | 28 | 29 | class CascadedGroupAttention(torch.nn.Module): 30 | r""" Cascaded Group Attention. 31 | 32 | Args: 33 | dim (int): Number of input channels. 34 | key_dim (int): The dimension for query and key. 35 | num_heads (int): Number of attention heads. 36 | attn_ratio (int): Multiplier for the query dim for value dimension. 37 | resolution (int): Input resolution, correspond to the window size. 38 | kernels (List[int]): The kernel size of the dw conv on query. 
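        Example (editor's addition, a minimal sketch not from the original repo; it assumes
        dim is divisible by num_heads, len(kernels) == num_heads, and the input's spatial
        size equals `resolution`):
            >>> import torch
            >>> attn = CascadedGroupAttention(dim=64, key_dim=16, num_heads=4, resolution=14)
            >>> x = torch.randn(2, 64, 14, 14)
            >>> y = attn(x)   # -> (2, 64, 14, 14), same shape as the input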
39 | """ 40 | def __init__(self, dim, key_dim, num_heads=8, 41 | attn_ratio=4, 42 | resolution=14, 43 | kernels=[5, 5, 5, 5],): 44 | super().__init__() 45 | self.num_heads = num_heads 46 | self.scale = key_dim ** -0.5 47 | self.key_dim = key_dim 48 | self.d = int(attn_ratio * key_dim) 49 | self.attn_ratio = attn_ratio 50 | 51 | qkvs = [] 52 | dws = [] 53 | for i in range(num_heads): 54 | qkvs.append(Conv2d_BN(dim // (num_heads), self.key_dim * 2 + self.d, resolution=resolution)) 55 | dws.append(Conv2d_BN(self.key_dim, self.key_dim, kernels[i], 1, kernels[i]//2, groups=self.key_dim, resolution=resolution)) 56 | self.qkvs = torch.nn.ModuleList(qkvs) 57 | self.dws = torch.nn.ModuleList(dws) 58 | self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN( 59 | self.d * num_heads, dim, bn_weight_init=0, resolution=resolution)) 60 | 61 | points = list(itertools.product(range(resolution), range(resolution))) 62 | N = len(points) 63 | attention_offsets = {} 64 | idxs = [] 65 | for p1 in points: 66 | for p2 in points: 67 | offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) 68 | if offset not in attention_offsets: 69 | attention_offsets[offset] = len(attention_offsets) 70 | idxs.append(attention_offsets[offset]) 71 | self.attention_biases = torch.nn.Parameter( 72 | torch.zeros(num_heads, len(attention_offsets))) 73 | self.register_buffer('attention_bias_idxs', 74 | torch.LongTensor(idxs).view(N, N)) 75 | 76 | @torch.no_grad() 77 | def train(self, mode=True): 78 | super().train(mode) 79 | if mode and hasattr(self, 'ab'): 80 | del self.ab 81 | else: 82 | self.ab = self.attention_biases[:, self.attention_bias_idxs] 83 | 84 | def forward(self, x): # x (B,C,H,W) 85 | B, C, H, W = x.shape 86 | trainingab = self.attention_biases[:, self.attention_bias_idxs] 87 | feats_in = x.chunk(len(self.qkvs), dim=1) 88 | feats_out = [] 89 | feat = feats_in[0] 90 | for i, qkv in enumerate(self.qkvs): 91 | if i > 0: # add the previous output to the input 92 | feat = feat + feats_in[i] 93 | feat = qkv(feat) 94 | q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1) # B, C/h, H, W 95 | q = self.dws[i](q) 96 | q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) # B, C/h, N 97 | attn = ( 98 | (q.transpose(-2, -1) @ k) * self.scale 99 | + 100 | (trainingab[i] if self.training else self.ab[i]) 101 | ) 102 | attn = attn.softmax(dim=-1) # BNN 103 | feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W) # BCHW 104 | feats_out.append(feat) 105 | x = self.proj(torch.cat(feats_out, 1)) 106 | return x -------------------------------------------------------------------------------- /注意力模块/CGLU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DWConv(nn.Module): 5 | def __init__(self, dim=768): 6 | super(DWConv, self).__init__() 7 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim) 8 | 9 | def forward(self, x, H, W): 10 | B, N, C = x.shape 11 | x = x.transpose(1, 2).view(B, C, H, W).contiguous() 12 | x = self.dwconv(x) 13 | x = x.flatten(2).transpose(1, 2) 14 | 15 | return x 16 | 17 | 18 | class ConvolutionalGLU(nn.Module): 19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 20 | super().__init__() 21 | out_features = out_features or in_features 22 | hidden_features = hidden_features or in_features 23 | hidden_features = int(2 * hidden_features / 3) 24 | self.fc1 = nn.Linear(in_features, hidden_features * 2) 25 | 
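        # Editor's note (added comment): fc1 projects to 2 * hidden_features on purpose --
        # forward() chunks its output into two halves, passes one half through the depthwise
        # conv and the activation, and uses the result to gate the other half elementwise
        # (the convolutional GLU), so fc2 only needs to map hidden_features back out.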
self.dwconv = DWConv(hidden_features) 26 | self.act = act_layer() 27 | self.fc2 = nn.Linear(hidden_features, out_features) 28 | self.drop = nn.Dropout(drop) 29 | 30 | def forward(self, x, H, W): 31 | x, v = self.fc1(x).chunk(2, dim=-1) 32 | x = self.act(self.dwconv(x, H, W)) * v 33 | x = self.drop(x) 34 | x = self.fc2(x) 35 | x = self.drop(x) 36 | return x -------------------------------------------------------------------------------- /注意力模块/DANet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \ 5 | NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding 6 | torch_ver = torch.__version__[:3] 7 | 8 | __all__ = ['PAM_Module', 'CAM_Module'] 9 | 10 | 11 | class PAM_Module(Module): 12 | """ Position attention module""" 13 | #Ref from SAGAN 14 | def __init__(self, in_dim): 15 | super(PAM_Module, self).__init__() 16 | self.chanel_in = in_dim 17 | 18 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) 19 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) 20 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 21 | self.gamma = Parameter(torch.zeros(1)) 22 | 23 | self.softmax = Softmax(dim=-1) 24 | def forward(self, x): 25 | """ 26 | inputs : 27 | x : input feature maps( B X C X H X W) 28 | returns : 29 | out : attention value + input feature 30 | attention: B X (HxW) X (HxW) 31 | """ 32 | m_batchsize, C, height, width = x.size() 33 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) 34 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height) 35 | energy = torch.bmm(proj_query, proj_key) 36 | attention = self.softmax(energy) 37 | proj_value = self.value_conv(x).view(m_batchsize, -1, width*height) 38 | 39 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 40 | out = out.view(m_batchsize, C, height, width) 41 | 42 | out = self.gamma*out + x 43 | return out 44 | 45 | 46 | class CAM_Module(Module): 47 | """ Channel attention module""" 48 | def __init__(self, in_dim): 49 | super(CAM_Module, self).__init__() 50 | self.chanel_in = in_dim 51 | 52 | 53 | self.gamma = Parameter(torch.zeros(1)) 54 | self.softmax = Softmax(dim=-1) 55 | def forward(self,x): 56 | """ 57 | inputs : 58 | x : input feature maps( B X C X H X W) 59 | returns : 60 | out : attention value + input feature 61 | attention: B X C X C 62 | """ 63 | m_batchsize, C, height, width = x.size() 64 | proj_query = x.view(m_batchsize, C, -1) 65 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1) 66 | energy = torch.bmm(proj_query, proj_key) 67 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy 68 | attention = self.softmax(energy_new) 69 | proj_value = x.view(m_batchsize, C, -1) 70 | 71 | out = torch.bmm(attention, proj_value) 72 | out = out.view(m_batchsize, C, height, width) 73 | 74 | out = self.gamma*out + x 75 | return out 76 | 77 | 78 | class DANet(Module): 79 | """ DANet module """ 80 | 81 | def __init__(self, in_channels, out_channels): 82 | super(DANet, self).__init__() 83 | 84 | # Shared convolutional layers 85 | self.conv = Sequential( 86 | Conv2d(in_channels, in_channels, kernel_size=3, padding=1), 87 | nn.BatchNorm2d(in_channels), 88 | ReLU(inplace=True) 89 | ) 90 | 91 | # Position and Channel attention 
modules 92 | self.position_attention = PAM_Module(in_channels) 93 | self.channel_attention = CAM_Module(in_channels) 94 | 95 | # Output convolution 96 | self.output_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 97 | 98 | def forward(self, x): 99 | x = self.conv(x) 100 | 101 | # Position and Channel Attention 102 | pos_out = self.position_attention(x) 103 | chn_out = self.channel_attention(x) 104 | 105 | # Fusion and output 106 | fusion = pos_out + chn_out 107 | out = self.output_conv(fusion) 108 | 109 | return out 110 | -------------------------------------------------------------------------------- /注意力模块/ECA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.parameter import Parameter 4 | 5 | class eca_layer(nn.Module): 6 | """Constructs a ECA module. 7 | 8 | Args: 9 | channel: Number of channels of the input feature map 10 | k_size: Adaptive selection of kernel size 11 | """ 12 | def __init__(self, channel, k_size=3): 13 | super(eca_layer, self).__init__() 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 16 | self.sigmoid = nn.Sigmoid() 17 | 18 | def forward(self, x): 19 | # feature descriptor on the global spatial information 20 | y = self.avg_pool(x) 21 | 22 | # Two different branches of ECA module 23 | y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 24 | 25 | # Multi-scale information fusion 26 | y = self.sigmoid(y) 27 | 28 | return x * y.expand_as(x) -------------------------------------------------------------------------------- /注意力模块/FcaNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | def get_freq_indices(method): 7 | assert method in ['top1','top2','top4','top8','top16','top32', 8 | 'bot1','bot2','bot4','bot8','bot16','bot32', 9 | 'low1','low2','low4','low8','low16','low32'] 10 | num_freq = int(method[3:]) 11 | if 'top' in method: 12 | all_top_indices_x = [0,0,6,0,0,1,1,4,5,1,3,0,0,0,3,2,4,6,3,5,5,2,6,5,5,3,3,4,2,2,6,1] 13 | all_top_indices_y = [0,1,0,5,2,0,2,0,0,6,0,4,6,3,5,2,6,3,3,3,5,1,1,2,4,2,1,1,3,0,5,3] 14 | mapper_x = all_top_indices_x[:num_freq] 15 | mapper_y = all_top_indices_y[:num_freq] 16 | elif 'low' in method: 17 | all_low_indices_x = [0,0,1,1,0,2,2,1,2,0,3,4,0,1,3,0,1,2,3,4,5,0,1,2,3,4,5,6,1,2,3,4] 18 | all_low_indices_y = [0,1,0,1,2,0,1,2,2,3,0,0,4,3,1,5,4,3,2,1,0,6,5,4,3,2,1,0,6,5,4,3] 19 | mapper_x = all_low_indices_x[:num_freq] 20 | mapper_y = all_low_indices_y[:num_freq] 21 | elif 'bot' in method: 22 | all_bot_indices_x = [6,1,3,3,2,4,1,2,4,4,5,1,4,6,2,5,6,1,6,2,2,4,3,3,5,5,6,2,5,5,3,6] 23 | all_bot_indices_y = [6,4,4,6,6,3,1,4,4,5,6,5,2,2,5,1,4,3,5,0,3,1,1,2,4,2,1,1,5,3,3,3] 24 | mapper_x = all_bot_indices_x[:num_freq] 25 | mapper_y = all_bot_indices_y[:num_freq] 26 | else: 27 | raise NotImplementedError 28 | return mapper_x, mapper_y 29 | 30 | 31 | class MultiSpectralDCTLayer(nn.Module): 32 | """ 33 | Generate dct filters 34 | """ 35 | 36 | def __init__(self, height, width, mapper_x, mapper_y, channel): 37 | super(MultiSpectralDCTLayer, self).__init__() 38 | 39 | assert len(mapper_x) == len(mapper_y) 40 | assert channel % len(mapper_x) == 0 41 | 42 | self.num_freq = len(mapper_x) 43 | 44 | # fixed DCT init 45 | self.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel)) 46 | 47 | # fixed 
random init 48 | # self.register_buffer('weight', torch.rand(channel, height, width)) 49 | 50 | # learnable DCT init 51 | # self.register_parameter('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel)) 52 | 53 | # learnable random init 54 | # self.register_parameter('weight', torch.rand(channel, height, width)) 55 | 56 | # num_freq, h, w 57 | 58 | def forward(self, x): 59 | assert len(x.shape) == 4, 'x must been 4 dimensions, but got ' + str(len(x.shape)) 60 | # n, c, h, w = x.shape 61 | 62 | x = x * self.weight 63 | 64 | result = torch.sum(x, dim=[2, 3]) 65 | return result 66 | 67 | def build_filter(self, pos, freq, POS): 68 | result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS) 69 | if freq == 0: 70 | return result 71 | else: 72 | return result * math.sqrt(2) 73 | 74 | def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel): 75 | dct_filter = torch.zeros(channel, tile_size_x, tile_size_y) 76 | 77 | c_part = channel // len(mapper_x) 78 | 79 | for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)): 80 | for t_x in range(tile_size_x): 81 | for t_y in range(tile_size_y): 82 | dct_filter[i * c_part: (i + 1) * c_part, t_x, t_y] = self.build_filter(t_x, u_x, 83 | tile_size_x) * self.build_filter( 84 | t_y, v_y, tile_size_y) 85 | 86 | return dct_filter 87 | 88 | 89 | 90 | class MultiSpectralAttentionLayer(torch.nn.Module): 91 | def __init__(self, channel, dct_h, dct_w, reduction = 16, freq_sel_method = 'top16'): 92 | super(MultiSpectralAttentionLayer, self).__init__() 93 | self.reduction = reduction 94 | self.dct_h = dct_h 95 | self.dct_w = dct_w 96 | 97 | mapper_x, mapper_y = get_freq_indices(freq_sel_method) 98 | self.num_split = len(mapper_x) 99 | mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x] 100 | mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y] 101 | # make the frequencies in different sizes are identical to a 7x7 frequency space 102 | # eg, (2,2) in 14x14 is identical to (1,1) in 7x7 103 | 104 | self.dct_layer = MultiSpectralDCTLayer(dct_h, dct_w, mapper_x, mapper_y, channel) 105 | self.fc = nn.Sequential( 106 | nn.Linear(channel, channel // reduction, bias=False), 107 | nn.ReLU(inplace=True), 108 | nn.Linear(channel // reduction, channel, bias=False), 109 | nn.Sigmoid() 110 | ) 111 | 112 | def forward(self, x): 113 | n,c,h,w = x.shape 114 | x_pooled = x 115 | if h != self.dct_h or w != self.dct_w: 116 | x_pooled = torch.nn.functional.adaptive_avg_pool2d(x, (self.dct_h, self.dct_w)) 117 | # If you have concerns about one-line-change, don't worry. :) 118 | # In the ImageNet models, this line will never be triggered. 119 | # This is for compatibility in instance segmentation and object detection. 
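        # Editor's note (added comment): dct_layer multiplies the (possibly pooled) feature map
        # by fixed 2D DCT basis filters and sums over the spatial dimensions, yielding one
        # frequency statistic per channel; the fc stack then maps these statistics to per-channel
        # sigmoid weights, i.e. an SE-style gate driven by multi-spectral pooling instead of GAP.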
120 | y = self.dct_layer(x_pooled) 121 | 122 | y = self.fc(y).view(n, c, 1, 1) 123 | return x * y.expand_as(x) 124 | 125 | 126 | 127 | class FcaBasicBlock(nn.Module): 128 | expansion = 1 129 | 130 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 131 | base_width=64, dilation=1, norm_layer=None, 132 | *, reduction=16, ): 133 | global _mapper_x, _mapper_y 134 | super(FcaBasicBlock, self).__init__() 135 | # assert fea_h is not None 136 | # assert fea_w is not None 137 | c2wh = dict([(64,56), (128,28), (256,14) ,(512,7)]) 138 | self.planes = planes 139 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 140 | self.bn1 = nn.BatchNorm2d(planes) 141 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) 142 | self.bn2 = nn.BatchNorm2d(planes) 143 | self.relu = nn.ReLU(inplace=True) 144 | self.att = MultiSpectralAttentionLayer(planes, c2wh[planes], c2wh[planes], reduction=reduction, freq_sel_method = 'top16') 145 | self.downsample = downsample 146 | self.stride = stride 147 | 148 | def forward(self, x): 149 | residual = x 150 | 151 | out = self.conv1(x) 152 | out = self.bn1(out) 153 | out = self.relu(out) 154 | 155 | out = self.conv2(out) 156 | out = self.bn2(out) 157 | 158 | out = self.att(out) 159 | 160 | if self.downsample is not None: 161 | residual = self.downsample(x) 162 | 163 | out += residual 164 | out = self.relu(out) 165 | 166 | return out -------------------------------------------------------------------------------- /注意力模块/LGAG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | import math 6 | from timm.models.layers import trunc_normal_tf_ 7 | from timm.models.helpers import named_apply 8 | 9 | 10 | # Other types of layers can go here (e.g., nn.Linear, etc.) 
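# Editor's note (added comment): _init_weights below is applied recursively through timm's
# named_apply (see LGAG.init_weights). The schemes 'normal', 'trunc_normal', 'xavier_normal'
# and 'kaiming_normal' pick the conv weight init; any other string falls back to an
# EfficientNet-style fan-out normal init, and BatchNorm/LayerNorm layers always get
# weight=1, bias=0.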
11 | def _init_weights(module, name, scheme=''): 12 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv3d): 13 | if scheme == 'normal': 14 | nn.init.normal_(module.weight, std=.02) 15 | if module.bias is not None: 16 | nn.init.zeros_(module.bias) 17 | elif scheme == 'trunc_normal': 18 | trunc_normal_tf_(module.weight, std=.02) 19 | if module.bias is not None: 20 | nn.init.zeros_(module.bias) 21 | elif scheme == 'xavier_normal': 22 | nn.init.xavier_normal_(module.weight) 23 | if module.bias is not None: 24 | nn.init.zeros_(module.bias) 25 | elif scheme == 'kaiming_normal': 26 | nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') 27 | if module.bias is not None: 28 | nn.init.zeros_(module.bias) 29 | else: 30 | # efficientnet like 31 | fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels 32 | fan_out //= module.groups 33 | nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) 34 | if module.bias is not None: 35 | nn.init.zeros_(module.bias) 36 | elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d): 37 | nn.init.constant_(module.weight, 1) 38 | nn.init.constant_(module.bias, 0) 39 | elif isinstance(module, nn.LayerNorm): 40 | nn.init.constant_(module.weight, 1) 41 | nn.init.constant_(module.bias, 0) 42 | 43 | def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1): 44 | # activation layer 45 | act = act.lower() 46 | if act == 'relu': 47 | layer = nn.ReLU(inplace) 48 | elif act == 'relu6': 49 | layer = nn.ReLU6(inplace) 50 | elif act == 'leakyrelu': 51 | layer = nn.LeakyReLU(neg_slope, inplace) 52 | elif act == 'prelu': 53 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 54 | elif act == 'gelu': 55 | layer = nn.GELU() 56 | elif act == 'hswish': 57 | layer = nn.Hardswish(inplace) 58 | else: 59 | raise NotImplementedError('activation layer [%s] is not found' % act) 60 | return layer 61 | 62 | 63 | 64 | 65 | # Large-kernel grouped attention gate (LGAG) 66 | class LGAG(nn.Module): 67 | def __init__(self, F_g, F_l, F_int, kernel_size=3, groups=1, activation='relu'): 68 | super(LGAG, self).__init__() 69 | 70 | if kernel_size == 1: 71 | groups = 1 72 | self.W_g = nn.Sequential( 73 | nn.Conv2d(F_g, F_int, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, groups=groups, 74 | bias=True), 75 | nn.BatchNorm2d(F_int) 76 | ) 77 | self.W_x = nn.Sequential( 78 | nn.Conv2d(F_l, F_int, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, groups=groups, 79 | bias=True), 80 | nn.BatchNorm2d(F_int) 81 | ) 82 | self.psi = nn.Sequential( 83 | nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True), 84 | nn.BatchNorm2d(1), 85 | nn.Sigmoid() 86 | ) 87 | self.activation = act_layer(activation, inplace=True) 88 | 89 | self.init_weights('normal') 90 | 91 | def init_weights(self, scheme=''): 92 | named_apply(partial(_init_weights, scheme=scheme), self) 93 | 94 | def forward(self, g, x): 95 | g1 = self.W_g(g) 96 | x1 = self.W_x(x) 97 | psi = self.activation(g1 + x1) 98 | psi = self.psi(psi) 99 | 100 | return x * psi -------------------------------------------------------------------------------- /注意力模块/MAB.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LayerNorm(nn.Module): 6 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 7 | The ordering of the dimensions in the inputs. 
channels_last corresponds to inputs with 8 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 9 | with shape (batch_size, channels, height, width). 10 | """ 11 | 12 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 13 | super().__init__() 14 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 15 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 16 | self.eps = eps 17 | self.data_format = data_format 18 | if self.data_format not in ["channels_last", "channels_first"]: 19 | raise NotImplementedError 20 | self.normalized_shape = (normalized_shape,) 21 | 22 | def forward(self, x): 23 | if self.data_format == "channels_last": 24 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 25 | elif self.data_format == "channels_first": 26 | u = x.mean(1, keepdim=True) 27 | s = (x - u).pow(2).mean(1, keepdim=True) 28 | x = (x - u) / torch.sqrt(s + self.eps) 29 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 30 | return x 31 | 32 | 33 | class SGAB(nn.Module): 34 | def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor=15, attn='GLKA'): 35 | super().__init__() 36 | i_feats = n_feats * 2 37 | 38 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0) 39 | self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7 // 2, groups=n_feats) 40 | self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0) 41 | 42 | self.norm = LayerNorm(n_feats, data_format='channels_first') 43 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 44 | 45 | def forward(self, x): 46 | shortcut = x.clone() 47 | 48 | # Ghost Expand 49 | x = self.Conv1(self.norm(x)) 50 | a, x = torch.chunk(x, 2, dim=1) 51 | x = x * self.DWConv1(a) 52 | x = self.Conv2(x) 53 | 54 | return x * self.scale + shortcut 55 | 56 | 57 | class GroupGLKA(nn.Module): 58 | def __init__(self, n_feats, k=2, squeeze_factor=15): 59 | super().__init__() 60 | i_feats = 2 * n_feats 61 | 62 | self.n_feats = n_feats 63 | self.i_feats = i_feats 64 | 65 | self.norm = LayerNorm(n_feats, data_format='channels_first') 66 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 67 | 68 | # Multiscale Large Kernel Attention 69 | self.LKA7 = nn.Sequential( 70 | nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3), 71 | nn.Conv2d(n_feats // 3, n_feats // 3, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // 3, dilation=4), 72 | nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0)) 73 | self.LKA5 = nn.Sequential( 74 | nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3), 75 | nn.Conv2d(n_feats // 3, n_feats // 3, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // 3, dilation=3), 76 | nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0)) 77 | self.LKA3 = nn.Sequential( 78 | nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3), 79 | nn.Conv2d(n_feats // 3, n_feats // 3, 5, stride=1, padding=(5 // 2) * 2, groups=n_feats // 3, dilation=2), 80 | nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0)) 81 | 82 | self.X3 = nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3) 83 | self.X5 = nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3) 84 | self.X7 = nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3) 85 | 86 | self.proj_first = nn.Sequential( 87 | nn.Conv2d(n_feats, i_feats, 1, 1, 0)) 88 | 89 | self.proj_last = nn.Sequential( 90 | nn.Conv2d(n_feats, n_feats, 1, 1, 0)) 91 | 92 | def forward(self, x, pre_attn=None, RAA=None): 93 | 
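        # Editor's note (added comments): the steps below are norm -> 1x1 proj to 2*n_feats ->
        # chunk into an attention branch `a` and a value branch `x`; `a` is split into three
        # channel groups, each modulated by a large-kernel attention path (LKA3/LKA5/LKA7)
        # multiplied with a plain depthwise conv (X3/X5/X7); the concatenated result gates `x`
        # before the final 1x1 projection, learnable scaling, and the residual connection.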
shortcut = x.clone() 94 | 95 | x = self.norm(x) 96 | 97 | x = self.proj_first(x) 98 | 99 | a, x = torch.chunk(x, 2, dim=1) 100 | 101 | a_1, a_2, a_3 = torch.chunk(a, 3, dim=1) 102 | 103 | a = torch.cat([self.LKA3(a_1) * self.X3(a_1), self.LKA5(a_2) * self.X5(a_2), self.LKA7(a_3) * self.X7(a_3)], 104 | dim=1) 105 | 106 | x = self.proj_last(x * a) * self.scale + shortcut 107 | 108 | return x 109 | 110 | 111 | class MAB(nn.Module): 112 | def __init__( 113 | self, n_feats): 114 | super().__init__() 115 | 116 | self.LKA = GroupGLKA(n_feats) 117 | 118 | self.LFE = SGAB(n_feats) 119 | 120 | def forward(self, x, pre_attn=None, RAA=None): 121 | # large kernel attention 122 | x = self.LKA(x) 123 | 124 | # local feature extraction 125 | x = self.LFE(x) 126 | 127 | return x -------------------------------------------------------------------------------- /注意力模块/MCA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | __all__ = ['MCALayer', 'MCAGate'] 6 | 7 | 8 | class StdPool(nn.Module): 9 | def __init__(self): 10 | super(StdPool, self).__init__() 11 | 12 | def forward(self, x): 13 | b, c, _, _ = x.size() 14 | 15 | std = x.view(b, c, -1).std(dim=2, keepdim=True) 16 | std = std.reshape(b, c, 1, 1) 17 | 18 | return std 19 | 20 | 21 | class MCAGate(nn.Module): 22 | def __init__(self, k_size, pool_types=['avg', 'std']): 23 | """Constructs a MCAGate module. 24 | Args: 25 | k_size: kernel size 26 | pool_types: pooling type. 'avg': average pooling, 'max': max pooling, 'std': standard deviation pooling. 27 | """ 28 | super(MCAGate, self).__init__() 29 | 30 | self.pools = nn.ModuleList([]) 31 | for pool_type in pool_types: 32 | if pool_type == 'avg': 33 | self.pools.append(nn.AdaptiveAvgPool2d(1)) 34 | elif pool_type == 'max': 35 | self.pools.append(nn.AdaptiveMaxPool2d(1)) 36 | elif pool_type == 'std': 37 | self.pools.append(StdPool()) 38 | else: 39 | raise NotImplementedError 40 | 41 | self.conv = nn.Conv2d(1, 1, kernel_size=(1, k_size), stride=1, padding=(0, (k_size - 1) // 2), bias=False) 42 | self.sigmoid = nn.Sigmoid() 43 | 44 | self.weight = nn.Parameter(torch.rand(2)) 45 | 46 | def forward(self, x): 47 | feats = [pool(x) for pool in self.pools] 48 | 49 | if len(feats) == 1: 50 | out = feats[0] 51 | elif len(feats) == 2: 52 | weight = torch.sigmoid(self.weight) 53 | out = 1 / 2 * (feats[0] + feats[1]) + weight[0] * feats[0] + weight[1] * feats[1] 54 | else: 55 | assert False, "Feature Extraction Exception!" 56 | 57 | out = out.permute(0, 3, 2, 1).contiguous() 58 | out = self.conv(out) 59 | out = out.permute(0, 3, 2, 1).contiguous() 60 | 61 | out = self.sigmoid(out) 62 | out = out.expand_as(x) 63 | 64 | return x * out 65 | 66 | 67 | class MCALayer(nn.Module): 68 | def __init__(self, inp, no_spatial=False): 69 | """Constructs a MCA module. 
70 | Args: 71 | inp: Number of channels of the input feature maps 72 | no_spatial: whether to build channel dimension interactions 73 | """ 74 | super(MCALayer, self).__init__() 75 | 76 | lambd = 1.5 77 | gamma = 1 78 | temp = round(abs((math.log2(inp) - gamma) / lambd)) 79 | kernel = temp if temp % 2 else temp - 1 80 | 81 | self.h_cw = MCAGate(3) 82 | self.w_hc = MCAGate(3) 83 | self.no_spatial = no_spatial 84 | if not no_spatial: 85 | self.c_hw = MCAGate(kernel) 86 | 87 | def forward(self, x): 88 | x_h = x.permute(0, 2, 1, 3).contiguous() 89 | x_h = self.h_cw(x_h) 90 | x_h = x_h.permute(0, 2, 1, 3).contiguous() 91 | 92 | x_w = x.permute(0, 3, 2, 1).contiguous() 93 | x_w = self.w_hc(x_w) 94 | x_w = x_w.permute(0, 3, 2, 1).contiguous() 95 | 96 | if not self.no_spatial: 97 | x_c = self.c_hw(x) 98 | x_out = 1 / 3 * (x_c + x_h + x_w) 99 | else: 100 | x_out = 1 / 2 * (x_h + x_w) 101 | 102 | return x_out -------------------------------------------------------------------------------- /注意力模块/MCPA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class MCrossAttention(nn.Module): 5 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.1, proj_drop=0.1): 6 | super().__init__() 7 | self.num_heads = num_heads 8 | head_dim = dim // num_heads 9 | self.scale = qk_scale or head_dim ** -0.5 10 | 11 | self.wq = nn.Linear(head_dim, dim , bias=qkv_bias) 12 | self.wk = nn.Linear(head_dim, dim , bias=qkv_bias) 13 | self.wv = nn.Linear(head_dim, dim , bias=qkv_bias) 14 | # self.attn_drop = nn.Dropout(attn_drop) 15 | self.proj = nn.Linear(dim * num_heads, dim) 16 | self.proj_drop = nn.Dropout(proj_drop) 17 | 18 | def forward(self, x): 19 | 20 | B, N, C = x.shape 21 | q = self.wq(x[:, 0:1, ...].reshape(B, 1, self.num_heads, C // self.num_heads)).permute(0, 2, 1, 3) # B1C -> B1H(C/H) -> BH1(C/H) 22 | k = self.wk(x.reshape(B, N, self.num_heads, C // self.num_heads)).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) 23 | v = self.wv(x.reshape(B, N, self.num_heads, C // self.num_heads)).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) 24 | attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 25 | # attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N 26 | attn = attn.softmax(dim=-1) 27 | # attn = self.attn_drop(attn) 28 | x = torch.einsum('bhij,bhjd->bhid', attn, v).transpose(1, 2) 29 | # x = (attn @ v).transpose(1, 2) 30 | x = x.reshape(B, 1, C * self.num_heads) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C 31 | x = self.proj(x) 32 | x = self.proj_drop(x) 33 | return x -------------------------------------------------------------------------------- /注意力模块/MEGA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | 6 | def gauss_kernel(channels=3, cuda=True): 7 | kernel = torch.tensor([[1., 4., 6., 4., 1], 8 | [4., 16., 24., 16., 4.], 9 | [6., 24., 36., 24., 6.], 10 | [4., 16., 24., 16., 4.], 11 | [1., 4., 6., 4., 1.]]) 12 | kernel /= 256. 
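    # Editor's note (added comment): the 5x5 binomial kernel above sums to 256, so the division
    # normalizes it to unit sum; repeat() below replicates it to shape (channels, 1, 5, 5) so
    # conv_gauss can apply it as a per-channel (groups=channels) Gaussian blur.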
13 | kernel = kernel.repeat(channels, 1, 1, 1) 14 | if cuda: 15 | kernel = kernel.cuda() 16 | return kernel 17 | 18 | 19 | def downsample(x): 20 | return x[:, :, ::2, ::2] 21 | 22 | 23 | def conv_gauss(img, kernel): 24 | img = F.pad(img, (2, 2, 2, 2), mode='reflect') 25 | out = F.conv2d(img, kernel, groups=img.shape[1]) 26 | return out 27 | 28 | 29 | def upsample(x, channels): 30 | cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3) 31 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3]) 32 | cc = cc.permute(0, 1, 3, 2) 33 | cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device)], dim=3) 34 | cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2) 35 | x_up = cc.permute(0, 1, 3, 2) 36 | return conv_gauss(x_up, 4 * gauss_kernel(channels)) 37 | 38 | 39 | def make_laplace(img, channels): 40 | filtered = conv_gauss(img, gauss_kernel(channels)) 41 | down = downsample(filtered) 42 | up = upsample(down, channels) 43 | if up.shape[2] != img.shape[2] or up.shape[3] != img.shape[3]: 44 | up = nn.functional.interpolate(up, size=(img.shape[2], img.shape[3])) 45 | diff = img - up 46 | return diff 47 | 48 | 49 | def make_laplace_pyramid(img, level, channels): 50 | current = img 51 | pyr = [] 52 | for _ in range(level): 53 | filtered = conv_gauss(current, gauss_kernel(channels)) 54 | down = downsample(filtered) 55 | up = upsample(down, channels) 56 | if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]: 57 | up = nn.functional.interpolate(up, size=(current.shape[2], current.shape[3])) 58 | diff = current - up 59 | pyr.append(diff) 60 | current = down 61 | pyr.append(current) 62 | return pyr 63 | 64 | 65 | class ChannelGate(nn.Module): 66 | def __init__(self, gate_channels, reduction_ratio=16): 67 | super(ChannelGate, self).__init__() 68 | self.gate_channels = gate_channels 69 | self.mlp = nn.Sequential( 70 | nn.Flatten(), 71 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 72 | nn.ReLU(), 73 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 74 | ) 75 | 76 | def forward(self, x): 77 | avg_out = self.mlp(F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))) 78 | max_out = self.mlp(F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))) 79 | channel_att_sum = avg_out + max_out 80 | 81 | scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x) 82 | return x * scale 83 | 84 | 85 | class SpatialGate(nn.Module): 86 | def __init__(self): 87 | super(SpatialGate, self).__init__() 88 | kernel_size = 7 89 | self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2) 90 | 91 | def forward(self, x): 92 | x_compress = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) 93 | x_out = self.spatial(x_compress) 94 | scale = torch.sigmoid(x_out) # broadcasting 95 | return x * scale 96 | 97 | 98 | class CBAM(nn.Module): 99 | def __init__(self, gate_channels, reduction_ratio=16): 100 | super(CBAM, self).__init__() 101 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio) 102 | self.SpatialGate = SpatialGate() 103 | 104 | def forward(self, x): 105 | x_out = self.ChannelGate(x) 106 | x_out = self.SpatialGate(x_out) 107 | return x_out 108 | 109 | 110 | # Edge-Guided Attention Module 111 | class EGA(nn.Module): 112 | def __init__(self, in_channels): 113 | super(EGA, self).__init__() 114 | 115 | self.fusion_conv = nn.Sequential( 116 | 
nn.Conv2d(in_channels * 3, in_channels, 3, 1, 1), 117 | nn.BatchNorm2d(in_channels), 118 | nn.ReLU(inplace=True)) 119 | 120 | self.attention = nn.Sequential( 121 | nn.Conv2d(in_channels, 1, 3, 1, 1), 122 | nn.BatchNorm2d(1), 123 | nn.Sigmoid()) 124 | 125 | self.cbam = CBAM(in_channels) 126 | 127 | def forward(self, edge_feature, x, pred): 128 | residual = x 129 | xsize = x.size()[2:] 130 | pred = torch.sigmoid(pred) 131 | 132 | # reverse attention 133 | background_att = 1 - pred 134 | background_x = x * background_att 135 | 136 | # boudary attention 137 | edge_pred = make_laplace(pred, 1) 138 | pred_feature = x * edge_pred 139 | 140 | # high-frequency feature 141 | edge_input = F.interpolate(edge_feature, size=xsize, mode='bilinear', align_corners=True) 142 | input_feature = x * edge_input 143 | 144 | fusion_feature = torch.cat([background_x, pred_feature, input_feature], dim=1) 145 | fusion_feature = self.fusion_conv(fusion_feature) 146 | 147 | attention_map = self.attention(fusion_feature) 148 | fusion_feature = fusion_feature * attention_map 149 | 150 | out = fusion_feature + residual 151 | out = self.cbam(out) 152 | return out -------------------------------------------------------------------------------- /注意力模块/MLKA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LayerNorm(nn.Module): 7 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 8 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 9 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 10 | with shape (batch_size, channels, height, width). 11 | """ 12 | 13 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 14 | super().__init__() 15 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 16 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 17 | self.eps = eps 18 | self.data_format = data_format 19 | if self.data_format not in ["channels_last", "channels_first"]: 20 | raise NotImplementedError 21 | self.normalized_shape = (normalized_shape,) 22 | 23 | def forward(self, x): 24 | if self.data_format == "channels_last": 25 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 26 | elif self.data_format == "channels_first": 27 | u = x.mean(1, keepdim=True) 28 | s = (x - u).pow(2).mean(1, keepdim=True) 29 | x = (x - u) / torch.sqrt(s + self.eps) 30 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 31 | return x 32 | 33 | 34 | class MLKA_Ablation(nn.Module): 35 | def __init__(self, n_feats, k=2, squeeze_factor=15): 36 | super().__init__() 37 | i_feats = 2 * n_feats 38 | 39 | self.n_feats = n_feats 40 | self.i_feats = i_feats 41 | 42 | self.norm = LayerNorm(n_feats, data_format='channels_first') 43 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 44 | 45 | k = 2 46 | 47 | # Multiscale Large Kernel Attention 48 | self.LKA7 = nn.Sequential( 49 | nn.Conv2d(n_feats // k, n_feats // k, 7, 1, 7 // 2, groups=n_feats // k), 50 | nn.Conv2d(n_feats // k, n_feats // k, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // k, dilation=4), 51 | nn.Conv2d(n_feats // k, n_feats // k, 1, 1, 0)) 52 | self.LKA5 = nn.Sequential( 53 | nn.Conv2d(n_feats // k, n_feats // k, 5, 1, 5 // 2, groups=n_feats // k), 54 | nn.Conv2d(n_feats // k, n_feats // k, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats 
// k, dilation=3), 55 | nn.Conv2d(n_feats // k, n_feats // k, 1, 1, 0)) 56 | '''self.LKA3 = nn.Sequential( 57 | nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k), 58 | nn.Conv2d(n_feats//k, n_feats//k, 5, stride=1, padding=(5//2)*2, groups=n_feats//k, dilation=2), 59 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0))''' 60 | 61 | # self.X3 = nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k) 62 | self.X5 = nn.Conv2d(n_feats // k, n_feats // k, 5, 1, 5 // 2, groups=n_feats // k) 63 | self.X7 = nn.Conv2d(n_feats // k, n_feats // k, 7, 1, 7 // 2, groups=n_feats // k) 64 | 65 | self.proj_first = nn.Sequential( 66 | nn.Conv2d(n_feats, i_feats, 1, 1, 0)) 67 | 68 | self.proj_last = nn.Sequential( 69 | nn.Conv2d(n_feats, n_feats, 1, 1, 0)) 70 | 71 | def forward(self, x, pre_attn=None, RAA=None): 72 | shortcut = x.clone() 73 | 74 | x = self.norm(x) 75 | 76 | x = self.proj_first(x) 77 | 78 | a, x = torch.chunk(x, 2, dim=1) 79 | 80 | # u_1, u_2, u_3= torch.chunk(u, 3, dim=1) 81 | a_1, a_2 = torch.chunk(a, 2, dim=1) 82 | 83 | a = torch.cat([self.LKA7(a_1) * self.X7(a_1), self.LKA5(a_2) * self.X5(a_2)], dim=1) 84 | 85 | x = self.proj_last(x * a) * self.scale + shortcut 86 | 87 | return x -------------------------------------------------------------------------------- /注意力模块/MSPA.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | """3x3 convolution with padding""" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 9 | 10 | 11 | def conv1x1(in_planes, out_planes, stride=1): 12 | """1x1 convolution""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 14 | 15 | 16 | def convdilated(in_planes, out_planes, kSize=3, stride=1, dilation=1): 17 | """3x3 convolution with dilation""" 18 | padding = int((kSize - 1) / 2) * dilation 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=kSize, stride=stride, padding=padding, 20 | dilation=dilation, bias=False) 21 | 22 | 23 | class SPRModule(nn.Module): 24 | def __init__(self, channels, reduction=16): 25 | super(SPRModule, self).__init__() 26 | 27 | self.avg_pool1 = nn.AdaptiveAvgPool2d(1) 28 | self.avg_pool2 = nn.AdaptiveAvgPool2d(2) 29 | 30 | self.fc1 = nn.Conv2d(channels * 5, channels//reduction, kernel_size=1, padding=0) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.fc2 = nn.Conv2d(channels//reduction, channels, kernel_size=1, padding=0) 33 | self.sigmoid = nn.Sigmoid() 34 | 35 | def forward(self, x): 36 | 37 | out1 = self.avg_pool1(x).view(x.size(0), -1, 1, 1) 38 | out2 = self.avg_pool2(x).view(x.size(0), -1, 1, 1) 39 | out = torch.cat((out1, out2), 1) 40 | 41 | out = self.fc1(out) 42 | out = self.relu(out) 43 | out = self.fc2(out) 44 | weight = self.sigmoid(out) 45 | 46 | return weight 47 | 48 | 49 | class MSAModule(nn.Module): 50 | def __init__(self, inplanes, scale=3, stride=1, stype='normal'): 51 | """ Constructor 52 | Args: 53 | inplanes: input channel dimensionality. 54 | scale: number of scale. 55 | stride: conv stride. 56 | stype: 'normal': normal set. 'stage': first block of a new stage. 
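Example (a minimal shape sketch, not from the original source; the input is assumed to carry inplanes * scale channels, e.g. 32 * 3 = 96):
    msa = MSAModule(inplanes=32, scale=3)
    x = torch.randn(2, 96, 56, 56)
    out = msa(x)  # -> torch.Size([2, 96, 56, 56])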
57 | """ 58 | super(MSAModule, self).__init__() 59 | 60 | self.width = inplanes 61 | self.nums = scale 62 | self.stride = stride 63 | assert stype in ['stage', 'normal'], 'One of these is suppported (stage or normal)' 64 | self.stype = stype 65 | 66 | self.convs = nn.ModuleList([]) 67 | self.bns = nn.ModuleList([]) 68 | 69 | for i in range(self.nums): 70 | if self.stype == 'stage' and self.stride != 1: 71 | self.convs.append(convdilated(self.width, self.width, stride=stride, dilation=int(i + 1))) 72 | else: 73 | self.convs.append(conv3x3(self.width, self.width, stride)) 74 | 75 | self.bns.append(nn.BatchNorm2d(self.width)) 76 | 77 | self.attention = SPRModule(self.width) 78 | 79 | self.softmax = nn.Softmax(dim=1) 80 | 81 | def forward(self, x): 82 | batch_size = x.shape[0] 83 | 84 | spx = torch.split(x, self.width, 1) 85 | for i in range(self.nums): 86 | if i == 0 or (self.stype == 'stage' and self.stride != 1): 87 | sp = spx[i] 88 | else: 89 | sp = sp + spx[i] 90 | sp = self.convs[i](sp) 91 | sp = self.bns[i](sp) 92 | 93 | if i == 0: 94 | out = sp 95 | else: 96 | out = torch.cat((out, sp), 1) 97 | 98 | feats = out 99 | feats = feats.view(batch_size, self.nums, self.width, feats.shape[2], feats.shape[3]) 100 | 101 | sp_inp = torch.split(out, self.width, 1) 102 | 103 | attn_weight = [] 104 | for inp in sp_inp: 105 | attn_weight.append(self.attention(inp)) 106 | 107 | attn_weight = torch.cat(attn_weight, dim=1) 108 | attn_vectors = attn_weight.view(batch_size, self.nums, self.width, 1, 1) 109 | attn_vectors = self.softmax(attn_vectors) 110 | feats_weight = feats * attn_vectors 111 | 112 | for i in range(self.nums): 113 | x_attn_weight = feats_weight[:, i, :, :, :] 114 | if i == 0: 115 | out = x_attn_weight 116 | else: 117 | out = torch.cat((out, x_attn_weight), 1) 118 | 119 | return out 120 | 121 | 122 | class MSPABlock(nn.Module): 123 | expansion = 4 124 | 125 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=30, scale=3, 126 | norm_layer=None, stype='normal'): 127 | """ Constructor 128 | Args: 129 | inplanes: input channel dimensionality. 130 | planes: output channel dimensionality. 131 | stride: conv stride. 132 | downsample: None when stride = 1. 133 | baseWidth: basic width of conv3x3. 134 | scale: number of scale. 135 | norm_layer: regularization layer. 136 | stype: 'normal': normal set. 'stage': first block of a new stage. 
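Example (a minimal shape sketch, not from the original source; with the default baseWidth=30 and scale=3, an input of planes * expansion = 256 channels needs no downsample branch):
    block = MSPABlock(inplanes=256, planes=64)
    x = torch.randn(2, 256, 56, 56)
    out = block(x)  # -> torch.Size([2, 256, 56, 56])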
137 | """ 138 | super(MSPABlock, self).__init__() 139 | if norm_layer is None: 140 | norm_layer = nn.BatchNorm2d 141 | width = int(math.floor(planes * (baseWidth / 64.0))) 142 | 143 | self.conv1 = conv1x1(inplanes, width * scale) 144 | self.bn1 = norm_layer(width * scale) 145 | 146 | self.conv2 = MSAModule(width, scale=scale, stride=stride, stype=stype) 147 | self.bn2 = norm_layer(width * scale) 148 | 149 | self.conv3 = conv1x1(width * scale, planes * self.expansion) 150 | self.bn3 = norm_layer(planes * self.expansion) 151 | self.relu = nn.ReLU(inplace=True) 152 | 153 | self.downsample = downsample 154 | 155 | def forward(self, x): 156 | identity = x 157 | 158 | out = self.conv1(x) 159 | out = self.bn1(out) 160 | out = self.relu(out) 161 | 162 | out = self.conv2(out) 163 | out = self.bn2(out) 164 | out = self.relu(out) 165 | 166 | out = self.conv3(out) 167 | out = self.bn3(out) 168 | 169 | if self.downsample is not None: 170 | identity = self.downsample(x) 171 | 172 | out += identity 173 | out = self.relu(out) 174 | 175 | return out 176 | 177 | 178 | class MSPANet(nn.Module): 179 | def __init__(self, block, layers, num_classes=1000, baseWidth=30, scale=3, norm_layer=None): 180 | super(MSPANet, self).__init__() 181 | if norm_layer is None: 182 | norm_layer = nn.BatchNorm2d 183 | self._norm_layer = norm_layer 184 | 185 | self.inplanes = 64 186 | self.baseWidth = baseWidth 187 | self.scale = scale 188 | 189 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 190 | self.bn1 = norm_layer(self.inplanes) 191 | self.relu = nn.ReLU(inplace=True) 192 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 193 | 194 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 195 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 196 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 197 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 198 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 199 | 200 | self.fc = nn.Linear(512 * block.expansion, num_classes) 201 | 202 | # weight initialization 203 | self._initialize_weights() 204 | 205 | def _initialize_weights(self): 206 | for m in self.modules(): 207 | if isinstance(m, nn.Conv2d): 208 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 209 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 210 | nn.init.constant_(m.weight, 1) 211 | nn.init.constant_(m.bias, 0) 212 | 213 | def _make_layer(self, block, planes, blocks, stride=1): 214 | norm_layer = self._norm_layer 215 | downsample = None 216 | if stride != 1 or self.inplanes != planes * block.expansion: 217 | downsample = nn.Sequential( 218 | conv1x1(self.inplanes, planes * block.expansion, stride), 219 | norm_layer(planes * block.expansion), 220 | ) 221 | 222 | layers = [] 223 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 224 | baseWidth=self.baseWidth, scale=self.scale, norm_layer=norm_layer, stype='stage')) 225 | self.inplanes = planes * block.expansion 226 | for i in range(1, blocks): 227 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale, 228 | norm_layer=norm_layer)) 229 | 230 | return nn.Sequential(*layers) 231 | 232 | def forward(self, x): 233 | x = self.conv1(x) 234 | x = self.bn1(x) 235 | x = self.relu(x) 236 | x = self.maxpool(x) 237 | 238 | x = self.layer1(x) 239 | x = self.layer2(x) 240 | x = self.layer3(x) 241 | x = self.layer4(x) 242 | 243 | x = self.avgpool(x) 244 | x = x.view(x.size(0), -1) 245 | 246 | 
x = self.fc(x) 247 | 248 | return x -------------------------------------------------------------------------------- /注意力模块/NonLocal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class NonLocalBlock(nn.Module): 6 | def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False, bn_layer=True): 7 | """ 8 | :param in_channels: Number of channels in the input feature map. 9 | :param inter_channels: Number of channels in the intermediate feature map. 10 | If None, it will be set to in_channels // 2. 11 | :param dimension: The spatial dimensions of the input. 12 | 2 for 2D (image), 3 for 3D (video). 13 | :param sub_sample: Whether to apply subsampling to reduce computation. 14 | :param bn_layer: Whether to add BatchNorm layer after the Non-Local operation. 15 | """ 16 | super(NonLocalBlock, self).__init__() 17 | 18 | assert dimension in [1, 2, 3], "Only 1D, 2D, and 3D inputs are supported." 19 | self.dimension = dimension 20 | self.sub_sample = sub_sample 21 | 22 | # Determine the dimension for 1D, 2D, or 3D 23 | if dimension == 3: 24 | self.pool = nn.MaxPool3d(kernel_size=(1, 2, 2)) 25 | elif dimension == 2: 26 | self.pool = nn.MaxPool2d(kernel_size=(2, 2)) 27 | else: 28 | self.pool = nn.MaxPool1d(kernel_size=2) 29 | 30 | # Define the intermediate channels size 31 | if inter_channels is None: 32 | inter_channels = in_channels // 2 33 | if inter_channels == 0: 34 | inter_channels = 1 35 | 36 | self.g = nn.Conv2d(in_channels=in_channels, out_channels=inter_channels, kernel_size=1, stride=1, padding=0) 37 | 38 | # theta and phi will be used to compute similarity (dot product) 39 | self.theta = nn.Conv2d(in_channels=in_channels, out_channels=inter_channels, kernel_size=1, stride=1, padding=0) 40 | self.phi = nn.Conv2d(in_channels=in_channels, out_channels=inter_channels, kernel_size=1, stride=1, padding=0) 41 | 42 | # W will transform the output of NonLocal block back to the original dimension 43 | self.W = nn.Conv2d(in_channels=inter_channels, out_channels=in_channels, kernel_size=1, stride=1, padding=0) 44 | 45 | # Optionally, include a BatchNorm layer after W 46 | if bn_layer: 47 | self.W = nn.Sequential( 48 | self.W, 49 | nn.BatchNorm2d(in_channels) 50 | ) 51 | 52 | # Optional subsampling 53 | if sub_sample: 54 | self.g = nn.Sequential(self.g, self.pool) 55 | self.phi = nn.Sequential(self.phi, self.pool) 56 | 57 | def forward(self, x): 58 | """ 59 | :param x: Input feature map of shape (N, C, H, W) where 60 | N is the batch size, 61 | C is the number of channels, 62 | H is the height, 63 | W is the width. 64 | :return: Output feature map of the same shape as input. 
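Example (a minimal sketch, assuming the default inter_channels = in_channels // 2):
    nlb = NonLocalBlock(in_channels=64)
    y = nlb(torch.randn(2, 64, 32, 32))  # -> torch.Size([2, 64, 32, 32])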
65 | """ 66 | batch_size, C, H, W = x.size() 67 | 68 | # Apply transformations: theta, phi, and g 69 | g_x = self.g(x).view(batch_size, -1, H * W) # g(x) shape: (N, C', H*W) 70 | g_x = g_x.permute(0, 2, 1) # g(x) shape: (N, H*W, C') 71 | 72 | theta_x = self.theta(x).view(batch_size, -1, H * W) # theta(x) shape: (N, C', H*W) 73 | phi_x = self.phi(x).view(batch_size, -1, H * W) # phi(x) shape: (N, C', H*W) 74 | phi_x = phi_x.permute(0, 2, 1) # phi(x) shape: (N, H*W, C') 75 | 76 | # Compute similarity: theta_x * phi_x^T (matrix multiplication) 77 | f = torch.matmul(theta_x, phi_x) # shape: (N, C', C') 78 | f_div_C = F.softmax(f, dim=-1) # Apply softmax to normalize similarity 79 | 80 | # Apply attention map to g(x) 81 | y = torch.matmul(f_div_C, g_x) # shape: (N, C', H*W) 82 | y = y.permute(0, 2, 1).contiguous() # Reshape: (N, H*W, C') 83 | y = y.view(batch_size, C // 2, H, W) # Reshape: (N, C', H, W) 84 | 85 | # Transform the output back to original input dimension with W 86 | W_y = self.W(y) 87 | 88 | # Residual connection: adding input x to the output 89 | z = W_y + x 90 | 91 | return z 92 | 93 | 94 | """ 95 | Example of use, inserted into the middle layer of ResNet 96 | 97 | from torchvision.models import resnet50 98 | 99 | class ResNetWithNonLocal(nn.Module): 100 | def __init__(self): 101 | super(ResNetWithNonLocal, self).__init__() 102 | self.resnet = resnet50(pretrained=True) 103 | self.nonlocal_block = NonLocalBlock(in_channels=256) 104 | 105 | def forward(self, x): 106 | x = self.resnet.layer1(x) 107 | x = self.nonlocal_block(x) 108 | x = self.resnet.layer2(x) 109 | return x 110 | """ -------------------------------------------------------------------------------- /注意力模块/SA-Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | 6 | class sa_layer(nn.Module): 7 | """Constructs a Channel Spatial Group module. 
8 | 9 | Args: 10 | k_size: Adaptive selection of kernel size 11 | """ 12 | 13 | def __init__(self, channel, groups=64): 14 | super(sa_layer, self).__init__() 15 | self.groups = groups 16 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 17 | self.cweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1)) 18 | self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1)) 19 | self.sweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1)) 20 | self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1)) 21 | 22 | self.sigmoid = nn.Sigmoid() 23 | self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups)) 24 | 25 | @staticmethod 26 | def channel_shuffle(x, groups): 27 | b, c, h, w = x.shape 28 | 29 | x = x.reshape(b, groups, -1, h, w) 30 | x = x.permute(0, 2, 1, 3, 4) 31 | 32 | # flatten 33 | x = x.reshape(b, -1, h, w) 34 | 35 | return x 36 | 37 | def forward(self, x): 38 | b, c, h, w = x.shape 39 | 40 | x = x.reshape(b * self.groups, -1, h, w) 41 | x_0, x_1 = x.chunk(2, dim=1) 42 | 43 | # channel attention 44 | xn = self.avg_pool(x_0) 45 | xn = self.cweight * xn + self.cbias 46 | xn = x_0 * self.sigmoid(xn) 47 | 48 | # spatial attention 49 | xs = self.gn(x_1) 50 | xs = self.sweight * xs + self.sbias 51 | xs = x_1 * self.sigmoid(xs) 52 | 53 | # concatenate along channel axis 54 | out = torch.cat([xn, xs], dim=1) 55 | out = out.reshape(b, -1, h, w) 56 | 57 | out = self.channel_shuffle(out, 2) 58 | return out -------------------------------------------------------------------------------- /注意力模块/SENet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SEBlock(nn.Module): 6 | def __init__(self, in_channels, reduction=16): 7 | super(SEBlock, self).__init__() 8 | 9 | # 全局平均池化 10 | self.global_avg_pool = nn.AdaptiveAvgPool2d(1) 11 | 12 | # 两个全连接层 13 | self.fc1 = nn.Linear(in_channels, in_channels // reduction, bias=False) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.fc2 = nn.Linear(in_channels // reduction, in_channels, bias=False) 16 | 17 | # Sigmoid 激活函数用于生成权重 18 | self.sigmoid = nn.Sigmoid() 19 | 20 | def forward(self, x): 21 | batch_size, channels, _, _ = x.size() 22 | 23 | # 全局平均池化 24 | y = self.global_avg_pool(x).view(batch_size, channels) 25 | 26 | # 通过全连接层并生成注意力权重 27 | y = self.fc1(y) 28 | y = self.relu(y) 29 | y = self.fc2(y) 30 | y = self.sigmoid(y).view(batch_size, channels, 1, 1) 31 | 32 | # 将权重与输入特征相乘 33 | return x * y.expand_as(x) 34 | 35 | 36 | class SENet(nn.Module): 37 | def __init__(self, in_channels, reduction=16): 38 | super(SENet, self).__init__() 39 | self.se_block = SEBlock(in_channels, reduction) 40 | 41 | def forward(self, x): 42 | return self.se_block(x) -------------------------------------------------------------------------------- /注意力模块/SLAB.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | class SimplifiedLinearAttention(nn.Module): 7 | def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, 8 | focusing_factor=3, kernel_size=5, norm_layer=nn.LayerNorm): 9 | super().__init__() 10 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
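# forward(x, H, W) below expects token input of shape (B, N, C) with N = H * W equal to num_patches,
# so the learned positional encoding lines up with the (possibly spatially reduced) keys.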
11 | 12 | self.dim = dim 13 | self.num_heads = num_heads 14 | head_dim = dim // num_heads 15 | 16 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 17 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 18 | self.attn_drop = nn.Dropout(attn_drop) 19 | self.proj = nn.Linear(dim, dim) 20 | self.proj_drop = nn.Dropout(proj_drop) 21 | 22 | self.sr_ratio = sr_ratio 23 | if sr_ratio > 1: 24 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 25 | self.norm = norm_layer(dim) 26 | 27 | self.focusing_factor = focusing_factor 28 | self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size, 29 | groups=head_dim, padding=kernel_size // 2) 30 | self.positional_encoding = nn.Parameter(torch.zeros(size=(1, num_patches // (sr_ratio * sr_ratio), dim))) 31 | print('Linear Attention sr_ratio{} f{} kernel{}'. 32 | format(sr_ratio, focusing_factor, kernel_size)) 33 | 34 | def forward(self, x, H, W): 35 | B, N, C = x.shape 36 | q = self.q(x) 37 | 38 | if self.sr_ratio > 1: 39 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 40 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 41 | x_ = self.norm(x_) 42 | kv = self.kv(x_).reshape(B, -1, 2, C).permute(2, 0, 1, 3) 43 | else: 44 | kv = self.kv(x).reshape(B, -1, 2, C).permute(2, 0, 1, 3) 45 | k, v = kv[0], kv[1] 46 | 47 | k = k + self.positional_encoding 48 | kernel_function = nn.ReLU() 49 | q = kernel_function(q) 50 | k = kernel_function(k) 51 | 52 | q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v]) 53 | i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1] 54 | 55 | z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6) 56 | if i * j * (c + d) > c * d * (i + j): 57 | kv = torch.einsum("b j c, b j d -> b c d", k, v) 58 | x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z) 59 | else: 60 | qk = torch.einsum("b i c, b j c -> b i j", q, k) 61 | x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z) 62 | 63 | if self.sr_ratio > 1: 64 | v = nn.functional.interpolate(v.permute(0, 2, 1), size=x.shape[1], mode='linear').permute(0, 2, 1) 65 | num = int(v.shape[1] ** 0.5) 66 | feature_map = rearrange(v, "b (w h) c -> b c w h", w=num, h=num) 67 | feature_map = rearrange(self.dwc(feature_map), "b c w h -> b (w h) c") 68 | x = x + feature_map 69 | x = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads) 70 | 71 | x = self.proj(x) 72 | x = self.proj_drop(x) 73 | 74 | return x 75 | -------------------------------------------------------------------------------- /注意力模块/SRA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class SRA(nn.Module): 5 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 6 | super().__init__() 7 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
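# Spatial-Reduction Attention (PVT-style): forward(x, H, W) takes tokens of shape (B, N, C) with N = H * W;
# when sr_ratio > 1, keys/values come from a strided-conv-reduced map, cutting attention cost by sr_ratio ** 2.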
8 | 9 | self.dim = dim 10 | self.num_heads = num_heads 11 | head_dim = dim // num_heads 12 | self.scale = qk_scale or head_dim ** -0.5 13 | 14 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 15 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 16 | self.attn_drop = nn.Dropout(attn_drop) 17 | self.proj = nn.Linear(dim, dim) 18 | self.proj_drop = nn.Dropout(proj_drop) 19 | 20 | self.sr_ratio = sr_ratio 21 | if sr_ratio > 1: 22 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 23 | self.norm = nn.LayerNorm(dim) 24 | 25 | def forward(self, x, H, W): 26 | B, N, C = x.shape 27 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 28 | 29 | if self.sr_ratio > 1: 30 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 31 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 32 | x_ = self.norm(x_) 33 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 34 | else: 35 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 36 | k, v = kv[0], kv[1] 37 | 38 | attn = (q @ k.transpose(-2, -1)) * self.scale 39 | attn = attn.softmax(dim=-1) 40 | attn = self.attn_drop(attn) 41 | 42 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 43 | x = self.proj(x) 44 | x = self.proj_drop(x) 45 | 46 | return x -------------------------------------------------------------------------------- /注意力模块/SimAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class SimAM(nn.Module): 5 | def __init__(self, lambda_=0.0001): 6 | """ 7 | SimAM attention module. 8 | :param lambda_: A coefficient lambda in the equation. 9 | """ 10 | super(SimAM, self).__init__() 11 | self.lambda_ = lambda_ 12 | 13 | def forward(self, X): 14 | """ 15 | Forward pass for SimAM. 16 | :param X: Input feature map of shape (N, C, H, W) 17 | :return: Output feature map with attention applied, same shape as input (N, C, H, W) 18 | """ 19 | # Calculate the spatial size minus 1 for normalization 20 | n = X.shape[2] * X.shape[3] - 1 21 | 22 | # Calculate the square of (X - mean(X)) 23 | d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2) 24 | 25 | # Channel variance (d.sum() / n) 26 | v = d.sum(dim=[2, 3], keepdim=True) / n 27 | 28 | # Calculate E_inv which contains importance of X 29 | E_inv = d / (4 * (v + self.lambda_)) + 0.5 30 | 31 | # Return attended features using sigmoid activation 32 | return X * torch.sigmoid(E_inv) 33 | -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/AFF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AFF(nn.Module): 6 | def __init__(self, channels=64, reduction=4): 7 | """ 8 | Attentional Feature Fusion (AFF) module. 9 | :param channels: Number of input channels. 10 | :param reduction: Reduction ratio for intermediate channels in attention layers. Default is 4. 
11 | """ 12 | super(AFF, self).__init__() 13 | intermediate_channels = channels // reduction 14 | 15 | # Local attention layer 16 | self.local_attention = nn.Sequential( 17 | nn.Conv2d(channels, intermediate_channels, kernel_size=1, stride=1, padding=0), 18 | nn.BatchNorm2d(intermediate_channels), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(intermediate_channels, channels, kernel_size=1, stride=1, padding=0), 21 | nn.BatchNorm2d(channels), 22 | ) 23 | 24 | # Global attention layer 25 | self.global_attention = nn.Sequential( 26 | nn.AdaptiveAvgPool2d(1), 27 | nn.Conv2d(channels, intermediate_channels, kernel_size=1, stride=1, padding=0), 28 | nn.BatchNorm2d(intermediate_channels), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(intermediate_channels, channels, kernel_size=1, stride=1, padding=0), 31 | nn.BatchNorm2d(channels), 32 | ) 33 | 34 | # Activation function for attention weight 35 | self.sigmoid = nn.Sigmoid() 36 | 37 | def forward(self, input_feature, residual_feature): 38 | """ 39 | Forward pass for the AFF module. 40 | :param input_feature: First input feature map (tensor of shape N x C x H x W). 41 | :param residual_feature: Second input feature map (tensor of shape N x C x H x W). 42 | :return: Output feature map with fused attention applied (same shape as input). 43 | """ 44 | # Initial fusion by element-wise addition 45 | combined_feature = input_feature + residual_feature 46 | 47 | # Compute local and global attention 48 | local_attention = self.local_attention(combined_feature) 49 | global_attention = self.global_attention(combined_feature) 50 | 51 | # Sum local and global attention, then apply sigmoid for attention weight 52 | attention_weight = self.sigmoid(local_attention + global_attention) 53 | 54 | # Weighted combination of input and residual features 55 | output_feature = 2 * input_feature * attention_weight + 2 * residual_feature * (1 - attention_weight) 56 | 57 | return output_feature 58 | 59 | 60 | """ 61 | 使用示例 62 | 63 | # 假设输入特征图 input_feature 和 residual_feature 的通道数为 64,尺寸为 32x32 64 | input_feature = torch.randn(1, 64, 32, 32) # (N, C, H, W) 65 | residual_feature = torch.randn(1, 64, 32, 32) 66 | 67 | # 初始化 AFF 模块,指定输入通道数为 64,reduction 比例为 4 68 | aff_module = AFF(channels=64, reduction=4) 69 | 70 | # 计算 AFF 模块的输出 71 | output_feature = aff_module(input_feature, residual_feature) 72 | 73 | # 输出特征图的形状,应该与输入特征图一致 74 | print(output_feature.shape) # 输出: torch.Size([1, 64, 32, 32]) 75 | 76 | """ -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/CAFM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | 6 | class CAFM(nn.Module): 7 | def __init__(self, dim, num_heads, bias): 8 | super(CAFM, self).__init__() 9 | self.num_heads = num_heads 10 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 11 | 12 | self.qkv = nn.Conv3d(dim, dim * 3, kernel_size=(1, 1, 1), bias=bias) 13 | self.qkv_dwconv = nn.Conv3d(dim * 3, dim * 3, kernel_size=(3, 3, 3), stride=1, padding=1, groups=dim * 3, 14 | bias=bias) 15 | self.project_out = nn.Conv3d(dim, dim, kernel_size=(1, 1, 1), bias=bias) 16 | self.fc = nn.Conv3d(3 * self.num_heads, 9, kernel_size=(1, 1, 1), bias=True) 17 | 18 | self.dep_conv = nn.Conv3d(9 * dim // self.num_heads, dim, kernel_size=(3, 3, 3), bias=True, 19 | groups=dim // self.num_heads, padding=1) 20 | 21 | def forward(self, x): 22 | b, c, h, w = x.shape 23 | x = x.unsqueeze(2) 24 | qkv = 
self.qkv_dwconv(self.qkv(x)) 25 | qkv = qkv.squeeze(2) 26 | f_conv = qkv.permute(0, 2, 3, 1) 27 | f_all = qkv.reshape(f_conv.shape[0], h * w, 3 * self.num_heads, -1).permute(0, 2, 1, 3) 28 | f_all = self.fc(f_all.unsqueeze(2)) 29 | f_all = f_all.squeeze(2) 30 | 31 | # local conv 32 | f_conv = f_all.permute(0, 3, 1, 2).reshape(x.shape[0], 9 * x.shape[1] // self.num_heads, h, w) 33 | f_conv = f_conv.unsqueeze(2) 34 | out_conv = self.dep_conv(f_conv) # B, C, H, W 35 | out_conv = out_conv.squeeze(2) 36 | 37 | # global SA 38 | q, k, v = qkv.chunk(3, dim=1) 39 | 40 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 41 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 42 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 43 | 44 | q = torch.nn.functional.normalize(q, dim=-1) 45 | k = torch.nn.functional.normalize(k, dim=-1) 46 | 47 | attn = (q @ k.transpose(-2, -1)) * self.temperature 48 | attn = attn.softmax(dim=-1) 49 | 50 | out = (attn @ v) 51 | 52 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 53 | out = out.unsqueeze(2) 54 | out = self.project_out(out) 55 | out = out.squeeze(2) 56 | output = out + out_conv 57 | 58 | return output -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/CCFF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def get_activation(act: str, inpace: bool = True): 7 | '''get activation 8 | ''' 9 | act = act.lower() 10 | 11 | if act == 'silu': 12 | m = nn.SiLU() 13 | 14 | elif act == 'relu': 15 | m = nn.ReLU() 16 | 17 | elif act == 'leaky_relu': 18 | m = nn.LeakyReLU() 19 | 20 | elif act == 'silu': 21 | m = nn.SiLU() 22 | 23 | elif act == 'gelu': 24 | m = nn.GELU() 25 | 26 | elif act is None: 27 | m = nn.Identity() 28 | 29 | elif isinstance(act, nn.Module): 30 | m = act 31 | 32 | else: 33 | raise RuntimeError('') 34 | 35 | if hasattr(m, 'inplace'): 36 | m.inplace = inpace 37 | 38 | return m 39 | 40 | class ConvNormLayer(nn.Module): 41 | def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None): 42 | super().__init__() 43 | self.conv = nn.Conv2d( 44 | ch_in, 45 | ch_out, 46 | kernel_size, 47 | stride, 48 | padding=(kernel_size - 1) // 2 if padding is None else padding, 49 | bias=bias) 50 | self.norm = nn.BatchNorm2d(ch_out) 51 | self.act = nn.Identity() if act is None else get_activation(act) 52 | 53 | def forward(self, x): 54 | return self.act(self.norm(self.conv(x))) 55 | 56 | 57 | class RepVggBlock(nn.Module): 58 | def __init__(self, ch_in, ch_out, act='relu'): 59 | super().__init__() 60 | self.ch_in = ch_in 61 | self.ch_out = ch_out 62 | self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None) 63 | self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None) 64 | self.act = nn.Identity() if act is None else get_activation(act) 65 | 66 | def forward(self, x): 67 | if hasattr(self, 'conv'): 68 | y = self.conv(x) 69 | else: 70 | y = self.conv1(x) + self.conv2(x) 71 | 72 | return self.act(y) 73 | 74 | def convert_to_deploy(self): 75 | if not hasattr(self, 'conv'): 76 | self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1) 77 | 78 | kernel, bias = self.get_equivalent_kernel_bias() 79 | self.conv.weight.data = kernel 80 | self.conv.bias.data = bias 81 | # self.__delattr__('conv1') 82 | # self.__delattr__('conv2') 83 | 84 | def 
get_equivalent_kernel_bias(self): 85 | kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1) 86 | kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2) 87 | 88 | return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1 89 | 90 | def _pad_1x1_to_3x3_tensor(self, kernel1x1): 91 | if kernel1x1 is None: 92 | return 0 93 | else: 94 | return F.pad(kernel1x1, [1, 1, 1, 1]) 95 | 96 | def _fuse_bn_tensor(self, branch: ConvNormLayer): 97 | if branch is None: 98 | return 0, 0 99 | kernel = branch.conv.weight 100 | running_mean = branch.norm.running_mean 101 | running_var = branch.norm.running_var 102 | gamma = branch.norm.weight 103 | beta = branch.norm.bias 104 | eps = branch.norm.eps 105 | std = (running_var + eps).sqrt() 106 | t = (gamma / std).reshape(-1, 1, 1, 1) 107 | return kernel * t, beta - running_mean * gamma / std 108 | 109 | 110 | class CCFF(nn.Module): 111 | def __init__(self, 112 | in_channels, 113 | out_channels, 114 | num_blocks=3, 115 | expansion=1.0, 116 | bias=None, 117 | act="silu"): 118 | super(CCFF, self).__init__() 119 | hidden_channels = int(out_channels * expansion) 120 | self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act) 121 | self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act) 122 | self.bottlenecks = nn.Sequential(*[ 123 | RepVggBlock(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks) 124 | ]) 125 | if hidden_channels != out_channels: 126 | self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act) 127 | else: 128 | self.conv3 = nn.Identity() 129 | 130 | def forward(self, x): 131 | x_1 = self.conv1(x) 132 | x_1 = self.bottlenecks(x_1) 133 | x_2 = self.conv2(x) 134 | return self.conv3(x_1 + x_2) -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/CGAFusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops.layers.torch import Rearrange 4 | 5 | 6 | class SpatialAttention(nn.Module): 7 | def __init__(self): 8 | super(SpatialAttention, self).__init__() 9 | self.sa = nn.Conv2d(2, 1, 7, padding=3, padding_mode='reflect' ,bias=True) 10 | 11 | def forward(self, x): 12 | x_avg = torch.mean(x, dim=1, keepdim=True) 13 | x_max, _ = torch.max(x, dim=1, keepdim=True) 14 | x2 = torch.concat([x_avg, x_max], dim=1) 15 | sattn = self.sa(x2) 16 | return sattn 17 | 18 | 19 | class ChannelAttention(nn.Module): 20 | def __init__(self, dim, reduction=8): 21 | super(ChannelAttention, self).__init__() 22 | self.gap = nn.AdaptiveAvgPool2d(1) 23 | self.ca = nn.Sequential( 24 | nn.Conv2d(dim, dim // reduction, 1, padding=0, bias=True), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(dim // reduction, dim, 1, padding=0, bias=True), 27 | ) 28 | 29 | def forward(self, x): 30 | x_gap = self.gap(x) 31 | cattn = self.ca(x_gap) 32 | return cattn 33 | 34 | 35 | class PixelAttention(nn.Module): 36 | def __init__(self, dim): 37 | super(PixelAttention, self).__init__() 38 | self.pa2 = nn.Conv2d(2 * dim, dim, 7, padding=3, padding_mode='reflect', groups=dim, bias=True) 39 | self.sigmoid = nn.Sigmoid() 40 | 41 | def forward(self, x, pattn1): 42 | B, C, H, W = x.shape 43 | x = x.unsqueeze(dim=2) # B, C, 1, H, W 44 | pattn1 = pattn1.unsqueeze(dim=2) # B, C, 1, H, W 45 | x2 = torch.cat([x, pattn1], dim=2) # B, C, 2, H, W 46 | x2 = Rearrange('b c t h w -> b (c t) h w')(x2) 47 | pattn2 = self.pa2(x2) 48 | pattn2 = self.sigmoid(pattn2) 49 | return pattn2 50 | 51 | 
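# Shape sketch for the three gates above (illustrative; dims assumed), with x of shape (B, C, H, W):
#   SpatialAttention()(x)          -> (B, 1, H, W)   spatial map
#   ChannelAttention(C)(x)         -> (B, C, 1, 1)   channel weights
#   PixelAttention(C)(x, pattn1)   -> (B, C, H, W)   per-pixel gate in (0, 1)
# CGAFusion below broadcasts the sum of the first two as pattn1 and uses the pixel gate to blend x and y.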
52 | class CGAFusion(nn.Module): 53 | def __init__(self, dim, reduction=8): 54 | super(CGAFusion, self).__init__() 55 | self.sa = SpatialAttention() 56 | self.ca = ChannelAttention(dim, reduction) 57 | self.pa = PixelAttention(dim) 58 | self.conv = nn.Conv2d(dim, dim, 1, bias=True) 59 | self.sigmoid = nn.Sigmoid() 60 | 61 | def forward(self, x, y): 62 | initial = x + y 63 | cattn = self.ca(initial) 64 | sattn = self.sa(initial) 65 | pattn1 = sattn + cattn 66 | pattn2 = self.sigmoid(self.pa(initial, pattn1)) 67 | result = initial + pattn2 * x + (1 - pattn2) * y 68 | result = self.conv(result) 69 | return result -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/CSAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributions as td 4 | 5 | 6 | def custom_max(x,dim,keepdim=True): 7 | temp_x=x 8 | for i in dim: 9 | temp_x=torch.max(temp_x,dim=i,keepdim=True)[0] 10 | if not keepdim: 11 | temp_x=temp_x.squeeze() 12 | return temp_x 13 | 14 | class PositionalAttentionModule(nn.Module): 15 | def __init__(self): 16 | super(PositionalAttentionModule,self).__init__() 17 | self.conv=nn.Conv2d(in_channels=2,out_channels=1,kernel_size=(7,7),padding=3) 18 | def forward(self,x): 19 | max_x=custom_max(x,dim=(0,1),keepdim=True) 20 | avg_x=torch.mean(x,dim=(0,1),keepdim=True) 21 | att=torch.cat((max_x,avg_x),dim=1) 22 | att=self.conv(att) 23 | att=torch.sigmoid(att) 24 | return x*att 25 | 26 | class SemanticAttentionModule(nn.Module): 27 | def __init__(self,in_features,reduction_rate=16): 28 | super(SemanticAttentionModule,self).__init__() 29 | self.linear=[] 30 | self.linear.append(nn.Linear(in_features=in_features,out_features=in_features//reduction_rate)) 31 | self.linear.append(nn.ReLU()) 32 | self.linear.append(nn.Linear(in_features=in_features//reduction_rate,out_features=in_features)) 33 | self.linear=nn.Sequential(*self.linear) 34 | def forward(self,x): 35 | max_x=custom_max(x,dim=(0,2,3),keepdim=False).unsqueeze(0) 36 | avg_x=torch.mean(x,dim=(0,2,3),keepdim=False).unsqueeze(0) 37 | max_x=self.linear(max_x) 38 | avg_x=self.linear(avg_x) 39 | att=max_x+avg_x 40 | att=torch.sigmoid(att).unsqueeze(-1).unsqueeze(-1) 41 | return x*att 42 | 43 | class SliceAttentionModule(nn.Module): 44 | def __init__(self,in_features,rate=4,uncertainty=True,rank=5): 45 | super(SliceAttentionModule,self).__init__() 46 | self.uncertainty=uncertainty 47 | self.rank=rank 48 | self.linear=[] 49 | self.linear.append(nn.Linear(in_features=in_features,out_features=int(in_features*rate))) 50 | self.linear.append(nn.ReLU()) 51 | self.linear.append(nn.Linear(in_features=int(in_features*rate),out_features=in_features)) 52 | self.linear=nn.Sequential(*self.linear) 53 | if uncertainty: 54 | self.non_linear=nn.ReLU() 55 | self.mean=nn.Linear(in_features=in_features,out_features=in_features) 56 | self.log_diag=nn.Linear(in_features=in_features,out_features=in_features) 57 | self.factor=nn.Linear(in_features=in_features,out_features=in_features*rank) 58 | def forward(self,x): 59 | max_x=custom_max(x,dim=(1,2,3),keepdim=False).unsqueeze(0) 60 | avg_x=torch.mean(x,dim=(1,2,3),keepdim=False).unsqueeze(0) 61 | max_x=self.linear(max_x) 62 | avg_x=self.linear(avg_x) 63 | att=max_x+avg_x 64 | if self.uncertainty: 65 | temp=self.non_linear(att) 66 | mean=self.mean(temp) 67 | diag=self.log_diag(temp).exp() 68 | factor=self.factor(temp) 69 | factor=factor.view(1,-1,self.rank) 70 | 
dist=td.LowRankMultivariateNormal(loc=mean,cov_factor=factor,cov_diag=diag) 71 | att=dist.sample() 72 | att=torch.sigmoid(att).squeeze().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 73 | return x*att 74 | 75 | 76 | class CSAM(nn.Module): 77 | def __init__(self,num_slices,num_channels,semantic=True,positional=True,slice=True,uncertainty=True,rank=5): 78 | super(CSAM,self).__init__() 79 | self.semantic=semantic 80 | self.positional=positional 81 | self.slice=slice 82 | if semantic: 83 | self.semantic_att=SemanticAttentionModule(num_channels) 84 | if positional: 85 | self.positional_att=PositionalAttentionModule() 86 | if slice: 87 | self.slice_att=SliceAttentionModule(num_slices,uncertainty=uncertainty,rank=rank) 88 | def forward(self,x): 89 | if self.semantic: 90 | x=self.semantic_att(x) 91 | if self.positional: 92 | x=self.positional_att(x) 93 | if self.slice: 94 | x=self.slice_att(x) 95 | return x -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/FARM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import pytorch_lightning as pl 6 | from torchvision.ops import DeformConv2d 7 | from pytorch_lightning import seed_everything 8 | from einops import rearrange 9 | import numbers 10 | 11 | seed_everything(13) 12 | 13 | 14 | ########################################################################## 15 | ## Layer Norm 16 | 17 | def to_3d(x): 18 | return rearrange(x, 'b c h w -> b (h w) c') 19 | 20 | 21 | def to_4d(x, h, w): 22 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 23 | 24 | 25 | class BiasFree_LayerNorm(pl.LightningModule): 26 | def __init__(self, normalized_shape): 27 | super(BiasFree_LayerNorm, self).__init__() 28 | if isinstance(normalized_shape, numbers.Integral): 29 | normalized_shape = (normalized_shape,) 30 | normalized_shape = torch.Size(normalized_shape) 31 | 32 | assert len(normalized_shape) == 1 33 | 34 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 35 | self.normalized_shape = normalized_shape 36 | 37 | def forward(self, x): 38 | sigma = x.var(-1, keepdim=True, unbiased=False) 39 | return x / torch.sqrt(sigma + 1e-5) * self.weight 40 | 41 | 42 | class WithBias_LayerNorm(pl.LightningModule): 43 | def __init__(self, normalized_shape): 44 | super(WithBias_LayerNorm, self).__init__() 45 | if isinstance(normalized_shape, numbers.Integral): 46 | normalized_shape = (normalized_shape,) 47 | normalized_shape = torch.Size(normalized_shape) 48 | 49 | assert len(normalized_shape) == 1 50 | 51 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 52 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 53 | self.normalized_shape = normalized_shape 54 | 55 | def forward(self, x): 56 | mu = x.mean(-1, keepdim=True) 57 | sigma = x.var(-1, keepdim=True, unbiased=False) 58 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias 59 | 60 | 61 | class LayerNorm(pl.LightningModule): 62 | def __init__(self, dim, LayerNorm_type): 63 | super(LayerNorm, self).__init__() 64 | if LayerNorm_type == 'BiasFree': 65 | self.body = BiasFree_LayerNorm(dim) 66 | else: 67 | self.body = WithBias_LayerNorm(dim) 68 | 69 | def forward(self, x): 70 | h, w = x.shape[-2:] 71 | return to_4d(self.body(to_3d(x)), h, w) 72 | 73 | 74 | ########################################################################## 75 | ## Gated-Dconv Feed-Forward Network (GDFN) 76 | class FeedForward(pl.LightningModule): 77 | def 
__init__(self, dim, ffn_expansion_factor, bias): 78 | super(FeedForward, self).__init__() 79 | 80 | hidden_features = int(dim * ffn_expansion_factor) 81 | 82 | self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) 83 | 84 | self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, 85 | groups=hidden_features * 2, bias=bias) 86 | 87 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 88 | 89 | def forward(self, x): 90 | x = self.project_in(x) 91 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 92 | x = F.gelu(x1) * x2 93 | x = self.project_out(x) 94 | return x 95 | 96 | 97 | ########################################################################## 98 | ## Multi-DConv Head Transposed Self-Attention (MDTA) 99 | class Attention(pl.LightningModule): 100 | def __init__(self, dim, num_heads, stride, bias): 101 | super(Attention, self).__init__() 102 | self.num_heads = num_heads 103 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 104 | 105 | self.stride = stride 106 | self.qk = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) 107 | self.qk_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=self.stride, padding=1, groups=dim * 2, 108 | bias=bias) 109 | 110 | self.v = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) 111 | self.v_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) 112 | 113 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) 114 | 115 | def forward(self, x): 116 | b, c, h, w = x.shape 117 | 118 | qk = self.qk_dwconv(self.qk(x)) 119 | q, k = qk.chunk(2, dim=1) 120 | 121 | v = self.v_dwconv(self.v(x)) 122 | 123 | b, f, h1, w1 = q.size() 124 | 125 | q = rearrange(q, 'b (head c) h1 w1 -> b head c (h1 w1)', head=self.num_heads) 126 | k = rearrange(k, 'b (head c) h1 w1 -> b head c (h1 w1)', head=self.num_heads) 127 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 128 | 129 | q = torch.nn.functional.normalize(q, dim=-1) 130 | k = torch.nn.functional.normalize(k, dim=-1) 131 | 132 | attn = (q @ k.transpose(-2, -1)) * self.temperature 133 | attn = attn.softmax(dim=-1) 134 | 135 | out = (attn @ v) 136 | 137 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 138 | 139 | out = self.project_out(out) 140 | return out 141 | 142 | 143 | ########################################################################## 144 | ######### Burst Feature Attention ######################################## 145 | 146 | ########################################################################## 147 | class BFA(pl.LightningModule): 148 | def __init__(self, dim, num_heads, stride, ffn_expansion_factor, bias, LayerNorm_type): 149 | super(BFA, self).__init__() 150 | 151 | self.norm1 = LayerNorm(dim, LayerNorm_type) 152 | self.attn = Attention(dim, num_heads, stride, bias) 153 | self.norm2 = LayerNorm(dim, LayerNorm_type) 154 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 155 | 156 | def forward(self, x): 157 | x = x + self.attn(self.norm1(x)) 158 | x = x + self.ffn(self.norm2(x)) 159 | 160 | return x 161 | 162 | 163 | ########################################################################## 164 | ## Overlapped image patch embedding with 3x3 Conv 165 | class OverlapPatchEmbed(pl.LightningModule): 166 | def __init__(self, in_c=3, embed_dim=48, bias=False): 167 | super(OverlapPatchEmbed, self).__init__() 168 | 169 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, 
bias=bias) 170 | 171 | def forward(self, x): 172 | # print("Inside patch embed:::", inp_enc_level1.size()) 173 | x = self.proj(x) 174 | 175 | return x 176 | 177 | 178 | class ref_back_projection(pl.LightningModule): 179 | def __init__(self, in_channels, stride): 180 | super(ref_back_projection, self).__init__() 181 | 182 | bias = False 183 | 184 | self.feat_fusion = nn.Sequential(nn.Conv2d(in_channels * 2, in_channels, 3, stride=1, padding=1), nn.GELU()) 185 | self.feat_expand = nn.Sequential(nn.Conv2d(in_channels, in_channels * 2, 3, stride=1, padding=1), nn.GELU()) 186 | self.diff_fusion = nn.Sequential(nn.Conv2d(in_channels * 2, in_channels, 3, stride=1, padding=1), nn.GELU()) 187 | 188 | self.encoder1 = nn.Sequential(*[ 189 | BFA(dim=in_channels * 2, num_heads=1, stride=stride, ffn_expansion_factor=2.66, bias=bias, 190 | LayerNorm_type='WithBias') for i in range(2)]) 191 | 192 | def forward(self, x): 193 | B, f, H, W = x.size() 194 | 195 | ref = x[0].unsqueeze(0) 196 | ref = torch.repeat_interleave(ref, B, dim=0) 197 | feat = self.encoder1(torch.cat([ref, x], dim=1)) 198 | 199 | fused_feat = self.feat_fusion(feat) 200 | exp_feat = self.feat_expand(fused_feat) 201 | 202 | residual = exp_feat - feat 203 | residual = self.diff_fusion(residual) 204 | 205 | fused_feat = fused_feat + residual 206 | 207 | return fused_feat 208 | 209 | 210 | class alignment(pl.LightningModule): 211 | def __init__(self, dim=48, memory=False, stride=1, type='group_conv'): 212 | 213 | super(alignment, self).__init__() 214 | 215 | act = nn.GELU() 216 | bias = False 217 | 218 | kernel_size = 3 219 | padding = kernel_size // 2 220 | deform_groups = 8 221 | out_channels = deform_groups * 3 * (kernel_size ** 2) 222 | 223 | self.offset_conv = nn.Conv2d(dim, out_channels, kernel_size, stride=1, padding=padding, bias=bias) 224 | self.deform = DeformConv2d(dim, dim, kernel_size, padding=2, groups=deform_groups, dilation=2) 225 | self.back_projection = ref_back_projection(dim, stride=1) 226 | 227 | self.bottleneck = nn.Sequential(nn.Conv2d(dim * 2, dim, kernel_size=3, padding=1, bias=bias), act) 228 | 229 | if memory == True: 230 | self.bottleneck_o = nn.Sequential(nn.Conv2d(dim * 2, dim, kernel_size=3, padding=1, bias=bias), act) 231 | 232 | def offset_gen(self, x): 233 | 234 | o1, o2, mask = torch.chunk(x, 3, dim=1) 235 | offset = torch.cat((o1, o2), dim=1) 236 | mask = torch.sigmoid(mask) 237 | 238 | return offset, mask 239 | 240 | def forward(self, x, prev_offset_feat=None): 241 | 242 | B, f, H, W = x.size() 243 | ref = x[0].unsqueeze(0) 244 | ref = torch.repeat_interleave(ref, B, dim=0) 245 | 246 | offset_feat = self.bottleneck(torch.cat([ref, x], dim=1)) 247 | 248 | if not prev_offset_feat == None: 249 | offset_feat = self.bottleneck_o(torch.cat([prev_offset_feat, offset_feat], dim=1)) 250 | 251 | offset, mask = self.offset_gen(self.offset_conv(offset_feat)) 252 | 253 | aligned_feat = self.deform(x, offset, mask) 254 | aligned_feat[0] = x[0].unsqueeze(0) 255 | 256 | aligned_feat = self.back_projection(aligned_feat) 257 | 258 | # return aligned_feat, offset_feat 259 | return aligned_feat 260 | 261 | -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/FCA.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class Mix(nn.Module): 7 | def __init__(self, m=-0.80): 8 | super(Mix, self).__init__() 9 | w = torch.nn.Parameter(torch.FloatTensor([m]), requires_grad=True) 10 | w = 
torch.nn.Parameter(w, requires_grad=True) 11 | self.w = w 12 | self.mix_block = nn.Sigmoid() 13 | 14 | def forward(self, fea1, fea2): 15 | mix_factor = self.mix_block(self.w) 16 | out = fea1 * mix_factor.expand_as(fea1) + fea2 * (1 - mix_factor.expand_as(fea2)) 17 | return out 18 | 19 | 20 | class FCA(nn.Module): 21 | def __init__(self,channel,b=1, gamma=2): 22 | super(FCA, self).__init__() 23 | self.avg_pool = nn.AdaptiveAvgPool2d(1)#全局平均池化 24 | #一维卷积 25 | t = int(abs((math.log(channel, 2) + b) / gamma)) 26 | k = t if t % 2 else t + 1 27 | self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k / 2), bias=False) 28 | self.fc = nn.Conv2d(channel, channel, 1, padding=0, bias=True) 29 | self.sigmoid = nn.Sigmoid() 30 | self.mix = Mix() 31 | 32 | 33 | def forward(self, input): 34 | x = self.avg_pool(input) 35 | x1 = self.conv1(x.squeeze(-1).transpose(-1, -2)).transpose(-1, -2)#(1,64,1) 36 | x2 = self.fc(x).squeeze(-1).transpose(-1, -2)#(1,1,64) 37 | out1 = torch.sum(torch.matmul(x1,x2),dim=1).unsqueeze(-1).unsqueeze(-1)#(1,64,1,1) 38 | #x1 = x1.transpose(-1, -2).unsqueeze(-1) 39 | out1 = self.sigmoid(out1) 40 | out2 = torch.sum(torch.matmul(x2.transpose(-1, -2),x1.transpose(-1, -2)),dim=1).unsqueeze(-1).unsqueeze(-1) 41 | 42 | #out2 = self.fc(x) 43 | out2 = self.sigmoid(out2) 44 | out = self.mix(out1,out2) 45 | out = self.conv1(out.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 46 | out = self.sigmoid(out) 47 | 48 | return input*out -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/GLSA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from mmcv.cnn import constant_init, kaiming_init 5 | 6 | 7 | class BasicConv2d(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 9 | super(BasicConv2d, self).__init__() 10 | 11 | self.conv = nn.Conv2d(in_planes, out_planes, 12 | kernel_size=kernel_size, stride=stride, 13 | padding=padding, dilation=dilation, bias=False) 14 | self.bn = nn.BatchNorm2d(out_planes) 15 | self.relu = nn.ReLU(inplace=True) 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | x = self.bn(x) 20 | x = self.relu(x) 21 | return x 22 | 23 | 24 | class Block(nn.Sequential): 25 | def __init__(self, input_num, num1, num2, dilation_rate, drop_out, bn_start=True, norm_layer=nn.BatchNorm2d): 26 | super(Block, self).__init__() 27 | if bn_start: 28 | self.add_module('norm1', norm_layer(input_num)), 29 | 30 | self.add_module('relu1', nn.ReLU(inplace=True)), 31 | self.add_module('conv1', nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1)), 32 | 33 | self.add_module('norm2', norm_layer(num1)), 34 | self.add_module('relu2', nn.ReLU(inplace=True)), 35 | self.add_module('conv2', nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3, 36 | dilation=dilation_rate, padding=dilation_rate)), 37 | self.drop_rate = drop_out 38 | 39 | def forward(self, _input): 40 | feature = super(Block, self).forward(_input) 41 | if self.drop_rate > 0: 42 | feature = F.dropout2d(feature, p=self.drop_rate, training=self.training) 43 | return feature 44 | 45 | 46 | def Upsample(x, size, align_corners=False): 47 | """ 48 | Wrapper Around the Upsample Call 49 | """ 50 | return nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=align_corners) 51 | 52 | 53 | def last_zero_init(m): 54 | if isinstance(m, nn.Sequential): 55 | constant_init(m[-1], val=0) 
56 | else: 57 | constant_init(m, val=0) 58 | 59 | 60 | class ContextBlock(nn.Module): 61 | 62 | def __init__(self, 63 | inplanes, 64 | ratio, 65 | pooling_type='att', 66 | fusion_types=('channel_mul',)): 67 | super(ContextBlock, self).__init__() 68 | assert pooling_type in ['avg', 'att'] 69 | assert isinstance(fusion_types, (list, tuple)) 70 | valid_fusion_types = ['channel_add', 'channel_mul'] 71 | assert all([f in valid_fusion_types for f in fusion_types]) 72 | assert len(fusion_types) > 0, 'at least one fusion should be used' 73 | self.inplanes = inplanes 74 | self.ratio = ratio 75 | self.planes = int(inplanes * ratio) 76 | self.pooling_type = pooling_type 77 | self.fusion_types = fusion_types 78 | if pooling_type == 'att': 79 | self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1) 80 | self.softmax = nn.Softmax(dim=2) 81 | else: 82 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 83 | if 'channel_add' in fusion_types: 84 | self.channel_add_conv = nn.Sequential( 85 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 86 | nn.LayerNorm([self.planes, 1, 1]), 87 | nn.ReLU(inplace=True), # yapf: disable 88 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)) 89 | else: 90 | self.channel_add_conv = None 91 | if 'channel_mul' in fusion_types: 92 | self.channel_mul_conv = nn.Sequential( 93 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 94 | nn.LayerNorm([self.planes, 1, 1]), 95 | nn.ReLU(inplace=True), # yapf: disable 96 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)) 97 | else: 98 | self.channel_mul_conv = None 99 | self.reset_parameters() 100 | 101 | def reset_parameters(self): 102 | if self.pooling_type == 'att': 103 | kaiming_init(self.conv_mask, mode='fan_in') 104 | self.conv_mask.inited = True 105 | 106 | if self.channel_add_conv is not None: 107 | last_zero_init(self.channel_add_conv) 108 | if self.channel_mul_conv is not None: 109 | last_zero_init(self.channel_mul_conv) 110 | 111 | def spatial_pool(self, x): 112 | batch, channel, height, width = x.size() 113 | if self.pooling_type == 'att': 114 | input_x = x 115 | # [N, C, H * W] 116 | input_x = input_x.view(batch, channel, height * width) 117 | # [N, 1, C, H * W] 118 | input_x = input_x.unsqueeze(1) 119 | # [N, 1, H, W] 120 | context_mask = self.conv_mask(x) 121 | # [N, 1, H * W] 122 | context_mask = context_mask.view(batch, 1, height * width) 123 | # [N, 1, H * W] 124 | context_mask = self.softmax(context_mask) 125 | # [N, 1, H * W, 1] 126 | context_mask = context_mask.unsqueeze(-1) 127 | # [N, 1, C, 1] 128 | context = torch.matmul(input_x, context_mask) 129 | # [N, C, 1, 1] 130 | context = context.view(batch, channel, 1, 1) 131 | else: 132 | # [N, C, 1, 1] 133 | context = self.avg_pool(x) 134 | 135 | return context 136 | 137 | def forward(self, x): 138 | # [N, C, 1, 1] 139 | context = self.spatial_pool(x) 140 | 141 | out = x 142 | if self.channel_mul_conv is not None: 143 | # [N, C, 1, 1] 144 | channel_mul_term = torch.sigmoid(self.channel_mul_conv(context)) 145 | out = out + out * channel_mul_term 146 | if self.channel_add_conv is not None: 147 | # [N, C, 1, 1] 148 | channel_add_term = self.channel_add_conv(context) 149 | out = out + channel_add_term 150 | 151 | return out 152 | 153 | 154 | class ChannelAttention(nn.Module): 155 | def __init__(self, in_planes, ratio=16): 156 | super(ChannelAttention, self).__init__() 157 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 158 | self.max_pool = nn.AdaptiveMaxPool2d(1) 159 | 160 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 161 | self.relu1 = nn.ReLU() 162 
| self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 163 | 164 | self.sigmoid = nn.Sigmoid() 165 | 166 | def forward(self, x): 167 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 168 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 169 | out = avg_out + max_out 170 | return self.sigmoid(out) 171 | 172 | 173 | class SpatialAttention(nn.Module): 174 | def __init__(self, kernel_size=7): 175 | super(SpatialAttention, self).__init__() 176 | 177 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 178 | padding = 3 if kernel_size == 7 else 1 179 | 180 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 181 | self.sigmoid = nn.Sigmoid() 182 | 183 | def forward(self, x): 184 | avg_out = torch.mean(x, dim=1, keepdim=True) 185 | max_out, _ = torch.max(x, dim=1, keepdim=True) 186 | x = torch.cat([avg_out, max_out], dim=1) 187 | x = self.conv1(x) 188 | return self.sigmoid(x) 189 | 190 | 191 | class ConvBranch(nn.Module): 192 | def __init__(self, in_features, hidden_features=None, out_features=None): 193 | super().__init__() 194 | hidden_features = hidden_features or in_features 195 | out_features = out_features or in_features 196 | self.conv1 = nn.Sequential( 197 | nn.Conv2d(in_features, hidden_features, 1, bias=False), 198 | nn.BatchNorm2d(hidden_features), 199 | nn.ReLU(inplace=True) 200 | ) 201 | self.conv2 = nn.Sequential( 202 | nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False), 203 | nn.BatchNorm2d(hidden_features), 204 | nn.ReLU(inplace=True) 205 | ) 206 | self.conv3 = nn.Sequential( 207 | nn.Conv2d(hidden_features, hidden_features, 1, bias=False), 208 | nn.BatchNorm2d(hidden_features), 209 | nn.ReLU(inplace=True) 210 | ) 211 | self.conv4 = nn.Sequential( 212 | nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False), 213 | nn.BatchNorm2d(hidden_features), 214 | nn.ReLU(inplace=True) 215 | ) 216 | self.conv5 = nn.Sequential( 217 | nn.Conv2d(hidden_features, hidden_features, 1, bias=False), 218 | nn.BatchNorm2d(hidden_features), 219 | nn.SiLU(inplace=True) 220 | ) 221 | self.conv6 = nn.Sequential( 222 | nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False), 223 | nn.BatchNorm2d(hidden_features), 224 | nn.ReLU(inplace=True) 225 | ) 226 | self.conv7 = nn.Sequential( 227 | nn.Conv2d(hidden_features, out_features, 1, bias=False), 228 | nn.ReLU(inplace=True) 229 | ) 230 | self.ca = ChannelAttention(64) 231 | self.sa = SpatialAttention() 232 | self.sigmoid_spatial = nn.Sigmoid() 233 | 234 | def forward(self, x): 235 | res1 = x 236 | res2 = x 237 | x = self.conv1(x) 238 | x = x + self.conv2(x) 239 | x = self.conv3(x) 240 | x = x + self.conv4(x) 241 | x = self.conv5(x) 242 | x = x + self.conv6(x) 243 | x = self.conv7(x) 244 | x_mask = self.sigmoid_spatial(x) 245 | res1 = res1 * x_mask 246 | return res2 + res1 247 | 248 | 249 | class GLSA(nn.Module): 250 | 251 | def __init__(self, input_dim=512, embed_dim=32, k_s=3): 252 | super().__init__() 253 | 254 | self.conv1_1 = BasicConv2d(embed_dim * 2, embed_dim, 1) 255 | self.conv1_1_1 = BasicConv2d(input_dim // 2, embed_dim, 1) 256 | self.local_11conv = nn.Conv2d(input_dim // 2, embed_dim, 1) 257 | self.global_11conv = nn.Conv2d(input_dim // 2, embed_dim, 1) 258 | self.GlobelBlock = ContextBlock(inplanes=embed_dim, ratio=2) 259 | self.local = ConvBranch(in_features=embed_dim, hidden_features=embed_dim, out_features=embed_dim) 260 | 261 | def forward(self, x): 262 | b, c, h, w 
= x.size() 263 | x_0, x_1 = x.chunk(2, dim=1) 264 | 265 | # local block 266 | local = self.local(self.local_11conv(x_0)) 267 | 268 | # global block 269 | Globel = self.GlobelBlock(self.global_11conv(x_1)) 270 | 271 | # concat global + local 272 | x = torch.cat([local, Globel], dim=1) 273 | x = self.conv1_1(x) 274 | 275 | return x -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/LSK.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LSKblock(nn.Module): 5 | def __init__(self, dim): 6 | super().__init__() 7 | self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) 8 | self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3) 9 | self.conv1 = nn.Conv2d(dim, dim//2, 1) 10 | self.conv2 = nn.Conv2d(dim, dim//2, 1) 11 | self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3) 12 | self.conv = nn.Conv2d(dim//2, dim, 1) 13 | 14 | def forward(self, x): 15 | attn1 = self.conv0(x) 16 | attn2 = self.conv_spatial(attn1) 17 | 18 | attn1 = self.conv1(attn1) 19 | attn2 = self.conv2(attn2) 20 | 21 | attn = torch.cat([attn1, attn2], dim=1) 22 | avg_attn = torch.mean(attn, dim=1, keepdim=True) 23 | max_attn, _ = torch.max(attn, dim=1, keepdim=True) 24 | agg = torch.cat([avg_attn, max_attn], dim=1) 25 | sig = self.conv_squeeze(agg).sigmoid() 26 | attn = attn1 * sig[:, 0, :, :].unsqueeze(1) + attn2 * sig[:, 1, :, :].unsqueeze(1) 27 | attn = self.conv(attn) 28 | return x * attn -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/MFII.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class eca_layer(nn.Module): 6 | """Constructs an ECA module.
7 | Args: 8 | channel: Number of channels of the input feature map 9 | k_size: Adaptive selection of kernel size 10 | source: https://github.com/BangguWu/ECANet 11 | """ 12 | def __init__(self, channel, k_size=3): 13 | super(eca_layer, self).__init__() 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 16 | self.sigmoid = nn.Sigmoid() 17 | 18 | def forward(self, x): 19 | # x: input features with shape [b, c, h, w] 20 | b, c, h, w = x.size() 21 | 22 | # feature descriptor on the global spatial information 23 | y = self.avg_pool(x) 24 | 25 | # Two different branches of ECA module 26 | y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 27 | 28 | # Multi-scale information fusion 29 | y = self.sigmoid(y) 30 | 31 | return x * y.expand_as(x) 32 | 33 | 34 | class SELayer(nn.Module): 35 | def __init__(self, channel, reduction=16): 36 | super(SELayer, self).__init__() 37 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 38 | self.fc = nn.Sequential( 39 | nn.Linear(channel, channel // reduction, bias=False), 40 | nn.ReLU(inplace=True), 41 | nn.Linear(channel // reduction, channel, bias=False), 42 | nn.Sigmoid() 43 | ) 44 | 45 | def forward(self, x): 46 | b, c, _, _ = x.size() 47 | y = self.avg_pool(x).view(b, c) 48 | y = self.fc(y).view(b, c, 1, 1) 49 | return x * y.expand_as(x) 50 | 51 | 52 | 53 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 54 | """3x3 convolution with padding""" 55 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 56 | padding=dilation, groups=groups, bias=False, dilation=dilation) 57 | 58 | def conv1x1(in_planes, out_planes, stride=1): 59 | """1x1 convolution""" 60 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 61 | 62 | class MFII_BasicBlock(nn.Module): 63 | expansion = 1 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None, 66 | rla_channel=32, SE=False, ECA_size=None, groups=1, 67 | base_width=64, dilation=1, norm_layer=None, reduction=16): 68 | super(MFII_BasicBlock, self).__init__() 69 | if norm_layer is None: 70 | norm_layer = nn.BatchNorm2d 71 | 72 | # if groups != 1 or base_width != 64: 73 | # raise ValueError('BasicBlock only supports groups=1 and base_width=64') 74 | if dilation > 1: 75 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 76 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 77 | self.conv1 = conv3x3(inplanes + rla_channel, planes, stride) 78 | self.bn1 = norm_layer(planes) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.conv2 = conv3x3(planes, planes) 81 | self.bn2 = norm_layer(planes) 82 | self.downsample = downsample 83 | self.stride = stride 84 | 85 | self.averagePooling = None 86 | if downsample is not None and stride != 1: 87 | self.averagePooling = nn.AvgPool2d((2, 2), stride=(2, 2)) 88 | 89 | self.se = None 90 | if SE: 91 | self.se = SELayer(planes * self.expansion, reduction) 92 | 93 | self.eca = None 94 | if ECA_size != None: 95 | self.eca = eca_layer(planes * self.expansion, int(ECA_size)) 96 | 97 | def forward(self, x, h): 98 | identity = x 99 | 100 | x = torch.cat((x, h), dim=1) # [8, 96, 56, 56] 101 | 102 | out = self.conv1(x) # [8, 64, 56, 56] 103 | out = self.bn1(out) # [8, 64, 56, 56] 104 | out = self.relu(out) 105 | 106 | out = self.conv2(out) # [8, 64, 56, 56] 107 | out = self.bn2(out) 108 | 109 | if self.se != None: 110 | out = self.se(out) 111 | 112 | if self.eca != None: 113 
| out = self.eca(out) 114 | 115 | y = out 116 | 117 | if self.downsample is not None: 118 | identity = self.downsample(identity) 119 | if self.averagePooling is not None: 120 | h = self.averagePooling(h) 121 | 122 | out += identity 123 | out = self.relu(out) 124 | 125 | return out, y, h 126 | 127 | 128 | class MFII_BasicBlock_half(nn.Module): 129 | expansion = 1 130 | 131 | def __init__(self, inplanes, planes, stride=1, downsample=None, 132 | rla_channel=16, SE=False, ECA_size=None, groups=1, 133 | base_width=64, dilation=1, norm_layer=None, reduction=16): 134 | super(MFII_BasicBlock_half, self).__init__() 135 | if norm_layer is None: 136 | norm_layer = nn.BatchNorm2d 137 | 138 | # if groups != 1 or base_width != 64: 139 | # raise ValueError('BasicBlock only supports groups=1 and base_width=64') 140 | if dilation > 1: 141 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 142 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 143 | self.conv1 = conv3x3(inplanes + rla_channel, planes, stride) 144 | self.bn1 = norm_layer(planes) 145 | self.relu = nn.ReLU(inplace=True) 146 | self.conv2 = conv3x3(planes, planes) 147 | self.bn2 = norm_layer(planes) 148 | self.downsample = downsample 149 | self.stride = stride 150 | 151 | self.averagePooling = None 152 | if downsample is not None and stride != 1: 153 | self.averagePooling = nn.AvgPool2d((2, 2), stride=(2, 2)) 154 | 155 | self.se = None 156 | if SE: 157 | self.se = SELayer(planes * self.expansion, reduction) 158 | 159 | self.eca = None 160 | if ECA_size != None: 161 | self.eca = eca_layer(planes * self.expansion, int(ECA_size)) 162 | 163 | def forward(self, x, h): 164 | identity = x 165 | 166 | x = torch.cat((x, h), dim=1) # [8, 96, 56, 56] 167 | 168 | out = self.conv1(x) # [8, 64, 56, 56] 169 | out = self.bn1(out) # [8, 64, 56, 56] 170 | out = self.relu(out) 171 | 172 | out = self.conv2(out) # [8, 64, 56, 56] 173 | out = self.bn2(out) 174 | 175 | if self.se != None: 176 | out = self.se(out) 177 | 178 | if self.eca != None: 179 | out = self.eca(out) 180 | 181 | y = out 182 | 183 | if self.downsample is not None: 184 | identity = self.downsample(identity) 185 | if self.averagePooling is not None: 186 | h = self.averagePooling(h) 187 | 188 | out += identity 189 | out = self.relu(out) 190 | 191 | return out, y, h -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/PSFM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BBasicConv2d(nn.Module): 6 | def __init__( 7 | self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, 8 | ): 9 | super(BBasicConv2d, self).__init__() 10 | 11 | self.basicconv = nn.Sequential( 12 | nn.Conv2d( 13 | in_planes, 14 | out_planes, 15 | kernel_size=kernel_size, 16 | stride=stride, 17 | padding=padding, 18 | dilation=dilation, 19 | groups=groups, 20 | bias=bias, 21 | ), 22 | nn.BatchNorm2d(out_planes), 23 | nn.ReLU(inplace=True), 24 | ) 25 | 26 | def forward(self, x): 27 | return self.basicconv(x) 28 | 29 | 30 | class DenseLayer(nn.Module): 31 | def __init__(self, in_C, out_C, down_factor=4, k=4): 32 | """ 33 | Works like a DenseNet block: builds dense connections within the feature. 34 | """ 35 | super(DenseLayer, self).__init__() 36 | self.k = k 37 | self.down_factor = down_factor 38 | mid_C = out_C // self.down_factor 39 | 40 | self.down = nn.Conv2d(in_C, mid_C, 1) 41 | 42 | self.denseblock = nn.ModuleList() 43 | for i in range(1, 
self.k + 1): 44 | self.denseblock.append(BBasicConv2d(mid_C * i, mid_C, 3, 1, 1)) 45 | 46 | self.fuse = BBasicConv2d(in_C + mid_C, out_C, kernel_size=3, stride=1, padding=1) 47 | 48 | def forward(self, in_feat): 49 | down_feats = self.down(in_feat) 50 | # print(down_feats.shape) 51 | # print(self.denseblock) 52 | out_feats = [] 53 | for i in self.denseblock: 54 | # print(self.denseblock) 55 | feats = i(torch.cat((*out_feats, down_feats), dim=1)) 56 | # print(feats.shape) 57 | out_feats.append(feats) 58 | 59 | feats = torch.cat((in_feat, feats), dim=1) 60 | return self.fuse(feats) 61 | 62 | 63 | class GEFM(nn.Module): 64 | def __init__(self, in_C, out_C): 65 | super(GEFM, self).__init__() 66 | self.RGB_K = BBasicConv2d(out_C, out_C, 3, 1, 1) 67 | self.RGB_V = BBasicConv2d(out_C, out_C, 3, 1, 1) 68 | self.Q = BBasicConv2d(in_C, out_C, 3, 1, 1) 69 | self.INF_K = BBasicConv2d(out_C, out_C, 3, 1, 1) 70 | self.INF_V = BBasicConv2d(out_C, out_C, 3, 1, 1) 71 | self.Second_reduce = BBasicConv2d(in_C, out_C, 3, 1, 1) 72 | self.gamma1 = nn.Parameter(torch.zeros(1)) 73 | self.gamma2 = nn.Parameter(torch.zeros(1)) 74 | self.softmax = nn.Softmax(dim=-1) 75 | 76 | def forward(self, x, y): 77 | Q = self.Q(torch.cat([x, y], dim=1)) 78 | RGB_K = self.RGB_K(x) 79 | RGB_V = self.RGB_V(x) 80 | m_batchsize, C, height, width = RGB_V.size() 81 | RGB_V = RGB_V.view(m_batchsize, -1, width * height) 82 | RGB_K = RGB_K.view(m_batchsize, -1, width * height).permute(0, 2, 1) 83 | RGB_Q = Q.view(m_batchsize, -1, width * height) 84 | RGB_mask = torch.bmm(RGB_K, RGB_Q) 85 | RGB_mask = self.softmax(RGB_mask) 86 | RGB_refine = torch.bmm(RGB_V, RGB_mask.permute(0, 2, 1)) 87 | RGB_refine = RGB_refine.view(m_batchsize, -1, height, width) 88 | RGB_refine = self.gamma1 * RGB_refine + y 89 | 90 | INF_K = self.INF_K(y) 91 | INF_V = self.INF_V(y) 92 | INF_V = INF_V.view(m_batchsize, -1, width * height) 93 | INF_K = INF_K.view(m_batchsize, -1, width * height).permute(0, 2, 1) 94 | INF_Q = Q.view(m_batchsize, -1, width * height) 95 | INF_mask = torch.bmm(INF_K, INF_Q) 96 | INF_mask = self.softmax(INF_mask) 97 | INF_refine = torch.bmm(INF_V, INF_mask.permute(0, 2, 1)) 98 | INF_refine = INF_refine.view(m_batchsize, -1, height, width) 99 | INF_refine = self.gamma2 * INF_refine + x 100 | 101 | out = self.Second_reduce(torch.cat([RGB_refine, INF_refine], dim=1)) 102 | return out 103 | 104 | 105 | class PSFM(nn.Module): 106 | def __init__(self, in_C, out_C, cat_C): 107 | super(PSFM, self).__init__() 108 | self.RGBobj = DenseLayer(in_C, out_C) 109 | self.Infobj = DenseLayer(in_C, out_C) 110 | self.obj_fuse = GEFM(cat_C, out_C) 111 | 112 | def forward(self, rgb, depth): 113 | rgb_sum = self.RGBobj(rgb) 114 | Inf_sum = self.Infobj(depth) 115 | out = self.obj_fuse(rgb_sum, Inf_sum) 116 | return out -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/SMFA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from feature_show import feature_show  # feature-map visualization helper used by the original authors; the module is not shipped with this repo, so the import and the feature_show calls below are commented out to keep SMFA importable 5 | 6 | 7 | class DMlp(nn.Module): 8 | def __init__(self, dim, growth_rate=2.0): 9 | super().__init__() 10 | hidden_dim = int(dim * growth_rate) 11 | self.conv_0 = nn.Sequential( 12 | nn.Conv2d(dim,hidden_dim,3,1,1,groups=dim), 13 | nn.Conv2d(hidden_dim,hidden_dim,1,1,0) 14 | ) 15 | self.act = nn.GELU() 16 | self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0) 17 | 18 | def forward(self, x): 19 | x = self.conv_0(x) 20 | x = self.act(x) 21 | x = 
self.conv_1(x) 22 | return x 23 | 24 | class SMFA(nn.Module): 25 | def __init__(self, dim=36, id = 0): 26 | super(SMFA, self).__init__() 27 | 28 | self.id = id 29 | 30 | self.linear_0 = nn.Conv2d(dim,dim*2,1,1,0) 31 | self.linear_1 = nn.Conv2d(dim,dim,1,1,0) 32 | self.linear_2 = nn.Conv2d(dim,dim,1,1,0) 33 | 34 | self.lde = DMlp(dim,2) 35 | 36 | self.dw_conv = nn.Conv2d(dim,dim,3,1,1,groups=dim) 37 | 38 | self.gelu = nn.GELU() 39 | self.down_scale = 8 40 | 41 | self.alpha = nn.Parameter(torch.ones((1,dim,1,1))) 42 | self.belt = nn.Parameter(torch.zeros((1,dim,1,1))) 43 | 44 | def forward(self, f): 45 | # feature_show(f, f"{self.id}_input")  # visualization hook (requires the external feature_show helper) 46 | _,_,h,w = f.shape 47 | 48 | y, x = self.linear_0(f).chunk(2, dim=1) 49 | x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale))) 50 | x_v = torch.var(x, dim=(-2,-1), keepdim=True) 51 | x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h,w), mode='nearest') 52 | y_d = self.lde(y) 53 | out = self.linear_2(x_l + y_d) 54 | 55 | # feature_show(x_l, f"{self.id}_easa")  # visualization hook 56 | # feature_show(y_d, f"{self.id}_lde")  # visualization hook 57 | # feature_show(out, f"{self.id}_output")  # visualization hook 58 | 59 | return out -------------------------------------------------------------------------------- /特征提取or融合or对齐模块/SSFF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def autopad(k, p=None, d=1): # kernel, padding, dilation 7 | # Pad to 'same' shape outputs 8 | if d > 1: 9 | k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size 10 | if p is None: 11 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad 12 | return p 13 | 14 | class Conv(nn.Module): 15 | # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation) 16 | default_act = nn.SiLU() # default activation 17 | 18 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): 19 | super().__init__() 20 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) 21 | self.bn = nn.BatchNorm2d(c2) 22 | self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() 23 | 24 | def forward(self, x): 25 | return self.act(self.bn(self.conv(x))) 26 | 27 | def forward_fuse(self, x): 28 | return self.act(self.conv(x)) 29 | 30 | class Zoom_cat(nn.Module): 31 | def __init__(self, in_dim): 32 | super().__init__() 33 | #self.conv_l_post_down = Conv(in_dim, 2*in_dim, 3, 1, 1) 34 | 35 | def forward(self, x): 36 | """l, m, s are the large, medium and small scales; all three are fused onto the medium scale m.""" 37 | l, m, s = x[0], x[1], x[2] 38 | tgt_size = m.shape[2:] 39 | l = F.adaptive_max_pool2d(l, tgt_size) + F.adaptive_avg_pool2d(l, tgt_size) 40 | #l = self.conv_l_post_down(l) 41 | # m = self.conv_m(m) 42 | # s = self.conv_s_pre_up(s) 43 | s = F.interpolate(s, m.shape[2:], mode='nearest') 44 | # s = self.conv_s_post_up(s) 45 | lms = torch.cat([l, m, s], dim=1) 46 | return lms 47 | 48 | 49 | class SSFF(nn.Module): 50 | def __init__(self, channel): 51 | super(SSFF, self).__init__() 52 | self.conv1 = Conv(512, channel,1) 53 | self.conv2 = Conv(1024, channel,1) 54 | self.conv3d = nn.Conv3d(channel,channel,kernel_size=(1,1,1)) 55 | self.bn = nn.BatchNorm3d(channel) 56 | self.act = nn.LeakyReLU(0.1) 57 | self.pool_3d = nn.MaxPool3d(kernel_size=(3,1,1)) 58 | 59 | def forward(self, x): 60 | p3, p4, p5 = x[0],x[1],x[2] 61 | p4_2 = self.conv1(p4) 62 | p4_2 = 
F.interpolate(p4_2, p3.size()[2:], mode='nearest') 63 | p5_2 = self.conv2(p5) 64 | p5_2 = F.interpolate(p5_2, p3.size()[2:], mode='nearest') 65 | p3_3d = torch.unsqueeze(p3, -3) 66 | p4_3d = torch.unsqueeze(p4_2, -3) 67 | p5_3d = torch.unsqueeze(p5_2, -3) 68 | combine = torch.cat([p3_3d,p4_3d,p5_3d],dim = 2) 69 | conv_3d = self.conv3d(combine) 70 | bn = self.bn(conv_3d) 71 | act = self.act(bn) 72 | x = self.pool_3d(act) 73 | x = torch.squeeze(x, 2) 74 | return x --------------------------------------------------------------------------------
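Usage sketch (not part of the repository): a minimal smoke test for a few of the modules listed above, useful for checking constructor arguments and expected tensor shapes. It assumes the files under 特征提取or融合or对齐模块 are importable by adding that folder to the Python path, and the batch size, channel counts and spatial sizes are illustrative assumptions rather than values required by the original papers.

import sys
import torch

sys.path.append("特征提取or融合or对齐模块")  # assumed layout: run from the repo root

from LSK import LSKblock
from GLSA import GLSA
from PSFM import PSFM

if __name__ == "__main__":
    # LSKblock keeps both the channel count and the spatial size unchanged.
    x = torch.randn(2, 64, 32, 32)
    print(LSKblock(dim=64)(x).shape)                   # torch.Size([2, 64, 32, 32])

    # GLSA splits its input channels into a local and a global branch,
    # then fuses them back to embed_dim output channels.
    y = torch.randn(2, 512, 32, 32)
    print(GLSA(input_dim=512, embed_dim=32)(y).shape)  # torch.Size([2, 32, 32, 32])

    # PSFM fuses two same-shaped modalities; cat_C must equal 2 * out_C so that
    # GEFM's query and reduce convolutions see the concatenated channels.
    rgb = torch.randn(2, 64, 32, 32)
    tir = torch.randn(2, 64, 32, 32)
    print(PSFM(in_C=64, out_C=64, cat_C=128)(rgb, tir).shape)  # torch.Size([2, 64, 32, 32])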