├── .idea
│   ├── .gitignore
│   ├── DeepLearningPlugAndPlayModule.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── 上采样
│   └── DySample.py
├── 即插即用卷积
│   ├── ARConv.py
│   ├── CAMixer.py
│   ├── CFBConv.py
│   ├── CKGConv.py
│   ├── DCNv4.py
│   ├── DEConv.py
│   ├── DynamicConv.py
│   ├── LDConv.py
│   ├── MorphologyConv.py
│   ├── PConv.py
│   ├── SCConv.py
│   ├── StarConv.py
│   └── WTConv.py
├── 注意力模块
│   ├── AGF.py
│   ├── ASSA.py
│   ├── AgentAttention.py
│   ├── CA.py
│   ├── CBAM.py
│   ├── CGA.py
│   ├── CGLU.py
│   ├── DANet.py
│   ├── DAT.py
│   ├── ECA.py
│   ├── FcaNet.py
│   ├── LGAG.py
│   ├── MAB.py
│   ├── MCA.py
│   ├── MCPA.py
│   ├── MEGA.py
│   ├── MLKA.py
│   ├── MSPA.py
│   ├── NonLocal.py
│   ├── SA-Net.py
│   ├── SENet.py
│   ├── SLAB.py
│   ├── SRA.py
│   └── SimAM.py
└── 特征提取or融合or对齐模块
    ├── AFF.py
    ├── CAFM.py
    ├── CCFF.py
    ├── CCMF.py
    ├── CGAFusion.py
    ├── CSAM.py
    ├── FARM.py
    ├── FCA.py
    ├── GLSA.py
    ├── LSK.py
    ├── MFII.py
    ├── PSFM.py
    ├── SMFA.py
    └── SSFF.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/上采样/DySample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def normal_init(module, mean=0, std=1, bias=0):
7 | if hasattr(module, 'weight') and module.weight is not None:
8 | nn.init.normal_(module.weight, mean, std)
9 | if hasattr(module, 'bias') and module.bias is not None:
10 | nn.init.constant_(module.bias, bias)
11 |
12 |
13 | def constant_init(module, val, bias=0):
14 | if hasattr(module, 'weight') and module.weight is not None:
15 | nn.init.constant_(module.weight, val)
16 | if hasattr(module, 'bias') and module.bias is not None:
17 | nn.init.constant_(module.bias, bias)
18 |
19 |
20 | class DySample(nn.Module):
21 | def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False):
22 | super().__init__()
23 | self.scale = scale
24 | self.style = style
25 | self.groups = groups
26 | assert style in ['lp', 'pl']
27 | if style == 'pl':
28 | assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
29 | assert in_channels >= groups and in_channels % groups == 0
30 |
31 | if style == 'pl':
32 | in_channels = in_channels // scale ** 2
33 | out_channels = 2 * groups
34 | else:
35 | out_channels = 2 * groups * scale ** 2
36 |
37 | self.offset = nn.Conv2d(in_channels, out_channels, 1)
38 | normal_init(self.offset, std=0.001)
39 | if dyscope:
40 | self.scope = nn.Conv2d(in_channels, out_channels, 1, bias=False)
41 | constant_init(self.scope, val=0.)
42 |
43 | self.register_buffer('init_pos', self._init_pos())
44 |
45 | def _init_pos(self):
46 | h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
47 | return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)
48 |
49 | def sample(self, x, offset):
50 | B, _, H, W = offset.shape
51 | offset = offset.view(B, 2, -1, H, W)
52 | coords_h = torch.arange(H) + 0.5
53 | coords_w = torch.arange(W) + 0.5
54 | coords = torch.stack(torch.meshgrid([coords_w, coords_h])
55 | ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
56 | normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
57 | coords = 2 * (coords + offset) / normalizer - 1
58 | coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(
59 | B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
60 | return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',
61 | align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)
62 |
63 | def forward_lp(self, x):
64 | if hasattr(self, 'scope'):
65 | offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
66 | else:
67 | offset = self.offset(x) * 0.25 + self.init_pos
68 | return self.sample(x, offset)
69 |
70 | def forward_pl(self, x):
71 | x_ = F.pixel_shuffle(x, self.scale)
72 | if hasattr(self, 'scope'):
73 | offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos
74 | else:
75 | offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos
76 | return self.sample(x, offset)
77 |
78 | def forward(self, x):
79 | if self.style == 'pl':
80 | return self.forward_pl(x)
81 | return self.forward_lp(x)
--------------------------------------------------------------------------------
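
A minimal usage sketch for DySample (not part of the original repository). The batch size, channel count, and feature-map size below are illustrative assumptions; with scale=2 the module predicts per-pixel sampling offsets from the input and doubles the spatial resolution. Note that the 'pl' style additionally requires in_channels to be divisible by scale**2:

    import torch
    upsampler = DySample(in_channels=64, scale=2, style='lp', groups=4)
    feat = torch.rand(2, 64, 32, 32)   # (B, C, H, W)
    out = upsampler(feat)              # offsets predicted from feat, then grid_sample upsamples it
    print(out.shape)                   # torch.Size([2, 64, 64, 64])
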
/即插即用卷积/ARConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ARConv(nn.Module):
6 | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, l_max=9, w_max=9, flag=False, modulation=True):
7 | super(ARConv, self).__init__()
8 | self.lmax = l_max
9 | self.wmax = w_max
10 | self.inc = inc
11 | self.outc = outc
12 | self.kernel_size = kernel_size
13 | self.padding = padding
14 | self.stride = stride
15 | self.zero_padding = nn.ZeroPad2d(padding)
16 | self.flag = flag
17 | self.modulation = modulation
18 | self.i_list = [33, 35, 53, 37, 73, 55, 57, 75, 77]
19 | self.convs = nn.ModuleList(
20 | [
21 | nn.Conv2d(inc, outc, kernel_size=(i // 10, i % 10), stride=(i // 10, i % 10), padding=0)
22 | for i in self.i_list
23 | ]
24 | )
25 | self.m_conv = nn.Sequential(
26 | nn.Conv2d(inc, outc, kernel_size=3, padding=1, stride=stride),
27 | nn.LeakyReLU(),
28 | nn.Dropout2d(0.3),
29 | nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=stride),
30 | nn.LeakyReLU(),
31 | nn.Dropout2d(0.3),
32 | nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=stride),
33 | nn.Tanh()
34 | )
35 | self.b_conv = nn.Sequential(
36 | nn.Conv2d(inc, outc, kernel_size=3, padding=1, stride=stride),
37 | nn.LeakyReLU(),
38 | nn.Dropout2d(0.3),
39 | nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=stride),
40 | nn.LeakyReLU(),
41 | nn.Dropout2d(0.3),
42 | nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=stride)
43 | )
44 | self.p_conv = nn.Sequential(
45 | nn.Conv2d(inc, inc, kernel_size=3, padding=1, stride=stride),
46 | nn.BatchNorm2d(inc),
47 | nn.LeakyReLU(),
48 | nn.Dropout2d(0),
49 | nn.Conv2d(inc, inc, kernel_size=3, padding=1, stride=stride),
50 | nn.BatchNorm2d(inc),
51 | nn.LeakyReLU(),
52 | )
53 | self.l_conv = nn.Sequential(
54 | nn.Conv2d(inc, 1, kernel_size=3, padding=1, stride=stride),
55 | nn.BatchNorm2d(1),
56 | nn.LeakyReLU(),
57 | nn.Dropout2d(0),
58 | nn.Conv2d(1, 1, 1),
59 | nn.BatchNorm2d(1),
60 | nn.Sigmoid()
61 | )
62 | self.w_conv = nn.Sequential(
63 | nn.Conv2d(inc, 1, kernel_size=3, padding=1, stride=stride),
64 | nn.BatchNorm2d(1),
65 | nn.LeakyReLU(),
66 | nn.Dropout2d(0),
67 | nn.Conv2d(1, 1, 1),
68 | nn.BatchNorm2d(1),
69 | nn.Sigmoid()
70 | )
71 | self.dropout1 = nn.Dropout(0.3)
72 | self.dropout2 = nn.Dropout2d(0.3)
73 | self.hook_handles = []
74 | self.hook_handles.append(self.m_conv[0].register_full_backward_hook(self._set_lr))
75 | self.hook_handles.append(self.m_conv[1].register_full_backward_hook(self._set_lr))
76 | self.hook_handles.append(self.b_conv[0].register_full_backward_hook(self._set_lr))
77 | self.hook_handles.append(self.b_conv[1].register_full_backward_hook(self._set_lr))
78 | self.hook_handles.append(self.p_conv[0].register_full_backward_hook(self._set_lr))
79 | self.hook_handles.append(self.p_conv[1].register_full_backward_hook(self._set_lr))
80 | self.hook_handles.append(self.l_conv[0].register_full_backward_hook(self._set_lr))
81 | self.hook_handles.append(self.l_conv[1].register_full_backward_hook(self._set_lr))
82 | self.hook_handles.append(self.w_conv[0].register_full_backward_hook(self._set_lr))
83 | self.hook_handles.append(self.w_conv[1].register_full_backward_hook(self._set_lr))
84 |
85 | self.reserved_NXY = nn.Parameter(torch.tensor([3, 3], dtype=torch.int32), requires_grad=False)
86 |
87 | @staticmethod
88 | def _set_lr(module, grad_input, grad_output):
89 | grad_input = tuple(g * 0.1 if g is not None else None for g in grad_input)
90 | grad_output = tuple(g * 0.1 if g is not None else None for g in grad_output)
91 | return grad_input
92 |
93 | def remove_hooks(self):
94 | for handle in self.hook_handles:
95 |             handle.remove()  # remove the registered hook
96 |         self.hook_handles.clear()  # clear the list of hook handles
97 |
98 | def forward(self, x, epoch, hw_range):
99 | assert isinstance(hw_range, list) and len(
100 |             hw_range) == 2, "hw_range should be a list with 2 elements, representing the range of h and w"
101 | scale = hw_range[1] // 9
102 | if hw_range[0] == 1 and hw_range[1] == 3:
103 | scale = 1
104 | m = self.m_conv(x)
105 | bias = self.b_conv(x)
106 | offset = self.p_conv(x * 100)
107 | l = self.l_conv(offset) * (hw_range[1] - 1) + 1 # b, 1, h, w
108 | w = self.w_conv(offset) * (hw_range[1] - 1) + 1 # b, 1, h, w
109 | if epoch <= 100:
110 | mean_l = l.mean(dim=0).mean(dim=1).mean(dim=1)
111 | mean_w = w.mean(dim=0).mean(dim=1).mean(dim=1)
112 | N_X = int(mean_l // scale)
113 | N_Y = int(mean_w // scale)
114 |
115 | def phi(x):
116 | if x % 2 == 0:
117 | x -= 1
118 | return x
119 |
120 | N_X, N_Y = phi(N_X), phi(N_Y)
121 | N_X, N_Y = max(N_X, 3), max(N_Y, 3)
122 | N_X, N_Y = min(N_X, 7), min(N_Y, 7)
123 | if epoch == 100:
124 |                 self.reserved_NXY = nn.Parameter(
125 | torch.tensor([N_X, N_Y], dtype=torch.int32, device=x.device),
126 | requires_grad=False
127 | )
128 | else:
129 | N_X = self.reserved_NXY[0]
130 | N_Y = self.reserved_NXY[1]
131 |
132 | N = N_X * N_Y
133 | # print(N_X, N_Y)
134 | l = l.repeat([1, N, 1, 1])
135 | w = w.repeat([1, N, 1, 1])
136 | offset = torch.cat((l, w), dim=1)
137 | dtype = offset.data.type()
138 | if self.padding:
139 | x = self.zero_padding(x)
140 | p = self._get_p(offset, dtype, N_X, N_Y) # (b, 2*N, h, w)
141 | p = p.contiguous().permute(0, 2, 3, 1) # (b, h, w, 2*N)
142 | q_lt = p.detach().floor()
143 | q_rb = q_lt + 1
144 | q_lt = torch.cat(
145 | [
146 | torch.clamp(q_lt[..., :N], 0, x.size(2) - 1),
147 | torch.clamp(q_lt[..., N:], 0, x.size(3) - 1),
148 | ],
149 | dim=-1,
150 | ).long()
151 | q_rb = torch.cat(
152 | [
153 | torch.clamp(q_rb[..., :N], 0, x.size(2) - 1),
154 | torch.clamp(q_rb[..., N:], 0, x.size(3) - 1),
155 | ],
156 | dim=-1,
157 | ).long()
158 | q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
159 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
160 | # clip p
161 | p = torch.cat(
162 | [
163 | torch.clamp(p[..., :N], 0, x.size(2) - 1),
164 | torch.clamp(p[..., N:], 0, x.size(3) - 1),
165 | ],
166 | dim=-1,
167 | )
168 | # bilinear kernel (b, h, w, N)
169 | g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (
170 | 1 + (q_lt[..., N:].type_as(p) - p[..., N:])
171 | )
172 | g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (
173 | 1 - (q_rb[..., N:].type_as(p) - p[..., N:])
174 | )
175 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (
176 | 1 - (q_lb[..., N:].type_as(p) - p[..., N:])
177 | )
178 | g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (
179 | 1 + (q_rt[..., N:].type_as(p) - p[..., N:])
180 | )
181 | # (b, c, h, w, N)
182 | x_q_lt = self._get_x_q(x, q_lt, N)
183 | x_q_rb = self._get_x_q(x, q_rb, N)
184 | x_q_lb = self._get_x_q(x, q_lb, N)
185 | x_q_rt = self._get_x_q(x, q_rt, N)
186 | # (b, c, h, w, N)
187 | x_offset = (
188 | g_lt.unsqueeze(dim=1) * x_q_lt
189 | + g_rb.unsqueeze(dim=1) * x_q_rb
190 | + g_lb.unsqueeze(dim=1) * x_q_lb
191 | + g_rt.unsqueeze(dim=1) * x_q_rt
192 | )
193 | x_offset = self._reshape_x_offset(x_offset, N_X, N_Y)
194 | x_offset = self.dropout2(x_offset)
195 | x_offset = self.convs[self.i_list.index(N_X * 10 + N_Y)](x_offset)
196 | out = x_offset * m + bias
197 | return out
198 |
199 | def _get_p_n(self, N, dtype, n_x, n_y):
200 | p_n_x, p_n_y = torch.meshgrid(
201 | torch.arange(-(n_x - 1) // 2, (n_x - 1) // 2 + 1),
202 | torch.arange(-(n_y - 1) // 2, (n_y - 1) // 2 + 1),
203 | )
204 | p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
205 | p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)
206 | return p_n
207 |
208 | def _get_p_0(self, h, w, N, dtype):
209 | p_0_x, p_0_y = torch.meshgrid(
210 | torch.arange(1, h * self.stride + 1, self.stride),
211 | torch.arange(1, w * self.stride + 1, self.stride),
212 | )
213 | p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
214 | p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
215 | p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
216 | return p_0
217 |
218 | def _get_p(self, offset, dtype, n_x, n_y):
219 | N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)
220 | L, W = offset.split([N, N], dim=1)
221 | L = L / n_x
222 | W = W / n_y
223 | offsett = torch.cat([L, W], dim=1)
224 | p_n = self._get_p_n(N, dtype, n_x, n_y)
225 | p_n = p_n.repeat([1, 1, h, w])
226 | p_0 = self._get_p_0(h, w, N, dtype)
227 | p = p_0 + offsett * p_n
228 | return p
229 |
230 | def _get_x_q(self, x, q, N):
231 | b, h, w, _ = q.size()
232 | padded_w = x.size(3)
233 | c = x.size(1)
234 | x = x.contiguous().view(b, c, -1)
235 | index = q[..., :N] * padded_w + q[..., N:]
236 | index = (
237 | index.contiguous()
238 | .unsqueeze(dim=1)
239 | .expand(-1, c, -1, -1, -1)
240 | .contiguous()
241 | .view(b, c, -1)
242 | )
243 | x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
244 | return x_offset
245 |
246 | @staticmethod
247 | def _reshape_x_offset(x_offset, n_x, n_y):
248 | b, c, h, w, N = x_offset.size()
249 | x_offset = torch.cat([x_offset[..., s:s + n_y].contiguous().view(b, c, h, w * n_y) for s in range(0, N, n_y)],
250 | dim=-1)
251 | x_offset = x_offset.contiguous().view(b, c, h * n_x, w * n_y)
252 | return x_offset
--------------------------------------------------------------------------------
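
A minimal usage sketch for ARConv (not part of the original repository). Unlike a standard conv layer, forward also takes the current training epoch and an hw_range list; the values below are illustrative assumptions consistent with the default l_max/w_max of 9. Until epoch 100 the rectangular kernel shape (N_X, N_Y) is re-estimated from the input, and from epoch 100 onwards the stored reserved_NXY is reused:

    import torch
    arconv = ARConv(inc=32, outc=64)
    x = torch.rand(2, 32, 16, 16)
    out = arconv(x, epoch=1, hw_range=[1, 9])
    print(out.shape)  # torch.Size([2, 64, 16, 16])
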
/即插即用卷积/CFBConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from mmcv.cnn import (Conv2d,ConvModule)
5 | from mmcv.runner import BaseModule
6 |
7 | from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,trunc_normal_init,normal_init)
8 | from timm.models.layers import DropPath
9 | from mmseg.models.decode_heads.decode_head import BaseDecodeHead
10 | from mmseg.ops import resize
11 |
12 | #BN->Conv->GELU->drop->Conv2->drop
13 | class MLP(BaseModule):
14 | def __init__(self,
15 | in_channels,
16 | hidden_channels=None,
17 | out_channels=None,
18 | drop_rate=0.):
19 | super(MLP,self).__init__()
20 | hidden_channels = hidden_channels or in_channels
21 | out_channels = out_channels or in_channels
22 | self.norm = nn.SyncBatchNorm(in_channels, eps=1e-06)
23 | self.conv1 = nn.Conv2d(in_channels, hidden_channels, 3, 1, 1)
24 | self.act = nn.GELU()
25 | self.conv2 = nn.Conv2d(hidden_channels, out_channels, 3, 1, 1)
26 | self.drop = nn.Dropout(drop_rate)
27 |
28 | self.apply(self._init_weights)
29 |
30 | def _init_weights(self, m):
31 | if isinstance(m, nn.Linear):
32 | trunc_normal_init(m.weight, std=.02)
33 | if m.bias is not None:
34 | constant_init(m.bias, val=0)
35 | elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)):
36 | constant_init(m.weight, val=1.0)
37 | constant_init(m.bias, val=0)
38 | elif isinstance(m, nn.Conv2d):
39 | kaiming_init(m.weight)
40 | if m.bias is not None:
41 | constant_init(m.bias, val=0)
42 |
43 | def forward(self, x):
44 | x = self.norm(x)
45 | x = self.conv1(x)
46 | x = self.act(x)
47 | x = self.drop(x)
48 | x = self.conv2(x)
49 | x = self.drop(x)
50 | return x
51 |
52 | class ConvolutionalAttention(BaseModule):
53 | """
54 | The ConvolutionalAttention implementation
55 | Args:
56 | in_channels (int, optional): The input channels.
57 | inter_channels (int, optional): The channels of intermediate feature.
58 | out_channels (int, optional): The output channels.
59 | num_heads (int, optional): The num of heads in attention. Default: 8
60 | """
61 |
62 | def __init__(self,
63 | in_channels,
64 | out_channels,
65 | inter_channels,
66 | num_heads=8):
67 | super(ConvolutionalAttention,self).__init__()
68 | assert out_channels % num_heads == 0, \
69 | "out_channels ({}) should be be a multiple of num_heads ({})".format(out_channels, num_heads)
70 | self.in_channels = in_channels
71 | self.out_channels = out_channels
72 | self.inter_channels = inter_channels
73 | self.num_heads = num_heads
74 | self.norm = nn.SyncBatchNorm(in_channels)
75 |
76 | self.kv =nn.Parameter(torch.zeros(inter_channels, in_channels, 7, 1))
77 | self.kv3 =nn.Parameter(torch.zeros(inter_channels, in_channels, 1, 7))
78 | trunc_normal_init(self.kv, std=0.001)
79 | trunc_normal_init(self.kv3, std=0.001)
80 |
81 | self.apply(self._init_weights)
82 |
83 | def _init_weights(self, m):
84 | if isinstance(m, nn.Linear):
85 | trunc_normal_init(m.weight, std=.001)
86 | if m.bias is not None:
87 | constant_init(m.bias, val=0.)
88 | elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)):
89 | constant_init(m.weight, val=1.)
90 | constant_init(m.bias, val=.0)
91 | elif isinstance(m, nn.Conv2d):
92 | trunc_normal_init(m.weight, std=.001)
93 | if m.bias is not None:
94 | constant_init(m.bias, val=0.)
95 |
96 |
97 | def _act_dn(self, x):
98 | x_shape = x.shape # n,c_inter,h,w
99 | h, w = x_shape[2], x_shape[3]
100 | x = x.reshape(
101 | [x_shape[0], self.num_heads, self.inter_channels // self.num_heads, -1]) #n,c_inter,h,w -> n,heads,c_inner//heads,hw
102 | x = F.softmax(x, dim=3)
103 | x = x / (torch.sum(x, dim =2, keepdim=True) + 1e-06)
104 | x = x.reshape([x_shape[0], self.inter_channels, h, w])
105 | return x
106 |
107 | def forward(self, x):
108 | """
109 | Args:
110 | x (Tensor): The input tensor. (n,c,h,w)
111 | cross_k (Tensor, optional): The dims is (n*144, c_in, 1, 1)
112 | cross_v (Tensor, optional): The dims is (n*c_in, 144, 1, 1)
113 | """
114 | x = self.norm(x)
115 | x1 = F.conv2d(
116 | x,
117 | self.kv,
118 | bias=None,
119 | stride=1,
120 | padding=(3,0))
121 | x1 = self._act_dn(x1)
122 | x1 = F.conv2d(
123 | x1, self.kv.transpose(1, 0), bias=None, stride=1,
124 | padding=(3,0))
125 | x3 = F.conv2d(
126 | x,
127 | self.kv3,
128 | bias=None,
129 | stride=1,
130 | padding=(0,3))
131 | x3 = self._act_dn(x3)
132 | x3 = F.conv2d(
133 | x3, self.kv3.transpose(1, 0), bias=None, stride=1,padding=(0,3))
134 | x=x1+x3
135 | return x
136 |
137 | class CFBlock(BaseModule):
138 | """
139 | The CFBlock implementation based on PaddlePaddle.
140 | Args:
141 | in_channels (int, optional): The input channels.
142 | out_channels (int, optional): The output channels.
143 | num_heads (int, optional): The num of heads in attention. Default: 8
144 | drop_rate (float, optional): The drop rate in MLP. Default:0.
145 | drop_path_rate (float, optional): The drop path rate in CFBlock. Default: 0.2
146 | """
147 |
148 | def __init__(self,
149 | in_channels,
150 | out_channels,
151 | num_heads=8,
152 | drop_rate=0.,
153 | drop_path_rate=0.):
154 | super(CFBlock,self).__init__()
155 | in_channels_l = in_channels
156 | out_channels_l = out_channels
157 | self.attn_l = ConvolutionalAttention(
158 | in_channels_l,
159 | out_channels_l,
160 | inter_channels=64,
161 | num_heads=num_heads)
162 | self.mlp_l = MLP(out_channels_l, drop_rate=drop_rate)
163 | self.drop_path = DropPath(
164 | drop_path_rate) if drop_path_rate > 0. else nn.Identity()
165 |
166 | def _init_weights_kaiming(self, m):
167 | if isinstance(m, nn.Linear):
168 | trunc_normal_init(m.weight, std=.02)
169 | if m.bias is not None:
170 | constant_init(m.bias, val=0)
171 | elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)):
172 | constant_init(m.weight, val=1.0)
173 | constant_init(m.bias, val=0)
174 | elif isinstance(m, nn.Conv2d):
175 | kaiming_init(m.weight)
176 | if m.bias is not None:
177 | constant_init(m.bias, val=0)
178 |
179 | def forward(self, x):
180 | x_res = x
181 | x = x_res + self.drop_path(self.attn_l(x))
182 | x = x + self.drop_path(self.mlp_l(x))
183 | return x
--------------------------------------------------------------------------------
/即插即用卷积/CKGConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch_geometric as pyg
4 | import torch_geometric.graphgym.register as register
5 | from torch_geometric.graphgym.config import cfg
6 | from torch_geometric.graphgym.models.gnn import GNNPreMP
7 | from torch_geometric.graphgym.models.layer import (new_layer_config,
8 | BatchNorm1dNode)
9 | from torch_geometric.graphgym.register import register_network
10 | from functools import partial
11 |
12 |
13 |
14 |
15 | class FeatureEncoder(torch.nn.Module):
16 | """
17 | Encoding node and edge features
18 |
19 | Args:
20 | dim_in (int): Input feature dimension
21 | """
22 | def __init__(self, dim_in):
23 | super(FeatureEncoder, self).__init__()
24 | self.dim_in = dim_in
25 | if cfg.dataset.node_encoder:
26 | # Encode integer node features via nn.Embeddings
27 | NodeEncoder = register.node_encoder_dict[
28 | cfg.dataset.node_encoder_name]
29 | self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
30 | if cfg.dataset.node_encoder_bn:
31 | self.node_encoder_bn = BatchNorm1dNode(
32 | new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False,
33 | has_bias=False, cfg=cfg))
34 |             # Update dim_in to reflect the new dimension of the node features
35 | self.dim_in = cfg.gnn.dim_inner
36 |
37 | if cfg.dataset.edge_encoder:
38 | # Hard-limit max edge dim for PNA.
39 | if 'PNA' in cfg.gt.layer_type:
40 | cfg.gnn.dim_edge = min(128, cfg.gnn.dim_inner)
41 | else:
42 | cfg.gnn.dim_edge = cfg.gnn.dim_inner
43 | # Encode integer edge features via nn.Embeddings
44 | EdgeEncoder = register.edge_encoder_dict[
45 | cfg.dataset.edge_encoder_name]
46 | self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge)
47 | if cfg.dataset.edge_encoder_bn:
48 | self.edge_encoder_bn = BatchNorm1dNode(
49 | new_layer_config(cfg.gnn.dim_edge, -1, -1, has_act=False,
50 | has_bias=False, cfg=cfg))
51 |
52 | def forward(self, batch):
53 | for module in self.children():
54 | batch = module(batch)
55 | return batch
56 |
57 |
58 | # class PosencEncoder(torch.nn.Module):
59 | class Stem(torch.nn.Module):
60 | """
61 | Encoding node and edge Positional Encoding
62 | Args:
63 | dim_in (int): Input feature dimension
64 | """
65 | def __init__(self, dim_in, **kwargs):
66 | super().__init__()
67 |
68 |
69 | node_stem = cfg.gt.stem.get('node_stem', None)
70 | if node_stem is None: # to be compatible with previous versions
71 | node_stem = cfg.gt.get('node_pe_encoder', None)
72 |
73 | if node_stem is None: node_stem=''
74 | node_stem = node_stem.split("+")
75 | self.node_stem = nn.Sequential(*[register.node_encoder_dict[enc](out_dim=cfg.gnn.dim_inner,
76 | )
77 | for enc in node_stem if enc != ''])
78 |
79 | edge_stem = cfg.gt.stem.get('edge_stem', None)
80 | if edge_stem is None: # to be compatible with previous versions
81 | edge_stem = cfg.gt.get('edge_pe_encoder', None)
82 |
83 | if edge_stem is None: edge_stem=''
84 | edge_stem = edge_stem.split("+")
85 | self.edge_stem = nn.Sequential(*[register.edge_encoder_dict[enc](out_dim=cfg.gnn.dim_inner,
86 | )
87 | for enc in edge_stem if enc != ''])
88 |
89 | def forward(self, batch):
90 | for module in self.children():
91 | batch = module(batch)
92 | return batch
93 |
94 |
95 | @register_network('CKGConvNet')
96 | class CKGraphConvNet(torch.nn.Module):
97 | '''
98 | The proposed Continuous Kernel Convolution Networks
99 | '''
100 |
101 | def __init__(self, dim_in, dim_out):
102 | super().__init__()
103 |
104 | if cfg.gnn.dim_inner == -1:
105 | cfg.gnn.dim_inner = cfg.gt.dim_hidden
106 |
107 | self.feat_enc = FeatureEncoder(dim_in)
108 | dim_in = self.feat_enc.dim_in
109 | # self.pe_encoder = PosencEncoder(dim_in)
110 | self.stem = Stem(dim_in)
111 |
112 | # pre-backbone normalization
113 | pre_backbone_norm = cfg.gt.get('pre_backbone_norm', False)
114 | norm_fn = nn.Identity
115 | if cfg.gt.batch_norm: norm_fn = partial(nn.BatchNorm1d, momentum=cfg.gt.bn_momentum)
116 | if cfg.gt.layer_norm: norm_fn = nn.LayerNorm
117 | graph_norm = False
118 | if cfg.gt.get('graph_norm', False):
119 | norm_fn = pyg.nn.GraphNorm
120 | graph_norm = True
121 |
122 | if pre_backbone_norm:
123 | self.pre_backbone_norm = NormalizationLayer(norm_fn(cfg.gt.dim_hidden), graph_norm=graph_norm)
124 | else:
125 | self.pre_backbone_norm = nn.Identity()
126 |
127 | if cfg.gnn.layers_pre_mp > 0:
128 | self.pre_mp = GNNPreMP(
129 | dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
130 | dim_in = cfg.gnn.dim_inner
131 |
132 | assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
133 | "The inner and hidden dims must match."
134 |
135 | global_model_type = cfg.gt.get('layer_type', "CKGraphConvMLP")
136 | # global_model_type = "GritTransformer"
137 |
138 | ConvBlock = register.layer_dict.get(global_model_type)
139 | kernel_size = cfg.gt.get('kernel_size', -1)
140 | dilation = cfg.gt.get('dilation', 1)
141 | if isinstance(kernel_size, str):
142 | kernel_size = [int(i) for i in kernel_size.split(',')]
143 | else:
144 | kernel_size = [kernel_size] * cfg.gt.layers
145 |
146 |
147 | if isinstance(dilation, str):
148 | dilation = [int(i) for i in dilation.split(',')]
149 | elif isinstance(dilation, tuple):
150 | pass
151 | else:
152 | dilation = [dilation] * cfg.gt.layers
153 |
154 |
155 | layers = []
156 | for l in range(cfg.gt.layers):
157 | layers.append(ConvBlock(
158 | in_dim=cfg.gt.dim_hidden,
159 | out_dim=cfg.gt.dim_hidden,
160 | num_heads=cfg.gt.n_heads,
161 | kernel_size=kernel_size[l],
162 | dilation=dilation[l],
163 | dropout=cfg.gt.dropout,
164 | act=cfg.gnn.act,
165 | attn_dropout=cfg.gt.attn_dropout,
166 | layer_norm=cfg.gt.layer_norm,
167 | batch_norm=cfg.gt.batch_norm,
168 | residual=True,
169 | norm_e=cfg.gt.attn.norm_e,
170 | O_e=cfg.gt.attn.O_e,
171 | cfg=cfg.gt,
172 | out_norm=l==cfg.gt.layers-1
173 | # log_attn_weights=cfg.train.mode == 'log-attn-weights',
174 | ))
175 |
176 | # if global_model_type == "Norm-Res-GritTransformer" or global_model_type == "PreNormGritTransformer":
177 | # layers.append(register.layer_dict["GeneralNormLayer"]\
178 | # (dim=cfg.gt.dim_hidden,
179 | # layer_norm=cfg.gt.layer_norm,
180 | # batch_norm=cfg.gt.batch_norm,
181 | # cfg=cfg.gt
182 | # ))
183 |
184 | self.layers = torch.nn.Sequential(*layers)
185 |
186 | # pre-backbone normalization
187 | post_backbone_norm = cfg.gt.get('post_backbone_norm', False)
188 | if post_backbone_norm:
189 | self.post_backbone_norm = NormalizationLayer(norm_fn(cfg.gt.dim_hidden), graph_norm=graph_norm)
190 | else:
191 | self.post_backbone_norm = nn.Identity()
192 |
193 | GNNHead = register.head_dict[cfg.gnn.head]
194 | self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)
195 |
196 | def forward(self, batch):
197 |
198 | for module in self.children():
199 | batch = module(batch)
200 |
201 | return batch
202 |
203 |
204 |
205 | class NormalizationLayer(nn.Module):
206 | def __init__(self, norm_layer, graph_norm=False):
207 | super().__init__()
208 | self.norm_layer = norm_layer
209 | self.graph_norm = graph_norm
210 |
211 | def forward(self, batch):
212 | if self.graph_norm:
213 | batch.x = self.norm_layer(batch.x, batch.batch)
214 |             return batch
215 | batch.x = self.norm_layer(batch.x)
216 | return batch
--------------------------------------------------------------------------------
/即插即用卷积/DCNv4.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Function
6 | from torch.nn.init import xavier_uniform_, constant_
7 |
8 | """
9 | A complete implementation involves C++/CUDA extensions and is very complex to write directly in Python.
10 | Since the full implementation is usually shipped in setup.py or an ops/ subdirectory, the following is a core PyTorch-level sketch (assuming no C++/CUDA extension is used).
11 | The concrete CUDA implementation details need to be taken from the official repository.
12 | """
13 |
14 | class DCNv4Function(Function):
15 | @staticmethod
16 | def forward(ctx, input, offset_mask, kernel_h, kernel_w, stride_h, stride_w,
17 | pad_h, pad_w, dilation_h, dilation_w, group, group_channels,
18 | offset_scale, im2col_step, remove_center):
19 | ctx.save_for_backward(input, offset_mask)
20 | ctx.kernel_h, ctx.kernel_w = kernel_h, kernel_w
21 | ctx.stride_h, ctx.stride_w = stride_h, stride_w
22 | ctx.pad_h, ctx.pad_w = pad_h, pad_w
23 | ctx.dilation_h, ctx.dilation_w = dilation_h, dilation_w
24 | ctx.group = group
25 | ctx.group_channels = group_channels
26 | ctx.offset_scale = offset_scale
27 | ctx.im2col_step = im2col_step
28 | ctx.remove_center = remove_center
29 |
30 | # Placeholder for forward computation (mocked as a linear operation)
31 | # Replace this with an efficient im2col + convolution implementation if needed
32 | output = torch.nn.functional.linear(input, offset_mask)
33 | return output
34 |
35 | @staticmethod
36 | def backward(ctx, grad_output):
37 | input, offset_mask = ctx.saved_tensors
38 | # Placeholder for backward computation
39 | grad_input = grad_offset_mask = None
40 | if ctx.needs_input_grad[0]:
41 | grad_input = torch.nn.functional.linear(grad_output, offset_mask.t())
42 | if ctx.needs_input_grad[1]:
43 | grad_offset_mask = torch.nn.functional.linear(input.t(), grad_output)
44 |
45 | return (grad_input, grad_offset_mask, None, None, None, None,
46 | None, None, None, None, None, None, None, None, None)
47 |
48 |
49 | class CenterFeatureScaleModule(nn.Module):
50 | def forward(self, query, center_feature_scale_proj_weight, center_feature_scale_proj_bias):
51 | center_feature_scale = F.linear(query, weight=center_feature_scale_proj_weight,
52 | bias=center_feature_scale_proj_bias).sigmoid()
53 | return center_feature_scale
54 |
55 |
56 | class DCNv4(nn.Module):
57 | def __init__(self, channels=64, kernel_size=3, stride=1, pad=1, dilation=1,
58 | group=4, offset_scale=1.0, dw_kernel_size=None, center_feature_scale=False,
59 | remove_center=False, output_bias=True, without_pointwise=False, **kwargs):
60 | super().__init__()
61 | if channels % group != 0:
62 | raise ValueError(f'channels must be divisible by group, but got {channels} and {group}')
63 | self.offset_scale = offset_scale
64 | self.channels = channels
65 | self.kernel_size = kernel_size
66 | self.stride = stride
67 | self.dilation = dilation
68 | self.pad = pad
69 | self.group = group
70 | self.group_channels = channels // group
71 | self.offset_scale = offset_scale
72 | self.dw_kernel_size = dw_kernel_size
73 | self.center_feature_scale = center_feature_scale
74 | self.remove_center = int(remove_center)
75 | self.without_pointwise = without_pointwise
76 |
77 | self.K = group * (kernel_size * kernel_size - self.remove_center)
78 | if dw_kernel_size is not None:
79 | self.offset_mask_dw = nn.Conv2d(channels, channels, dw_kernel_size, stride=1,
80 | padding=(dw_kernel_size - 1) // 2, groups=channels)
81 | self.offset_mask = nn.Linear(channels, int(math.ceil((self.K * 3) / 8) * 8))
82 | if not without_pointwise:
83 | self.value_proj = nn.Linear(channels, channels)
84 | self.output_proj = nn.Linear(channels, channels, bias=output_bias)
85 | self._reset_parameters()
86 |
87 | if center_feature_scale:
88 | self.center_feature_scale_proj_weight = nn.Parameter(
89 | torch.zeros((group, channels), dtype=torch.float))
90 | self.center_feature_scale_proj_bias = nn.Parameter(
91 | torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group,))
92 | self.center_feature_scale_module = CenterFeatureScaleModule()
93 |
94 | def _reset_parameters(self):
95 | constant_(self.offset_mask.weight.data, 0.)
96 | constant_(self.offset_mask.bias.data, 0.)
97 | if not self.without_pointwise:
98 | xavier_uniform_(self.value_proj.weight.data)
99 | constant_(self.value_proj.bias.data, 0.)
100 | xavier_uniform_(self.output_proj.weight.data)
101 | if self.output_proj.bias is not None:
102 | constant_(self.output_proj.bias.data, 0.)
103 |
104 | def forward(self, input, shape=None):
105 | N, L, C = input.shape
106 | if shape is not None:
107 | H, W = shape
108 | else:
109 | H, W = int(L**0.5), int(L**0.5)
110 |
111 | x = input
112 | if not self.without_pointwise:
113 | x = self.value_proj(x)
114 | x = x.reshape(N, H, W, -1)
115 | if self.dw_kernel_size is not None:
116 | offset_mask_input = self.offset_mask_dw(input.view(N, H, W, C).permute(0, 3, 1, 2))
117 | offset_mask_input = offset_mask_input.permute(0, 2, 3, 1).view(N, L, C)
118 | else:
119 | offset_mask_input = input
120 | offset_mask = self.offset_mask(offset_mask_input).reshape(N, H, W, -1)
121 |
122 | x_proj = x
123 | x = DCNv4Function.apply(
124 | x, offset_mask, self.kernel_size, self.kernel_size, self.stride, self.stride,
125 | self.pad, self.pad, self.dilation, self.dilation, self.group, self.group_channels,
126 | self.offset_scale, 256, self.remove_center
127 | )
128 |
129 | if self.center_feature_scale:
130 | center_feature_scale = self.center_feature_scale_module(
131 | x, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
132 | center_feature_scale = center_feature_scale[..., None].repeat(
133 | 1, 1, 1, 1, self.channels // self.group).flatten(-2)
134 | x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
135 |
136 | x = x.view(N, L, -1)
137 |
138 | if not self.without_pointwise:
139 | x = self.output_proj(x)
140 | return x
141 |
--------------------------------------------------------------------------------
/即插即用卷积/DEConv.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from einops.layers.torch import Rearrange
5 |
6 |
7 | class Conv2d_cd(nn.Module):
8 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
9 | padding=1, dilation=1, groups=1, bias=False, theta=1.0):
10 | super(Conv2d_cd, self).__init__()
11 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
12 | dilation=dilation, groups=groups, bias=bias)
13 | self.theta = theta
14 |
15 | def get_weight(self):
16 | conv_weight = self.conv.weight
17 | conv_shape = conv_weight.shape
18 | conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
19 | conv_weight_cd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
20 | conv_weight_cd[:, :, :] = conv_weight[:, :, :]
21 | conv_weight_cd[:, :, 4] = conv_weight[:, :, 4] - conv_weight[:, :, :].sum(2)
22 | conv_weight_cd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
23 | conv_weight_cd)
24 | return conv_weight_cd, self.conv.bias
25 |
26 |
27 | class Conv2d_ad(nn.Module):
28 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
29 | padding=1, dilation=1, groups=1, bias=False, theta=1.0):
30 | super(Conv2d_ad, self).__init__()
31 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
32 | dilation=dilation, groups=groups, bias=bias)
33 | self.theta = theta
34 |
35 | def get_weight(self):
36 | conv_weight = self.conv.weight
37 | conv_shape = conv_weight.shape
38 | conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
39 | conv_weight_ad = conv_weight - self.theta * conv_weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]
40 | conv_weight_ad = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
41 | conv_weight_ad)
42 | return conv_weight_ad, self.conv.bias
43 |
44 |
45 | class Conv2d_rd(nn.Module):
46 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
47 | padding=2, dilation=1, groups=1, bias=False, theta=1.0):
48 |
49 | super(Conv2d_rd, self).__init__()
50 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
51 | dilation=dilation, groups=groups, bias=bias)
52 | self.theta = theta
53 |
54 | def forward(self, x):
55 |
56 | if math.fabs(self.theta - 0.0) < 1e-8:
57 | out_normal = self.conv(x)
58 | return out_normal
59 | else:
60 | conv_weight = self.conv.weight
61 | conv_shape = conv_weight.shape
62 | if conv_weight.is_cuda:
63 | conv_weight_rd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 5 * 5).fill_(0)
64 | else:
65 | conv_weight_rd = torch.zeros(conv_shape[0], conv_shape[1], 5 * 5)
66 | conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
67 | conv_weight_rd[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = conv_weight[:, :, 1:]
68 | conv_weight_rd[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -conv_weight[:, :, 1:] * self.theta
69 | conv_weight_rd[:, :, 12] = conv_weight[:, :, 0] * (1 - self.theta)
70 | conv_weight_rd = conv_weight_rd.view(conv_shape[0], conv_shape[1], 5, 5)
71 | out_diff = nn.functional.conv2d(input=x, weight=conv_weight_rd, bias=self.conv.bias,
72 | stride=self.conv.stride, padding=self.conv.padding, groups=self.conv.groups)
73 |
74 | return out_diff
75 |
76 |
77 | class Conv2d_hd(nn.Module):
78 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
79 | padding=1, dilation=1, groups=1, bias=False, theta=1.0):
80 | super(Conv2d_hd, self).__init__()
81 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
82 | dilation=dilation, groups=groups, bias=bias)
83 |
84 | def get_weight(self):
85 | conv_weight = self.conv.weight
86 | conv_shape = conv_weight.shape
87 | conv_weight_hd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
88 | conv_weight_hd[:, :, [0, 3, 6]] = conv_weight[:, :, :]
89 | conv_weight_hd[:, :, [2, 5, 8]] = -conv_weight[:, :, :]
90 | conv_weight_hd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
91 | conv_weight_hd)
92 | return conv_weight_hd, self.conv.bias
93 |
94 |
95 | class Conv2d_vd(nn.Module):
96 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
97 | padding=1, dilation=1, groups=1, bias=False):
98 | super(Conv2d_vd, self).__init__()
99 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
100 | dilation=dilation, groups=groups, bias=bias)
101 |
102 | def get_weight(self):
103 | conv_weight = self.conv.weight
104 | conv_shape = conv_weight.shape
105 | conv_weight_vd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
106 | conv_weight_vd[:, :, [0, 1, 2]] = conv_weight[:, :, :]
107 | conv_weight_vd[:, :, [6, 7, 8]] = -conv_weight[:, :, :]
108 | conv_weight_vd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
109 | conv_weight_vd)
110 | return conv_weight_vd, self.conv.bias
111 |
112 |
113 | class DEConv(nn.Module):
114 | def __init__(self, dim):
115 | super(DEConv, self).__init__()
116 | self.conv1_1 = Conv2d_cd(dim, dim, 3, bias=True)
117 | self.conv1_2 = Conv2d_hd(dim, dim, 3, bias=True)
118 | self.conv1_3 = Conv2d_vd(dim, dim, 3, bias=True)
119 | self.conv1_4 = Conv2d_ad(dim, dim, 3, bias=True)
120 | self.conv1_5 = nn.Conv2d(dim, dim, 3, padding=1, bias=True)
121 |
122 | def forward(self, x):
123 | w1, b1 = self.conv1_1.get_weight()
124 | w2, b2 = self.conv1_2.get_weight()
125 | w3, b3 = self.conv1_3.get_weight()
126 | w4, b4 = self.conv1_4.get_weight()
127 | w5, b5 = self.conv1_5.weight, self.conv1_5.bias
128 |
129 | w = w1 + w2 + w3 + w4 + w5
130 | b = b1 + b2 + b3 + b4 + b5
131 | res = nn.functional.conv2d(input=x, weight=w, bias=b, stride=1, padding=1, groups=1)
132 |
133 | return res
--------------------------------------------------------------------------------
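
A minimal usage sketch for DEConv (not part of the original repository). As written, Conv2d_cd/Conv2d_hd/Conv2d_vd build their re-parameterized weights with torch.cuda.FloatTensor, so this sketch assumes a CUDA device; the channel count and input size are illustrative:

    import torch
    if torch.cuda.is_available():
        deconv = DEConv(dim=16).cuda()
        x = torch.rand(1, 16, 32, 32, device='cuda')
        out = deconv(x)        # the five branch weights are summed into a single 3x3 conv
        print(out.shape)       # torch.Size([1, 16, 32, 32])
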
/即插即用卷积/DynamicConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from timm.models.layers import CondConv2d
5 |
6 | class DynamicConv(nn.Module):
7 | """
8 | Dynamic Conv layer
9 |     Brings additional parameters into the network
10 | """
11 |
12 | def __init__(self, in_features, out_features, kernel_size=1, stride=1, padding='', dilation=1,
13 | groups=1, bias=False, num_experts=4):
14 | super().__init__()
15 | print('+++', num_experts)
16 | self.routing = nn.Linear(in_features, num_experts)
17 | self.cond_conv = CondConv2d(in_features, out_features, kernel_size, stride, padding, dilation,
18 | groups, bias, num_experts)
19 |
20 | def forward(self, x):
21 | pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) # CondConv routing
22 | routing_weights = torch.sigmoid(self.routing(pooled_inputs))
23 | x = self.cond_conv(x, routing_weights)
24 | return x
--------------------------------------------------------------------------------
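
A minimal usage sketch for DynamicConv (not part of the original repository). It relies on timm's CondConv2d; the routing Linear produces one sigmoid weight per expert from globally pooled features, and the shapes below are illustrative:

    import torch
    dyn_conv = DynamicConv(in_features=64, out_features=64, kernel_size=3, num_experts=4)
    x = torch.rand(2, 64, 32, 32)
    out = dyn_conv(x)
    print(out.shape)  # torch.Size([2, 64, 32, 32])
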
/即插即用卷积/LDConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from einops import rearrange
5 |
6 | class LDConv(nn.Module):
7 | def __init__(self, inc, outc, num_param, stride=1, bias=None):
8 | super(LDConv, self).__init__()
9 | self.num_param = num_param
10 | self.stride = stride
11 | self.conv = nn.Sequential(nn.Conv2d(inc, outc, kernel_size=(num_param, 1), stride=(num_param, 1), bias=bias),
12 | nn.BatchNorm2d(outc),
13 |                                   nn.SiLU())  # the conv adds BN and SiLU to match the original Conv block in YOLOv5.
14 | self.p_conv = nn.Conv2d(inc, 2 * num_param, kernel_size=3, padding=1, stride=stride)
15 | nn.init.constant_(self.p_conv.weight, 0)
16 | self.p_conv.register_full_backward_hook(self._set_lr)
17 | self.register_buffer("p_n", self._get_p_n(N=self.num_param))
18 |
19 | @staticmethod
20 | def _set_lr(module, grad_input, grad_output):
21 | grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
22 | grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))
23 |
24 | def forward(self, x):
25 | # N is num_param.
26 | offset = self.p_conv(x)
27 | dtype = offset.data.type()
28 | N = offset.size(1) // 2
29 | # (b, 2N, h, w)
30 | p = self._get_p(offset, dtype)
31 |
32 | # (b, h, w, 2N)
33 | p = p.contiguous().permute(0, 2, 3, 1)
34 | q_lt = p.detach().floor()
35 | q_rb = q_lt + 1
36 |
37 | q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)],
38 | dim=-1).long()
39 | q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)],
40 | dim=-1).long()
41 | q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
42 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
43 |
44 | # clip p
45 | p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)
46 |
47 | # bilinear kernel (b, h, w, N)
48 | g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
49 | g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
50 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
51 | g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
52 |
53 | # resampling the features based on the modified coordinates.
54 | x_q_lt = self._get_x_q(x, q_lt, N)
55 | x_q_rb = self._get_x_q(x, q_rb, N)
56 | x_q_lb = self._get_x_q(x, q_lb, N)
57 | x_q_rt = self._get_x_q(x, q_rt, N)
58 |
59 | # bilinear
60 | x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
61 | g_rb.unsqueeze(dim=1) * x_q_rb + \
62 | g_lb.unsqueeze(dim=1) * x_q_lb + \
63 | g_rt.unsqueeze(dim=1) * x_q_rt
64 |
65 | x_offset = self._reshape_x_offset(x_offset, self.num_param)
66 | out = self.conv(x_offset)
67 |
68 | return out
69 |
70 |     # generating the initial sampled shapes for the LDConv with different sizes.
71 | def _get_p_n(self, N):
72 | base_int = round(math.sqrt(self.num_param))
73 | row_number = self.num_param // base_int
74 | mod_number = self.num_param % base_int
75 | p_n_x, p_n_y = torch.meshgrid(
76 | torch.arange(0, row_number),
77 | torch.arange(0, base_int))
78 | p_n_x = torch.flatten(p_n_x)
79 | p_n_y = torch.flatten(p_n_y)
80 | if mod_number > 0:
81 | mod_p_n_x, mod_p_n_y = torch.meshgrid(
82 | torch.arange(row_number, row_number + 1),
83 | torch.arange(0, mod_number))
84 |
85 | mod_p_n_x = torch.flatten(mod_p_n_x)
86 | mod_p_n_y = torch.flatten(mod_p_n_y)
87 | p_n_x, p_n_y = torch.cat((p_n_x, mod_p_n_x)), torch.cat((p_n_y, mod_p_n_y))
88 | p_n = torch.cat([p_n_x, p_n_y], 0)
89 | p_n = p_n.view(1, 2 * N, 1, 1)
90 | return p_n
91 |
92 | # no zero-padding
93 | def _get_p_0(self, h, w, N, dtype):
94 | p_0_x, p_0_y = torch.meshgrid(
95 | torch.arange(0, h * self.stride, self.stride),
96 | torch.arange(0, w * self.stride, self.stride))
97 |
98 | p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
99 | p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
100 | p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
101 |
102 | return p_0
103 |
104 | def _get_p(self, offset, dtype):
105 | N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)
106 |
107 | # (1, 2N, 1, 1)
108 | # p_n = self._get_p_n(N, dtype)
109 | # (1, 2N, h, w)
110 | p_0 = self._get_p_0(h, w, N, dtype)
111 | p = p_0 + self.p_n + offset
112 | return p
113 |
114 | def _get_x_q(self, x, q, N):
115 | b, h, w, _ = q.size()
116 | padded_w = x.size(3)
117 | c = x.size(1)
118 | # (b, c, h*w)
119 | x = x.contiguous().view(b, c, -1)
120 |
121 | # (b, h, w, N)
122 | index = q[..., :N] * padded_w + q[..., N:] # offset_x*w + offset_y
123 | # (b, c, h*w*N)
124 | index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
125 |
126 | x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
127 |
128 | return x_offset
129 |
130 | # Stacking resampled features in the row direction.
131 | @staticmethod
132 | def _reshape_x_offset(x_offset, num_param):
133 | b, c, h, w, n = x_offset.size()
134 | # using Conv3d
135 | # x_offset = x_offset.permute(0,1,4,2,3), then Conv3d(c,c_out, kernel_size =(num_param,1,1),stride=(num_param,1,1),bias= False)
136 | # using 1 × 1 Conv
137 | # x_offset = x_offset.permute(0,1,4,2,3), then, x_offset.view(b,c×num_param,h,w) finally, Conv2d(c×num_param,c_out, kernel_size =1,stride=1,bias= False)
138 | # using the column conv as follow, then, Conv2d(inc, outc, kernel_size=(num_param, 1), stride=(num_param, 1), bias=bias)
139 |
140 | x_offset = rearrange(x_offset, 'b c h w n -> b c (h n) w')
141 | return x_offset
--------------------------------------------------------------------------------
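
A minimal usage sketch for LDConv (not part of the original repository). num_param is the number of sampled points per output location and can be an arbitrary integer (not only a square number); the values below are illustrative:

    import torch
    ldconv = LDConv(inc=32, outc=64, num_param=5)
    x = torch.rand(2, 32, 32, 32)
    out = ldconv(x)
    print(out.shape)  # torch.Size([2, 64, 32, 32])
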
/即插即用卷积/MorphologyConv.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def fixed_padding(inputs, kernel_size, dilation):
8 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
9 | pad_total = kernel_size_effective - 1
10 | pad_beg = pad_total // 2
11 | pad_end = pad_total - pad_beg
12 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
13 | return padded_inputs
14 |
15 |
16 | class MorphologyConv(nn.Module):
17 | '''
18 |     Base class for morphological operators
19 | For now, only supports stride=1, dilation=1, kernel_size H==W, and padding='same'.
20 | '''
21 |
22 | def __init__(self, in_channels, out_channels, kernel_size=5, soft_max=True, beta=15, type=None):
23 | '''
24 | in_channels: scalar
25 | out_channels: scalar, the number of the morphological neure.
26 | kernel_size: scalar, the spatial size of the morphological neure.
27 | soft_max: bool, using the soft max rather the torch.max(), ref: Dense Morphological Networks: An Universal Function Approximator (Mondal et al. (2019)).
28 | beta: scalar, used by soft_max.
29 | type: str, dilation2d or erosion2d.
30 | '''
31 | super(MorphologyConv, self).__init__()
32 | self.in_channels = in_channels
33 | self.out_channels = out_channels
34 | self.kernel_size = kernel_size
35 | self.soft_max = soft_max
36 | self.beta = beta
37 | self.type = type
38 |
39 | self.weight = nn.Parameter(torch.ones(out_channels, in_channels, kernel_size, kernel_size), requires_grad=True)
40 | self.unfold = nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
41 |
42 | def forward(self, x):
43 | '''
44 | x: tensor of shape (B,C,H,W)
45 | '''
46 | # padding
47 | x = fixed_padding(x, self.kernel_size, dilation=1)
48 |
49 | # unfold
50 | x = self.unfold(x) # (B, Cin*kH*kW, L), where L is the numbers of patches
51 | x = x.unsqueeze(1) # (B, 1, Cin*kH*kW, L)
52 | L = x.size(-1)
53 | L_sqrt = int(math.sqrt(L))
54 |
55 | # erosion
56 | weight = self.weight.view(self.out_channels, -1) # (Cout, Cin*kH*kW)
57 | weight = weight.unsqueeze(0).unsqueeze(-1) # (1, Cout, Cin*kH*kW, 1)
58 |
59 | if self.type == 'erosion2d':
60 | x = weight - x # (B, Cout, Cin*kH*kW, L)
61 | elif self.type == 'dilation2d':
62 | x = weight + x # (B, Cout, Cin*kH*kW, L)
63 | else:
64 | raise ValueError
65 |
66 | if not self.soft_max:
67 | x, _ = torch.max(x, dim=2, keepdim=False) # (B, Cout, L)
68 | else:
69 | x = torch.logsumexp(x * self.beta, dim=2, keepdim=False) / self.beta # (B, Cout, L)
70 |
71 | if self.type == 'erosion2d':
72 | x = -1 * x
73 |
74 | # instead of fold, we use view to avoid copy
75 |         x = x.view(-1, self.out_channels, L_sqrt, L_sqrt)  # (B, Cout, sqrt(L), sqrt(L))
76 |
77 | return x
78 |
79 |
80 | class Dilation2d(MorphologyConv):
81 | def __init__(self, in_channels, out_channels, kernel_size=5, soft_max=True, beta=20):
82 | super(Dilation2d, self).__init__(in_channels, out_channels, kernel_size, soft_max, beta, 'dilation2d')
83 |
84 | class Erosion2d(MorphologyConv):
85 | def __init__(self, in_channels, out_channels, kernel_size=5, soft_max=True, beta=20):
86 | super(Erosion2d, self).__init__(in_channels, out_channels, kernel_size, soft_max, beta, 'erosion2d')
--------------------------------------------------------------------------------
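
A minimal usage sketch for the morphological operators (not part of the original repository). Dilation2d and Erosion2d keep the spatial size (stride 1, 'same'-style padding); the channel counts below are illustrative:

    import torch
    dilate = Dilation2d(in_channels=3, out_channels=8, kernel_size=5)
    erode = Erosion2d(in_channels=3, out_channels=8, kernel_size=5)
    x = torch.rand(1, 3, 32, 32)
    print(dilate(x).shape)  # torch.Size([1, 8, 32, 32])
    print(erode(x).shape)   # torch.Size([1, 8, 32, 32])
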
/即插即用卷积/PConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 |
5 | class Partial_conv3(nn.Module):
6 |
7 | def __init__(self, dim, n_div, forward):
8 | super().__init__()
9 | self.dim_conv3 = dim // n_div
10 | self.dim_untouched = dim - self.dim_conv3
11 | self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)
12 |
13 | if forward == 'slicing':
14 | self.forward = self.forward_slicing
15 | elif forward == 'split_cat':
16 | self.forward = self.forward_split_cat
17 | else:
18 | raise NotImplementedError
19 |
20 | def forward_slicing(self, x: Tensor) -> Tensor:
21 | # only for inference
22 | x = x.clone() # !!! Keep the original input intact for the residual connection later
23 | x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
24 |
25 | return x
26 |
27 | def forward_split_cat(self, x: Tensor) -> Tensor:
28 | # for training/inference
29 | x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
30 | x1 = self.partial_conv3(x1)
31 | x = torch.cat((x1, x2), 1)
32 |
33 | return x
--------------------------------------------------------------------------------
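
A minimal usage sketch for Partial_conv3 (PConv, not part of the original repository). Only dim // n_div channels pass through the 3x3 convolution while the remaining channels are left untouched; use forward='slicing' only for inference, as noted in the code. Shapes below are illustrative:

    import torch
    pconv = Partial_conv3(dim=64, n_div=4, forward='split_cat')
    x = torch.rand(2, 64, 32, 32)
    out = pconv(x)
    print(out.shape)  # torch.Size([2, 64, 32, 32])
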
/即插即用卷积/SCConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 |
6 | class GroupBatchnorm2d(nn.Module):
7 | def __init__(self, c_num: int,
8 | group_num: int = 16,
9 | eps: float = 1e-10
10 | ):
11 | super(GroupBatchnorm2d, self).__init__()
12 | assert c_num >= group_num
13 | self.group_num = group_num
14 | self.weight = nn.Parameter(torch.randn(c_num, 1, 1))
15 | self.bias = nn.Parameter(torch.zeros(c_num, 1, 1))
16 | self.eps = eps
17 |
18 | def forward(self, x):
19 | N, C, H, W = x.size()
20 | x = x.view(N, self.group_num, -1)
21 | mean = x.mean(dim=2, keepdim=True)
22 | std = x.std(dim=2, keepdim=True)
23 | x = (x - mean) / (std + self.eps)
24 | x = x.view(N, C, H, W)
25 | return x * self.weight + self.bias
26 |
27 |
28 | class SRU(nn.Module):
29 | def __init__(self,
30 | oup_channels: int,
31 | group_num: int = 16,
32 | gate_treshold: float = 0.5,
33 | torch_gn: bool = True
34 | ):
35 | super().__init__()
36 |
37 | self.gn = nn.GroupNorm(num_channels=oup_channels, num_groups=group_num) if torch_gn else GroupBatchnorm2d(
38 | c_num=oup_channels, group_num=group_num)
39 | self.gate_treshold = gate_treshold
40 | self.sigomid = nn.Sigmoid()
41 |
42 | def forward(self, x):
43 | gn_x = self.gn(x)
44 | w_gamma = self.gn.weight / sum(self.gn.weight)
45 | w_gamma = w_gamma.view(1, -1, 1, 1)
46 | reweigts = self.sigomid(gn_x * w_gamma)
47 | # Gate
48 |         w1 = torch.where(reweigts > self.gate_treshold, torch.ones_like(reweigts), reweigts)  # values above the gate threshold are set to 1, others keep their original value
49 |         w2 = torch.where(reweigts > self.gate_treshold, torch.zeros_like(reweigts), reweigts)  # values above the gate threshold are set to 0, others keep their original value
50 | x_1 = w1 * x
51 | x_2 = w2 * x
52 | y = self.reconstruct(x_1, x_2)
53 | return y
54 |
55 | def reconstruct(self, x_1, x_2):
56 | x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1)
57 | x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1)
58 | return torch.cat([x_11 + x_22, x_12 + x_21], dim=1)
59 |
60 |
61 | class CRU(nn.Module):
62 |     '''
63 |     alpha: 0 < alpha < 1
64 |     '''
--------------------------------------------------------------------------------
/即插即用卷积/StarConv.py:
--------------------------------------------------------------------------------
25 |
26 | def forward(self, x):
27 | input = x
28 | x = self.dwconv(x)
29 | x1, x2 = self.f1(x), self.f2(x)
30 | x = self.act(x1) * x2
31 | x = self.dwconv2(self.g(x))
32 | x = input + self.drop_path(x)
33 | return x
--------------------------------------------------------------------------------
/即插即用卷积/WTConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import pywt
5 |
6 | def create_wavelet_filter(wave, in_size, out_size, type=torch.float):
7 | w = pywt.Wavelet(wave)
8 | dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)
9 | dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)
10 | dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
11 | dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
12 | dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
13 | dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)
14 |
15 | dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)
16 |
17 | rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])
18 | rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])
19 | rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),
20 | rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),
21 | rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),
22 | rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)
23 |
24 | rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)
25 |
26 | return dec_filters, rec_filters
27 |
28 | def wavelet_transform(x, filters):
29 | b, c, h, w = x.shape
30 | pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
31 | x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)
32 | x = x.reshape(b, c, 4, h // 2, w // 2)
33 | return x
34 |
35 |
36 | def inverse_wavelet_transform(x, filters):
37 | b, c, _, h_half, w_half = x.shape
38 | pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
39 | x = x.reshape(b, c * 4, h_half, w_half)
40 | x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)
41 | return x
42 |
43 |
44 | class _ScaleModule(nn.Module):
45 | def __init__(self, dims, init_scale=1.0, init_bias=0):
46 | super(_ScaleModule, self).__init__()
47 | self.dims = dims
48 | self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
49 | self.bias = None
50 |
51 | def forward(self, x):
52 | return torch.mul(self.weight, x)
53 |
54 |
55 | class WTConv2d(nn.Module):
56 | def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):
57 | super(WTConv2d, self).__init__()
58 |
59 | assert in_channels == out_channels
60 |
61 | self.in_channels = in_channels
62 | self.wt_levels = wt_levels
63 | self.stride = stride
64 | self.dilation = 1
65 |
66 | self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)
67 | self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
68 | self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)
69 |
70 | self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1, groups=in_channels, bias=bias)
71 | self.base_scale = _ScaleModule([1 ,in_channels ,1 ,1])
72 |
73 | self.wavelet_convs = nn.ModuleList(
74 | [nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1, groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)]
75 | )
76 | self.wavelet_scale = nn.ModuleList(
77 | [_ScaleModule([1 ,in_channels * 4 ,1 ,1], init_scale=0.1) for _ in range(self.wt_levels)]
78 | )
79 |
80 | if self.stride > 1:
81 | self.do_stride = nn.AvgPool2d(kernel_size=1, stride=stride)
82 | else:
83 | self.do_stride = None
84 |
85 | def forward(self, x):
86 |
87 | x_ll_in_levels = []
88 | x_h_in_levels = []
89 | shapes_in_levels = []
90 |
91 | curr_x_ll = x
92 |
93 | for i in range(self.wt_levels):
94 | curr_shape = curr_x_ll.shape
95 | shapes_in_levels.append(curr_shape)
96 | if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
97 | curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)
98 | curr_x_ll = F.pad(curr_x_ll, curr_pads)
99 |
100 | curr_x = wavelet_transform(curr_x_ll, self.wt_filter)
101 | curr_x_ll = curr_x[: ,: ,0 ,: ,:]
102 |
103 | shape_x = curr_x.shape
104 | curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
105 | curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))
106 | curr_x_tag = curr_x_tag.reshape(shape_x)
107 |
108 | x_ll_in_levels.append(curr_x_tag[: ,: ,0 ,: ,:])
109 | x_h_in_levels.append(curr_x_tag[: ,: ,1:4 ,: ,:])
110 |
111 | next_x_ll = 0
112 |
113 | for i in range(self.wt_levels-1, -1, -1):
114 | curr_x_ll = x_ll_in_levels.pop()
115 | curr_x_h = x_h_in_levels.pop()
116 | curr_shape = shapes_in_levels.pop()
117 |
118 | curr_x_ll = curr_x_ll + next_x_ll
119 |
120 | curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)
121 | next_x_ll = inverse_wavelet_transform(curr_x, self.iwt_filter)
122 |
123 | next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]
124 |
125 | x_tag = next_x_ll
126 | assert len(x_ll_in_levels) == 0
127 |
128 | x = self.base_scale(self.base_conv(x))
129 | x = x + x_tag
130 |
131 | if self.do_stride is not None:
132 | x = self.do_stride(x)
133 |
134 | return x
--------------------------------------------------------------------------------
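A minimal smoke test for the WTConv2d module above; the channel count, wavelet levels, and input size are illustrative choices, not values from the original repo (pywt must be installed, and in_channels must equal out_channels):

import torch

wtconv = WTConv2d(32, 32, kernel_size=5, wt_levels=2)   # depthwise conv plus 2 wavelet levels
x = torch.randn(1, 32, 64, 64)
print(wtconv(x).shape)   # expected: torch.Size([1, 32, 64, 64]) -- spatial size preserved at stride 1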
/注意力模块/ASSA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
4 | from einops import rearrange, repeat
5 |
6 |
7 | def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
8 | return nn.Conv2d(
9 | in_channels, out_channels, kernel_size,
10 | padding=(kernel_size // 2), bias=bias, stride=stride)
11 |
12 |
13 | #########################################
14 | class ConvBlock(nn.Module):
15 | def __init__(self, in_channel, out_channel, strides=1):
16 | super(ConvBlock, self).__init__()
17 | self.strides = strides
18 | self.in_channel = in_channel
19 | self.out_channel = out_channel
20 | self.block = nn.Sequential(
21 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1),
22 | nn.LeakyReLU(inplace=True),
23 | nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=strides, padding=1),
24 | nn.LeakyReLU(inplace=True),
25 | )
26 | self.conv11 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, padding=0)
27 |
28 | def forward(self, x):
29 | out1 = self.block(x)
30 | out2 = self.conv11(x)
31 | out = out1 + out2
32 | return out
33 |
34 |
35 | class LinearProjection(nn.Module):
36 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True):
37 | super().__init__()
38 | inner_dim = dim_head * heads
39 | self.heads = heads
40 | self.to_q = nn.Linear(dim, inner_dim, bias=bias)
41 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias)
42 | self.dim = dim
43 | self.inner_dim = inner_dim
44 |
45 | def forward(self, x, attn_kv=None):
46 | B_, N, C = x.shape
47 | if attn_kv is not None:
48 | attn_kv = attn_kv.unsqueeze(0).repeat(B_, 1, 1)
49 | else:
50 | attn_kv = x
51 | N_kv = attn_kv.size(1)
52 | q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
53 | kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
54 | q = q[0]
55 | k, v = kv[0], kv[1]
56 | return q, k, v
57 |
58 |
59 | #########################################
60 | ########### window-based self-attention #############
61 | class WindowAttention(nn.Module):
62 | def __init__(self, dim, win_size, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0.,
63 | proj_drop=0.):
64 |
65 | super().__init__()
66 | self.dim = dim
67 | self.win_size = win_size # Wh, Ww
68 | self.num_heads = num_heads
69 | head_dim = dim // num_heads
70 | self.scale = qk_scale or head_dim ** -0.5
71 |
72 | # define a parameter table of relative position bias
73 | self.relative_position_bias_table = nn.Parameter(
74 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
75 |
76 | # get pair-wise relative position index for each token inside the window
77 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1]
78 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1]
79 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
80 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
81 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
82 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
83 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0
84 | relative_coords[:, :, 1] += self.win_size[1] - 1
85 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
86 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
87 | self.register_buffer("relative_position_index", relative_position_index)
88 | trunc_normal_(self.relative_position_bias_table, std=.02)
89 |
90 | if token_projection == 'linear':
91 | self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
92 | else:
93 | raise Exception("Projection error!")
94 |
95 | self.token_projection = token_projection
96 | self.attn_drop = nn.Dropout(attn_drop)
97 | self.proj = nn.Linear(dim, dim)
98 | self.proj_drop = nn.Dropout(proj_drop)
99 |
100 | self.softmax = nn.Softmax(dim=-1)
101 |
102 | def forward(self, x, attn_kv=None, mask=None):
103 | B_, N, C = x.shape
104 | q, k, v = self.qkv(x, attn_kv)
105 | q = q * self.scale
106 | attn = (q @ k.transpose(-2, -1))
107 |
108 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
109 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH
110 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
111 | ratio = attn.size(-1) // relative_position_bias.size(-1)
112 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)
113 |
114 | attn = attn + relative_position_bias.unsqueeze(0)
115 |
116 | if mask is not None:
117 | nW = mask.shape[0]
118 | mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio)
119 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0)
120 | attn = attn.view(-1, self.num_heads, N, N * ratio)
121 | attn = self.softmax(attn)
122 | else:
123 | attn = self.softmax(attn)
124 |
125 | attn = self.attn_drop(attn)
126 |
127 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
128 | x = self.proj(x)
129 | x = self.proj_drop(x)
130 | return x
131 |
132 | def extra_repr(self) -> str:
133 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}'
134 |
135 |
136 | ########### window-based self-attention #############
137 | class WindowAttention_sparse(nn.Module):
138 | def __init__(self, dim, win_size, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0.,
139 | proj_drop=0.):
140 |
141 | super().__init__()
142 | self.dim = dim
143 | self.win_size = win_size # Wh, Ww
144 | self.num_heads = num_heads
145 | head_dim = dim // num_heads
146 | self.scale = qk_scale or head_dim ** -0.5
147 |
148 | # define a parameter table of relative position bias
149 | self.relative_position_bias_table = nn.Parameter(
150 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
151 |
152 | # get pair-wise relative position index for each token inside the window
153 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1]
154 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1]
155 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
156 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
157 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
158 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
159 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0
160 | relative_coords[:, :, 1] += self.win_size[1] - 1
161 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
162 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
163 | self.register_buffer("relative_position_index", relative_position_index)
164 | trunc_normal_(self.relative_position_bias_table, std=.02)
165 |
166 | if token_projection == 'linear':
167 | self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
168 | else:
169 | raise Exception("Projection error!")
170 |
171 | self.token_projection = token_projection
172 | self.attn_drop = nn.Dropout(attn_drop)
173 | self.proj = nn.Linear(dim, dim)
174 | self.proj_drop = nn.Dropout(proj_drop)
175 |
176 | self.softmax = nn.Softmax(dim=-1)
177 | self.relu = nn.ReLU()
178 | self.w = nn.Parameter(torch.ones(2))
179 |
180 | def forward(self, x, attn_kv=None, mask=None):
181 | B_, N, C = x.shape
182 | q, k, v = self.qkv(x, attn_kv)
183 | q = q * self.scale
184 | attn = (q @ k.transpose(-2, -1))
185 |
186 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
187 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH
188 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
189 | ratio = attn.size(-1) // relative_position_bias.size(-1)
190 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)
191 |
192 | attn = attn + relative_position_bias.unsqueeze(0)
193 |
194 | if mask is not None:
195 | nW = mask.shape[0]
196 | mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio)
197 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0)
198 | attn = attn.view(-1, self.num_heads, N, N * ratio)
199 | attn0 = self.softmax(attn)
200 | attn1 = self.relu(attn) ** 2 # b,h,w,c
201 | else:
202 | attn0 = self.softmax(attn)
203 | attn1 = self.relu(attn) ** 2
204 | w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w))
205 | w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w))
206 | attn = attn0 * w1 + attn1 * w2
207 | attn = self.attn_drop(attn)
208 |
209 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
210 | x = self.proj(x)
211 | x = self.proj_drop(x)
212 | return x
213 |
214 | def extra_repr(self) -> str:
215 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}'
216 |
217 |
218 | ########### self-attention #############
219 | class Attention(nn.Module):
220 | def __init__(self, dim, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0.,
221 | proj_drop=0.):
222 |
223 | super().__init__()
224 | self.dim = dim
225 | self.num_heads = num_heads
226 | head_dim = dim // num_heads
227 | self.scale = qk_scale or head_dim ** -0.5
228 |
229 | self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
230 |
231 | self.token_projection = token_projection
232 | self.attn_drop = nn.Dropout(attn_drop)
233 | self.proj = nn.Linear(dim, dim)
234 | self.proj_drop = nn.Dropout(proj_drop)
235 |
236 | self.softmax = nn.Softmax(dim=-1)
237 |
238 | def forward(self, x, attn_kv=None, mask=None):
239 | B_, N, C = x.shape
240 | q, k, v = self.qkv(x, attn_kv)
241 | q = q * self.scale
242 | attn = (q @ k.transpose(-2, -1))
243 | if mask is not None:
244 | nW = mask.shape[0]
245 | # mask = repeat(mask, 'nW m n -> nW m (n d)',d = ratio)
246 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
247 | attn = attn.view(-1, self.num_heads, N, N)
248 | attn = self.softmax(attn)
249 | else:
250 | attn = self.softmax(attn)
251 |
252 | attn = self.attn_drop(attn)
253 |
254 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
255 | x = self.proj(x)
256 | x = self.proj_drop(x)
257 | return x
258 |
259 | def extra_repr(self) -> str:
260 | return f'dim={self.dim}, num_heads={self.num_heads}'
--------------------------------------------------------------------------------
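A rough usage sketch for WindowAttention_sparse above: the input is a batch of flattened windows of shape (num_windows*B, Wh*Ww, dim); the sizes below are assumptions for illustration only.

import torch

attn = WindowAttention_sparse(dim=64, win_size=(8, 8), num_heads=4)
x = torch.randn(16, 8 * 8, 64)   # 16 windows, 64 tokens per window, 64 channels
print(attn(x).shape)             # expected: torch.Size([16, 64, 64])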
/注意力模块/AgentAttention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import trunc_normal_
4 | import numpy as np
5 |
6 |
7 | def img2windows(img, H_sp, W_sp):
8 | """
9 | img: B C H W
10 | """
11 | B, C, H, W = img.shape
12 | img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
13 | img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp * W_sp, C)
14 | return img_perm
15 |
16 | def windows2img(img_splits_hw, H_sp, W_sp, H, W):
17 | """
18 | img_splits_hw: B' H W C
19 | """
20 | B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
21 |
22 | img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1)
23 | img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
24 | return img
25 |
26 |
27 | class AgentAttention(nn.Module):
28 | def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, attn_drop=0., proj_drop=0.,
29 | agent_num=49, **kwargs):
30 | super().__init__()
31 | self.dim = dim
32 | self.dim_out = dim_out or dim
33 | self.resolution = resolution
34 | self.split_size = split_size
35 | self.num_heads = num_heads
36 | head_dim = dim // num_heads
37 | self.agent_num = agent_num
38 | self.scale = head_dim ** -0.5
39 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
40 | # self.scale = qk_scale or head_dim ** -0.5
41 | if idx == -1:
42 | H_sp, W_sp = self.resolution, self.resolution
43 | elif idx == 0:
44 | H_sp, W_sp = self.resolution, self.split_size
45 | elif idx == 1:
46 | W_sp, H_sp = self.resolution, self.split_size
47 | else:
48 | print("ERROR MODE", idx)
49 | exit(0)
50 | self.H_sp = H_sp
51 | self.W_sp = W_sp
52 | self.get_v = nn.Conv2d(dim, dim, kernel_size=(3, 3), stride=(1, 1), padding=1, groups=dim)
53 |
54 | self.attn_drop = nn.Dropout(attn_drop)
55 |
56 | self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
57 | self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
58 | self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, H_sp, 1))
59 | self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, W_sp))
60 | self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, H_sp, 1, agent_num))
61 | self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, W_sp, agent_num))
62 | trunc_normal_(self.an_bias, std=.02)
63 | trunc_normal_(self.na_bias, std=.02)
64 | trunc_normal_(self.ah_bias, std=.02)
65 | trunc_normal_(self.aw_bias, std=.02)
66 | trunc_normal_(self.ha_bias, std=.02)
67 | trunc_normal_(self.wa_bias, std=.02)
68 | pool_size = int(agent_num ** 0.5)
69 | self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))
70 | self.softmax = nn.Softmax(dim=-1)
71 |
72 | def im2cswin(self, x):
73 | B, N, C = x.shape
74 | H = W = int(np.sqrt(N))
75 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
76 | x = img2windows(x, self.H_sp, self.W_sp)
77 | # x = x.reshape(-1, self.H_sp * self.W_sp, C).contiguous()
78 | return x
79 |
80 | def get_lepe(self, x, func):
81 | B, N, C = x.shape
82 | H = W = int(np.sqrt(N))
83 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
84 |
85 | H_sp, W_sp = self.H_sp, self.W_sp
86 | x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
87 | x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp) ### B', C, H', W'
88 |
89 | lepe = func(x) ### B', C, H', W'
90 | lepe = lepe.reshape(-1, C, H_sp * W_sp).permute(0, 2, 1).contiguous()
91 |
92 | x = x.reshape(-1, C, self.H_sp * self.W_sp).permute(0, 2, 1).contiguous()
93 | return x, lepe
94 |
95 | def forward(self, qkv):
96 | """
97 | x: B L C
98 | """
99 | q, k, v = qkv[0], qkv[1], qkv[2]
100 |
101 | ### Img2Window
102 | H = W = self.resolution
103 | B, L, C = q.shape
104 | assert L == H * W, "flatten img_tokens has wrong size"
105 |
106 | q = self.im2cswin(q)
107 | k = self.im2cswin(k)
108 | v, lepe = self.get_lepe(v, self.get_v)
109 | # q, k, v = (rearrange(x, "b h n c -> b n (h c)", h=self.num_heads) for x in [q, k, v])
110 |
111 | b, n, c = q.shape
112 | h, w = self.H_sp, self.W_sp
113 | num_heads, head_dim = self.num_heads, self.dim // self.num_heads
114 |
115 | agent_tokens = self.pool(q.reshape(b, h, w, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1)
116 | q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
117 | k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
118 | v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
119 | agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3)
120 |
121 | position_bias1 = nn.functional.interpolate(self.an_bias, size=(self.H_sp, self.W_sp), mode='bilinear')
122 | position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
123 | position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
124 | position_bias = position_bias1 + position_bias2
125 | agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias)
126 | agent_attn = self.attn_drop(agent_attn)
127 | agent_v = agent_attn @ v
128 |
129 | agent_bias1 = nn.functional.interpolate(self.na_bias, size=(self.H_sp, self.W_sp), mode='bilinear')
130 | agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1)
131 | agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1)
132 | agent_bias = agent_bias1 + agent_bias2
133 | q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias)
134 | q_attn = self.attn_drop(q_attn)
135 | x = q_attn @ agent_v
136 |
137 | x = x.transpose(1, 2).reshape(b, n, c)
138 | x = x + lepe
139 | x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C)
140 |
141 | return x
--------------------------------------------------------------------------------
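A hedged sketch of calling AgentAttention: forward takes a stacked (q, k, v) tensor whose token count must equal resolution**2, and agent_num should be a perfect square since it is pooled to sqrt(agent_num) x sqrt(agent_num) agent tokens. The numbers below are illustrative.

import torch

attn = AgentAttention(dim=64, resolution=14, idx=-1, num_heads=8, agent_num=49)
qkv = torch.randn(3, 2, 14 * 14, 64)   # stacked q, k, v: batch 2, 196 tokens, 64 channels
print(attn(qkv).shape)                  # expected: torch.Size([2, 196, 64])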
/注意力模块/CA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class h_sigmoid(nn.Module):
6 | def __init__(self, inplace=True):
7 | super(h_sigmoid, self).__init__()
8 | self.relu = nn.ReLU6(inplace=inplace)
9 |
10 | def forward(self, x):
11 | return self.relu(x + 3) / 6
12 |
13 |
14 | class h_swish(nn.Module):
15 | def __init__(self, inplace=True):
16 | super(h_swish, self).__init__()
17 | self.sigmoid = h_sigmoid(inplace=inplace)
18 |
19 | def forward(self, x):
20 | return x * self.sigmoid(x)
21 |
22 |
23 | class CoordAtt(nn.Module):
24 | def __init__(self, inp, oup, reduction=32):
25 | super(CoordAtt, self).__init__()
26 | self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
27 | self.pool_w = nn.AdaptiveAvgPool2d((1, None))
28 |
29 | mip = max(8, inp // reduction)
30 |
31 | self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
32 | self.bn1 = nn.BatchNorm2d(mip)
33 | self.act = h_swish()
34 |
35 | self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
36 | self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
37 |
38 | def forward(self, x):
39 | identity = x
40 |
41 | n, c, h, w = x.size()
42 | x_h = self.pool_h(x)
43 | x_w = self.pool_w(x).permute(0, 1, 3, 2)
44 |
45 | y = torch.cat([x_h, x_w], dim=2)
46 | y = self.conv1(y)
47 | y = self.bn1(y)
48 | y = self.act(y)
49 |
50 | x_h, x_w = torch.split(y, [h, w], dim=2)
51 | x_w = x_w.permute(0, 1, 3, 2)
52 |
53 | a_h = self.conv_h(x_h).sigmoid()
54 | a_w = self.conv_w(x_w).sigmoid()
55 |
56 | out = identity * a_w * a_h
57 |
58 | return out
--------------------------------------------------------------------------------
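Coordinate Attention is a drop-in reweighting block on a 4-D feature map; a minimal check with assumed sizes:

import torch

ca = CoordAtt(inp=64, oup=64, reduction=32)
x = torch.randn(2, 64, 32, 32)
print(ca(x).shape)   # expected: torch.Size([2, 64, 32, 32])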
/注意力模块/CBAM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # Channel Attention Module
6 | class ChannelAttention(nn.Module):
7 | def __init__(self, in_planes, ratio=16):
8 | super(ChannelAttention, self).__init__()
9 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
10 |         self.max_pool = nn.AdaptiveMaxPool2d(1)  # global max pooling
11 |
12 |         # use 1x1 convolutions instead of fully-connected layers to cut parameters; ratio controls the channel reduction
13 | self.fc = nn.Sequential(
14 | nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
15 | nn.ReLU(),
16 | nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
17 | )
18 |         self.sigmoid = nn.Sigmoid()  # sigmoid activation
19 |
20 | def forward(self, x):
21 |         avg_out = self.fc(self.avg_pool(x))  # global average pooling branch
22 |         max_out = self.fc(self.max_pool(x))  # global max pooling branch
23 |         out = avg_out + max_out  # fuse the two pooling branches
24 |         return self.sigmoid(out)  # return the channel attention weights
25 |
26 |
27 | # Spatial Attention Module
28 | class SpatialAttention(nn.Module):
29 | def __init__(self, kernel_size=7):
30 | super(SpatialAttention, self).__init__()
31 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
32 | self.sigmoid = nn.Sigmoid()
33 |
34 | def forward(self, x):
35 |         avg_out = torch.mean(x, dim=1, keepdim=True)  # channel-wise mean of the feature map
36 |         max_out, _ = torch.max(x, dim=1, keepdim=True)  # channel-wise max of the feature map
37 |         x = torch.cat([avg_out, max_out], dim=1)  # concatenate along the channel dimension
38 |         x = self.conv1(x)  # fuse with a convolution
39 |         return self.sigmoid(x)  # return the spatial attention weights
40 |
41 |
42 | # CBAM Module
43 | class CBAM(nn.Module):
44 | def __init__(self, in_planes, ratio=16, kernel_size=7):
45 | super(CBAM, self).__init__()
46 |         self.channel_attention = ChannelAttention(in_planes, ratio)  # channel attention module
47 |         self.spatial_attention = SpatialAttention(kernel_size)  # spatial attention module
48 |
49 | def forward(self, x):
50 |         out = self.channel_attention(x) * x  # apply channel attention first
51 |         out = self.spatial_attention(out) * out  # then apply spatial attention
52 |         return out  # return the refined features
--------------------------------------------------------------------------------
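CBAM is likewise plug-and-play on (B, C, H, W) features; a quick check with illustrative sizes (in_planes should be divisible by ratio):

import torch

cbam = CBAM(in_planes=64, ratio=16, kernel_size=7)
x = torch.randn(2, 64, 32, 32)
print(cbam(x).shape)   # expected: torch.Size([2, 64, 32, 32])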
/注意力模块/CGA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import itertools
4 |
5 | class Conv2d_BN(torch.nn.Sequential):
6 | def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
7 | groups=1, bn_weight_init=1, resolution=-10000):
8 | super().__init__()
9 | self.add_module('c', torch.nn.Conv2d(
10 | a, b, ks, stride, pad, dilation, groups, bias=False))
11 | self.add_module('bn', torch.nn.BatchNorm2d(b))
12 | torch.nn.init.constant_(self.bn.weight, bn_weight_init)
13 | torch.nn.init.constant_(self.bn.bias, 0)
14 |
15 | @torch.no_grad()
16 | def fuse(self):
17 | c, bn = self._modules.values()
18 | w = bn.weight / (bn.running_var + bn.eps)**0.5
19 | w = c.weight * w[:, None, None, None]
20 | b = bn.bias - bn.running_mean * bn.weight / \
21 | (bn.running_var + bn.eps)**0.5
22 | m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
23 | 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
24 | m.weight.data.copy_(w)
25 | m.bias.data.copy_(b)
26 | return m
27 |
28 |
29 | class CascadedGroupAttention(torch.nn.Module):
30 | r""" Cascaded Group Attention.
31 |
32 | Args:
33 | dim (int): Number of input channels.
34 | key_dim (int): The dimension for query and key.
35 | num_heads (int): Number of attention heads.
36 | attn_ratio (int): Multiplier for the query dim for value dimension.
37 | resolution (int): Input resolution, correspond to the window size.
38 | kernels (List[int]): The kernel size of the dw conv on query.
39 | """
40 | def __init__(self, dim, key_dim, num_heads=8,
41 | attn_ratio=4,
42 | resolution=14,
43 | kernels=[5, 5, 5, 5],):
44 | super().__init__()
45 | self.num_heads = num_heads
46 | self.scale = key_dim ** -0.5
47 | self.key_dim = key_dim
48 | self.d = int(attn_ratio * key_dim)
49 | self.attn_ratio = attn_ratio
50 |
51 | qkvs = []
52 | dws = []
53 | for i in range(num_heads):
54 | qkvs.append(Conv2d_BN(dim // (num_heads), self.key_dim * 2 + self.d, resolution=resolution))
55 | dws.append(Conv2d_BN(self.key_dim, self.key_dim, kernels[i], 1, kernels[i]//2, groups=self.key_dim, resolution=resolution))
56 | self.qkvs = torch.nn.ModuleList(qkvs)
57 | self.dws = torch.nn.ModuleList(dws)
58 | self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN(
59 | self.d * num_heads, dim, bn_weight_init=0, resolution=resolution))
60 |
61 | points = list(itertools.product(range(resolution), range(resolution)))
62 | N = len(points)
63 | attention_offsets = {}
64 | idxs = []
65 | for p1 in points:
66 | for p2 in points:
67 | offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
68 | if offset not in attention_offsets:
69 | attention_offsets[offset] = len(attention_offsets)
70 | idxs.append(attention_offsets[offset])
71 | self.attention_biases = torch.nn.Parameter(
72 | torch.zeros(num_heads, len(attention_offsets)))
73 | self.register_buffer('attention_bias_idxs',
74 | torch.LongTensor(idxs).view(N, N))
75 |
76 | @torch.no_grad()
77 | def train(self, mode=True):
78 | super().train(mode)
79 | if mode and hasattr(self, 'ab'):
80 | del self.ab
81 | else:
82 | self.ab = self.attention_biases[:, self.attention_bias_idxs]
83 |
84 | def forward(self, x): # x (B,C,H,W)
85 | B, C, H, W = x.shape
86 | trainingab = self.attention_biases[:, self.attention_bias_idxs]
87 | feats_in = x.chunk(len(self.qkvs), dim=1)
88 | feats_out = []
89 | feat = feats_in[0]
90 | for i, qkv in enumerate(self.qkvs):
91 | if i > 0: # add the previous output to the input
92 | feat = feat + feats_in[i]
93 | feat = qkv(feat)
94 | q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1) # B, C/h, H, W
95 | q = self.dws[i](q)
96 | q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) # B, C/h, N
97 | attn = (
98 | (q.transpose(-2, -1) @ k) * self.scale
99 | +
100 | (trainingab[i] if self.training else self.ab[i])
101 | )
102 | attn = attn.softmax(dim=-1) # BNN
103 | feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W) # BCHW
104 | feats_out.append(feat)
105 | x = self.proj(torch.cat(feats_out, 1))
106 | return x
--------------------------------------------------------------------------------
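For CascadedGroupAttention, the input spatial size must match resolution, dim must split evenly across the heads, and len(kernels) must be at least num_heads (the default kernels list has 4 entries, so num_heads=4 is used here); the sizes below are illustrative:

import torch

attn = CascadedGroupAttention(dim=64, key_dim=16, num_heads=4, attn_ratio=4, resolution=14)
x = torch.randn(2, 64, 14, 14)
print(attn(x).shape)   # expected: torch.Size([2, 64, 14, 14])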
/注意力模块/CGLU.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class DWConv(nn.Module):
5 | def __init__(self, dim=768):
6 | super(DWConv, self).__init__()
7 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)
8 |
9 | def forward(self, x, H, W):
10 | B, N, C = x.shape
11 | x = x.transpose(1, 2).view(B, C, H, W).contiguous()
12 | x = self.dwconv(x)
13 | x = x.flatten(2).transpose(1, 2)
14 |
15 | return x
16 |
17 |
18 | class ConvolutionalGLU(nn.Module):
19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
20 | super().__init__()
21 | out_features = out_features or in_features
22 | hidden_features = hidden_features or in_features
23 | hidden_features = int(2 * hidden_features / 3)
24 | self.fc1 = nn.Linear(in_features, hidden_features * 2)
25 | self.dwconv = DWConv(hidden_features)
26 | self.act = act_layer()
27 | self.fc2 = nn.Linear(hidden_features, out_features)
28 | self.drop = nn.Dropout(drop)
29 |
30 | def forward(self, x, H, W):
31 | x, v = self.fc1(x).chunk(2, dim=-1)
32 | x = self.act(self.dwconv(x, H, W)) * v
33 | x = self.drop(x)
34 | x = self.fc2(x)
35 | x = self.drop(x)
36 | return x
--------------------------------------------------------------------------------
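ConvolutionalGLU operates on a token sequence together with the spatial size used to fold it back into a feature map; a minimal sketch with assumed dimensions:

import torch

cglu = ConvolutionalGLU(in_features=64, hidden_features=128)
x = torch.randn(2, 16 * 16, 64)   # B, N = H*W, C
print(cglu(x, 16, 16).shape)       # expected: torch.Size([2, 256, 64])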
/注意力模块/DANet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | import torch
4 | from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \
5 | NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding
6 | torch_ver = torch.__version__[:3]
7 |
8 | __all__ = ['PAM_Module', 'CAM_Module']
9 |
10 |
11 | class PAM_Module(Module):
12 | """ Position attention module"""
13 | #Ref from SAGAN
14 | def __init__(self, in_dim):
15 | super(PAM_Module, self).__init__()
16 | self.chanel_in = in_dim
17 |
18 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
19 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
20 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
21 | self.gamma = Parameter(torch.zeros(1))
22 |
23 | self.softmax = Softmax(dim=-1)
24 | def forward(self, x):
25 | """
26 | inputs :
27 | x : input feature maps( B X C X H X W)
28 | returns :
29 | out : attention value + input feature
30 | attention: B X (HxW) X (HxW)
31 | """
32 | m_batchsize, C, height, width = x.size()
33 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
34 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
35 | energy = torch.bmm(proj_query, proj_key)
36 | attention = self.softmax(energy)
37 | proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
38 |
39 | out = torch.bmm(proj_value, attention.permute(0, 2, 1))
40 | out = out.view(m_batchsize, C, height, width)
41 |
42 | out = self.gamma*out + x
43 | return out
44 |
45 |
46 | class CAM_Module(Module):
47 | """ Channel attention module"""
48 | def __init__(self, in_dim):
49 | super(CAM_Module, self).__init__()
50 | self.chanel_in = in_dim
51 |
52 |
53 | self.gamma = Parameter(torch.zeros(1))
54 | self.softmax = Softmax(dim=-1)
55 | def forward(self,x):
56 | """
57 | inputs :
58 | x : input feature maps( B X C X H X W)
59 | returns :
60 | out : attention value + input feature
61 | attention: B X C X C
62 | """
63 | m_batchsize, C, height, width = x.size()
64 | proj_query = x.view(m_batchsize, C, -1)
65 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
66 | energy = torch.bmm(proj_query, proj_key)
67 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
68 | attention = self.softmax(energy_new)
69 | proj_value = x.view(m_batchsize, C, -1)
70 |
71 | out = torch.bmm(attention, proj_value)
72 | out = out.view(m_batchsize, C, height, width)
73 |
74 | out = self.gamma*out + x
75 | return out
76 |
77 |
78 | class DANet(Module):
79 | """ DANet module """
80 |
81 | def __init__(self, in_channels, out_channels):
82 | super(DANet, self).__init__()
83 |
84 | # Shared convolutional layers
85 | self.conv = Sequential(
86 | Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
87 | nn.BatchNorm2d(in_channels),
88 | ReLU(inplace=True)
89 | )
90 |
91 | # Position and Channel attention modules
92 | self.position_attention = PAM_Module(in_channels)
93 | self.channel_attention = CAM_Module(in_channels)
94 |
95 | # Output convolution
96 | self.output_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
97 |
98 | def forward(self, x):
99 | x = self.conv(x)
100 |
101 | # Position and Channel Attention
102 | pos_out = self.position_attention(x)
103 | chn_out = self.channel_attention(x)
104 |
105 | # Fusion and output
106 | fusion = pos_out + chn_out
107 | out = self.output_conv(fusion)
108 |
109 | return out
110 |
--------------------------------------------------------------------------------
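A quick shape check for the DANet head above (channel counts are illustrative):

import torch

danet = DANet(in_channels=64, out_channels=64)
x = torch.randn(2, 64, 32, 32)
print(danet(x).shape)   # expected: torch.Size([2, 64, 32, 32])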
/注意力模块/ECA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn.parameter import Parameter
4 |
5 | class eca_layer(nn.Module):
6 |     """Constructs an ECA module.
7 |
8 | Args:
9 | channel: Number of channels of the input feature map
10 | k_size: Adaptive selection of kernel size
11 | """
12 | def __init__(self, channel, k_size=3):
13 | super(eca_layer, self).__init__()
14 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
15 | self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
16 | self.sigmoid = nn.Sigmoid()
17 |
18 | def forward(self, x):
19 | # feature descriptor on the global spatial information
20 | y = self.avg_pool(x)
21 |
22 | # Two different branches of ECA module
23 | y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
24 |
25 | # Multi-scale information fusion
26 | y = self.sigmoid(y)
27 |
28 | return x * y.expand_as(x)
--------------------------------------------------------------------------------
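eca_layer only needs the channel count (and optionally an odd kernel size); a minimal check with assumed sizes:

import torch

eca = eca_layer(channel=64, k_size=3)
x = torch.randn(2, 64, 32, 32)
print(eca(x).shape)   # expected: torch.Size([2, 64, 32, 32])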
/注意力模块/FcaNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 |
6 | def get_freq_indices(method):
7 | assert method in ['top1','top2','top4','top8','top16','top32',
8 | 'bot1','bot2','bot4','bot8','bot16','bot32',
9 | 'low1','low2','low4','low8','low16','low32']
10 | num_freq = int(method[3:])
11 | if 'top' in method:
12 | all_top_indices_x = [0,0,6,0,0,1,1,4,5,1,3,0,0,0,3,2,4,6,3,5,5,2,6,5,5,3,3,4,2,2,6,1]
13 | all_top_indices_y = [0,1,0,5,2,0,2,0,0,6,0,4,6,3,5,2,6,3,3,3,5,1,1,2,4,2,1,1,3,0,5,3]
14 | mapper_x = all_top_indices_x[:num_freq]
15 | mapper_y = all_top_indices_y[:num_freq]
16 | elif 'low' in method:
17 | all_low_indices_x = [0,0,1,1,0,2,2,1,2,0,3,4,0,1,3,0,1,2,3,4,5,0,1,2,3,4,5,6,1,2,3,4]
18 | all_low_indices_y = [0,1,0,1,2,0,1,2,2,3,0,0,4,3,1,5,4,3,2,1,0,6,5,4,3,2,1,0,6,5,4,3]
19 | mapper_x = all_low_indices_x[:num_freq]
20 | mapper_y = all_low_indices_y[:num_freq]
21 | elif 'bot' in method:
22 | all_bot_indices_x = [6,1,3,3,2,4,1,2,4,4,5,1,4,6,2,5,6,1,6,2,2,4,3,3,5,5,6,2,5,5,3,6]
23 | all_bot_indices_y = [6,4,4,6,6,3,1,4,4,5,6,5,2,2,5,1,4,3,5,0,3,1,1,2,4,2,1,1,5,3,3,3]
24 | mapper_x = all_bot_indices_x[:num_freq]
25 | mapper_y = all_bot_indices_y[:num_freq]
26 | else:
27 | raise NotImplementedError
28 | return mapper_x, mapper_y
29 |
30 |
31 | class MultiSpectralDCTLayer(nn.Module):
32 | """
33 | Generate dct filters
34 | """
35 |
36 | def __init__(self, height, width, mapper_x, mapper_y, channel):
37 | super(MultiSpectralDCTLayer, self).__init__()
38 |
39 | assert len(mapper_x) == len(mapper_y)
40 | assert channel % len(mapper_x) == 0
41 |
42 | self.num_freq = len(mapper_x)
43 |
44 | # fixed DCT init
45 | self.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))
46 |
47 | # fixed random init
48 | # self.register_buffer('weight', torch.rand(channel, height, width))
49 |
50 | # learnable DCT init
51 | # self.register_parameter('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))
52 |
53 | # learnable random init
54 | # self.register_parameter('weight', torch.rand(channel, height, width))
55 |
56 | # num_freq, h, w
57 |
58 | def forward(self, x):
59 | assert len(x.shape) == 4, 'x must been 4 dimensions, but got ' + str(len(x.shape))
60 | # n, c, h, w = x.shape
61 |
62 | x = x * self.weight
63 |
64 | result = torch.sum(x, dim=[2, 3])
65 | return result
66 |
67 | def build_filter(self, pos, freq, POS):
68 | result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS)
69 | if freq == 0:
70 | return result
71 | else:
72 | return result * math.sqrt(2)
73 |
74 | def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
75 | dct_filter = torch.zeros(channel, tile_size_x, tile_size_y)
76 |
77 | c_part = channel // len(mapper_x)
78 |
79 | for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):
80 | for t_x in range(tile_size_x):
81 | for t_y in range(tile_size_y):
82 | dct_filter[i * c_part: (i + 1) * c_part, t_x, t_y] = self.build_filter(t_x, u_x,
83 | tile_size_x) * self.build_filter(
84 | t_y, v_y, tile_size_y)
85 |
86 | return dct_filter
87 |
88 |
89 |
90 | class MultiSpectralAttentionLayer(torch.nn.Module):
91 | def __init__(self, channel, dct_h, dct_w, reduction = 16, freq_sel_method = 'top16'):
92 | super(MultiSpectralAttentionLayer, self).__init__()
93 | self.reduction = reduction
94 | self.dct_h = dct_h
95 | self.dct_w = dct_w
96 |
97 | mapper_x, mapper_y = get_freq_indices(freq_sel_method)
98 | self.num_split = len(mapper_x)
99 | mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x]
100 | mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y]
101 |         # map frequency indices at different input sizes onto the same 7x7 frequency space
102 | # eg, (2,2) in 14x14 is identical to (1,1) in 7x7
103 |
104 | self.dct_layer = MultiSpectralDCTLayer(dct_h, dct_w, mapper_x, mapper_y, channel)
105 | self.fc = nn.Sequential(
106 | nn.Linear(channel, channel // reduction, bias=False),
107 | nn.ReLU(inplace=True),
108 | nn.Linear(channel // reduction, channel, bias=False),
109 | nn.Sigmoid()
110 | )
111 |
112 | def forward(self, x):
113 | n,c,h,w = x.shape
114 | x_pooled = x
115 | if h != self.dct_h or w != self.dct_w:
116 | x_pooled = torch.nn.functional.adaptive_avg_pool2d(x, (self.dct_h, self.dct_w))
117 | # If you have concerns about one-line-change, don't worry. :)
118 | # In the ImageNet models, this line will never be triggered.
119 | # This is for compatibility in instance segmentation and object detection.
120 | y = self.dct_layer(x_pooled)
121 |
122 | y = self.fc(y).view(n, c, 1, 1)
123 | return x * y.expand_as(x)
124 |
125 |
126 |
127 | class FcaBasicBlock(nn.Module):
128 | expansion = 1
129 |
130 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
131 | base_width=64, dilation=1, norm_layer=None,
132 | *, reduction=16, ):
133 | global _mapper_x, _mapper_y
134 | super(FcaBasicBlock, self).__init__()
135 | # assert fea_h is not None
136 | # assert fea_w is not None
137 | c2wh = dict([(64,56), (128,28), (256,14) ,(512,7)])
138 | self.planes = planes
139 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
140 | self.bn1 = nn.BatchNorm2d(planes)
141 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
142 | self.bn2 = nn.BatchNorm2d(planes)
143 | self.relu = nn.ReLU(inplace=True)
144 | self.att = MultiSpectralAttentionLayer(planes, c2wh[planes], c2wh[planes], reduction=reduction, freq_sel_method = 'top16')
145 | self.downsample = downsample
146 | self.stride = stride
147 |
148 | def forward(self, x):
149 | residual = x
150 |
151 | out = self.conv1(x)
152 | out = self.bn1(out)
153 | out = self.relu(out)
154 |
155 | out = self.conv2(out)
156 | out = self.bn2(out)
157 |
158 | out = self.att(out)
159 |
160 | if self.downsample is not None:
161 | residual = self.downsample(x)
162 |
163 | out += residual
164 | out = self.relu(out)
165 |
166 | return out
--------------------------------------------------------------------------------
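For MultiSpectralAttentionLayer, channel must be divisible by the number of selected frequencies (16 for 'top16'), and dct_h/dct_w give the DCT tile size the input is pooled to (7 matches the hard-coded 7x7 frequency space); the values below are illustrative:

import torch

att = MultiSpectralAttentionLayer(channel=64, dct_h=7, dct_w=7, reduction=16, freq_sel_method='top16')
x = torch.randn(2, 64, 56, 56)   # larger inputs are adaptively pooled to 7x7 before the DCT
print(att(x).shape)              # expected: torch.Size([2, 64, 56, 56])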
/注意力模块/LGAG.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 |
5 | import math
6 | from timm.models.layers import trunc_normal_tf_
7 | from timm.models.helpers import named_apply
8 |
9 |
10 | # Other types of layers can go here (e.g., nn.Linear, etc.)
11 | def _init_weights(module, name, scheme=''):
12 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv3d):
13 | if scheme == 'normal':
14 | nn.init.normal_(module.weight, std=.02)
15 | if module.bias is not None:
16 | nn.init.zeros_(module.bias)
17 | elif scheme == 'trunc_normal':
18 | trunc_normal_tf_(module.weight, std=.02)
19 | if module.bias is not None:
20 | nn.init.zeros_(module.bias)
21 | elif scheme == 'xavier_normal':
22 | nn.init.xavier_normal_(module.weight)
23 | if module.bias is not None:
24 | nn.init.zeros_(module.bias)
25 | elif scheme == 'kaiming_normal':
26 | nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
27 | if module.bias is not None:
28 | nn.init.zeros_(module.bias)
29 | else:
30 | # efficientnet like
31 | fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
32 | fan_out //= module.groups
33 | nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
34 | if module.bias is not None:
35 | nn.init.zeros_(module.bias)
36 | elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d):
37 | nn.init.constant_(module.weight, 1)
38 | nn.init.constant_(module.bias, 0)
39 | elif isinstance(module, nn.LayerNorm):
40 | nn.init.constant_(module.weight, 1)
41 | nn.init.constant_(module.bias, 0)
42 |
43 | def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
44 | # activation layer
45 | act = act.lower()
46 | if act == 'relu':
47 | layer = nn.ReLU(inplace)
48 | elif act == 'relu6':
49 | layer = nn.ReLU6(inplace)
50 | elif act == 'leakyrelu':
51 | layer = nn.LeakyReLU(neg_slope, inplace)
52 | elif act == 'prelu':
53 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
54 | elif act == 'gelu':
55 | layer = nn.GELU()
56 | elif act == 'hswish':
57 | layer = nn.Hardswish(inplace)
58 | else:
59 | raise NotImplementedError('activation layer [%s] is not found' % act)
60 | return layer
61 |
62 |
63 |
64 |
65 | # Large-kernel grouped attention gate (LGAG)
66 | class LGAG(nn.Module):
67 | def __init__(self, F_g, F_l, F_int, kernel_size=3, groups=1, activation='relu'):
68 | super(LGAG, self).__init__()
69 |
70 | if kernel_size == 1:
71 | groups = 1
72 | self.W_g = nn.Sequential(
73 | nn.Conv2d(F_g, F_int, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, groups=groups,
74 | bias=True),
75 | nn.BatchNorm2d(F_int)
76 | )
77 | self.W_x = nn.Sequential(
78 | nn.Conv2d(F_l, F_int, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, groups=groups,
79 | bias=True),
80 | nn.BatchNorm2d(F_int)
81 | )
82 | self.psi = nn.Sequential(
83 | nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
84 | nn.BatchNorm2d(1),
85 | nn.Sigmoid()
86 | )
87 | self.activation = act_layer(activation, inplace=True)
88 |
89 | self.init_weights('normal')
90 |
91 | def init_weights(self, scheme=''):
92 | named_apply(partial(_init_weights, scheme=scheme), self)
93 |
94 | def forward(self, g, x):
95 | g1 = self.W_g(g)
96 | x1 = self.W_x(x)
97 | psi = self.activation(g1 + x1)
98 | psi = self.psi(psi)
99 |
100 | return x * psi
--------------------------------------------------------------------------------
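LGAG gates a skip-connection feature x with a gating signal g of the same spatial size; F_g, F_l and F_int must all be divisible by groups. A hedged example with assumed sizes:

import torch

lgag = LGAG(F_g=64, F_l=64, F_int=32, kernel_size=3, groups=32)
g = torch.randn(2, 64, 32, 32)   # gating signal (e.g. decoder feature)
x = torch.randn(2, 64, 32, 32)   # skip-connection feature to be gated
print(lgag(g, x).shape)          # expected: torch.Size([2, 64, 32, 32])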
/注意力模块/MAB.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class LayerNorm(nn.Module):
6 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
7 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
8 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
9 | with shape (batch_size, channels, height, width).
10 | """
11 |
12 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
13 | super().__init__()
14 | self.weight = nn.Parameter(torch.ones(normalized_shape))
15 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
16 | self.eps = eps
17 | self.data_format = data_format
18 | if self.data_format not in ["channels_last", "channels_first"]:
19 | raise NotImplementedError
20 | self.normalized_shape = (normalized_shape,)
21 |
22 | def forward(self, x):
23 | if self.data_format == "channels_last":
24 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
25 | elif self.data_format == "channels_first":
26 | u = x.mean(1, keepdim=True)
27 | s = (x - u).pow(2).mean(1, keepdim=True)
28 | x = (x - u) / torch.sqrt(s + self.eps)
29 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
30 | return x
31 |
32 |
33 | class SGAB(nn.Module):
34 | def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor=15, attn='GLKA'):
35 | super().__init__()
36 | i_feats = n_feats * 2
37 |
38 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
39 | self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7 // 2, groups=n_feats)
40 | self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)
41 |
42 | self.norm = LayerNorm(n_feats, data_format='channels_first')
43 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
44 |
45 | def forward(self, x):
46 | shortcut = x.clone()
47 |
48 | # Ghost Expand
49 | x = self.Conv1(self.norm(x))
50 | a, x = torch.chunk(x, 2, dim=1)
51 | x = x * self.DWConv1(a)
52 | x = self.Conv2(x)
53 |
54 | return x * self.scale + shortcut
55 |
56 |
57 | class GroupGLKA(nn.Module):
58 | def __init__(self, n_feats, k=2, squeeze_factor=15):
59 | super().__init__()
60 | i_feats = 2 * n_feats
61 |
62 | self.n_feats = n_feats
63 | self.i_feats = i_feats
64 |
65 | self.norm = LayerNorm(n_feats, data_format='channels_first')
66 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
67 |
68 | # Multiscale Large Kernel Attention
69 | self.LKA7 = nn.Sequential(
70 | nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3),
71 | nn.Conv2d(n_feats // 3, n_feats // 3, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // 3, dilation=4),
72 | nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
73 | self.LKA5 = nn.Sequential(
74 | nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3),
75 | nn.Conv2d(n_feats // 3, n_feats // 3, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // 3, dilation=3),
76 | nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
77 | self.LKA3 = nn.Sequential(
78 | nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3),
79 | nn.Conv2d(n_feats // 3, n_feats // 3, 5, stride=1, padding=(5 // 2) * 2, groups=n_feats // 3, dilation=2),
80 | nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
81 |
82 | self.X3 = nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3)
83 | self.X5 = nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3)
84 | self.X7 = nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3)
85 |
86 | self.proj_first = nn.Sequential(
87 | nn.Conv2d(n_feats, i_feats, 1, 1, 0))
88 |
89 | self.proj_last = nn.Sequential(
90 | nn.Conv2d(n_feats, n_feats, 1, 1, 0))
91 |
92 | def forward(self, x, pre_attn=None, RAA=None):
93 | shortcut = x.clone()
94 |
95 | x = self.norm(x)
96 |
97 | x = self.proj_first(x)
98 |
99 | a, x = torch.chunk(x, 2, dim=1)
100 |
101 | a_1, a_2, a_3 = torch.chunk(a, 3, dim=1)
102 |
103 | a = torch.cat([self.LKA3(a_1) * self.X3(a_1), self.LKA5(a_2) * self.X5(a_2), self.LKA7(a_3) * self.X7(a_3)],
104 | dim=1)
105 |
106 | x = self.proj_last(x * a) * self.scale + shortcut
107 |
108 | return x
109 |
110 |
111 | class MAB(nn.Module):
112 | def __init__(
113 | self, n_feats):
114 | super().__init__()
115 |
116 | self.LKA = GroupGLKA(n_feats)
117 |
118 | self.LFE = SGAB(n_feats)
119 |
120 | def forward(self, x, pre_attn=None, RAA=None):
121 | # large kernel attention
122 | x = self.LKA(x)
123 |
124 | # local feature extraction
125 | x = self.LFE(x)
126 |
127 | return x
--------------------------------------------------------------------------------
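MAB splits its attention branch into three kernel scales, so n_feats should be divisible by 3; a minimal check with an assumed width of 48:

import torch

mab = MAB(n_feats=48)             # 48 is divisible by 3, as required by the three LKA branches
x = torch.randn(2, 48, 32, 32)
print(mab(x).shape)               # expected: torch.Size([2, 48, 32, 32])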
/注意力模块/MCA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 |
5 | __all__ = ['MCALayer', 'MCAGate']
6 |
7 |
8 | class StdPool(nn.Module):
9 | def __init__(self):
10 | super(StdPool, self).__init__()
11 |
12 | def forward(self, x):
13 | b, c, _, _ = x.size()
14 |
15 | std = x.view(b, c, -1).std(dim=2, keepdim=True)
16 | std = std.reshape(b, c, 1, 1)
17 |
18 | return std
19 |
20 |
21 | class MCAGate(nn.Module):
22 | def __init__(self, k_size, pool_types=['avg', 'std']):
23 | """Constructs a MCAGate module.
24 | Args:
25 | k_size: kernel size
26 | pool_types: pooling type. 'avg': average pooling, 'max': max pooling, 'std': standard deviation pooling.
27 | """
28 | super(MCAGate, self).__init__()
29 |
30 | self.pools = nn.ModuleList([])
31 | for pool_type in pool_types:
32 | if pool_type == 'avg':
33 | self.pools.append(nn.AdaptiveAvgPool2d(1))
34 | elif pool_type == 'max':
35 | self.pools.append(nn.AdaptiveMaxPool2d(1))
36 | elif pool_type == 'std':
37 | self.pools.append(StdPool())
38 | else:
39 | raise NotImplementedError
40 |
41 | self.conv = nn.Conv2d(1, 1, kernel_size=(1, k_size), stride=1, padding=(0, (k_size - 1) // 2), bias=False)
42 | self.sigmoid = nn.Sigmoid()
43 |
44 | self.weight = nn.Parameter(torch.rand(2))
45 |
46 | def forward(self, x):
47 | feats = [pool(x) for pool in self.pools]
48 |
49 | if len(feats) == 1:
50 | out = feats[0]
51 | elif len(feats) == 2:
52 | weight = torch.sigmoid(self.weight)
53 | out = 1 / 2 * (feats[0] + feats[1]) + weight[0] * feats[0] + weight[1] * feats[1]
54 | else:
55 | assert False, "Feature Extraction Exception!"
56 |
57 | out = out.permute(0, 3, 2, 1).contiguous()
58 | out = self.conv(out)
59 | out = out.permute(0, 3, 2, 1).contiguous()
60 |
61 | out = self.sigmoid(out)
62 | out = out.expand_as(x)
63 |
64 | return x * out
65 |
66 |
67 | class MCALayer(nn.Module):
68 | def __init__(self, inp, no_spatial=False):
69 | """Constructs a MCA module.
70 | Args:
71 | inp: Number of channels of the input feature maps
72 | no_spatial: whether to build channel dimension interactions
73 | """
74 | super(MCALayer, self).__init__()
75 |
76 | lambd = 1.5
77 | gamma = 1
78 | temp = round(abs((math.log2(inp) - gamma) / lambd))
79 | kernel = temp if temp % 2 else temp - 1
80 |
81 | self.h_cw = MCAGate(3)
82 | self.w_hc = MCAGate(3)
83 | self.no_spatial = no_spatial
84 | if not no_spatial:
85 | self.c_hw = MCAGate(kernel)
86 |
87 | def forward(self, x):
88 | x_h = x.permute(0, 2, 1, 3).contiguous()
89 | x_h = self.h_cw(x_h)
90 | x_h = x_h.permute(0, 2, 1, 3).contiguous()
91 |
92 | x_w = x.permute(0, 3, 2, 1).contiguous()
93 | x_w = self.w_hc(x_w)
94 | x_w = x_w.permute(0, 3, 2, 1).contiguous()
95 |
96 | if not self.no_spatial:
97 | x_c = self.c_hw(x)
98 | x_out = 1 / 3 * (x_c + x_h + x_w)
99 | else:
100 | x_out = 1 / 2 * (x_h + x_w)
101 |
102 | return x_out
--------------------------------------------------------------------------------
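MCALayer infers its channel-branch kernel size from inp and otherwise needs no configuration; a quick check with illustrative sizes:

import torch

mca = MCALayer(inp=64)
x = torch.randn(2, 64, 32, 32)
print(mca(x).shape)   # expected: torch.Size([2, 64, 32, 32])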
/注意力模块/MCPA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class MCrossAttention(nn.Module):
5 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.1, proj_drop=0.1):
6 | super().__init__()
7 | self.num_heads = num_heads
8 | head_dim = dim // num_heads
9 | self.scale = qk_scale or head_dim ** -0.5
10 |
11 | self.wq = nn.Linear(head_dim, dim , bias=qkv_bias)
12 | self.wk = nn.Linear(head_dim, dim , bias=qkv_bias)
13 | self.wv = nn.Linear(head_dim, dim , bias=qkv_bias)
14 | # self.attn_drop = nn.Dropout(attn_drop)
15 | self.proj = nn.Linear(dim * num_heads, dim)
16 | self.proj_drop = nn.Dropout(proj_drop)
17 |
18 | def forward(self, x):
19 |
20 | B, N, C = x.shape
21 | q = self.wq(x[:, 0:1, ...].reshape(B, 1, self.num_heads, C // self.num_heads)).permute(0, 2, 1, 3) # B1C -> B1H(C/H) -> BH1(C/H)
22 | k = self.wk(x.reshape(B, N, self.num_heads, C // self.num_heads)).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H)
23 | v = self.wv(x.reshape(B, N, self.num_heads, C // self.num_heads)).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H)
24 | attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
25 | # attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N
26 | attn = attn.softmax(dim=-1)
27 | # attn = self.attn_drop(attn)
28 | x = torch.einsum('bhij,bhjd->bhid', attn, v).transpose(1, 2)
29 | # x = (attn @ v).transpose(1, 2)
30 | x = x.reshape(B, 1, C * self.num_heads) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
31 | x = self.proj(x)
32 | x = self.proj_drop(x)
33 | return x
--------------------------------------------------------------------------------
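MCrossAttention attends the first token (e.g. a class token) to the whole sequence, so the output keeps only that single query position; a hedged sketch with assumed sizes:

import torch

mcpa = MCrossAttention(dim=64, num_heads=8)
x = torch.randn(2, 197, 64)   # e.g. 1 class token + 196 patch tokens
print(mcpa(x).shape)           # expected: torch.Size([2, 1, 64])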
/注意力模块/MEGA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 |
6 | def gauss_kernel(channels=3, cuda=True):
7 | kernel = torch.tensor([[1., 4., 6., 4., 1],
8 | [4., 16., 24., 16., 4.],
9 | [6., 24., 36., 24., 6.],
10 | [4., 16., 24., 16., 4.],
11 | [1., 4., 6., 4., 1.]])
12 | kernel /= 256.
13 | kernel = kernel.repeat(channels, 1, 1, 1)
14 | if cuda:
15 | kernel = kernel.cuda()
16 | return kernel
17 |
18 |
19 | def downsample(x):
20 | return x[:, :, ::2, ::2]
21 |
22 |
23 | def conv_gauss(img, kernel):
24 | img = F.pad(img, (2, 2, 2, 2), mode='reflect')
25 | out = F.conv2d(img, kernel, groups=img.shape[1])
26 | return out
27 |
28 |
29 | def upsample(x, channels):
30 | cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
31 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
32 | cc = cc.permute(0, 1, 3, 2)
33 | cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device)], dim=3)
34 | cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
35 | x_up = cc.permute(0, 1, 3, 2)
36 | return conv_gauss(x_up, 4 * gauss_kernel(channels))
37 |
38 |
39 | def make_laplace(img, channels):
40 | filtered = conv_gauss(img, gauss_kernel(channels))
41 | down = downsample(filtered)
42 | up = upsample(down, channels)
43 | if up.shape[2] != img.shape[2] or up.shape[3] != img.shape[3]:
44 | up = nn.functional.interpolate(up, size=(img.shape[2], img.shape[3]))
45 | diff = img - up
46 | return diff
47 |
48 |
49 | def make_laplace_pyramid(img, level, channels):
50 | current = img
51 | pyr = []
52 | for _ in range(level):
53 | filtered = conv_gauss(current, gauss_kernel(channels))
54 | down = downsample(filtered)
55 | up = upsample(down, channels)
56 | if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
57 | up = nn.functional.interpolate(up, size=(current.shape[2], current.shape[3]))
58 | diff = current - up
59 | pyr.append(diff)
60 | current = down
61 | pyr.append(current)
62 | return pyr
63 |
64 |
65 | class ChannelGate(nn.Module):
66 | def __init__(self, gate_channels, reduction_ratio=16):
67 | super(ChannelGate, self).__init__()
68 | self.gate_channels = gate_channels
69 | self.mlp = nn.Sequential(
70 | nn.Flatten(),
71 | nn.Linear(gate_channels, gate_channels // reduction_ratio),
72 | nn.ReLU(),
73 | nn.Linear(gate_channels // reduction_ratio, gate_channels)
74 | )
75 |
76 | def forward(self, x):
77 | avg_out = self.mlp(F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))))
78 | max_out = self.mlp(F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))))
79 | channel_att_sum = avg_out + max_out
80 |
81 | scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
82 | return x * scale
83 |
84 |
85 | class SpatialGate(nn.Module):
86 | def __init__(self):
87 | super(SpatialGate, self).__init__()
88 | kernel_size = 7
89 | self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
90 |
91 | def forward(self, x):
92 | x_compress = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
93 | x_out = self.spatial(x_compress)
94 | scale = torch.sigmoid(x_out) # broadcasting
95 | return x * scale
96 |
97 |
98 | class CBAM(nn.Module):
99 | def __init__(self, gate_channels, reduction_ratio=16):
100 | super(CBAM, self).__init__()
101 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio)
102 | self.SpatialGate = SpatialGate()
103 |
104 | def forward(self, x):
105 | x_out = self.ChannelGate(x)
106 | x_out = self.SpatialGate(x_out)
107 | return x_out
108 |
109 |
110 | # Edge-Guided Attention Module
111 | class EGA(nn.Module):
112 | def __init__(self, in_channels):
113 | super(EGA, self).__init__()
114 |
115 | self.fusion_conv = nn.Sequential(
116 | nn.Conv2d(in_channels * 3, in_channels, 3, 1, 1),
117 | nn.BatchNorm2d(in_channels),
118 | nn.ReLU(inplace=True))
119 |
120 | self.attention = nn.Sequential(
121 | nn.Conv2d(in_channels, 1, 3, 1, 1),
122 | nn.BatchNorm2d(1),
123 | nn.Sigmoid())
124 |
125 | self.cbam = CBAM(in_channels)
126 |
127 | def forward(self, edge_feature, x, pred):
128 | residual = x
129 | xsize = x.size()[2:]
130 | pred = torch.sigmoid(pred)
131 |
132 | # reverse attention
133 | background_att = 1 - pred
134 | background_x = x * background_att
135 |
136 |         # boundary attention
137 | edge_pred = make_laplace(pred, 1)
138 | pred_feature = x * edge_pred
139 |
140 | # high-frequency feature
141 | edge_input = F.interpolate(edge_feature, size=xsize, mode='bilinear', align_corners=True)
142 | input_feature = x * edge_input
143 |
144 | fusion_feature = torch.cat([background_x, pred_feature, input_feature], dim=1)
145 | fusion_feature = self.fusion_conv(fusion_feature)
146 |
147 | attention_map = self.attention(fusion_feature)
148 | fusion_feature = fusion_feature * attention_map
149 |
150 | out = fusion_feature + residual
151 | out = self.cbam(out)
152 | return out
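
A minimal usage sketch for the EGA module above. The tensor sizes are illustrative assumptions: edge_feature is taken to be a single-channel edge map (it is resized to x inside forward) and pred a single-channel coarse prediction at the same resolution as x, which is what the channel arithmetic in forward expects. It assumes the Gaussian kernel is created on the same device as the input, as in the calls above.

import torch

x = torch.randn(2, 64, 32, 32)             # backbone feature map (B, C, H, W)
edge_feature = torch.randn(2, 1, 64, 64)   # high-frequency edge map, any resolution
pred = torch.randn(2, 1, 32, 32)           # coarse prediction logits, same spatial size as x

ega = EGA(in_channels=64)
out = ega(edge_feature, x, pred)
print(out.shape)  # torch.Size([2, 64, 32, 32])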
--------------------------------------------------------------------------------
/注意力模块/MLKA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class LayerNorm(nn.Module):
7 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
8 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
9 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
10 | with shape (batch_size, channels, height, width).
11 | """
12 |
13 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
14 | super().__init__()
15 | self.weight = nn.Parameter(torch.ones(normalized_shape))
16 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
17 | self.eps = eps
18 | self.data_format = data_format
19 | if self.data_format not in ["channels_last", "channels_first"]:
20 | raise NotImplementedError
21 | self.normalized_shape = (normalized_shape,)
22 |
23 | def forward(self, x):
24 | if self.data_format == "channels_last":
25 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
26 | elif self.data_format == "channels_first":
27 | u = x.mean(1, keepdim=True)
28 | s = (x - u).pow(2).mean(1, keepdim=True)
29 | x = (x - u) / torch.sqrt(s + self.eps)
30 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
31 | return x
32 |
33 |
34 | class MLKA_Ablation(nn.Module):
35 | def __init__(self, n_feats, k=2, squeeze_factor=15):
36 | super().__init__()
37 | i_feats = 2 * n_feats
38 |
39 | self.n_feats = n_feats
40 | self.i_feats = i_feats
41 |
42 | self.norm = LayerNorm(n_feats, data_format='channels_first')
43 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
44 |
45 | k = 2
46 |
47 | # Multiscale Large Kernel Attention
48 | self.LKA7 = nn.Sequential(
49 | nn.Conv2d(n_feats // k, n_feats // k, 7, 1, 7 // 2, groups=n_feats // k),
50 | nn.Conv2d(n_feats // k, n_feats // k, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // k, dilation=4),
51 | nn.Conv2d(n_feats // k, n_feats // k, 1, 1, 0))
52 | self.LKA5 = nn.Sequential(
53 | nn.Conv2d(n_feats // k, n_feats // k, 5, 1, 5 // 2, groups=n_feats // k),
54 | nn.Conv2d(n_feats // k, n_feats // k, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // k, dilation=3),
55 | nn.Conv2d(n_feats // k, n_feats // k, 1, 1, 0))
56 | '''self.LKA3 = nn.Sequential(
57 | nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k),
58 | nn.Conv2d(n_feats//k, n_feats//k, 5, stride=1, padding=(5//2)*2, groups=n_feats//k, dilation=2),
59 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0))'''
60 |
61 | # self.X3 = nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k)
62 | self.X5 = nn.Conv2d(n_feats // k, n_feats // k, 5, 1, 5 // 2, groups=n_feats // k)
63 | self.X7 = nn.Conv2d(n_feats // k, n_feats // k, 7, 1, 7 // 2, groups=n_feats // k)
64 |
65 | self.proj_first = nn.Sequential(
66 | nn.Conv2d(n_feats, i_feats, 1, 1, 0))
67 |
68 | self.proj_last = nn.Sequential(
69 | nn.Conv2d(n_feats, n_feats, 1, 1, 0))
70 |
71 | def forward(self, x, pre_attn=None, RAA=None):
72 | shortcut = x.clone()
73 |
74 | x = self.norm(x)
75 |
76 | x = self.proj_first(x)
77 |
78 | a, x = torch.chunk(x, 2, dim=1)
79 |
80 | # u_1, u_2, u_3= torch.chunk(u, 3, dim=1)
81 | a_1, a_2 = torch.chunk(a, 2, dim=1)
82 |
83 | a = torch.cat([self.LKA7(a_1) * self.X7(a_1), self.LKA5(a_2) * self.X5(a_2)], dim=1)
84 |
85 | x = self.proj_last(x * a) * self.scale + shortcut
86 |
87 | return x
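
A minimal usage sketch for MLKA_Ablation above; the channel count and spatial size are illustrative assumptions. n_feats only needs to be even so the gating tensor can be split across the two large-kernel branches.

import torch

mlka = MLKA_Ablation(n_feats=64)
x = torch.randn(1, 64, 32, 32)
y = mlka(x)        # residual output, same shape as the input
print(y.shape)     # torch.Size([1, 64, 32, 32])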
--------------------------------------------------------------------------------
/注意力模块/MSPA.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch
4 |
5 |
6 | def conv3x3(in_planes, out_planes, stride=1):
7 | """3x3 convolution with padding"""
8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
9 |
10 |
11 | def conv1x1(in_planes, out_planes, stride=1):
12 | """1x1 convolution"""
13 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
14 |
15 |
16 | def convdilated(in_planes, out_planes, kSize=3, stride=1, dilation=1):
17 | """3x3 convolution with dilation"""
18 | padding = int((kSize - 1) / 2) * dilation
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=kSize, stride=stride, padding=padding,
20 | dilation=dilation, bias=False)
21 |
22 |
23 | class SPRModule(nn.Module):
24 | def __init__(self, channels, reduction=16):
25 | super(SPRModule, self).__init__()
26 |
27 | self.avg_pool1 = nn.AdaptiveAvgPool2d(1)
28 | self.avg_pool2 = nn.AdaptiveAvgPool2d(2)
29 |
30 | self.fc1 = nn.Conv2d(channels * 5, channels//reduction, kernel_size=1, padding=0)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.fc2 = nn.Conv2d(channels//reduction, channels, kernel_size=1, padding=0)
33 | self.sigmoid = nn.Sigmoid()
34 |
35 | def forward(self, x):
36 |
37 | out1 = self.avg_pool1(x).view(x.size(0), -1, 1, 1)
38 | out2 = self.avg_pool2(x).view(x.size(0), -1, 1, 1)
39 | out = torch.cat((out1, out2), 1)
40 |
41 | out = self.fc1(out)
42 | out = self.relu(out)
43 | out = self.fc2(out)
44 | weight = self.sigmoid(out)
45 |
46 | return weight
47 |
48 |
49 | class MSAModule(nn.Module):
50 | def __init__(self, inplanes, scale=3, stride=1, stype='normal'):
51 | """ Constructor
52 | Args:
53 | inplanes: input channel dimensionality.
54 | scale: number of scale.
55 | stride: conv stride.
56 | stype: 'normal': normal set. 'stage': first block of a new stage.
57 | """
58 | super(MSAModule, self).__init__()
59 |
60 | self.width = inplanes
61 | self.nums = scale
62 | self.stride = stride
63 |         assert stype in ['stage', 'normal'], 'One of these is supported (stage or normal)'
64 | self.stype = stype
65 |
66 | self.convs = nn.ModuleList([])
67 | self.bns = nn.ModuleList([])
68 |
69 | for i in range(self.nums):
70 | if self.stype == 'stage' and self.stride != 1:
71 | self.convs.append(convdilated(self.width, self.width, stride=stride, dilation=int(i + 1)))
72 | else:
73 | self.convs.append(conv3x3(self.width, self.width, stride))
74 |
75 | self.bns.append(nn.BatchNorm2d(self.width))
76 |
77 | self.attention = SPRModule(self.width)
78 |
79 | self.softmax = nn.Softmax(dim=1)
80 |
81 | def forward(self, x):
82 | batch_size = x.shape[0]
83 |
84 | spx = torch.split(x, self.width, 1)
85 | for i in range(self.nums):
86 | if i == 0 or (self.stype == 'stage' and self.stride != 1):
87 | sp = spx[i]
88 | else:
89 | sp = sp + spx[i]
90 | sp = self.convs[i](sp)
91 | sp = self.bns[i](sp)
92 |
93 | if i == 0:
94 | out = sp
95 | else:
96 | out = torch.cat((out, sp), 1)
97 |
98 | feats = out
99 | feats = feats.view(batch_size, self.nums, self.width, feats.shape[2], feats.shape[3])
100 |
101 | sp_inp = torch.split(out, self.width, 1)
102 |
103 | attn_weight = []
104 | for inp in sp_inp:
105 | attn_weight.append(self.attention(inp))
106 |
107 | attn_weight = torch.cat(attn_weight, dim=1)
108 | attn_vectors = attn_weight.view(batch_size, self.nums, self.width, 1, 1)
109 | attn_vectors = self.softmax(attn_vectors)
110 | feats_weight = feats * attn_vectors
111 |
112 | for i in range(self.nums):
113 | x_attn_weight = feats_weight[:, i, :, :, :]
114 | if i == 0:
115 | out = x_attn_weight
116 | else:
117 | out = torch.cat((out, x_attn_weight), 1)
118 |
119 | return out
120 |
121 |
122 | class MSPABlock(nn.Module):
123 | expansion = 4
124 |
125 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=30, scale=3,
126 | norm_layer=None, stype='normal'):
127 | """ Constructor
128 | Args:
129 | inplanes: input channel dimensionality.
130 | planes: output channel dimensionality.
131 | stride: conv stride.
132 | downsample: None when stride = 1.
133 | baseWidth: basic width of conv3x3.
134 | scale: number of scale.
135 |                 norm_layer: normalization layer.
136 | stype: 'normal': normal set. 'stage': first block of a new stage.
137 | """
138 | super(MSPABlock, self).__init__()
139 | if norm_layer is None:
140 | norm_layer = nn.BatchNorm2d
141 | width = int(math.floor(planes * (baseWidth / 64.0)))
142 |
143 | self.conv1 = conv1x1(inplanes, width * scale)
144 | self.bn1 = norm_layer(width * scale)
145 |
146 | self.conv2 = MSAModule(width, scale=scale, stride=stride, stype=stype)
147 | self.bn2 = norm_layer(width * scale)
148 |
149 | self.conv3 = conv1x1(width * scale, planes * self.expansion)
150 | self.bn3 = norm_layer(planes * self.expansion)
151 | self.relu = nn.ReLU(inplace=True)
152 |
153 | self.downsample = downsample
154 |
155 | def forward(self, x):
156 | identity = x
157 |
158 | out = self.conv1(x)
159 | out = self.bn1(out)
160 | out = self.relu(out)
161 |
162 | out = self.conv2(out)
163 | out = self.bn2(out)
164 | out = self.relu(out)
165 |
166 | out = self.conv3(out)
167 | out = self.bn3(out)
168 |
169 | if self.downsample is not None:
170 | identity = self.downsample(x)
171 |
172 | out += identity
173 | out = self.relu(out)
174 |
175 | return out
176 |
177 |
178 | class MSPANet(nn.Module):
179 | def __init__(self, block, layers, num_classes=1000, baseWidth=30, scale=3, norm_layer=None):
180 | super(MSPANet, self).__init__()
181 | if norm_layer is None:
182 | norm_layer = nn.BatchNorm2d
183 | self._norm_layer = norm_layer
184 |
185 | self.inplanes = 64
186 | self.baseWidth = baseWidth
187 | self.scale = scale
188 |
189 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
190 | self.bn1 = norm_layer(self.inplanes)
191 | self.relu = nn.ReLU(inplace=True)
192 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
193 |
194 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
195 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
196 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
197 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
198 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
199 |
200 | self.fc = nn.Linear(512 * block.expansion, num_classes)
201 |
202 | # weight initialization
203 | self._initialize_weights()
204 |
205 | def _initialize_weights(self):
206 | for m in self.modules():
207 | if isinstance(m, nn.Conv2d):
208 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
209 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
210 | nn.init.constant_(m.weight, 1)
211 | nn.init.constant_(m.bias, 0)
212 |
213 | def _make_layer(self, block, planes, blocks, stride=1):
214 | norm_layer = self._norm_layer
215 | downsample = None
216 | if stride != 1 or self.inplanes != planes * block.expansion:
217 | downsample = nn.Sequential(
218 | conv1x1(self.inplanes, planes * block.expansion, stride),
219 | norm_layer(planes * block.expansion),
220 | )
221 |
222 | layers = []
223 | layers.append(block(self.inplanes, planes, stride, downsample=downsample,
224 | baseWidth=self.baseWidth, scale=self.scale, norm_layer=norm_layer, stype='stage'))
225 | self.inplanes = planes * block.expansion
226 | for i in range(1, blocks):
227 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale,
228 | norm_layer=norm_layer))
229 |
230 | return nn.Sequential(*layers)
231 |
232 | def forward(self, x):
233 | x = self.conv1(x)
234 | x = self.bn1(x)
235 | x = self.relu(x)
236 | x = self.maxpool(x)
237 |
238 | x = self.layer1(x)
239 | x = self.layer2(x)
240 | x = self.layer3(x)
241 | x = self.layer4(x)
242 |
243 | x = self.avgpool(x)
244 | x = x.view(x.size(0), -1)
245 |
246 | x = self.fc(x)
247 |
248 | return x
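
A usage sketch for the classes above, assembling MSPABlock into MSPANet. The depth configuration [3, 4, 6, 3] and the input size are illustrative assumptions (a Res2Net-50-style layout), not values prescribed by the source.

import torch

model = MSPANet(MSPABlock, [3, 4, 6, 3], num_classes=1000, baseWidth=30, scale=3)
x = torch.randn(2, 3, 224, 224)
logits = model(x)
print(logits.shape)  # torch.Size([2, 1000])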
--------------------------------------------------------------------------------
/注意力模块/NonLocal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class NonLocalBlock(nn.Module):
6 | def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False, bn_layer=True):
7 | """
8 | :param in_channels: Number of channels in the input feature map.
9 | :param inter_channels: Number of channels in the intermediate feature map.
10 | If None, it will be set to in_channels // 2.
11 | :param dimension: The spatial dimensions of the input.
12 | 2 for 2D (image), 3 for 3D (video).
13 | :param sub_sample: Whether to apply subsampling to reduce computation.
14 | :param bn_layer: Whether to add BatchNorm layer after the Non-Local operation.
15 | """
16 | super(NonLocalBlock, self).__init__()
17 |
18 | assert dimension in [1, 2, 3], "Only 1D, 2D, and 3D inputs are supported."
19 | self.dimension = dimension
20 | self.sub_sample = sub_sample
21 |
22 | # Determine the dimension for 1D, 2D, or 3D
23 | if dimension == 3:
24 | self.pool = nn.MaxPool3d(kernel_size=(1, 2, 2))
25 | elif dimension == 2:
26 | self.pool = nn.MaxPool2d(kernel_size=(2, 2))
27 | else:
28 | self.pool = nn.MaxPool1d(kernel_size=2)
29 |
30 | # Define the intermediate channels size
31 | if inter_channels is None:
32 | inter_channels = in_channels // 2
33 | if inter_channels == 0:
34 | inter_channels = 1
35 |
36 | self.g = nn.Conv2d(in_channels=in_channels, out_channels=inter_channels, kernel_size=1, stride=1, padding=0)
37 |
38 | # theta and phi will be used to compute similarity (dot product)
39 | self.theta = nn.Conv2d(in_channels=in_channels, out_channels=inter_channels, kernel_size=1, stride=1, padding=0)
40 | self.phi = nn.Conv2d(in_channels=in_channels, out_channels=inter_channels, kernel_size=1, stride=1, padding=0)
41 |
42 | # W will transform the output of NonLocal block back to the original dimension
43 | self.W = nn.Conv2d(in_channels=inter_channels, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
44 |
45 | # Optionally, include a BatchNorm layer after W
46 | if bn_layer:
47 | self.W = nn.Sequential(
48 | self.W,
49 | nn.BatchNorm2d(in_channels)
50 | )
51 |
52 | # Optional subsampling
53 | if sub_sample:
54 | self.g = nn.Sequential(self.g, self.pool)
55 | self.phi = nn.Sequential(self.phi, self.pool)
56 |
57 | def forward(self, x):
58 | """
59 | :param x: Input feature map of shape (N, C, H, W) where
60 | N is the batch size,
61 | C is the number of channels,
62 | H is the height,
63 | W is the width.
64 | :return: Output feature map of the same shape as input.
65 | """
66 | batch_size, C, H, W = x.size()
67 |
68 | # Apply transformations: theta, phi, and g
69 | g_x = self.g(x).view(batch_size, -1, H * W) # g(x) shape: (N, C', H*W)
70 | g_x = g_x.permute(0, 2, 1) # g(x) shape: (N, H*W, C')
71 |
72 |         theta_x = self.theta(x).view(batch_size, -1, H * W)  # theta(x) shape: (N, C', H*W)
73 |         theta_x = theta_x.permute(0, 2, 1)  # theta(x) shape: (N, H*W, C')
74 |         phi_x = self.phi(x).view(batch_size, -1, H * W)  # phi(x) shape: (N, C', H*W), possibly subsampled
75 | 
76 |         # Compute similarity: theta_x * phi_x (matrix multiplication)
77 |         f = torch.matmul(theta_x, phi_x)  # shape: (N, H*W, H*W)
78 |         f_div_C = F.softmax(f, dim=-1)  # Apply softmax to normalize similarity
79 | 
80 |         # Apply attention map to g(x)
81 |         y = torch.matmul(f_div_C, g_x)  # shape: (N, H*W, C')
82 |         y = y.permute(0, 2, 1).contiguous()  # Reshape: (N, C', H*W)
83 |         y = y.view(batch_size, -1, H, W)  # Reshape: (N, C', H, W)
84 |
85 | # Transform the output back to original input dimension with W
86 | W_y = self.W(y)
87 |
88 | # Residual connection: adding input x to the output
89 | z = W_y + x
90 |
91 | return z
92 |
93 |
94 | """
95 | Example of use, inserted into the middle layer of ResNet
96 |
97 | from torchvision.models import resnet50
98 |
99 | class ResNetWithNonLocal(nn.Module):
100 | def __init__(self):
101 | super(ResNetWithNonLocal, self).__init__()
102 | self.resnet = resnet50(pretrained=True)
103 | self.nonlocal_block = NonLocalBlock(in_channels=256)
104 |
105 | def forward(self, x):
106 | x = self.resnet.layer1(x)
107 | x = self.nonlocal_block(x)
108 | x = self.resnet.layer2(x)
109 | return x
110 | """
--------------------------------------------------------------------------------
/注意力模块/SA-Net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.parameter import Parameter
4 |
5 |
6 | class sa_layer(nn.Module):
7 | """Constructs a Channel Spatial Group module.
8 |
9 | Args:
10 |         channel: number of input channels (must be divisible by 2 * groups); groups: number of channel groups
11 | """
12 |
13 | def __init__(self, channel, groups=64):
14 | super(sa_layer, self).__init__()
15 | self.groups = groups
16 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
17 | self.cweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
18 | self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
19 | self.sweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
20 | self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
21 |
22 | self.sigmoid = nn.Sigmoid()
23 | self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))
24 |
25 | @staticmethod
26 | def channel_shuffle(x, groups):
27 | b, c, h, w = x.shape
28 |
29 | x = x.reshape(b, groups, -1, h, w)
30 | x = x.permute(0, 2, 1, 3, 4)
31 |
32 | # flatten
33 | x = x.reshape(b, -1, h, w)
34 |
35 | return x
36 |
37 | def forward(self, x):
38 | b, c, h, w = x.shape
39 |
40 | x = x.reshape(b * self.groups, -1, h, w)
41 | x_0, x_1 = x.chunk(2, dim=1)
42 |
43 | # channel attention
44 | xn = self.avg_pool(x_0)
45 | xn = self.cweight * xn + self.cbias
46 | xn = x_0 * self.sigmoid(xn)
47 |
48 | # spatial attention
49 | xs = self.gn(x_1)
50 | xs = self.sweight * xs + self.sbias
51 | xs = x_1 * self.sigmoid(xs)
52 |
53 | # concatenate along channel axis
54 | out = torch.cat([xn, xs], dim=1)
55 | out = out.reshape(b, -1, h, w)
56 |
57 | out = self.channel_shuffle(out, 2)
58 | return out
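
A minimal usage sketch for sa_layer above; the channel count is an illustrative assumption and must be divisible by 2 * groups so the grouped chunking works out.

import torch

sa = sa_layer(channel=512, groups=64)   # 512 channels split into 64 groups of 8
x = torch.randn(2, 512, 28, 28)
y = sa(x)
print(y.shape)  # torch.Size([2, 512, 28, 28])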
--------------------------------------------------------------------------------
/注意力模块/SENet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SEBlock(nn.Module):
6 | def __init__(self, in_channels, reduction=16):
7 | super(SEBlock, self).__init__()
8 |
9 |         # Global average pooling
10 | self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
11 |
12 |         # Two fully connected layers
13 | self.fc1 = nn.Linear(in_channels, in_channels // reduction, bias=False)
14 | self.relu = nn.ReLU(inplace=True)
15 | self.fc2 = nn.Linear(in_channels // reduction, in_channels, bias=False)
16 |
17 |         # Sigmoid activation used to produce the channel weights
18 | self.sigmoid = nn.Sigmoid()
19 |
20 | def forward(self, x):
21 | batch_size, channels, _, _ = x.size()
22 |
23 |         # Global average pooling
24 | y = self.global_avg_pool(x).view(batch_size, channels)
25 |
26 |         # Pass through the fully connected layers to generate the attention weights
27 | y = self.fc1(y)
28 | y = self.relu(y)
29 | y = self.fc2(y)
30 | y = self.sigmoid(y).view(batch_size, channels, 1, 1)
31 |
32 |         # Multiply the weights with the input features
33 | return x * y.expand_as(x)
34 |
35 |
36 | class SENet(nn.Module):
37 | def __init__(self, in_channels, reduction=16):
38 | super(SENet, self).__init__()
39 | self.se_block = SEBlock(in_channels, reduction)
40 |
41 | def forward(self, x):
42 | return self.se_block(x)
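
A minimal usage sketch for the SE block above, with an assumed 64-channel feature map.

import torch

se = SENet(in_channels=64, reduction=16)
x = torch.randn(1, 64, 32, 32)
print(se(x).shape)  # torch.Size([1, 64, 32, 32])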
--------------------------------------------------------------------------------
/注意力模块/SLAB.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 |
6 | class SimplifiedLinearAttention(nn.Module):
7 | def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1,
8 | focusing_factor=3, kernel_size=5, norm_layer=nn.LayerNorm):
9 | super().__init__()
10 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
11 |
12 | self.dim = dim
13 | self.num_heads = num_heads
14 | head_dim = dim // num_heads
15 |
16 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
17 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
18 | self.attn_drop = nn.Dropout(attn_drop)
19 | self.proj = nn.Linear(dim, dim)
20 | self.proj_drop = nn.Dropout(proj_drop)
21 |
22 | self.sr_ratio = sr_ratio
23 | if sr_ratio > 1:
24 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
25 | self.norm = norm_layer(dim)
26 |
27 | self.focusing_factor = focusing_factor
28 | self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size,
29 | groups=head_dim, padding=kernel_size // 2)
30 | self.positional_encoding = nn.Parameter(torch.zeros(size=(1, num_patches // (sr_ratio * sr_ratio), dim)))
31 | print('Linear Attention sr_ratio{} f{} kernel{}'.
32 | format(sr_ratio, focusing_factor, kernel_size))
33 |
34 | def forward(self, x, H, W):
35 | B, N, C = x.shape
36 | q = self.q(x)
37 |
38 | if self.sr_ratio > 1:
39 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
40 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
41 | x_ = self.norm(x_)
42 | kv = self.kv(x_).reshape(B, -1, 2, C).permute(2, 0, 1, 3)
43 | else:
44 | kv = self.kv(x).reshape(B, -1, 2, C).permute(2, 0, 1, 3)
45 | k, v = kv[0], kv[1]
46 |
47 | k = k + self.positional_encoding
48 | kernel_function = nn.ReLU()
49 | q = kernel_function(q)
50 | k = kernel_function(k)
51 |
52 | q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])
53 | i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]
54 |
55 | z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)
56 | if i * j * (c + d) > c * d * (i + j):
57 | kv = torch.einsum("b j c, b j d -> b c d", k, v)
58 | x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)
59 | else:
60 | qk = torch.einsum("b i c, b j c -> b i j", q, k)
61 | x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)
62 |
63 | if self.sr_ratio > 1:
64 | v = nn.functional.interpolate(v.permute(0, 2, 1), size=x.shape[1], mode='linear').permute(0, 2, 1)
65 | num = int(v.shape[1] ** 0.5)
66 | feature_map = rearrange(v, "b (w h) c -> b c w h", w=num, h=num)
67 | feature_map = rearrange(self.dwc(feature_map), "b c w h -> b (w h) c")
68 | x = x + feature_map
69 | x = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads)
70 |
71 | x = self.proj(x)
72 | x = self.proj_drop(x)
73 |
74 | return x
75 |
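
A minimal usage sketch for SimplifiedLinearAttention above. The module works on token sequences of shape (B, N, C) with N = H * W; the sizes below are illustrative assumptions, and N is chosen as a perfect square because the depthwise-convolution branch reshapes the tokens back into a square feature map.

import torch

H = W = 14
attn = SimplifiedLinearAttention(dim=64, num_patches=H * W, num_heads=8, sr_ratio=1)
x = torch.randn(2, H * W, 64)   # (B, N, C) token sequence
y = attn(x, H, W)
print(y.shape)  # torch.Size([2, 196, 64])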
--------------------------------------------------------------------------------
/注意力模块/SRA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class SRA(nn.Module):
5 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
6 | super().__init__()
7 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
8 |
9 | self.dim = dim
10 | self.num_heads = num_heads
11 | head_dim = dim // num_heads
12 | self.scale = qk_scale or head_dim ** -0.5
13 |
14 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
15 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
16 | self.attn_drop = nn.Dropout(attn_drop)
17 | self.proj = nn.Linear(dim, dim)
18 | self.proj_drop = nn.Dropout(proj_drop)
19 |
20 | self.sr_ratio = sr_ratio
21 | if sr_ratio > 1:
22 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
23 | self.norm = nn.LayerNorm(dim)
24 |
25 | def forward(self, x, H, W):
26 | B, N, C = x.shape
27 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
28 |
29 | if self.sr_ratio > 1:
30 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
31 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
32 | x_ = self.norm(x_)
33 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
34 | else:
35 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
36 | k, v = kv[0], kv[1]
37 |
38 | attn = (q @ k.transpose(-2, -1)) * self.scale
39 | attn = attn.softmax(dim=-1)
40 | attn = self.attn_drop(attn)
41 |
42 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
43 | x = self.proj(x)
44 | x = self.proj_drop(x)
45 |
46 | return x
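
A minimal usage sketch for the SRA module above; token count and channel width are illustrative assumptions, with the key/value sequence reduced 2x along each spatial axis by sr_ratio=2.

import torch

H, W = 16, 16
sra = SRA(dim=64, num_heads=8, sr_ratio=2)
x = torch.randn(2, H * W, 64)   # (B, N, C) token sequence
y = sra(x, H, W)
print(y.shape)  # torch.Size([2, 256, 64])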
--------------------------------------------------------------------------------
/注意力模块/SimAM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class SimAM(nn.Module):
5 | def __init__(self, lambda_=0.0001):
6 | """
7 | SimAM attention module.
8 | :param lambda_: A coefficient lambda in the equation.
9 | """
10 | super(SimAM, self).__init__()
11 | self.lambda_ = lambda_
12 |
13 | def forward(self, X):
14 | """
15 | Forward pass for SimAM.
16 | :param X: Input feature map of shape (N, C, H, W)
17 | :return: Output feature map with attention applied, same shape as input (N, C, H, W)
18 | """
19 | # Calculate the spatial size minus 1 for normalization
20 | n = X.shape[2] * X.shape[3] - 1
21 |
22 | # Calculate the square of (X - mean(X))
23 | d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
24 |
25 | # Channel variance (d.sum() / n)
26 | v = d.sum(dim=[2, 3], keepdim=True) / n
27 |
28 | # Calculate E_inv which contains importance of X
29 | E_inv = d / (4 * (v + self.lambda_)) + 0.5
30 |
31 | # Return attended features using sigmoid activation
32 | return X * torch.sigmoid(E_inv)
33 |
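
A minimal usage sketch for SimAM above; since the module is parameter-free it can wrap any (N, C, H, W) feature map, and the sizes below are only an illustration.

import torch

simam = SimAM(lambda_=1e-4)
x = torch.randn(1, 64, 32, 32)
print(simam(x).shape)  # torch.Size([1, 64, 32, 32])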
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/AFF.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AFF(nn.Module):
6 | def __init__(self, channels=64, reduction=4):
7 | """
8 | Attentional Feature Fusion (AFF) module.
9 | :param channels: Number of input channels.
10 | :param reduction: Reduction ratio for intermediate channels in attention layers. Default is 4.
11 | """
12 | super(AFF, self).__init__()
13 | intermediate_channels = channels // reduction
14 |
15 | # Local attention layer
16 | self.local_attention = nn.Sequential(
17 | nn.Conv2d(channels, intermediate_channels, kernel_size=1, stride=1, padding=0),
18 | nn.BatchNorm2d(intermediate_channels),
19 | nn.ReLU(inplace=True),
20 | nn.Conv2d(intermediate_channels, channels, kernel_size=1, stride=1, padding=0),
21 | nn.BatchNorm2d(channels),
22 | )
23 |
24 | # Global attention layer
25 | self.global_attention = nn.Sequential(
26 | nn.AdaptiveAvgPool2d(1),
27 | nn.Conv2d(channels, intermediate_channels, kernel_size=1, stride=1, padding=0),
28 | nn.BatchNorm2d(intermediate_channels),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(intermediate_channels, channels, kernel_size=1, stride=1, padding=0),
31 | nn.BatchNorm2d(channels),
32 | )
33 |
34 | # Activation function for attention weight
35 | self.sigmoid = nn.Sigmoid()
36 |
37 | def forward(self, input_feature, residual_feature):
38 | """
39 | Forward pass for the AFF module.
40 | :param input_feature: First input feature map (tensor of shape N x C x H x W).
41 | :param residual_feature: Second input feature map (tensor of shape N x C x H x W).
42 | :return: Output feature map with fused attention applied (same shape as input).
43 | """
44 | # Initial fusion by element-wise addition
45 | combined_feature = input_feature + residual_feature
46 |
47 | # Compute local and global attention
48 | local_attention = self.local_attention(combined_feature)
49 | global_attention = self.global_attention(combined_feature)
50 |
51 | # Sum local and global attention, then apply sigmoid for attention weight
52 | attention_weight = self.sigmoid(local_attention + global_attention)
53 |
54 | # Weighted combination of input and residual features
55 | output_feature = 2 * input_feature * attention_weight + 2 * residual_feature * (1 - attention_weight)
56 |
57 | return output_feature
58 |
59 |
60 | """
61 | Usage example
62 | 
63 | # Assume input_feature and residual_feature both have 64 channels and a spatial size of 32x32
64 | input_feature = torch.randn(1, 64, 32, 32) # (N, C, H, W)
65 | residual_feature = torch.randn(1, 64, 32, 32)
66 | 
67 | # Initialize the AFF module with 64 input channels and a reduction ratio of 4
68 | aff_module = AFF(channels=64, reduction=4)
69 | 
70 | # Compute the output of the AFF module
71 | output_feature = aff_module(input_feature, residual_feature)
72 | 
73 | # The output feature map has the same shape as the inputs
74 | print(output_feature.shape) # Output: torch.Size([1, 64, 32, 32])
75 |
76 | """
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/CAFM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 |
5 |
6 | class CAFM(nn.Module):
7 | def __init__(self, dim, num_heads, bias):
8 | super(CAFM, self).__init__()
9 | self.num_heads = num_heads
10 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
11 |
12 | self.qkv = nn.Conv3d(dim, dim * 3, kernel_size=(1, 1, 1), bias=bias)
13 | self.qkv_dwconv = nn.Conv3d(dim * 3, dim * 3, kernel_size=(3, 3, 3), stride=1, padding=1, groups=dim * 3,
14 | bias=bias)
15 | self.project_out = nn.Conv3d(dim, dim, kernel_size=(1, 1, 1), bias=bias)
16 | self.fc = nn.Conv3d(3 * self.num_heads, 9, kernel_size=(1, 1, 1), bias=True)
17 |
18 | self.dep_conv = nn.Conv3d(9 * dim // self.num_heads, dim, kernel_size=(3, 3, 3), bias=True,
19 | groups=dim // self.num_heads, padding=1)
20 |
21 | def forward(self, x):
22 | b, c, h, w = x.shape
23 | x = x.unsqueeze(2)
24 | qkv = self.qkv_dwconv(self.qkv(x))
25 | qkv = qkv.squeeze(2)
26 | f_conv = qkv.permute(0, 2, 3, 1)
27 | f_all = qkv.reshape(f_conv.shape[0], h * w, 3 * self.num_heads, -1).permute(0, 2, 1, 3)
28 | f_all = self.fc(f_all.unsqueeze(2))
29 | f_all = f_all.squeeze(2)
30 |
31 | # local conv
32 | f_conv = f_all.permute(0, 3, 1, 2).reshape(x.shape[0], 9 * x.shape[1] // self.num_heads, h, w)
33 | f_conv = f_conv.unsqueeze(2)
34 | out_conv = self.dep_conv(f_conv) # B, C, H, W
35 | out_conv = out_conv.squeeze(2)
36 |
37 | # global SA
38 | q, k, v = qkv.chunk(3, dim=1)
39 |
40 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
41 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
42 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
43 |
44 | q = torch.nn.functional.normalize(q, dim=-1)
45 | k = torch.nn.functional.normalize(k, dim=-1)
46 |
47 | attn = (q @ k.transpose(-2, -1)) * self.temperature
48 | attn = attn.softmax(dim=-1)
49 |
50 | out = (attn @ v)
51 |
52 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
53 | out = out.unsqueeze(2)
54 | out = self.project_out(out)
55 | out = out.squeeze(2)
56 | output = out + out_conv
57 |
58 | return output
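
A minimal usage sketch for CAFM above; dim and the spatial size are illustrative assumptions, with dim divisible by num_heads as the reshapes in forward require.

import torch

cafm = CAFM(dim=64, num_heads=8, bias=False)
x = torch.randn(1, 64, 32, 32)
print(cafm(x).shape)  # torch.Size([1, 64, 32, 32])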
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/CCFF.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def get_activation(act: str, inplace: bool = True):
7 | '''get activation
8 | '''
9 |     act = act.lower() if isinstance(act, str) else act
10 |
11 | if act == 'silu':
12 | m = nn.SiLU()
13 |
14 | elif act == 'relu':
15 | m = nn.ReLU()
16 |
17 | elif act == 'leaky_relu':
18 | m = nn.LeakyReLU()
19 |
20 | elif act == 'silu':
21 | m = nn.SiLU()
22 |
23 | elif act == 'gelu':
24 | m = nn.GELU()
25 |
26 | elif act is None:
27 | m = nn.Identity()
28 |
29 | elif isinstance(act, nn.Module):
30 | m = act
31 |
32 | else:
33 |         raise RuntimeError(f'Unsupported activation: {act}')
34 |
35 | if hasattr(m, 'inplace'):
36 |         m.inplace = inplace
37 |
38 | return m
39 |
40 | class ConvNormLayer(nn.Module):
41 | def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
42 | super().__init__()
43 | self.conv = nn.Conv2d(
44 | ch_in,
45 | ch_out,
46 | kernel_size,
47 | stride,
48 | padding=(kernel_size - 1) // 2 if padding is None else padding,
49 | bias=bias)
50 | self.norm = nn.BatchNorm2d(ch_out)
51 | self.act = nn.Identity() if act is None else get_activation(act)
52 |
53 | def forward(self, x):
54 | return self.act(self.norm(self.conv(x)))
55 |
56 |
57 | class RepVggBlock(nn.Module):
58 | def __init__(self, ch_in, ch_out, act='relu'):
59 | super().__init__()
60 | self.ch_in = ch_in
61 | self.ch_out = ch_out
62 | self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None)
63 | self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None)
64 | self.act = nn.Identity() if act is None else get_activation(act)
65 |
66 | def forward(self, x):
67 | if hasattr(self, 'conv'):
68 | y = self.conv(x)
69 | else:
70 | y = self.conv1(x) + self.conv2(x)
71 |
72 | return self.act(y)
73 |
74 | def convert_to_deploy(self):
75 | if not hasattr(self, 'conv'):
76 | self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
77 |
78 | kernel, bias = self.get_equivalent_kernel_bias()
79 | self.conv.weight.data = kernel
80 | self.conv.bias.data = bias
81 | # self.__delattr__('conv1')
82 | # self.__delattr__('conv2')
83 |
84 | def get_equivalent_kernel_bias(self):
85 | kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
86 | kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
87 |
88 | return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1
89 |
90 | def _pad_1x1_to_3x3_tensor(self, kernel1x1):
91 | if kernel1x1 is None:
92 | return 0
93 | else:
94 | return F.pad(kernel1x1, [1, 1, 1, 1])
95 |
96 | def _fuse_bn_tensor(self, branch: ConvNormLayer):
97 | if branch is None:
98 | return 0, 0
99 | kernel = branch.conv.weight
100 | running_mean = branch.norm.running_mean
101 | running_var = branch.norm.running_var
102 | gamma = branch.norm.weight
103 | beta = branch.norm.bias
104 | eps = branch.norm.eps
105 | std = (running_var + eps).sqrt()
106 | t = (gamma / std).reshape(-1, 1, 1, 1)
107 | return kernel * t, beta - running_mean * gamma / std
108 |
109 |
110 | class CCFF(nn.Module):
111 | def __init__(self,
112 | in_channels,
113 | out_channels,
114 | num_blocks=3,
115 | expansion=1.0,
116 | bias=None,
117 | act="silu"):
118 | super(CCFF, self).__init__()
119 | hidden_channels = int(out_channels * expansion)
120 | self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
121 | self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
122 | self.bottlenecks = nn.Sequential(*[
123 | RepVggBlock(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)
124 | ])
125 | if hidden_channels != out_channels:
126 | self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act)
127 | else:
128 | self.conv3 = nn.Identity()
129 |
130 | def forward(self, x):
131 | x_1 = self.conv1(x)
132 | x_1 = self.bottlenecks(x_1)
133 | x_2 = self.conv2(x)
134 | return self.conv3(x_1 + x_2)
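
A minimal usage sketch for CCFF above; channel count and resolution are illustrative assumptions. After training, each RepVggBlock can be folded into a single 3x3 convolution via convert_to_deploy().

import torch

ccff = CCFF(in_channels=256, out_channels=256, num_blocks=3, expansion=1.0, act='silu')
x = torch.randn(1, 256, 20, 20)
print(ccff(x).shape)  # torch.Size([1, 256, 20, 20])

# optional structural re-parameterization for inference
rep_blocks = [m for m in ccff.modules() if isinstance(m, RepVggBlock)]
for m in rep_blocks:
    m.convert_to_deploy()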
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/CGAFusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from einops.layers.torch import Rearrange
4 |
5 |
6 | class SpatialAttention(nn.Module):
7 | def __init__(self):
8 | super(SpatialAttention, self).__init__()
9 |         self.sa = nn.Conv2d(2, 1, 7, padding=3, padding_mode='reflect', bias=True)
10 |
11 | def forward(self, x):
12 | x_avg = torch.mean(x, dim=1, keepdim=True)
13 | x_max, _ = torch.max(x, dim=1, keepdim=True)
14 | x2 = torch.concat([x_avg, x_max], dim=1)
15 | sattn = self.sa(x2)
16 | return sattn
17 |
18 |
19 | class ChannelAttention(nn.Module):
20 | def __init__(self, dim, reduction=8):
21 | super(ChannelAttention, self).__init__()
22 | self.gap = nn.AdaptiveAvgPool2d(1)
23 | self.ca = nn.Sequential(
24 | nn.Conv2d(dim, dim // reduction, 1, padding=0, bias=True),
25 | nn.ReLU(inplace=True),
26 | nn.Conv2d(dim // reduction, dim, 1, padding=0, bias=True),
27 | )
28 |
29 | def forward(self, x):
30 | x_gap = self.gap(x)
31 | cattn = self.ca(x_gap)
32 | return cattn
33 |
34 |
35 | class PixelAttention(nn.Module):
36 | def __init__(self, dim):
37 | super(PixelAttention, self).__init__()
38 | self.pa2 = nn.Conv2d(2 * dim, dim, 7, padding=3, padding_mode='reflect', groups=dim, bias=True)
39 | self.sigmoid = nn.Sigmoid()
40 |
41 | def forward(self, x, pattn1):
42 | B, C, H, W = x.shape
43 | x = x.unsqueeze(dim=2) # B, C, 1, H, W
44 | pattn1 = pattn1.unsqueeze(dim=2) # B, C, 1, H, W
45 | x2 = torch.cat([x, pattn1], dim=2) # B, C, 2, H, W
46 | x2 = Rearrange('b c t h w -> b (c t) h w')(x2)
47 | pattn2 = self.pa2(x2)
48 | pattn2 = self.sigmoid(pattn2)
49 | return pattn2
50 |
51 |
52 | class CGAFusion(nn.Module):
53 | def __init__(self, dim, reduction=8):
54 | super(CGAFusion, self).__init__()
55 | self.sa = SpatialAttention()
56 | self.ca = ChannelAttention(dim, reduction)
57 | self.pa = PixelAttention(dim)
58 | self.conv = nn.Conv2d(dim, dim, 1, bias=True)
59 | self.sigmoid = nn.Sigmoid()
60 |
61 | def forward(self, x, y):
62 | initial = x + y
63 | cattn = self.ca(initial)
64 | sattn = self.sa(initial)
65 | pattn1 = sattn + cattn
66 | pattn2 = self.sigmoid(self.pa(initial, pattn1))
67 | result = initial + pattn2 * x + (1 - pattn2) * y
68 | result = self.conv(result)
69 | return result
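
A minimal usage sketch for CGAFusion above; it fuses two feature maps of identical shape, with the channel count and resolution assumed for illustration.

import torch

fusion = CGAFusion(dim=64, reduction=8)
x = torch.randn(1, 64, 32, 32)   # e.g. a low-level feature
y = torch.randn(1, 64, 32, 32)   # e.g. a high-level feature at the same resolution
print(fusion(x, y).shape)        # torch.Size([1, 64, 32, 32])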
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/CSAM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.distributions as td
4 |
5 |
6 | def custom_max(x,dim,keepdim=True):
7 | temp_x=x
8 | for i in dim:
9 | temp_x=torch.max(temp_x,dim=i,keepdim=True)[0]
10 | if not keepdim:
11 | temp_x=temp_x.squeeze()
12 | return temp_x
13 |
14 | class PositionalAttentionModule(nn.Module):
15 | def __init__(self):
16 | super(PositionalAttentionModule,self).__init__()
17 | self.conv=nn.Conv2d(in_channels=2,out_channels=1,kernel_size=(7,7),padding=3)
18 | def forward(self,x):
19 | max_x=custom_max(x,dim=(0,1),keepdim=True)
20 | avg_x=torch.mean(x,dim=(0,1),keepdim=True)
21 | att=torch.cat((max_x,avg_x),dim=1)
22 | att=self.conv(att)
23 | att=torch.sigmoid(att)
24 | return x*att
25 |
26 | class SemanticAttentionModule(nn.Module):
27 | def __init__(self,in_features,reduction_rate=16):
28 | super(SemanticAttentionModule,self).__init__()
29 | self.linear=[]
30 | self.linear.append(nn.Linear(in_features=in_features,out_features=in_features//reduction_rate))
31 | self.linear.append(nn.ReLU())
32 | self.linear.append(nn.Linear(in_features=in_features//reduction_rate,out_features=in_features))
33 | self.linear=nn.Sequential(*self.linear)
34 | def forward(self,x):
35 | max_x=custom_max(x,dim=(0,2,3),keepdim=False).unsqueeze(0)
36 | avg_x=torch.mean(x,dim=(0,2,3),keepdim=False).unsqueeze(0)
37 | max_x=self.linear(max_x)
38 | avg_x=self.linear(avg_x)
39 | att=max_x+avg_x
40 | att=torch.sigmoid(att).unsqueeze(-1).unsqueeze(-1)
41 | return x*att
42 |
43 | class SliceAttentionModule(nn.Module):
44 | def __init__(self,in_features,rate=4,uncertainty=True,rank=5):
45 | super(SliceAttentionModule,self).__init__()
46 | self.uncertainty=uncertainty
47 | self.rank=rank
48 | self.linear=[]
49 | self.linear.append(nn.Linear(in_features=in_features,out_features=int(in_features*rate)))
50 | self.linear.append(nn.ReLU())
51 | self.linear.append(nn.Linear(in_features=int(in_features*rate),out_features=in_features))
52 | self.linear=nn.Sequential(*self.linear)
53 | if uncertainty:
54 | self.non_linear=nn.ReLU()
55 | self.mean=nn.Linear(in_features=in_features,out_features=in_features)
56 | self.log_diag=nn.Linear(in_features=in_features,out_features=in_features)
57 | self.factor=nn.Linear(in_features=in_features,out_features=in_features*rank)
58 | def forward(self,x):
59 | max_x=custom_max(x,dim=(1,2,3),keepdim=False).unsqueeze(0)
60 | avg_x=torch.mean(x,dim=(1,2,3),keepdim=False).unsqueeze(0)
61 | max_x=self.linear(max_x)
62 | avg_x=self.linear(avg_x)
63 | att=max_x+avg_x
64 | if self.uncertainty:
65 | temp=self.non_linear(att)
66 | mean=self.mean(temp)
67 | diag=self.log_diag(temp).exp()
68 | factor=self.factor(temp)
69 | factor=factor.view(1,-1,self.rank)
70 | dist=td.LowRankMultivariateNormal(loc=mean,cov_factor=factor,cov_diag=diag)
71 | att=dist.sample()
72 | att=torch.sigmoid(att).squeeze().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
73 | return x*att
74 |
75 |
76 | class CSAM(nn.Module):
77 | def __init__(self,num_slices,num_channels,semantic=True,positional=True,slice=True,uncertainty=True,rank=5):
78 | super(CSAM,self).__init__()
79 | self.semantic=semantic
80 | self.positional=positional
81 | self.slice=slice
82 | if semantic:
83 | self.semantic_att=SemanticAttentionModule(num_channels)
84 | if positional:
85 | self.positional_att=PositionalAttentionModule()
86 | if slice:
87 | self.slice_att=SliceAttentionModule(num_slices,uncertainty=uncertainty,rank=rank)
88 | def forward(self,x):
89 | if self.semantic:
90 | x=self.semantic_att(x)
91 | if self.positional:
92 | x=self.positional_att(x)
93 | if self.slice:
94 | x=self.slice_att(x)
95 | return x
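
A minimal usage sketch for CSAM above. The module treats the batch dimension as the slice dimension of a 2.5D volume, so the leading dimension must equal num_slices; all sizes below are illustrative assumptions.

import torch

csam = CSAM(num_slices=16, num_channels=32)
x = torch.randn(16, 32, 64, 64)   # (num_slices, num_channels, H, W)
print(csam(x).shape)              # torch.Size([16, 32, 64, 64])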
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/FARM.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import pytorch_lightning as pl
6 | from torchvision.ops import DeformConv2d
7 | from pytorch_lightning import seed_everything
8 | from einops import rearrange
9 | import numbers
10 |
11 | seed_everything(13)
12 |
13 |
14 | ##########################################################################
15 | ## Layer Norm
16 |
17 | def to_3d(x):
18 | return rearrange(x, 'b c h w -> b (h w) c')
19 |
20 |
21 | def to_4d(x, h, w):
22 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
23 |
24 |
25 | class BiasFree_LayerNorm(pl.LightningModule):
26 | def __init__(self, normalized_shape):
27 | super(BiasFree_LayerNorm, self).__init__()
28 | if isinstance(normalized_shape, numbers.Integral):
29 | normalized_shape = (normalized_shape,)
30 | normalized_shape = torch.Size(normalized_shape)
31 |
32 | assert len(normalized_shape) == 1
33 |
34 | self.weight = nn.Parameter(torch.ones(normalized_shape))
35 | self.normalized_shape = normalized_shape
36 |
37 | def forward(self, x):
38 | sigma = x.var(-1, keepdim=True, unbiased=False)
39 | return x / torch.sqrt(sigma + 1e-5) * self.weight
40 |
41 |
42 | class WithBias_LayerNorm(pl.LightningModule):
43 | def __init__(self, normalized_shape):
44 | super(WithBias_LayerNorm, self).__init__()
45 | if isinstance(normalized_shape, numbers.Integral):
46 | normalized_shape = (normalized_shape,)
47 | normalized_shape = torch.Size(normalized_shape)
48 |
49 | assert len(normalized_shape) == 1
50 |
51 | self.weight = nn.Parameter(torch.ones(normalized_shape))
52 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
53 | self.normalized_shape = normalized_shape
54 |
55 | def forward(self, x):
56 | mu = x.mean(-1, keepdim=True)
57 | sigma = x.var(-1, keepdim=True, unbiased=False)
58 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
59 |
60 |
61 | class LayerNorm(pl.LightningModule):
62 | def __init__(self, dim, LayerNorm_type):
63 | super(LayerNorm, self).__init__()
64 | if LayerNorm_type == 'BiasFree':
65 | self.body = BiasFree_LayerNorm(dim)
66 | else:
67 | self.body = WithBias_LayerNorm(dim)
68 |
69 | def forward(self, x):
70 | h, w = x.shape[-2:]
71 | return to_4d(self.body(to_3d(x)), h, w)
72 |
73 |
74 | ##########################################################################
75 | ## Gated-Dconv Feed-Forward Network (GDFN)
76 | class FeedForward(pl.LightningModule):
77 | def __init__(self, dim, ffn_expansion_factor, bias):
78 | super(FeedForward, self).__init__()
79 |
80 | hidden_features = int(dim * ffn_expansion_factor)
81 |
82 | self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
83 |
84 | self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
85 | groups=hidden_features * 2, bias=bias)
86 |
87 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
88 |
89 | def forward(self, x):
90 | x = self.project_in(x)
91 | x1, x2 = self.dwconv(x).chunk(2, dim=1)
92 | x = F.gelu(x1) * x2
93 | x = self.project_out(x)
94 | return x
95 |
96 |
97 | ##########################################################################
98 | ## Multi-DConv Head Transposed Self-Attention (MDTA)
99 | class Attention(pl.LightningModule):
100 | def __init__(self, dim, num_heads, stride, bias):
101 | super(Attention, self).__init__()
102 | self.num_heads = num_heads
103 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
104 |
105 | self.stride = stride
106 | self.qk = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
107 | self.qk_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=self.stride, padding=1, groups=dim * 2,
108 | bias=bias)
109 |
110 | self.v = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
111 | self.v_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
112 |
113 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
114 |
115 | def forward(self, x):
116 | b, c, h, w = x.shape
117 |
118 | qk = self.qk_dwconv(self.qk(x))
119 | q, k = qk.chunk(2, dim=1)
120 |
121 | v = self.v_dwconv(self.v(x))
122 |
123 | b, f, h1, w1 = q.size()
124 |
125 | q = rearrange(q, 'b (head c) h1 w1 -> b head c (h1 w1)', head=self.num_heads)
126 | k = rearrange(k, 'b (head c) h1 w1 -> b head c (h1 w1)', head=self.num_heads)
127 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
128 |
129 | q = torch.nn.functional.normalize(q, dim=-1)
130 | k = torch.nn.functional.normalize(k, dim=-1)
131 |
132 | attn = (q @ k.transpose(-2, -1)) * self.temperature
133 | attn = attn.softmax(dim=-1)
134 |
135 | out = (attn @ v)
136 |
137 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
138 |
139 | out = self.project_out(out)
140 | return out
141 |
142 |
143 | ##########################################################################
144 | ######### Burst Feature Attention ########################################
145 |
146 | ##########################################################################
147 | class BFA(pl.LightningModule):
148 | def __init__(self, dim, num_heads, stride, ffn_expansion_factor, bias, LayerNorm_type):
149 | super(BFA, self).__init__()
150 |
151 | self.norm1 = LayerNorm(dim, LayerNorm_type)
152 | self.attn = Attention(dim, num_heads, stride, bias)
153 | self.norm2 = LayerNorm(dim, LayerNorm_type)
154 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
155 |
156 | def forward(self, x):
157 | x = x + self.attn(self.norm1(x))
158 | x = x + self.ffn(self.norm2(x))
159 |
160 | return x
161 |
162 |
163 | ##########################################################################
164 | ## Overlapped image patch embedding with 3x3 Conv
165 | class OverlapPatchEmbed(pl.LightningModule):
166 | def __init__(self, in_c=3, embed_dim=48, bias=False):
167 | super(OverlapPatchEmbed, self).__init__()
168 |
169 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
170 |
171 | def forward(self, x):
172 | # print("Inside patch embed:::", inp_enc_level1.size())
173 | x = self.proj(x)
174 |
175 | return x
176 |
177 |
178 | class ref_back_projection(pl.LightningModule):
179 | def __init__(self, in_channels, stride):
180 | super(ref_back_projection, self).__init__()
181 |
182 | bias = False
183 |
184 | self.feat_fusion = nn.Sequential(nn.Conv2d(in_channels * 2, in_channels, 3, stride=1, padding=1), nn.GELU())
185 | self.feat_expand = nn.Sequential(nn.Conv2d(in_channels, in_channels * 2, 3, stride=1, padding=1), nn.GELU())
186 | self.diff_fusion = nn.Sequential(nn.Conv2d(in_channels * 2, in_channels, 3, stride=1, padding=1), nn.GELU())
187 |
188 | self.encoder1 = nn.Sequential(*[
189 | BFA(dim=in_channels * 2, num_heads=1, stride=stride, ffn_expansion_factor=2.66, bias=bias,
190 | LayerNorm_type='WithBias') for i in range(2)])
191 |
192 | def forward(self, x):
193 | B, f, H, W = x.size()
194 |
195 | ref = x[0].unsqueeze(0)
196 | ref = torch.repeat_interleave(ref, B, dim=0)
197 | feat = self.encoder1(torch.cat([ref, x], dim=1))
198 |
199 | fused_feat = self.feat_fusion(feat)
200 | exp_feat = self.feat_expand(fused_feat)
201 |
202 | residual = exp_feat - feat
203 | residual = self.diff_fusion(residual)
204 |
205 | fused_feat = fused_feat + residual
206 |
207 | return fused_feat
208 |
209 |
210 | class alignment(pl.LightningModule):
211 | def __init__(self, dim=48, memory=False, stride=1, type='group_conv'):
212 |
213 | super(alignment, self).__init__()
214 |
215 | act = nn.GELU()
216 | bias = False
217 |
218 | kernel_size = 3
219 | padding = kernel_size // 2
220 | deform_groups = 8
221 | out_channels = deform_groups * 3 * (kernel_size ** 2)
222 |
223 | self.offset_conv = nn.Conv2d(dim, out_channels, kernel_size, stride=1, padding=padding, bias=bias)
224 | self.deform = DeformConv2d(dim, dim, kernel_size, padding=2, groups=deform_groups, dilation=2)
225 | self.back_projection = ref_back_projection(dim, stride=1)
226 |
227 | self.bottleneck = nn.Sequential(nn.Conv2d(dim * 2, dim, kernel_size=3, padding=1, bias=bias), act)
228 |
229 | if memory == True:
230 | self.bottleneck_o = nn.Sequential(nn.Conv2d(dim * 2, dim, kernel_size=3, padding=1, bias=bias), act)
231 |
232 | def offset_gen(self, x):
233 |
234 | o1, o2, mask = torch.chunk(x, 3, dim=1)
235 | offset = torch.cat((o1, o2), dim=1)
236 | mask = torch.sigmoid(mask)
237 |
238 | return offset, mask
239 |
240 | def forward(self, x, prev_offset_feat=None):
241 |
242 | B, f, H, W = x.size()
243 | ref = x[0].unsqueeze(0)
244 | ref = torch.repeat_interleave(ref, B, dim=0)
245 |
246 | offset_feat = self.bottleneck(torch.cat([ref, x], dim=1))
247 |
248 |         if prev_offset_feat is not None:
249 | offset_feat = self.bottleneck_o(torch.cat([prev_offset_feat, offset_feat], dim=1))
250 |
251 | offset, mask = self.offset_gen(self.offset_conv(offset_feat))
252 |
253 | aligned_feat = self.deform(x, offset, mask)
254 | aligned_feat[0] = x[0].unsqueeze(0)
255 |
256 | aligned_feat = self.back_projection(aligned_feat)
257 |
258 | # return aligned_feat, offset_feat
259 | return aligned_feat
260 |
261 |
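
A minimal usage sketch for the alignment module above, assuming pytorch_lightning, torchvision (with modulated DeformConv2d) and einops are installed. The input is a burst of frame features whose first frame serves as the alignment reference; the burst length and sizes are illustrative assumptions.

import torch

align = alignment(dim=48)
burst_feat = torch.randn(5, 48, 64, 64)   # 5 burst frames, 48 channels each; frame 0 is the reference
aligned = align(burst_feat)
print(aligned.shape)  # torch.Size([5, 48, 64, 64])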
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/FCA.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 |
5 |
6 | class Mix(nn.Module):
7 | def __init__(self, m=-0.80):
8 | super(Mix, self).__init__()
9 |         # learnable mixing factor, initialized to m
10 |         w = torch.nn.Parameter(torch.FloatTensor([m]), requires_grad=True)
11 |         self.w = w
12 | self.mix_block = nn.Sigmoid()
13 |
14 | def forward(self, fea1, fea2):
15 | mix_factor = self.mix_block(self.w)
16 | out = fea1 * mix_factor.expand_as(fea1) + fea2 * (1 - mix_factor.expand_as(fea2))
17 | return out
18 |
19 |
20 | class FCA(nn.Module):
21 | def __init__(self,channel,b=1, gamma=2):
22 | super(FCA, self).__init__()
23 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
24 |         # 1D convolution (kernel size chosen adaptively from the channel count)
25 | t = int(abs((math.log(channel, 2) + b) / gamma))
26 | k = t if t % 2 else t + 1
27 | self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k / 2), bias=False)
28 | self.fc = nn.Conv2d(channel, channel, 1, padding=0, bias=True)
29 | self.sigmoid = nn.Sigmoid()
30 | self.mix = Mix()
31 |
32 |
33 | def forward(self, input):
34 | x = self.avg_pool(input)
35 | x1 = self.conv1(x.squeeze(-1).transpose(-1, -2)).transpose(-1, -2)#(1,64,1)
36 | x2 = self.fc(x).squeeze(-1).transpose(-1, -2)#(1,1,64)
37 | out1 = torch.sum(torch.matmul(x1,x2),dim=1).unsqueeze(-1).unsqueeze(-1)#(1,64,1,1)
38 | #x1 = x1.transpose(-1, -2).unsqueeze(-1)
39 | out1 = self.sigmoid(out1)
40 | out2 = torch.sum(torch.matmul(x2.transpose(-1, -2),x1.transpose(-1, -2)),dim=1).unsqueeze(-1).unsqueeze(-1)
41 |
42 | #out2 = self.fc(x)
43 | out2 = self.sigmoid(out2)
44 | out = self.mix(out1,out2)
45 | out = self.conv1(out.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
46 | out = self.sigmoid(out)
47 |
48 | return input*out
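
A minimal usage sketch for FCA above; the channel count is an illustrative assumption (it only determines the adaptive 1D kernel size).

import torch

fca = FCA(channel=64)
x = torch.randn(1, 64, 32, 32)
print(fca(x).shape)  # torch.Size([1, 64, 32, 32])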
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/GLSA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from mmcv.cnn import constant_init, kaiming_init
5 |
6 |
7 | class BasicConv2d(nn.Module):
8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
9 | super(BasicConv2d, self).__init__()
10 |
11 | self.conv = nn.Conv2d(in_planes, out_planes,
12 | kernel_size=kernel_size, stride=stride,
13 | padding=padding, dilation=dilation, bias=False)
14 | self.bn = nn.BatchNorm2d(out_planes)
15 | self.relu = nn.ReLU(inplace=True)
16 |
17 | def forward(self, x):
18 | x = self.conv(x)
19 | x = self.bn(x)
20 | x = self.relu(x)
21 | return x
22 |
23 |
24 | class Block(nn.Sequential):
25 | def __init__(self, input_num, num1, num2, dilation_rate, drop_out, bn_start=True, norm_layer=nn.BatchNorm2d):
26 | super(Block, self).__init__()
27 | if bn_start:
28 | self.add_module('norm1', norm_layer(input_num)),
29 |
30 | self.add_module('relu1', nn.ReLU(inplace=True)),
31 | self.add_module('conv1', nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1)),
32 |
33 | self.add_module('norm2', norm_layer(num1)),
34 | self.add_module('relu2', nn.ReLU(inplace=True)),
35 | self.add_module('conv2', nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3,
36 | dilation=dilation_rate, padding=dilation_rate)),
37 | self.drop_rate = drop_out
38 |
39 | def forward(self, _input):
40 | feature = super(Block, self).forward(_input)
41 | if self.drop_rate > 0:
42 | feature = F.dropout2d(feature, p=self.drop_rate, training=self.training)
43 | return feature
44 |
45 |
46 | def Upsample(x, size, align_corners=False):
47 | """
48 | Wrapper Around the Upsample Call
49 | """
50 | return nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=align_corners)
51 |
52 |
53 | def last_zero_init(m):
54 | if isinstance(m, nn.Sequential):
55 | constant_init(m[-1], val=0)
56 | else:
57 | constant_init(m, val=0)
58 |
59 |
60 | class ContextBlock(nn.Module):
61 |
62 | def __init__(self,
63 | inplanes,
64 | ratio,
65 | pooling_type='att',
66 | fusion_types=('channel_mul',)):
67 | super(ContextBlock, self).__init__()
68 | assert pooling_type in ['avg', 'att']
69 | assert isinstance(fusion_types, (list, tuple))
70 | valid_fusion_types = ['channel_add', 'channel_mul']
71 | assert all([f in valid_fusion_types for f in fusion_types])
72 | assert len(fusion_types) > 0, 'at least one fusion should be used'
73 | self.inplanes = inplanes
74 | self.ratio = ratio
75 | self.planes = int(inplanes * ratio)
76 | self.pooling_type = pooling_type
77 | self.fusion_types = fusion_types
78 | if pooling_type == 'att':
79 | self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
80 | self.softmax = nn.Softmax(dim=2)
81 | else:
82 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
83 | if 'channel_add' in fusion_types:
84 | self.channel_add_conv = nn.Sequential(
85 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
86 | nn.LayerNorm([self.planes, 1, 1]),
87 | nn.ReLU(inplace=True), # yapf: disable
88 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
89 | else:
90 | self.channel_add_conv = None
91 | if 'channel_mul' in fusion_types:
92 | self.channel_mul_conv = nn.Sequential(
93 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
94 | nn.LayerNorm([self.planes, 1, 1]),
95 | nn.ReLU(inplace=True), # yapf: disable
96 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
97 | else:
98 | self.channel_mul_conv = None
99 | self.reset_parameters()
100 |
101 | def reset_parameters(self):
102 | if self.pooling_type == 'att':
103 | kaiming_init(self.conv_mask, mode='fan_in')
104 | self.conv_mask.inited = True
105 |
106 | if self.channel_add_conv is not None:
107 | last_zero_init(self.channel_add_conv)
108 | if self.channel_mul_conv is not None:
109 | last_zero_init(self.channel_mul_conv)
110 |
111 | def spatial_pool(self, x):
112 | batch, channel, height, width = x.size()
113 | if self.pooling_type == 'att':
114 | input_x = x
115 | # [N, C, H * W]
116 | input_x = input_x.view(batch, channel, height * width)
117 | # [N, 1, C, H * W]
118 | input_x = input_x.unsqueeze(1)
119 | # [N, 1, H, W]
120 | context_mask = self.conv_mask(x)
121 | # [N, 1, H * W]
122 | context_mask = context_mask.view(batch, 1, height * width)
123 | # [N, 1, H * W]
124 | context_mask = self.softmax(context_mask)
125 | # [N, 1, H * W, 1]
126 | context_mask = context_mask.unsqueeze(-1)
127 | # [N, 1, C, 1]
128 | context = torch.matmul(input_x, context_mask)
129 | # [N, C, 1, 1]
130 | context = context.view(batch, channel, 1, 1)
131 | else:
132 | # [N, C, 1, 1]
133 | context = self.avg_pool(x)
134 |
135 | return context
136 |
137 | def forward(self, x):
138 | # [N, C, 1, 1]
139 | context = self.spatial_pool(x)
140 |
141 | out = x
142 | if self.channel_mul_conv is not None:
143 | # [N, C, 1, 1]
144 | channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
145 | out = out + out * channel_mul_term
146 | if self.channel_add_conv is not None:
147 | # [N, C, 1, 1]
148 | channel_add_term = self.channel_add_conv(context)
149 | out = out + channel_add_term
150 |
151 | return out
152 |
153 |
154 | class ChannelAttention(nn.Module):
155 | def __init__(self, in_planes, ratio=16):
156 | super(ChannelAttention, self).__init__()
157 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
158 | self.max_pool = nn.AdaptiveMaxPool2d(1)
159 |
160 |         self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
161 |         self.relu1 = nn.ReLU()
162 |         self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
163 |
164 | self.sigmoid = nn.Sigmoid()
165 |
166 | def forward(self, x):
167 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
168 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
169 | out = avg_out + max_out
170 | return self.sigmoid(out)
171 |
172 |
173 | class SpatialAttention(nn.Module):
174 | def __init__(self, kernel_size=7):
175 | super(SpatialAttention, self).__init__()
176 |
177 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
178 | padding = 3 if kernel_size == 7 else 1
179 |
180 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
181 | self.sigmoid = nn.Sigmoid()
182 |
183 | def forward(self, x):
184 | avg_out = torch.mean(x, dim=1, keepdim=True)
185 | max_out, _ = torch.max(x, dim=1, keepdim=True)
186 | x = torch.cat([avg_out, max_out], dim=1)
187 | x = self.conv1(x)
188 | return self.sigmoid(x)
189 |
190 |
191 | class ConvBranch(nn.Module):
192 | def __init__(self, in_features, hidden_features=None, out_features=None):
193 | super().__init__()
194 | hidden_features = hidden_features or in_features
195 | out_features = out_features or in_features
196 | self.conv1 = nn.Sequential(
197 | nn.Conv2d(in_features, hidden_features, 1, bias=False),
198 | nn.BatchNorm2d(hidden_features),
199 | nn.ReLU(inplace=True)
200 | )
201 | self.conv2 = nn.Sequential(
202 | nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False),
203 | nn.BatchNorm2d(hidden_features),
204 | nn.ReLU(inplace=True)
205 | )
206 | self.conv3 = nn.Sequential(
207 | nn.Conv2d(hidden_features, hidden_features, 1, bias=False),
208 | nn.BatchNorm2d(hidden_features),
209 | nn.ReLU(inplace=True)
210 | )
211 | self.conv4 = nn.Sequential(
212 | nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False),
213 | nn.BatchNorm2d(hidden_features),
214 | nn.ReLU(inplace=True)
215 | )
216 | self.conv5 = nn.Sequential(
217 | nn.Conv2d(hidden_features, hidden_features, 1, bias=False),
218 | nn.BatchNorm2d(hidden_features),
219 | nn.SiLU(inplace=True)
220 | )
221 | self.conv6 = nn.Sequential(
222 | nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False),
223 | nn.BatchNorm2d(hidden_features),
224 | nn.ReLU(inplace=True)
225 | )
226 | self.conv7 = nn.Sequential(
227 | nn.Conv2d(hidden_features, out_features, 1, bias=False),
228 | nn.ReLU(inplace=True)
229 | )
230 | self.ca = ChannelAttention(64)
231 | self.sa = SpatialAttention()
232 | self.sigmoid_spatial = nn.Sigmoid()
233 |
234 | def forward(self, x):
235 | res1 = x
236 | res2 = x
237 | x = self.conv1(x)
238 | x = x + self.conv2(x)
239 | x = self.conv3(x)
240 | x = x + self.conv4(x)
241 | x = self.conv5(x)
242 | x = x + self.conv6(x)
243 | x = self.conv7(x)
244 | x_mask = self.sigmoid_spatial(x)
245 | res1 = res1 * x_mask
246 | return res2 + res1
247 |
248 |
249 | class GLSA(nn.Module):
250 |
251 | def __init__(self, input_dim=512, embed_dim=32, k_s=3):
252 | super().__init__()
253 |
254 | self.conv1_1 = BasicConv2d(embed_dim * 2, embed_dim, 1)
255 | self.conv1_1_1 = BasicConv2d(input_dim // 2, embed_dim, 1)
256 | self.local_11conv = nn.Conv2d(input_dim // 2, embed_dim, 1)
257 | self.global_11conv = nn.Conv2d(input_dim // 2, embed_dim, 1)
258 | self.GlobelBlock = ContextBlock(inplanes=embed_dim, ratio=2)
259 | self.local = ConvBranch(in_features=embed_dim, hidden_features=embed_dim, out_features=embed_dim)
260 |
261 | def forward(self, x):
262 | b, c, h, w = x.size()
263 | x_0, x_1 = x.chunk(2, dim=1)
264 |
265 | # local block
266 | local = self.local(self.local_11conv(x_0))
267 |
268 |         # global block
269 |         Globel = self.GlobelBlock(self.global_11conv(x_1))
270 | 
271 |         # concatenate global + local
272 |         x = torch.cat([local, Globel], dim=1)
273 | x = self.conv1_1(x)
274 |
275 | return x
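276 | 
277 | 
278 | # --- Usage sketch (illustrative minimal smoke test) ---
279 | # Assumes a random 2x512x16x16 input; the batch size, input_dim=512 and the 16x16
280 | # spatial size are example values only, not values required by the module.
281 | if __name__ == '__main__':
282 |     x = torch.randn(2, 512, 16, 16)
283 |     glsa = GLSA(input_dim=512, embed_dim=32)
284 |     out = glsa(x)
285 |     print(out.shape)  # expected: torch.Size([2, 32, 16, 16])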
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/LSK.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class LSKblock(nn.Module):
5 | def __init__(self, dim):
6 | super().__init__()
7 | self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
8 | self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
9 | self.conv1 = nn.Conv2d(dim, dim//2, 1)
10 | self.conv2 = nn.Conv2d(dim, dim//2, 1)
11 | self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
12 | self.conv = nn.Conv2d(dim//2, dim, 1)
13 |
14 | def forward(self, x):
15 | attn1 = self.conv0(x)
16 | attn2 = self.conv_spatial(attn1)
17 |
18 | attn1 = self.conv1(attn1)
19 | attn2 = self.conv2(attn2)
20 |
21 | attn = torch.cat([attn1, attn2], dim=1)
22 | avg_attn = torch.mean(attn, dim=1, keepdim=True)
23 | max_attn, _ = torch.max(attn, dim=1, keepdim=True)
24 | agg = torch.cat([avg_attn, max_attn], dim=1)
25 | sig = self.conv_squeeze(agg).sigmoid()
26 |         attn = attn1 * sig[:, 0, :, :].unsqueeze(1) + attn2 * sig[:, 1, :, :].unsqueeze(1)
27 | attn = self.conv(attn)
28 | return x * attn
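29 | 
30 | 
31 | # --- Usage sketch (illustrative minimal smoke test) ---
32 | # dim=64 and the 32x32 input below are arbitrary example values; LSKblock returns a
33 | # tensor with the same shape as its input.
34 | if __name__ == '__main__':
35 |     x = torch.randn(2, 64, 32, 32)
36 |     block = LSKblock(dim=64)
37 |     print(block(x).shape)  # expected: torch.Size([2, 64, 32, 32])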
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/MFII.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class eca_layer(nn.Module):
6 | """Constructs a ECA module.
7 | Args:
8 | channel: Number of channels of the input feature map
9 | k_size: Adaptive selection of kernel size
10 | source: https://github.com/BangguWu/ECANet
11 | """
12 | def __init__(self, channel, k_size=3):
13 | super(eca_layer, self).__init__()
14 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
15 | self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
16 | self.sigmoid = nn.Sigmoid()
17 |
18 | def forward(self, x):
19 | # x: input features with shape [b, c, h, w]
20 | b, c, h, w = x.size()
21 |
22 | # feature descriptor on the global spatial information
23 | y = self.avg_pool(x)
24 |
25 | # Two different branches of ECA module
26 | y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
27 |
28 | # Multi-scale information fusion
29 | y = self.sigmoid(y)
30 |
31 | return x * y.expand_as(x)
32 |
33 |
34 | class SELayer(nn.Module):
35 | def __init__(self, channel, reduction=16):
36 | super(SELayer, self).__init__()
37 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
38 | self.fc = nn.Sequential(
39 | nn.Linear(channel, channel // reduction, bias=False),
40 | nn.ReLU(inplace=True),
41 | nn.Linear(channel // reduction, channel, bias=False),
42 | nn.Sigmoid()
43 | )
44 |
45 | def forward(self, x):
46 | b, c, _, _ = x.size()
47 | y = self.avg_pool(x).view(b, c)
48 | y = self.fc(y).view(b, c, 1, 1)
49 | return x * y.expand_as(x)
50 |
51 |
52 |
53 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
54 | """3x3 convolution with padding"""
55 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
56 | padding=dilation, groups=groups, bias=False, dilation=dilation)
57 |
58 | def conv1x1(in_planes, out_planes, stride=1):
59 | """1x1 convolution"""
60 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
61 |
62 | class MFII_BasicBlock(nn.Module):
63 | expansion = 1
64 |
65 | def __init__(self, inplanes, planes, stride=1, downsample=None,
66 | rla_channel=32, SE=False, ECA_size=None, groups=1,
67 | base_width=64, dilation=1, norm_layer=None, reduction=16):
68 | super(MFII_BasicBlock, self).__init__()
69 | if norm_layer is None:
70 | norm_layer = nn.BatchNorm2d
71 |
72 | # if groups != 1 or base_width != 64:
73 | # raise ValueError('BasicBlock only supports groups=1 and base_width=64')
74 | if dilation > 1:
75 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
76 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
77 | self.conv1 = conv3x3(inplanes + rla_channel, planes, stride)
78 | self.bn1 = norm_layer(planes)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.conv2 = conv3x3(planes, planes)
81 | self.bn2 = norm_layer(planes)
82 | self.downsample = downsample
83 | self.stride = stride
84 |
85 | self.averagePooling = None
86 | if downsample is not None and stride != 1:
87 | self.averagePooling = nn.AvgPool2d((2, 2), stride=(2, 2))
88 |
89 | self.se = None
90 | if SE:
91 | self.se = SELayer(planes * self.expansion, reduction)
92 |
93 | self.eca = None
94 |         if ECA_size is not None:
95 | self.eca = eca_layer(planes * self.expansion, int(ECA_size))
96 |
97 | def forward(self, x, h):
98 | identity = x
99 |
100 | x = torch.cat((x, h), dim=1) # [8, 96, 56, 56]
101 |
102 | out = self.conv1(x) # [8, 64, 56, 56]
103 | out = self.bn1(out) # [8, 64, 56, 56]
104 | out = self.relu(out)
105 |
106 | out = self.conv2(out) # [8, 64, 56, 56]
107 | out = self.bn2(out)
108 |
109 |         if self.se is not None:
110 |             out = self.se(out)
111 | 
112 |         if self.eca is not None:
113 |             out = self.eca(out)
114 |
115 | y = out
116 |
117 | if self.downsample is not None:
118 | identity = self.downsample(identity)
119 | if self.averagePooling is not None:
120 | h = self.averagePooling(h)
121 |
122 | out += identity
123 | out = self.relu(out)
124 |
125 | return out, y, h
126 |
127 |
128 | class MFII_BasicBlock_half(nn.Module):
129 | expansion = 1
130 |
131 | def __init__(self, inplanes, planes, stride=1, downsample=None,
132 | rla_channel=16, SE=False, ECA_size=None, groups=1,
133 | base_width=64, dilation=1, norm_layer=None, reduction=16):
134 | super(MFII_BasicBlock_half, self).__init__()
135 | if norm_layer is None:
136 | norm_layer = nn.BatchNorm2d
137 |
138 | # if groups != 1 or base_width != 64:
139 | # raise ValueError('BasicBlock only supports groups=1 and base_width=64')
140 | if dilation > 1:
141 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
142 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
143 | self.conv1 = conv3x3(inplanes + rla_channel, planes, stride)
144 | self.bn1 = norm_layer(planes)
145 | self.relu = nn.ReLU(inplace=True)
146 | self.conv2 = conv3x3(planes, planes)
147 | self.bn2 = norm_layer(planes)
148 | self.downsample = downsample
149 | self.stride = stride
150 |
151 | self.averagePooling = None
152 | if downsample is not None and stride != 1:
153 | self.averagePooling = nn.AvgPool2d((2, 2), stride=(2, 2))
154 |
155 | self.se = None
156 | if SE:
157 | self.se = SELayer(planes * self.expansion, reduction)
158 |
159 | self.eca = None
160 |         if ECA_size is not None:
161 | self.eca = eca_layer(planes * self.expansion, int(ECA_size))
162 |
163 | def forward(self, x, h):
164 | identity = x
165 |
166 | x = torch.cat((x, h), dim=1) # [8, 96, 56, 56]
167 |
168 | out = self.conv1(x) # [8, 64, 56, 56]
169 | out = self.bn1(out) # [8, 64, 56, 56]
170 | out = self.relu(out)
171 |
172 | out = self.conv2(out) # [8, 64, 56, 56]
173 | out = self.bn2(out)
174 |
175 |         if self.se is not None:
176 |             out = self.se(out)
177 | 
178 |         if self.eca is not None:
179 |             out = self.eca(out)
180 |
181 | y = out
182 |
183 | if self.downsample is not None:
184 | identity = self.downsample(identity)
185 | if self.averagePooling is not None:
186 | h = self.averagePooling(h)
187 |
188 | out += identity
189 | out = self.relu(out)
190 |
191 | return out, y, h
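192 | 
193 | 
194 | # --- Usage sketch (illustrative minimal smoke test) ---
195 | # MFII_BasicBlock takes the feature map x plus a state tensor h with rla_channel channels
196 | # and returns (out, y, h). All shapes and the SE/ECA settings below are example values only.
197 | if __name__ == '__main__':
198 |     block = MFII_BasicBlock(inplanes=64, planes=64, rla_channel=32, SE=True, ECA_size=3)
199 |     x = torch.randn(2, 64, 56, 56)
200 |     h = torch.randn(2, 32, 56, 56)
201 |     out, y, h = block(x, h)
202 |     # expected: out [2, 64, 56, 56], y [2, 64, 56, 56], h [2, 32, 56, 56]
203 |     print(out.shape, y.shape, h.shape)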
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/PSFM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BBasicConv2d(nn.Module):
6 | def __init__(
7 | self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False,
8 | ):
9 | super(BBasicConv2d, self).__init__()
10 |
11 | self.basicconv = nn.Sequential(
12 | nn.Conv2d(
13 | in_planes,
14 | out_planes,
15 | kernel_size=kernel_size,
16 | stride=stride,
17 | padding=padding,
18 | dilation=dilation,
19 | groups=groups,
20 | bias=bias,
21 | ),
22 | nn.BatchNorm2d(out_planes),
23 | nn.ReLU(inplace=True),
24 | )
25 |
26 | def forward(self, x):
27 | return self.basicconv(x)
28 |
29 |
30 | class DenseLayer(nn.Module):
31 | def __init__(self, in_C, out_C, down_factor=4, k=4):
32 |         """
33 |         Works much like a DenseNet block, building dense connections within the feature.
34 |         """
35 | super(DenseLayer, self).__init__()
36 | self.k = k
37 | self.down_factor = down_factor
38 | mid_C = out_C // self.down_factor
39 |
40 | self.down = nn.Conv2d(in_C, mid_C, 1)
41 |
42 | self.denseblock = nn.ModuleList()
43 | for i in range(1, self.k + 1):
44 | self.denseblock.append(BBasicConv2d(mid_C * i, mid_C, 3, 1, 1))
45 |
46 | self.fuse = BBasicConv2d(in_C + mid_C, out_C, kernel_size=3, stride=1, padding=1)
47 |
48 | def forward(self, in_feat):
49 | down_feats = self.down(in_feat)
50 | # print(down_feats.shape)
51 | # print(self.denseblock)
52 | out_feats = []
53 | for i in self.denseblock:
54 | # print(self.denseblock)
55 | feats = i(torch.cat((*out_feats, down_feats), dim=1))
56 | # print(feats.shape)
57 | out_feats.append(feats)
58 |
59 | feats = torch.cat((in_feat, feats), dim=1)
60 | return self.fuse(feats)
61 |
62 |
63 | class GEFM(nn.Module):
64 | def __init__(self, in_C, out_C):
65 | super(GEFM, self).__init__()
66 | self.RGB_K = BBasicConv2d(out_C, out_C, 3, 1, 1)
67 | self.RGB_V = BBasicConv2d(out_C, out_C, 3, 1, 1)
68 | self.Q = BBasicConv2d(in_C, out_C, 3, 1, 1)
69 | self.INF_K = BBasicConv2d(out_C, out_C, 3, 1, 1)
70 | self.INF_V = BBasicConv2d(out_C, out_C, 3, 1, 1)
71 | self.Second_reduce = BBasicConv2d(in_C, out_C, 3, 1, 1)
72 | self.gamma1 = nn.Parameter(torch.zeros(1))
73 | self.gamma2 = nn.Parameter(torch.zeros(1))
74 | self.softmax = nn.Softmax(dim=-1)
75 |
76 | def forward(self, x, y):
77 | Q = self.Q(torch.cat([x, y], dim=1))
78 | RGB_K = self.RGB_K(x)
79 | RGB_V = self.RGB_V(x)
80 | m_batchsize, C, height, width = RGB_V.size()
81 | RGB_V = RGB_V.view(m_batchsize, -1, width * height)
82 | RGB_K = RGB_K.view(m_batchsize, -1, width * height).permute(0, 2, 1)
83 | RGB_Q = Q.view(m_batchsize, -1, width * height)
84 | RGB_mask = torch.bmm(RGB_K, RGB_Q)
85 | RGB_mask = self.softmax(RGB_mask)
86 | RGB_refine = torch.bmm(RGB_V, RGB_mask.permute(0, 2, 1))
87 | RGB_refine = RGB_refine.view(m_batchsize, -1, height, width)
88 | RGB_refine = self.gamma1 * RGB_refine + y
89 |
90 | INF_K = self.INF_K(y)
91 | INF_V = self.INF_V(y)
92 | INF_V = INF_V.view(m_batchsize, -1, width * height)
93 | INF_K = INF_K.view(m_batchsize, -1, width * height).permute(0, 2, 1)
94 | INF_Q = Q.view(m_batchsize, -1, width * height)
95 | INF_mask = torch.bmm(INF_K, INF_Q)
96 | INF_mask = self.softmax(INF_mask)
97 | INF_refine = torch.bmm(INF_V, INF_mask.permute(0, 2, 1))
98 | INF_refine = INF_refine.view(m_batchsize, -1, height, width)
99 | INF_refine = self.gamma2 * INF_refine + x
100 |
101 | out = self.Second_reduce(torch.cat([RGB_refine, INF_refine], dim=1))
102 | return out
103 |
104 |
105 | class PSFM(nn.Module):
106 | def __init__(self, in_C, out_C, cat_C):
107 | super(PSFM, self).__init__()
108 | self.RGBobj = DenseLayer(in_C, out_C)
109 | self.Infobj = DenseLayer(in_C, out_C)
110 | self.obj_fuse = GEFM(cat_C, out_C)
111 |
112 | def forward(self, rgb, depth):
113 | rgb_sum = self.RGBobj(rgb)
114 | Inf_sum = self.Infobj(depth)
115 | out = self.obj_fuse(rgb_sum, Inf_sum)
116 | return out
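117 | 
118 | 
119 | # --- Usage sketch (illustrative minimal smoke test) ---
120 | # Fuses an RGB feature with a depth/infrared feature. Since GEFM concatenates the two
121 | # refined maps before its reduction convs, cat_C is set to 2 * out_C here; the channel
122 | # count and 32x32 spatial size are example values only.
123 | if __name__ == '__main__':
124 |     rgb = torch.randn(2, 64, 32, 32)
125 |     depth = torch.randn(2, 64, 32, 32)
126 |     psfm = PSFM(in_C=64, out_C=64, cat_C=128)
127 |     print(psfm(rgb, depth).shape)  # expected: torch.Size([2, 64, 32, 32])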
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/SMFA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | # from feature_show import feature_show  # optional visualization helper; not included in this repo, so it is disabled here
5 |
6 |
7 | class DMlp(nn.Module):
8 | def __init__(self, dim, growth_rate=2.0):
9 | super().__init__()
10 | hidden_dim = int(dim * growth_rate)
11 | self.conv_0 = nn.Sequential(
12 | nn.Conv2d(dim,hidden_dim,3,1,1,groups=dim),
13 | nn.Conv2d(hidden_dim,hidden_dim,1,1,0)
14 | )
15 | self.act = nn.GELU()
16 | self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)
17 |
18 | def forward(self, x):
19 | x = self.conv_0(x)
20 | x = self.act(x)
21 | x = self.conv_1(x)
22 | return x
23 |
24 | class SMFA(nn.Module):
25 | def __init__(self, dim=36, id = 0):
26 | super(SMFA, self).__init__()
27 |
28 | self.id = id
29 |
30 | self.linear_0 = nn.Conv2d(dim,dim*2,1,1,0)
31 | self.linear_1 = nn.Conv2d(dim,dim,1,1,0)
32 | self.linear_2 = nn.Conv2d(dim,dim,1,1,0)
33 |
34 | self.lde = DMlp(dim,2)
35 |
36 | self.dw_conv = nn.Conv2d(dim,dim,3,1,1,groups=dim)
37 |
38 | self.gelu = nn.GELU()
39 | self.down_scale = 8
40 |
41 | self.alpha = nn.Parameter(torch.ones((1,dim,1,1)))
42 | self.belt = nn.Parameter(torch.zeros((1,dim,1,1)))
43 |
44 | def forward(self, f):
45 |         # feature_show(f, f"{self.id}_input")  # visualization hook, disabled so the module runs standalone
46 | _,_,h,w = f.shape
47 |
48 | y, x = self.linear_0(f).chunk(2, dim=1)
49 | x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
50 | x_v = torch.var(x, dim=(-2,-1), keepdim=True)
51 | x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h,w), mode='nearest')
52 | y_d = self.lde(y)
53 | out = self.linear_2(x_l + y_d)
54 |
55 |         # feature_show(x_l, f"{self.id}_easa")   # visualization hooks, disabled so the module runs standalone
56 |         # feature_show(y_d, f"{self.id}_lde")
57 |         # feature_show(out, f"{self.id}_output")
58 |
59 | return out
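60 | 
61 | 
62 | # --- Usage sketch (illustrative minimal smoke test) ---
63 | # The spatial size should be at least down_scale (8) in each dimension; dim=36 matches the
64 | # default, and the 2x36x64x64 input is an arbitrary example.
65 | if __name__ == '__main__':
66 |     x = torch.randn(2, 36, 64, 64)
67 |     smfa = SMFA(dim=36)
68 |     print(smfa(x).shape)  # expected: torch.Size([2, 36, 64, 64])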
--------------------------------------------------------------------------------
/特征提取or融合or对齐模块/SSFF.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def autopad(k, p=None, d=1): # kernel, padding, dilation
7 | # Pad to 'same' shape outputs
8 | if d > 1:
9 | k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
10 | if p is None:
11 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
12 | return p
13 |
14 | class Conv(nn.Module):
15 | # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
16 | default_act = nn.SiLU() # default activation
17 |
18 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
19 | super().__init__()
20 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
21 | self.bn = nn.BatchNorm2d(c2)
22 | self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
23 |
24 | def forward(self, x):
25 | return self.act(self.bn(self.conv(x)))
26 |
27 | def forward_fuse(self, x):
28 | return self.act(self.conv(x))
29 |
30 | class Zoom_cat(nn.Module):
31 | def __init__(self, in_dim):
32 | super().__init__()
33 | #self.conv_l_post_down = Conv(in_dim, 2*in_dim, 3, 1, 1)
34 |
35 | def forward(self, x):
36 |         """l, m, s are the large, medium, and small scales; all of them are fused onto the m scale."""
37 | l, m, s = x[0], x[1], x[2]
38 | tgt_size = m.shape[2:]
39 | l = F.adaptive_max_pool2d(l, tgt_size) + F.adaptive_avg_pool2d(l, tgt_size)
40 | #l = self.conv_l_post_down(l)
41 | # m = self.conv_m(m)
42 | # s = self.conv_s_pre_up(s)
43 | s = F.interpolate(s, m.shape[2:], mode='nearest')
44 | # s = self.conv_s_post_up(s)
45 | lms = torch.cat([l, m, s], dim=1)
46 | return lms
47 |
48 |
49 | class SSFF(nn.Module):
50 | def __init__(self, channel):
51 | super(SSFF, self).__init__()
52 |         self.conv1 = Conv(512, channel, 1)
53 |         self.conv2 = Conv(1024, channel, 1)
54 |         self.conv3d = nn.Conv3d(channel, channel, kernel_size=(1, 1, 1))
55 |         self.bn = nn.BatchNorm3d(channel)
56 |         self.act = nn.LeakyReLU(0.1)
57 |         self.pool_3d = nn.MaxPool3d(kernel_size=(3, 1, 1))
58 |
59 | def forward(self, x):
60 | p3, p4, p5 = x[0],x[1],x[2]
61 | p4_2 = self.conv1(p4)
62 | p4_2 = F.interpolate(p4_2, p3.size()[2:], mode='nearest')
63 | p5_2 = self.conv2(p5)
64 | p5_2 = F.interpolate(p5_2, p3.size()[2:], mode='nearest')
65 | p3_3d = torch.unsqueeze(p3, -3)
66 | p4_3d = torch.unsqueeze(p4_2, -3)
67 | p5_3d = torch.unsqueeze(p5_2, -3)
68 |         combine = torch.cat([p3_3d, p4_3d, p5_3d], dim=2)
69 | conv_3d = self.conv3d(combine)
70 | bn = self.bn(conv_3d)
71 | act = self.act(bn)
72 | x = self.pool_3d(act)
73 | x = torch.squeeze(x, 2)
74 | return x
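75 | 
76 | 
77 | # --- Usage sketch (illustrative minimal smoke test) ---
78 | # SSFF as written assumes a P3/P4/P5 pyramid where P4 has 512 and P5 has 1024 channels,
79 | # while P3 carries `channel` channels; the batch size and spatial sizes are example values.
80 | if __name__ == '__main__':
81 |     p3 = torch.randn(2, 256, 80, 80)
82 |     p4 = torch.randn(2, 512, 40, 40)
83 |     p5 = torch.randn(2, 1024, 20, 20)
84 |     ssff = SSFF(channel=256)
85 |     print(ssff([p3, p4, p5]).shape)  # expected: torch.Size([2, 256, 80, 80])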
--------------------------------------------------------------------------------