├── Conv Modules for improve YOLO ├── YOLOV8-SPDConv.py ├── YOLOv7-ScConv.py └── YOLOv8-DySnakeConv.py ├── Attention Modules for improve YOLO ├── SimAM.py ├── ECA.py ├── SE.py ├── EMA.py ├── MHSA.py ├── CAM.py ├── CBAM.py ├── Focused Linear Attention.py └── BiFormer.py ├── README.md ├── SPP空间金字塔池化系列模块 ├── YOLOV8-SPPFCSPC.py ├── YOLOV8-ASPP.py └── SimSPPF.py ├── 更换YOLO系列中的IOU损失 └── MPDIoU.py └── Some modules for improve YOLO ├── YOLOV7-AIFI.py └── YOLOV8-AFPN.py /Conv Modules for improve YOLO/YOLOV8-SPDConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class SPDConv(nn.Module): 5 | # Changing the dimension of the Tensor 6 | def __init__(self, dimension=1): 7 | super().__init__() 8 | self.d = dimension 9 | def forward(self, x): 10 | return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1) 11 | 12 | 13 | if __name__ == '__main__': 14 | 15 | input = torch.randn(1, 128, 8, 8) 16 | dsconv = SPDConv() 17 | output = dsconv(input) 18 | print(output.shape) 19 | -------------------------------------------------------------------------------- /Attention Modules for improve YOLO/SimAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SimAM(torch.nn.Module): 6 | def __init__(self, e_lambda=1e-4): 7 | super(SimAM, self).__init__() 8 | 9 | self.activaton = nn.Sigmoid() 10 | self.e_lambda = e_lambda 11 | 12 | def __repr__(self): 13 | s = self.__class__.__name__ + '(' 14 | s += ('lambda=%f)' % self.e_lambda) 15 | return s 16 | 17 | @staticmethod 18 | def get_module_name(): 19 | return "simam" 20 | 21 | def forward(self, x): 22 | b, c, h, w = x.size() 23 | 24 | n = w * h - 1 25 | 26 | x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2) 27 | y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5 28 | 29 | return x * self.activaton(y) 30 | 31 | if __name__ == '__main__': 32 | input = torch.randn(1, 128, 16, 16) 33 | att = SimAM() 34 | outputs = att(input) 35 | print(outputs.shape) 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # object-detect 2 | 本仓库存放的是目标检测YOLO系列的一些代码以及改进模块的代码实现,需要的小伙伴自取就可以啦~ 3 | 4 | 文件夹Attention Modules for improve YOLO中存放了不同的注意力机制 5 | 文件夹Conv Modules for improve YOLO中存放了不同的卷积模块 6 | 文件夹Different IoU for improve YOLO中存放了不同的IoU损失函数 7 | 8 | YOLOV8模型改进之AFPN:https://www.bilibili.com/video/BV1Ea4y1Q7qc/?spm_id_from=333.999.0.0 9 | 10 | YOLOV8模型改进之SPDConv:https://www.bilibili.com/video/BV1iu4y1777u/?spm_id_from=333.999.0.0&vd_source=14cd0464a0d319ab6d156ba89adc03dd 11 | 12 | YOLOV8模型代码详解:https://www.bilibili.com/video/BV1uM411X7n4/?spm_id_from=333.999.0.0 13 | 14 | YOLOV7模型代码详解:https://www.bilibili.com/video/BV1jw411X79C/?spm_id_from=333.999.0.0 15 | 16 | YOLOV5模型代码详解:https://www.bilibili.com/video/BV1T94y1j7dr/?spm_id_from=333.999.0.0 17 | 18 | YOLOV8模型改进之Focused Linear Attention:https://www.bilibili.com/video/BV1cC4y1J7VL/?spm_id_from=333.999.0.0&vd_source=14cd0464a0d319ab6d156ba89adc03dd 19 | 20 | YOLOV7模型改进之Focused Linear Attention:https://www.bilibili.com/video/BV1nG411X7bS/?spm_id_from=333.999.0.0&vd_source=14cd0464a0d319ab6d156ba89adc03dd 21 | 22 | 
YOLOV7模型改进之AIFI:https://www.bilibili.com/video/BV1BH4y1z7it/?spm_id_from=333.999.0.0&vd_source=14cd0464a0d319ab6d156ba89adc03dd 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /SPP空间金字塔池化系列模块/YOLOV8-SPPFCSPC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | ####### SPPFCSPC ##### 6 | class SPPFCSPC(nn.Module): 7 | 8 | def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=5): 9 | super(SPPFCSPC, self).__init__() 10 | c_ = int(2 * c2 * e) # hidden channels 11 | self.cv1 = Conv(c1, c_, 1, 1) 12 | self.cv2 = Conv(c1, c_, 1, 1) 13 | self.cv3 = Conv(c_, c_, 3, 1) 14 | self.cv4 = Conv(c_, c_, 1, 1) 15 | self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) 16 | self.cv5 = Conv(4 * c_, c_, 1, 1) 17 | self.cv6 = Conv(c_, c_, 3, 1) 18 | self.cv7 = Conv(2 * c_, c2, 1, 1) 19 | 20 | def forward(self, x): 21 | x1 = self.cv4(self.cv3(self.cv1(x))) 22 | x2 = self.m(x1) 23 | x3 = self.m(x2) 24 | y1 = self.cv6(self.cv5(torch.cat((x1, x2, x3, self.m(x3)), 1))) 25 | y2 = self.cv2(x) 26 | return self.cv7(torch.cat((y1, y2), dim=1)) 27 | ####### end of SPPFCSPC ##### 28 | 29 | if __name__ == '__main__': 30 | x = torch.randn(1, 256, 16, 16) 31 | model = SPPFCSPC(256, 256) 32 | print(model(x).shape) 33 | -------------------------------------------------------------------------------- /Attention Modules for improve YOLO/ECA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | from collections import OrderedDict 5 | 6 | 7 | class ECA(nn.Module): 8 | 9 | def __init__(self, kernel_size=3): 10 | super().__init__() 11 | self.gap = nn.AdaptiveAvgPool2d(1) 12 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2) 13 | self.sigmoid = nn.Sigmoid() 14 | 15 | def init_weights(self): 16 | for m in self.modules(): 17 | if isinstance(m, nn.Conv2d): 18 | init.kaiming_normal_(m.weight, mode='fan_out') 19 | if m.bias is not None: 20 | init.constant_(m.bias, 0) 21 | elif isinstance(m, nn.BatchNorm2d): 22 | init.constant_(m.weight, 1) 23 | init.constant_(m.bias, 0) 24 | elif isinstance(m, nn.Linear): 25 | init.normal_(m.weight, std=0.001) 26 | if m.bias is not None: 27 | init.constant_(m.bias, 0) 28 | 29 | def forward(self, x): 30 | y = self.gap(x) # bs,c,1,1 31 | y = y.squeeze(-1).permute(0, 2, 1) # bs,1,c 32 | y = self.conv(y) # bs,1,c 33 | y = self.sigmoid(y) # bs,1,c 34 | y = y.permute(0, 2, 1).unsqueeze(-1) # bs,c,1,1 35 | return x * y.expand_as(x) 36 | 37 | if __name__ == '__main__': 38 | input = torch.randn(1, 128, 16, 16) 39 | att = ECA(kernel_size=3) 40 | output = att(input) 41 | print(output.shape) 42 | -------------------------------------------------------------------------------- /Attention Modules for improve YOLO/SE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class SEAttention(nn.Module): 9 | 10 | def __init__(self, channel=512,reduction=16): 11 | super().__init__() 12 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 13 | self.fc = nn.Sequential( 14 | nn.Linear(channel, channel // reduction, bias=False), 15 | nn.ReLU(inplace=True), 16 | nn.Linear(channel // reduction, channel, bias=False), 17 | nn.Sigmoid() 18 | ) 19 | 20 | 21 | def 
init_weights(self): 22 | for m in self.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | init.kaiming_normal_(m.weight, mode='fan_out') 25 | if m.bias is not None: 26 | init.constant_(m.bias, 0) 27 | elif isinstance(m, nn.BatchNorm2d): 28 | init.constant_(m.weight, 1) 29 | init.constant_(m.bias, 0) 30 | elif isinstance(m, nn.Linear): 31 | init.normal_(m.weight, std=0.001) 32 | if m.bias is not None: 33 | init.constant_(m.bias, 0) 34 | 35 | def forward(self, x): 36 | b, c, _, _ = x.size() 37 | y = self.avg_pool(x).view(b, c) 38 | y = self.fc(y).view(b, c, 1, 1) 39 | return x * y.expand_as(x) 40 | 41 | if __name__ == '__main__': 42 | input=torch.randn(1,128,7,7) 43 | att = SEAttention(channel=128,reduction=8) 44 | output=att(input) 45 | print(output.shape) 46 | -------------------------------------------------------------------------------- /SPP空间金字塔池化系列模块/YOLOV8-ASPP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | class ASPP(nn.Module): 6 | def __init__(self, in_channel=512, out_channel=256): 7 | super(ASPP, self).__init__() 8 | self.mean = nn.AdaptiveAvgPool2d((1, 1)) # (1,1)means ouput_dim 9 | self.conv = nn.Conv2d(in_channel, out_channel, 1, 1) 10 | self.atrous_block1 = nn.Conv2d(in_channel, out_channel, 1, 1) 11 | self.atrous_block6 = nn.Conv2d(in_channel, out_channel, 3, 1, padding=6, dilation=6) 12 | self.atrous_block12 = nn.Conv2d(in_channel, out_channel, 3, 1, padding=12, dilation=12) 13 | self.atrous_block18 = nn.Conv2d(in_channel, out_channel, 3, 1, padding=18, dilation=18) 14 | self.conv_1x1_output = nn.Conv2d(out_channel * 5, out_channel, 1, 1) 15 | 16 | def forward(self, x): 17 | size = x.shape[2:] 18 | 19 | image_features = self.mean(x) 20 | image_features = self.conv(image_features) 21 | image_features = F.upsample(image_features, size=size, mode='bilinear') 22 | 23 | atrous_block1 = self.atrous_block1(x) 24 | atrous_block6 = self.atrous_block6(x) 25 | atrous_block12 = self.atrous_block12(x) 26 | atrous_block18 = self.atrous_block18(x) 27 | 28 | net = self.conv_1x1_output(torch.cat([image_features, atrous_block1, atrous_block6, 29 | atrous_block12, atrous_block18], dim=1)) 30 | return net 31 | 32 | 33 | if __name__ == '__main__': 34 | x = torch.randn(1, 256, 16, 16) 35 | model = ASPP(256, 256) 36 | print(model(x).shape) 37 | -------------------------------------------------------------------------------- /SPP空间金字塔池化系列模块/SimSPPF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | 6 | class SimConv(nn.Module): 7 | '''Normal Conv with ReLU activation''' 8 | def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False): 9 | super().__init__() 10 | padding = kernel_size // 2 11 | self.conv = nn.Conv2d( 12 | in_channels, 13 | out_channels, 14 | kernel_size=kernel_size, 15 | stride=stride, 16 | padding=padding, 17 | groups=groups, 18 | bias=bias, 19 | ) 20 | self.bn = nn.BatchNorm2d(out_channels) 21 | self.act = nn.ReLU() 22 | 23 | def forward(self, x): 24 | return self.act(self.bn(self.conv(x))) 25 | 26 | def forward_fuse(self, x): 27 | return self.act(self.conv(x)) 28 | 29 | class SimSPPF(nn.Module): 30 | '''Simplified SPPF with ReLU activation''' 31 | def __init__(self, in_channels, out_channels, kernel_size=5): 32 | super().__init__() 33 | c_ = in_channels // 2 # hidden channels 34 | self.cv1 = 
SimConv(in_channels, c_, 1, 1)
35 |         self.cv2 = SimConv(c_ * 4, out_channels, 1, 1)
36 |         self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
37 | 
38 |     def forward(self, x):
39 |         x = self.cv1(x)
40 |         y1 = self.m(x)
41 |         y2 = self.m(y1)
42 |         return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
43 | 
44 | if __name__ == '__main__':
45 |     x = torch.randn(1, 256, 16, 16)
46 |     model = SimSPPF(256, 256)
47 |     print(model(x).shape)
48 | 
--------------------------------------------------------------------------------
/更换YOLO系列中的IOU损失/MPDIoU.py:
--------------------------------------------------------------------------------
1 | ################ MPDIoU ################
2 | def bbox_mpdiou(box1, box2, x1y1x2y2=True, mpdiou_hw=None, grid=None, eps=1e-7):
3 |     # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
4 |     box2 = box2.T
5 |     box1[:2] += grid
6 |     box2[:2] += grid
7 | 
8 |     # Get the coordinates of bounding boxes
9 |     if x1y1x2y2:  # x1, y1, x2, y2 = box1
10 |         b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
11 |         b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
12 |     else:  # transform from xywh to xyxy
13 |         b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
14 |         b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
15 |         b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
16 |         b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
17 | 
18 |     # Intersection area
19 |     inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
20 |             (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
21 | 
22 |     # Union Area
23 |     w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
24 |     w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
25 |     union = w1 * h1 + w2 * h2 - inter + eps
26 | 
27 |     iou = inter / union
28 |     d1 = (b2_x1 - b1_x1) ** 2 + (b2_y1 - b1_y1) ** 2
29 |     d2 = (b2_x2 - b1_x2) ** 2 + (b2_y2 - b1_y2) ** 2
30 |     return iou - d1 / mpdiou_hw - d2 / mpdiou_hw
31 | ################ the end of MPDIoU ################
32 | 
33 | # ComputeLoss (0)
34 | iou = bbox_mpdiou(pbox.T, tbox[i], x1y1x2y2=False, mpdiou_hw=pi.size(2) ** 2 + pi.size(3) ** 2, grid=torch.stack([gj, gi]))  # iou(prediction, target)
35 | 
36 | # ComputeLossOTA (1)
37 | iou = bbox_mpdiou(pbox.T, selected_tbox, x1y1x2y2=False, mpdiou_hw=pi.size(2) ** 2 + pi.size(3) ** 2, grid=torch.stack([gj, gi]))  # iou(prediction, target)
38 | 
--------------------------------------------------------------------------------
/Attention Modules for improve YOLO/EMA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | class EMA(nn.Module):
5 |     def __init__(self, channels, factor=8):
6 |         super(EMA, self).__init__()
7 |         self.groups = factor
8 |         assert channels // self.groups > 0
9 |         self.softmax = nn.Softmax(-1)
10 |         self.agp = nn.AdaptiveAvgPool2d((1, 1))
11 |         self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
12 |         self.pool_w = nn.AdaptiveAvgPool2d((1, None))
13 |         self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
14 |         self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
15 |         self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)
16 | 
17 |     def forward(self, x):
18 |         b, c, h, w = x.size()
19 |         group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w
20 |         x_h = self.pool_h(group_x)
21 |         x_w =
self.pool_w(group_x).permute(0, 1, 3, 2) 22 | hw = self.conv1x1(torch.cat([x_h, x_w], dim=2)) 23 | x_h, x_w = torch.split(hw, [h, w], dim=2) 24 | x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid()) 25 | x2 = self.conv3x3(group_x) 26 | x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1)) 27 | x12 = x2.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hw 28 | x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1)) 29 | x22 = x1.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hw 30 | weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w) 31 | return (group_x * weights.sigmoid()).reshape(b, c, h, w) 32 | 33 | if __name__ == '__main__': 34 | input = torch.randn(1, 128, 64, 64) 35 | ema = EMA(128) 36 | output = ema(input) 37 | print(output.shape) 38 | -------------------------------------------------------------------------------- /Attention Modules for improve YOLO/MHSA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class MHSA(nn.Module): 5 | def __init__(self, n_dims, width=14, height=14, heads=1, pos_emb=False): 6 | super(MHSA, self).__init__() 7 | 8 | self.heads = heads 9 | self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1) 10 | self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1) 11 | self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1) 12 | self.pos = pos_emb 13 | if self.pos: 14 | self.rel_h_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, 1, int(height)]), 15 | requires_grad=True) 16 | self.rel_w_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, int(width), 1]), 17 | requires_grad=True) 18 | self.softmax = nn.Softmax(dim=-1) 19 | 20 | def forward(self, x): 21 | n_batch, C, width, height = x.size() 22 | q = self.query(x).view(n_batch, self.heads, C // self.heads, -1) 23 | k = self.key(x).view(n_batch, self.heads, C // self.heads, -1) 24 | v = self.value(x).view(n_batch, self.heads, C // self.heads, -1) 25 | content_content = torch.matmul(q.permute(0, 1, 3, 2), k) # 1,C,h*w,h*w 26 | c1, c2, c3, c4 = content_content.size() 27 | if self.pos: 28 | content_position = (self.rel_h_weight + self.rel_w_weight).view(1, self.heads, C // self.heads, -1).permute( 29 | 0, 1, 3, 2) # 1,4,1024,64 30 | 31 | content_position = torch.matmul(content_position, q) # ([1, 4, 1024, 256]) 32 | content_position = content_position if ( 33 | content_content.shape == content_position.shape) else content_position[:, :, :c3, ] 34 | assert (content_content.shape == content_position.shape) 35 | energy = content_content + content_position 36 | else: 37 | energy = content_content 38 | attention = self.softmax(energy) 39 | out = torch.matmul(v, attention.permute(0, 1, 3, 2)) # 1,4,256,64 40 | out = out.view(n_batch, C, width, height) 41 | return out 42 | 43 | if __name__ == '__main__': 44 | input = torch.randn(50, 128, 7, 7) 45 | mhsa = MHSA(n_dims=128) 46 | output = mhsa(input) 47 | print(output.shape) 48 | -------------------------------------------------------------------------------- /Attention Modules for improve YOLO/CAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def autopad(k, p=None): # kernel, padding 6 | # Pad to 'same' 7 | if p is None: 8 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad 9 | return p 10 | 11 | 12 | class Conv(nn.Module): 13 | # 
Standard convolution 14 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups 15 | super(Conv, self).__init__() 16 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) 17 | self.bn = nn.BatchNorm2d(c2) 18 | self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) 19 | 20 | def forward(self, x): 21 | return self.act(self.bn(self.conv(x))) 22 | 23 | def fuseforward(self, x): 24 | return self.act(self.conv(x)) 25 | 26 | class CAM(nn.Module): 27 | def __init__(self, inc, fusion='weight'): 28 | super().__init__() 29 | 30 | assert fusion in ['weight', 'adaptive', 'concat'] 31 | self.fusion = fusion 32 | 33 | self.conv1 = Conv(inc, inc, 3, 1, None, 1, 1) 34 | self.conv2 = Conv(inc, inc, 3, 1, None, 1, 3) 35 | self.conv3 = Conv(inc, inc, 3, 1, None, 1, 5) 36 | 37 | self.fusion_1 = Conv(inc, inc, 1) 38 | self.fusion_2 = Conv(inc, inc, 1) 39 | self.fusion_3 = Conv(inc, inc, 1) 40 | 41 | if self.fusion == 'adaptive': 42 | self.fusion_4 = Conv(inc * 3, 3, 1) 43 | 44 | def forward(self, x): 45 | x1 = self.conv1(x) 46 | x2 = self.conv2(x) 47 | x3 = self.conv3(x) 48 | 49 | if self.fusion == 'weight': 50 | return self.fusion_1(x1) + self.fusion_2(x2) + self.fusion_3(x3) 51 | elif self.fusion == 'adaptive': 52 | fusion = torch.softmax(self.fusion_4(torch.cat([self.fusion_1(x1), self.fusion_2(x2), self.fusion_3(x3)], dim=1)), dim=1) 53 | x1_weight, x2_weight, x3_weight = torch.split(fusion, [1, 1, 1], dim=1) 54 | return x1 * x1_weight + x2 * x2_weight + x3 * x3_weight 55 | else: 56 | return torch.cat([self.fusion_1(x1), self.fusion_2(x2), self.fusion_3(x3)], dim=1) 57 | 58 | 59 | if __name__ == '__main__': 60 | input = torch.randn(1, 128, 64, 64) 61 | cam = CAM(128) 62 | output = cam(input) 63 | print(output.shape) 64 | -------------------------------------------------------------------------------- /Attention Modules for improve YOLO/CBAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | 5 | 6 | class ChannelAttention(nn.Module): 7 | def __init__(self, channel, reduction=16): 8 | super().__init__() 9 | self.maxpool = nn.AdaptiveMaxPool2d(1) 10 | self.avgpool = nn.AdaptiveAvgPool2d(1) 11 | self.se = nn.Sequential( 12 | nn.Conv2d(channel, channel // reduction, 1, bias=False), 13 | nn.ReLU(), 14 | nn.Conv2d(channel // reduction, channel, 1, bias=False) 15 | ) 16 | self.sigmoid = nn.Sigmoid() 17 | 18 | def forward(self, x): 19 | max_result = self.maxpool(x) 20 | avg_result = self.avgpool(x) 21 | max_out = self.se(max_result) 22 | avg_out = self.se(avg_result) 23 | output = self.sigmoid(max_out + avg_out) 24 | return output 25 | 26 | 27 | class SpatialAttention(nn.Module): 28 | def __init__(self, kernel_size=7): 29 | super().__init__() 30 | self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2) 31 | self.sigmoid = nn.Sigmoid() 32 | 33 | def forward(self, x): 34 | max_result, _ = torch.max(x, dim=1, keepdim=True) 35 | avg_result = torch.mean(x, dim=1, keepdim=True) 36 | result = torch.cat([max_result, avg_result], 1) 37 | output = self.conv(result) 38 | output = self.sigmoid(output) 39 | return output 40 | 41 | 42 | class CBAMBlock(nn.Module): 43 | 44 | def __init__(self, channel=512, reduction=16, kernel_size=7): 45 | super().__init__() 46 | self.ca = ChannelAttention(channel=channel, reduction=reduction) 47 | self.sa = SpatialAttention(kernel_size=kernel_size) 
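# self.ca and self.sa are applied sequentially in forward(): channel attention first, then spatial attention, with a residual add of the original input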
48 | 49 | def init_weights(self): 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | init.kaiming_normal_(m.weight, mode='fan_out') 53 | if m.bias is not None: 54 | init.constant_(m.bias, 0) 55 | elif isinstance(m, nn.BatchNorm2d): 56 | init.constant_(m.weight, 1) 57 | init.constant_(m.bias, 0) 58 | elif isinstance(m, nn.Linear): 59 | init.normal_(m.weight, std=0.001) 60 | if m.bias is not None: 61 | init.constant_(m.bias, 0) 62 | 63 | def forward(self, x): 64 | b, c, _, _ = x.size() 65 | residual = x 66 | out = x * self.ca(x) 67 | out = out * self.sa(out) 68 | out = out + residual 69 | return out 70 | 71 | if __name__ == '__main__': 72 | input = torch.randn(1, 128, 16, 16) 73 | cbam = CBAMBlock(channel=128, reduction=16) 74 | output = cbam(input) 75 | print(output.shape) 76 | -------------------------------------------------------------------------------- /Some modules for improve YOLO/YOLOV7-AIFI.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class TransformerEncoderLayer(nn.Module): 6 | """Defines a single layer of the transformer encoder.""" 7 | 8 | def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False): 9 | """Initialize the TransformerEncoderLayer with specified parameters.""" 10 | super().__init__() 11 | # from ...utils.torch_utils import TORCH_1_9 12 | # if not TORCH_1_9: 13 | # raise ModuleNotFoundError( 14 | # 'TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True).') 15 | self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True) 16 | # Implementation of Feedforward model 17 | self.fc1 = nn.Linear(c1, cm) 18 | self.fc2 = nn.Linear(cm, c1) 19 | 20 | self.norm1 = nn.LayerNorm(c1) 21 | self.norm2 = nn.LayerNorm(c1) 22 | self.dropout = nn.Dropout(dropout) 23 | self.dropout1 = nn.Dropout(dropout) 24 | self.dropout2 = nn.Dropout(dropout) 25 | 26 | self.act = act 27 | self.normalize_before = normalize_before 28 | 29 | @staticmethod 30 | def with_pos_embed(tensor, pos=None): 31 | """Add position embeddings to the tensor if provided.""" 32 | return tensor if pos is None else tensor + pos 33 | 34 | def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None): 35 | """Performs forward pass with post-normalization.""" 36 | q = k = self.with_pos_embed(src, pos) 37 | src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] 38 | src = src + self.dropout1(src2) 39 | src = self.norm1(src) 40 | src2 = self.fc2(self.dropout(self.act(self.fc1(src)))) 41 | src = src + self.dropout2(src2) 42 | return self.norm2(src) 43 | 44 | def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None): 45 | """Performs forward pass with pre-normalization.""" 46 | src2 = self.norm1(src) 47 | q = k = self.with_pos_embed(src2, pos) 48 | src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] 49 | src = src + self.dropout1(src2) 50 | src2 = self.norm2(src) 51 | src2 = self.fc2(self.dropout(self.act(self.fc1(src2)))) 52 | return src + self.dropout2(src2) 53 | 54 | def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None): 55 | """Forward propagates the input through the encoder module.""" 56 | if self.normalize_before: 57 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 58 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 59 | 60 | 61 | class 
AIFI(TransformerEncoderLayer): 62 | """Defines the AIFI transformer layer.""" 63 | 64 | def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False): 65 | """Initialize the AIFI instance with specified parameters.""" 66 | super().__init__(c1, cm, num_heads, dropout, act, normalize_before) 67 | 68 | def forward(self, x): 69 | """Forward pass for the AIFI transformer layer.""" 70 | c, h, w = x.shape[1:] 71 | pos_embed = self.build_2d_sincos_position_embedding(w, h, c) 72 | # Flatten [B, C, H, W] to [B, HxW, C] 73 | x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype)) 74 | return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous() 75 | 76 | @staticmethod 77 | def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0): 78 | """Builds 2D sine-cosine position embedding.""" 79 | grid_w = torch.arange(int(w), dtype=torch.float32) 80 | grid_h = torch.arange(int(h), dtype=torch.float32) 81 | grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij') 82 | assert embed_dim % 4 == 0, \ 83 | 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' 84 | pos_dim = embed_dim // 4 85 | omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim 86 | omega = 1. / (temperature ** omega) 87 | 88 | out_w = grid_w.flatten()[..., None] @ omega[None] 89 | out_h = grid_h.flatten()[..., None] @ omega[None] 90 | 91 | return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None] 92 | 93 | if __name__ == '__main__': 94 | x = torch.randn(1, 256, 8, 8) 95 | model = AIFI(256, 256) 96 | output = model(x) 97 | print(output.shape) 98 | -------------------------------------------------------------------------------- /Attention Modules for improve YOLO/Focused Linear Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | 6 | class FocusedLinearAttention(nn.Module): 7 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 8 | It supports both of shifted and non-shifted window. 9 | 10 | Args: 11 | dim (int): Number of input channels. 12 | window_size (tuple[int]): The height and width of the window. 13 | num_heads (int): Number of attention heads. 14 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 15 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 16 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 17 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 18 | """ 19 | 20 | def __init__(self, dim, window_size=[20, 20], num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., 21 | focusing_factor=3, kernel_size=5): 22 | 23 | super().__init__() 24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 25 | self.dim = dim 26 | self.num_heads = num_heads 27 | head_dim = dim // num_heads 28 | self.focusing_factor = focusing_factor 29 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 30 | self.attn_drop = nn.Dropout(attn_drop) 31 | self.proj = nn.Linear(dim, dim) 32 | self.proj_drop = nn.Dropout(proj_drop) 33 | self.window_size = window_size 34 | self.positional_encoding = nn.Parameter(torch.zeros(size=(1, window_size[0] * window_size[1], dim))) 35 | 36 | self.softmax = nn.Softmax(dim=-1) 37 | 38 | self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size, 39 | groups=head_dim, padding=kernel_size // 2) 40 | self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim))) 41 | 42 | def forward(self, x, mask=None): 43 | """ 44 | Args: 45 | x: input features with shape of (num_windows*B, N, C) 46 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 47 | """ 48 | # flatten: [B, C, H, W] -> [B, C, HW] 49 | # transpose: [B, C, HW] -> [B, HW, C] 50 | x = x.flatten(2).transpose(1, 2) 51 | B, N, C = x.shape 52 | qkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3) 53 | q, k, v = qkv.unbind(0) 54 | k = k + self.positional_encoding[:, :k.shape[1], :] 55 | focusing_factor = self.focusing_factor 56 | kernel_function = nn.ReLU() 57 | q = kernel_function(q) + 1e-6 58 | k = kernel_function(k) + 1e-6 59 | scale = nn.Softplus()(self.scale) 60 | q = q / scale 61 | k = k / scale 62 | q_norm = q.norm(dim=-1, keepdim=True) 63 | k_norm = k.norm(dim=-1, keepdim=True) 64 | if float(focusing_factor) <= 6: 65 | q = q ** focusing_factor 66 | k = k ** focusing_factor 67 | else: 68 | q = (q / q.max(dim=-1, keepdim=True)[0]) ** focusing_factor 69 | k = (k / k.max(dim=-1, keepdim=True)[0]) ** focusing_factor 70 | q = (q / q.norm(dim=-1, keepdim=True)) * q_norm 71 | k = (k / k.norm(dim=-1, keepdim=True)) * k_norm 72 | q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v]) 73 | i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1] 74 | 75 | z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6) 76 | if i * j * (c + d) > c * d * (i + j): 77 | kv = torch.einsum("b j c, b j d -> b c d", k, v) 78 | x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z) 79 | else: 80 | qk = torch.einsum("b i c, b j c -> b i j", q, k) 81 | x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z) 82 | 83 | num = int(v.shape[1] ** 0.5) 84 | feature_map = rearrange(v, "b (w h) c -> b c w h", w=num, h=num) 85 | feature_map = rearrange(self.dwc(feature_map), "b c w h -> b (w h) c") 86 | x = x + feature_map 87 | x = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads) 88 | x = self.proj(x) 89 | x = self.proj_drop(x) 90 | x = rearrange(x, "b (w h) c -> b c w h", b=B, c=self.dim, w=num, h=num) 91 | return x 92 | 93 | if __name__ == '__main__': 94 | 95 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 96 | input = torch.randn(1, 128, 20, 20) 97 | siatt = FocusedLinearAttention(dim=128).to(device) 98 | output = siatt(input.to(device)) 99 | print(output.shape) 100 | -------------------------------------------------------------------------------- /Conv Modules for improve YOLO/YOLOv7-ScConv.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | 6 | class GroupBatchnorm2d(nn.Module): 7 | def __init__(self, c_num: int, 8 | group_num: int = 16, 9 | eps: float = 1e-10 10 | ): 11 | super(GroupBatchnorm2d, self).__init__() 12 | assert c_num >= group_num 13 | self.group_num = group_num 14 | self.gamma = nn.Parameter(torch.randn(c_num, 1, 1)) 15 | self.beta = nn.Parameter(torch.zeros(c_num, 1, 1)) 16 | self.eps = eps 17 | 18 | def forward(self, x): 19 | N, C, H, W = x.size() 20 | x = x.view(N, self.group_num, -1) 21 | mean = x.mean(dim=2, keepdim=True) 22 | std = x.std(dim=2, keepdim=True) 23 | x = (x - mean) / (std + self.eps) 24 | x = x.view(N, C, H, W) 25 | return x * self.gamma + self.beta 26 | 27 | 28 | class SRU(nn.Module): 29 | def __init__(self, 30 | oup_channels: int, 31 | group_num: int = 16, 32 | gate_treshold: float = 0.5 33 | ): 34 | super().__init__() 35 | 36 | self.gn = GroupBatchnorm2d(oup_channels, group_num=group_num) 37 | self.gate_treshold = gate_treshold 38 | self.sigomid = nn.Sigmoid() 39 | 40 | def forward(self, x): 41 | gn_x = self.gn(x) 42 | w_gamma = F.softmax(self.gn.gamma, dim=0) 43 | reweigts = self.sigomid(gn_x * w_gamma) 44 | # Gate 45 | info_mask = reweigts > self.gate_treshold 46 | noninfo_mask = reweigts <= self.gate_treshold 47 | x_1 = info_mask * x 48 | x_2 = noninfo_mask * x 49 | x = self.reconstruct(x_1, x_2) 50 | return x 51 | 52 | def reconstruct(self, x_1, x_2): 53 | x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1) 54 | x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1) 55 | return torch.cat([x_11 + x_22, x_12 + x_21], dim=1) 56 | 57 | 58 | class CRU(nn.Module): 59 | ''' 60 | alpha: 0 None: 28 | super().__init__() 29 | 30 | self.conv_0 = Conv(inc, ouc, k, act=act) 31 | self.conv_x = DSConv(inc, ouc, 0, k, act=True) 32 | self.conv_y = DSConv(inc, ouc, 1, k, act=True) 33 | self.conv_1x1 = Conv(ouc * 3, ouc, 1, act=act) 34 | 35 | def forward(self, x): 36 | return self.conv_1x1(torch.cat([self.conv_0(x), self.conv_x(x), self.conv_y(x)], dim=1)) 37 | 38 | 39 | class DSConv(nn.Module): 40 | def __init__(self, in_ch, out_ch, morph, kernel_size=3, if_offset=True, extend_scope=1, act=True): 41 | """ 42 | The Dynamic Snake Convolution 43 | :param in_ch: input channel 44 | :param out_ch: output channel 45 | :param kernel_size: the size of kernel 46 | :param extend_scope: the range to expand (default 1 for this method) 47 | :param morph: the morphology of the convolution kernel is mainly divided into two types 48 | along the x-axis (0) and the y-axis (1) (see the paper for details) 49 | :param if_offset: whether deformation is required, if it is False, it is the standard convolution kernel 50 | """ 51 | super(DSConv, self).__init__() 52 | # use the to learn the deformable offset 53 | self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1) 54 | self.bn = nn.BatchNorm2d(2 * kernel_size) 55 | self.kernel_size = kernel_size 56 | 57 | # two types of the DSConv (along x-axis and y-axis) 58 | self.dsc_conv_x = nn.Conv2d( 59 | in_ch, 60 | out_ch, 61 | kernel_size=(kernel_size, 1), 62 | stride=(kernel_size, 1), 63 | padding=0, 64 | ) 65 | self.dsc_conv_y = nn.Conv2d( 66 | in_ch, 67 | out_ch, 68 | kernel_size=(1, kernel_size), 69 | stride=(1, kernel_size), 70 | padding=0, 71 | ) 72 | 73 | self.gn = nn.GroupNorm(out_ch // 4, out_ch) 74 | self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) 75 
| 76 | self.extend_scope = extend_scope 77 | self.morph = morph 78 | self.if_offset = if_offset 79 | 80 | def forward(self, f): 81 | offset = self.offset_conv(f) 82 | offset = self.bn(offset) 83 | # We need a range of deformation between -1 and 1 to mimic the snake's swing 84 | offset = torch.tanh(offset) 85 | input_shape = f.shape 86 | dsc = DSC(input_shape, self.kernel_size, self.extend_scope, self.morph) 87 | deformed_feature = dsc.deform_conv(f, offset, self.if_offset) 88 | if self.morph == 0: 89 | x = self.dsc_conv_x(deformed_feature.type(f.dtype)) 90 | x = self.gn(x) 91 | x = self.act(x) 92 | return x 93 | else: 94 | x = self.dsc_conv_y(deformed_feature.type(f.dtype)) 95 | x = self.gn(x) 96 | x = self.act(x) 97 | return x 98 | 99 | 100 | # Core code, for ease of understanding, we mark the dimensions of input and output next to the code 101 | class DSC(object): 102 | def __init__(self, input_shape, kernel_size, extend_scope, morph): 103 | self.num_points = kernel_size 104 | self.width = input_shape[2] 105 | self.height = input_shape[3] 106 | self.morph = morph 107 | self.extend_scope = extend_scope # offset (-1 ~ 1) * extend_scope 108 | 109 | # define feature map shape 110 | """ 111 | B: Batch size C: Channel W: Width H: Height 112 | """ 113 | self.num_batch = input_shape[0] 114 | self.num_channels = input_shape[1] 115 | 116 | """ 117 | input: offset [B,2*K,W,H] K: Kernel size (2*K: 2D image, deformation contains and ) 118 | output_x: [B,1,W,K*H] coordinate map 119 | output_y: [B,1,K*W,H] coordinate map 120 | """ 121 | 122 | def _coordinate_map_3D(self, offset, if_offset): 123 | device = offset.device 124 | # offset 125 | y_offset, x_offset = torch.split(offset, self.num_points, dim=1) 126 | 127 | y_center = torch.arange(0, self.width).repeat([self.height]) 128 | y_center = y_center.reshape(self.height, self.width) 129 | y_center = y_center.permute(1, 0) 130 | y_center = y_center.reshape([-1, self.width, self.height]) 131 | y_center = y_center.repeat([self.num_points, 1, 1]).float() 132 | y_center = y_center.unsqueeze(0) 133 | 134 | x_center = torch.arange(0, self.height).repeat([self.width]) 135 | x_center = x_center.reshape(self.width, self.height) 136 | x_center = x_center.permute(0, 1) 137 | x_center = x_center.reshape([-1, self.width, self.height]) 138 | x_center = x_center.repeat([self.num_points, 1, 1]).float() 139 | x_center = x_center.unsqueeze(0) 140 | 141 | if self.morph == 0: 142 | """ 143 | Initialize the kernel and flatten the kernel 144 | y: only need 0 145 | x: -num_points//2 ~ num_points//2 (Determined by the kernel size) 146 | !!! 
The related PPT will be submitted later, and the PPT will contain the whole changes of each step 147 | """ 148 | y = torch.linspace(0, 0, 1) 149 | x = torch.linspace( 150 | -int(self.num_points // 2), 151 | int(self.num_points // 2), 152 | int(self.num_points), 153 | ) 154 | 155 | y, x = torch.meshgrid(y, x) 156 | y_spread = y.reshape(-1, 1) 157 | x_spread = x.reshape(-1, 1) 158 | 159 | y_grid = y_spread.repeat([1, self.width * self.height]) 160 | y_grid = y_grid.reshape([self.num_points, self.width, self.height]) 161 | y_grid = y_grid.unsqueeze(0) # [B*K*K, W,H] 162 | 163 | x_grid = x_spread.repeat([1, self.width * self.height]) 164 | x_grid = x_grid.reshape([self.num_points, self.width, self.height]) 165 | x_grid = x_grid.unsqueeze(0) # [B*K*K, W,H] 166 | 167 | y_new = y_center + y_grid 168 | x_new = x_center + x_grid 169 | 170 | y_new = y_new.repeat(self.num_batch, 1, 1, 1).to(device) 171 | x_new = x_new.repeat(self.num_batch, 1, 1, 1).to(device) 172 | 173 | y_offset_new = y_offset.detach().clone() 174 | 175 | if if_offset: 176 | y_offset = y_offset.permute(1, 0, 2, 3) 177 | y_offset_new = y_offset_new.permute(1, 0, 2, 3) 178 | center = int(self.num_points // 2) 179 | 180 | # The center position remains unchanged and the rest of the positions begin to swing 181 | # This part is quite simple. The main idea is that "offset is an iterative process" 182 | y_offset_new[center] = 0 183 | for index in range(1, center): 184 | y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index]) 185 | y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index]) 186 | y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(device) 187 | y_new = y_new.add(y_offset_new.mul(self.extend_scope)) 188 | 189 | y_new = y_new.reshape( 190 | [self.num_batch, self.num_points, 1, self.width, self.height]) 191 | y_new = y_new.permute(0, 3, 1, 4, 2) 192 | y_new = y_new.reshape([ 193 | self.num_batch, self.num_points * self.width, 1 * self.height 194 | ]) 195 | x_new = x_new.reshape( 196 | [self.num_batch, self.num_points, 1, self.width, self.height]) 197 | x_new = x_new.permute(0, 3, 1, 4, 2) 198 | x_new = x_new.reshape([ 199 | self.num_batch, self.num_points * self.width, 1 * self.height 200 | ]) 201 | return y_new, x_new 202 | 203 | else: 204 | """ 205 | Initialize the kernel and flatten the kernel 206 | y: -num_points//2 ~ num_points//2 (Determined by the kernel size) 207 | x: only need 0 208 | """ 209 | y = torch.linspace( 210 | -int(self.num_points // 2), 211 | int(self.num_points // 2), 212 | int(self.num_points), 213 | ) 214 | x = torch.linspace(0, 0, 1) 215 | 216 | y, x = torch.meshgrid(y, x) 217 | y_spread = y.reshape(-1, 1) 218 | x_spread = x.reshape(-1, 1) 219 | 220 | y_grid = y_spread.repeat([1, self.width * self.height]) 221 | y_grid = y_grid.reshape([self.num_points, self.width, self.height]) 222 | y_grid = y_grid.unsqueeze(0) 223 | 224 | x_grid = x_spread.repeat([1, self.width * self.height]) 225 | x_grid = x_grid.reshape([self.num_points, self.width, self.height]) 226 | x_grid = x_grid.unsqueeze(0) 227 | 228 | y_new = y_center + y_grid 229 | x_new = x_center + x_grid 230 | 231 | y_new = y_new.repeat(self.num_batch, 1, 1, 1) 232 | x_new = x_new.repeat(self.num_batch, 1, 1, 1) 233 | 234 | y_new = y_new.to(device) 235 | x_new = x_new.to(device) 236 | x_offset_new = x_offset.detach().clone() 237 | 238 | if if_offset: 239 | x_offset = x_offset.permute(1, 0, 2, 3) 240 | x_offset_new = x_offset_new.permute(1, 0, 2, 3) 241 | center = int(self.num_points 
// 2) 242 | x_offset_new[center] = 0 243 | for index in range(1, center): 244 | x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index]) 245 | x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index]) 246 | x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(device) 247 | x_new = x_new.add(x_offset_new.mul(self.extend_scope)) 248 | 249 | y_new = y_new.reshape( 250 | [self.num_batch, 1, self.num_points, self.width, self.height]) 251 | y_new = y_new.permute(0, 3, 1, 4, 2) 252 | y_new = y_new.reshape([ 253 | self.num_batch, 1 * self.width, self.num_points * self.height 254 | ]) 255 | x_new = x_new.reshape( 256 | [self.num_batch, 1, self.num_points, self.width, self.height]) 257 | x_new = x_new.permute(0, 3, 1, 4, 2) 258 | x_new = x_new.reshape([ 259 | self.num_batch, 1 * self.width, self.num_points * self.height 260 | ]) 261 | return y_new, x_new 262 | 263 | """ 264 | input: input feature map [N,C,D,W,H];coordinate map [N,K*D,K*W,K*H] 265 | output: [N,1,K*D,K*W,K*H] deformed feature map 266 | """ 267 | 268 | def _bilinear_interpolate_3D(self, input_feature, y, x): 269 | device = input_feature.device 270 | y = y.reshape([-1]).float() 271 | x = x.reshape([-1]).float() 272 | 273 | zero = torch.zeros([]).int() 274 | max_y = self.width - 1 275 | max_x = self.height - 1 276 | 277 | # find 8 grid locations 278 | y0 = torch.floor(y).int() 279 | y1 = y0 + 1 280 | x0 = torch.floor(x).int() 281 | x1 = x0 + 1 282 | 283 | # clip out coordinates exceeding feature map volume 284 | y0 = torch.clamp(y0, zero, max_y) 285 | y1 = torch.clamp(y1, zero, max_y) 286 | x0 = torch.clamp(x0, zero, max_x) 287 | x1 = torch.clamp(x1, zero, max_x) 288 | 289 | input_feature_flat = input_feature.flatten() 290 | input_feature_flat = input_feature_flat.reshape( 291 | self.num_batch, self.num_channels, self.width, self.height) 292 | input_feature_flat = input_feature_flat.permute(0, 2, 3, 1) 293 | input_feature_flat = input_feature_flat.reshape(-1, self.num_channels) 294 | dimension = self.height * self.width 295 | 296 | base = torch.arange(self.num_batch) * dimension 297 | base = base.reshape([-1, 1]).float() 298 | 299 | repeat = torch.ones([self.num_points * self.width * self.height 300 | ]).unsqueeze(0) 301 | repeat = repeat.float() 302 | 303 | base = torch.matmul(base, repeat) 304 | base = base.reshape([-1]) 305 | 306 | base = base.to(device) 307 | 308 | base_y0 = base + y0 * self.height 309 | base_y1 = base + y1 * self.height 310 | 311 | # top rectangle of the neighbourhood volume 312 | index_a0 = base_y0 - base + x0 313 | index_c0 = base_y0 - base + x1 314 | 315 | # bottom rectangle of the neighbourhood volume 316 | index_a1 = base_y1 - base + x0 317 | index_c1 = base_y1 - base + x1 318 | 319 | # get 8 grid values 320 | value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(device) 321 | value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(device) 322 | value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(device) 323 | value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(device) 324 | 325 | # find 8 grid locations 326 | y0 = torch.floor(y).int() 327 | y1 = y0 + 1 328 | x0 = torch.floor(x).int() 329 | x1 = x0 + 1 330 | 331 | # clip out coordinates exceeding feature map volume 332 | y0 = torch.clamp(y0, zero, max_y + 1) 333 | y1 = torch.clamp(y1, zero, max_y + 1) 334 | x0 = torch.clamp(x0, zero, max_x + 1) 335 | x1 = torch.clamp(x1, zero, max_x + 1) 336 | 337 | x0_float = x0.float() 338 | x1_float = x1.float() 339 | y0_float 
= y0.float() 340 | y1_float = y1.float() 341 | 342 | vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(device) 343 | vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(device) 344 | vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(device) 345 | vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(device) 346 | 347 | outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 + 348 | value_c1 * vol_c1) 349 | 350 | if self.morph == 0: 351 | outputs = outputs.reshape([ 352 | self.num_batch, 353 | self.num_points * self.width, 354 | 1 * self.height, 355 | self.num_channels, 356 | ]) 357 | outputs = outputs.permute(0, 3, 1, 2) 358 | else: 359 | outputs = outputs.reshape([ 360 | self.num_batch, 361 | 1 * self.width, 362 | self.num_points * self.height, 363 | self.num_channels, 364 | ]) 365 | outputs = outputs.permute(0, 3, 1, 2) 366 | return outputs 367 | 368 | def deform_conv(self, input, offset, if_offset): 369 | y, x = self._coordinate_map_3D(offset, if_offset) 370 | deformed_feature = self._bilinear_interpolate_3D(input, y, x) 371 | return deformed_feature 372 | 373 | if __name__ == '__main__': 374 | 375 | input = torch.randn(1, 128, 8, 8) 376 | dsconv = DySnakeConv(128, 128) 377 | output = dsconv(input) 378 | print(output.shape) 379 | -------------------------------------------------------------------------------- /Attention Modules for improve YOLO/BiFormer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core of BiFormer, Bi-Level Routing Attention. 3 | 4 | To be refactored. 5 | 6 | author: ZHU Lei 7 | github: https://github.com/rayleizhu 8 | email: ray.leizhu@outlook.com 9 | 10 | This source code is licensed under the license found in the 11 | LICENSE file in the root directory of this source tree. 12 | """ 13 | from typing import Tuple 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from einops import rearrange 19 | from torch import Tensor 20 | 21 | 22 | class TopkRouting(nn.Module): 23 | """ 24 | differentiable topk routing with scaling 25 | Args: 26 | qk_dim: int, feature dimension of query and key 27 | topk: int, the 'topk' 28 | qk_scale: int or None, temperature (multiply) of softmax activation 29 | with_param: bool, wether inorporate learnable params in routing unit 30 | diff_routing: bool, wether make routing differentiable 31 | soft_routing: bool, wether make output value multiplied by routing weights 32 | """ 33 | def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False): 34 | super().__init__() 35 | self.topk = topk 36 | self.qk_dim = qk_dim 37 | self.scale = qk_scale or qk_dim ** -0.5 38 | self.diff_routing = diff_routing 39 | # TODO: norm layer before/after linear? 
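# routing embedding: a learnable linear map over the window-level queries/keys when param_routing=True, otherwise a parameter-free identity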
40 | self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity() 41 | # routing activation 42 | self.routing_act = nn.Softmax(dim=-1) 43 | 44 | def forward(self, query:Tensor, key:Tensor)->Tuple[Tensor]: 45 | """ 46 | Args: 47 | q, k: (n, p^2, c) tensor 48 | Return: 49 | r_weight, topk_index: (n, p^2, topk) tensor 50 | """ 51 | if not self.diff_routing: 52 | query, key = query.detach(), key.detach() 53 | query_hat, key_hat = self.emb(query), self.emb(key) # per-window pooling -> (n, p^2, c) 54 | attn_logit = (query_hat*self.scale) @ key_hat.transpose(-2, -1) # (n, p^2, p^2) 55 | topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1) # (n, p^2, k), (n, p^2, k) 56 | r_weight = self.routing_act(topk_attn_logit) # (n, p^2, k) 57 | 58 | return r_weight, topk_index 59 | 60 | 61 | class KVGather(nn.Module): 62 | def __init__(self, mul_weight='none'): 63 | super().__init__() 64 | assert mul_weight in ['none', 'soft', 'hard'] 65 | self.mul_weight = mul_weight 66 | 67 | def forward(self, r_idx:Tensor, r_weight:Tensor, kv:Tensor): 68 | """ 69 | r_idx: (n, p^2, topk) tensor 70 | r_weight: (n, p^2, topk) tensor 71 | kv: (n, p^2, w^2, c_kq+c_v) 72 | 73 | Return: 74 | (n, p^2, topk, w^2, c_kq+c_v) tensor 75 | """ 76 | # select kv according to routing index 77 | n, p2, w2, c_kv = kv.size() 78 | topk = r_idx.size(-1) 79 | # print(r_idx.size(), r_weight.size()) 80 | # FIXME: gather consumes much memory (topk times redundancy), write cuda kernel? 81 | topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1), # (n, p^2, p^2, w^2, c_kv) without mem cpy 82 | dim=2, 83 | index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv) # (n, p^2, k, w^2, c_kv) 84 | ) 85 | 86 | if self.mul_weight == 'soft': 87 | topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv # (n, p^2, k, w^2, c_kv) 88 | elif self.mul_weight == 'hard': 89 | raise NotImplementedError('differentiable hard routing TBA') 90 | # else: #'none' 91 | # topk_kv = topk_kv # do nothing 92 | 93 | return topk_kv 94 | 95 | class QKVLinear(nn.Module): 96 | def __init__(self, dim, qk_dim, bias=True): 97 | super().__init__() 98 | self.dim = dim 99 | self.qk_dim = qk_dim 100 | self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias) 101 | 102 | def forward(self, x): 103 | q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim+self.dim], dim=-1) 104 | return q, kv 105 | # q, k, v = self.qkv(x).split([self.qk_dim, self.qk_dim, self.dim], dim=-1) 106 | # return q, k, v 107 | 108 | class BiLevelRoutingAttention(nn.Module): 109 | """ 110 | n_win: number of windows in one side (so the actual number of windows is n_win*n_win) 111 | kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win. 
112 | topk: topk for window filtering 113 | param_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attention 114 | param_routing: extra linear for routing 115 | diff_routing: wether to set routing differentiable 116 | soft_routing: wether to multiply soft routing weights 117 | """ 118 | def __init__(self, dim, n_win=7, num_heads=8, qk_dim=None, qk_scale=None, 119 | kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity', 120 | topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3, 121 | auto_pad=True): 122 | super().__init__() 123 | # local attention setting 124 | self.dim = dim 125 | self.n_win = n_win # Wh, Ww 126 | self.num_heads = num_heads 127 | self.qk_dim = qk_dim or dim 128 | assert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!' 129 | self.scale = qk_scale or self.qk_dim ** -0.5 130 | 131 | 132 | ################side_dwconv (i.e. LCE in ShuntedTransformer)########### 133 | self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \ 134 | lambda x: torch.zeros_like(x) 135 | 136 | ################ global routing setting ################# 137 | self.topk = topk 138 | self.param_routing = param_routing 139 | self.diff_routing = diff_routing 140 | self.soft_routing = soft_routing 141 | # router 142 | assert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=False 143 | self.router = TopkRouting(qk_dim=self.qk_dim, 144 | qk_scale=self.scale, 145 | topk=self.topk, 146 | diff_routing=self.diff_routing, 147 | param_routing=self.param_routing) 148 | if self.soft_routing: # soft routing, always diffrentiable (if no detach) 149 | mul_weight = 'soft' 150 | elif self.diff_routing: # hard differentiable routing 151 | mul_weight = 'hard' 152 | else: # hard non-differentiable routing 153 | mul_weight = 'none' 154 | self.kv_gather = KVGather(mul_weight=mul_weight) 155 | 156 | # qkv mapping (shared by both global routing and local attention) 157 | self.param_attention = param_attention 158 | if self.param_attention == 'qkvo': 159 | self.qkv = QKVLinear(self.dim, self.qk_dim) 160 | self.wo = nn.Linear(dim, dim) 161 | elif self.param_attention == 'qkv': 162 | self.qkv = QKVLinear(self.dim, self.qk_dim) 163 | self.wo = nn.Identity() 164 | else: 165 | raise ValueError(f'param_attention mode {self.param_attention} is not surpported!') 166 | 167 | self.kv_downsample_mode = kv_downsample_mode 168 | self.kv_per_win = kv_per_win 169 | self.kv_downsample_ratio = kv_downsample_ratio 170 | self.kv_downsample_kenel = kv_downsample_kernel 171 | if self.kv_downsample_mode == 'ada_avgpool': 172 | assert self.kv_per_win is not None 173 | self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win) 174 | elif self.kv_downsample_mode == 'ada_maxpool': 175 | assert self.kv_per_win is not None 176 | self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win) 177 | elif self.kv_downsample_mode == 'maxpool': 178 | assert self.kv_downsample_ratio is not None 179 | self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity() 180 | elif self.kv_downsample_mode == 'avgpool': 181 | assert self.kv_downsample_ratio is not None 182 | self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity() 183 | elif self.kv_downsample_mode == 'identity': # no kv downsampling 184 | self.kv_down = 
nn.Identity() 185 | elif self.kv_downsample_mode == 'fracpool': 186 | # assert self.kv_downsample_ratio is not None 187 | # assert self.kv_downsample_kenel is not None 188 | # TODO: fracpool 189 | # 1. kernel size should be input size dependent 190 | # 2. there is a random factor, need to avoid independent sampling for k and v 191 | raise NotImplementedError('fracpool policy is not implemented yet!') 192 | elif kv_downsample_mode == 'conv': 193 | # TODO: need to consider the case where k != v so that need two downsample modules 194 | raise NotImplementedError('conv policy is not implemented yet!') 195 | else: 196 | raise ValueError(f'kv_down_sample_mode {self.kv_downsaple_mode} is not surpported!') 197 | 198 | # softmax for local attention 199 | self.attn_act = nn.Softmax(dim=-1) 200 | 201 | self.auto_pad=auto_pad 202 | 203 | def forward(self, x, ret_attn_mask=False): 204 | """ 205 | x: NHWC tensor 206 | 207 | Return: 208 | NHWC tensor 209 | """ 210 | x = rearrange(x, "n c h w -> n h w c") 211 | # NOTE: use padding for semantic segmentation 212 | ################################################### 213 | if self.auto_pad: 214 | N, H_in, W_in, C = x.size() 215 | 216 | pad_l = pad_t = 0 217 | pad_r = (self.n_win - W_in % self.n_win) % self.n_win 218 | pad_b = (self.n_win - H_in % self.n_win) % self.n_win 219 | x = F.pad(x, (0, 0, # dim=-1 220 | pad_l, pad_r, # dim=-2 221 | pad_t, pad_b)) # dim=-3 222 | _, H, W, _ = x.size() # padded size 223 | else: 224 | N, H, W, C = x.size() 225 | assert H%self.n_win == 0 and W%self.n_win == 0 # 226 | ################################################### 227 | 228 | 229 | # patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size 230 | x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win) 231 | 232 | #################qkv projection################### 233 | # q: (n, p^2, w, w, c_qk) 234 | # kv: (n, p^2, w, w, c_qk+c_v) 235 | # NOTE: separte kv if there were memory leak issue caused by gather 236 | q, kv = self.qkv(x) 237 | 238 | # pixel-wise qkv 239 | # q_pix: (n, p^2, w^2, c_qk) 240 | # kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v) 241 | q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c') 242 | kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w')) 243 | kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win) 244 | 245 | q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3]) # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk) 246 | 247 | ##################side_dwconv(lepe)################## 248 | # NOTE: call contiguous to avoid gradient warning when using ddp 249 | lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous()) 250 | lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win) 251 | 252 | ############ gather q dependent k/v ################# 253 | 254 | r_weight, r_idx = self.router(q_win, k_win) # both are (n, p^2, topk) tensors 255 | 256 | kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix) #(n, p^2, topk, h_kv*w_kv, c_qk+c_v) 257 | k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1) 258 | # kv_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk) 259 | # v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v) 260 | 261 | ######### do attention as normal #################### 262 | k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose 
here? 263 | v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m) 264 | q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads) # to BMLC tensor (n*p^2, m, w^2, c_qk//m) 265 | 266 | # param-free multihead attention 267 | attn_weight = (q_pix * self.scale) @ k_pix_sel # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv) 268 | attn_weight = self.attn_act(attn_weight) 269 | out = attn_weight @ v_pix_sel # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c) 270 | out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win, 271 | h=H//self.n_win, w=W//self.n_win) 272 | 273 | out = out + lepe 274 | # output linear 275 | out = self.wo(out) 276 | 277 | # NOTE: use padding for semantic segmentation 278 | # crop padded region 279 | if self.auto_pad and (pad_r > 0 or pad_b > 0): 280 | out = out[:, :H_in, :W_in, :].contiguous() 281 | 282 | if ret_attn_mask: 283 | return out, r_weight, r_idx, attn_weight 284 | else: 285 | return rearrange(out, "n h w c -> n c h w") 286 | 287 | class Attention(nn.Module): 288 | """ 289 | vanilla attention 290 | """ 291 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 292 | super().__init__() 293 | self.num_heads = num_heads 294 | head_dim = dim // num_heads 295 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 296 | self.scale = qk_scale or head_dim ** -0.5 297 | 298 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 299 | self.attn_drop = nn.Dropout(attn_drop) 300 | self.proj = nn.Linear(dim, dim) 301 | self.proj_drop = nn.Dropout(proj_drop) 302 | 303 | def forward(self, x): 304 | """ 305 | args: 306 | x: NCHW tensor 307 | return: 308 | NCHW tensor 309 | """ 310 | _, _, H, W = x.size() 311 | x = rearrange(x, 'n c h w -> n (h w) c') 312 | 313 | ####################################### 314 | B, N, C = x.shape 315 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 316 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 317 | 318 | attn = (q @ k.transpose(-2, -1)) * self.scale 319 | attn = attn.softmax(dim=-1) 320 | attn = self.attn_drop(attn) 321 | 322 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 323 | x = self.proj(x) 324 | x = self.proj_drop(x) 325 | ####################################### 326 | 327 | x = rearrange(x, 'n (h w) c -> n c h w', h=H, w=W) 328 | return x 329 | 330 | class AttentionLePE(nn.Module): 331 | """ 332 | vanilla attention 333 | """ 334 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., side_dwconv=5): 335 | super().__init__() 336 | self.num_heads = num_heads 337 | head_dim = dim // num_heads 338 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 339 | self.scale = qk_scale or head_dim ** -0.5 340 | 341 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 342 | self.attn_drop = nn.Dropout(attn_drop) 343 | self.proj = nn.Linear(dim, dim) 344 | self.proj_drop = nn.Dropout(proj_drop) 345 | self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \ 346 | lambda x: torch.zeros_like(x) 347 | 348 | def forward(self, x): 349 | """ 350 | args: 351 | x: NCHW tensor 352 | return: 353 | NCHW tensor 
354 | """ 355 | _, _, H, W = x.size() 356 | x = rearrange(x, 'n c h w -> n (h w) c') 357 | 358 | ####################################### 359 | B, N, C = x.shape 360 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 361 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 362 | 363 | lepe = self.lepe(rearrange(x, 'n (h w) c -> n c h w', h=H, w=W)) 364 | lepe = rearrange(lepe, 'n c h w -> n (h w) c') 365 | 366 | attn = (q @ k.transpose(-2, -1)) * self.scale 367 | attn = attn.softmax(dim=-1) 368 | attn = self.attn_drop(attn) 369 | 370 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 371 | x = x + lepe 372 | 373 | x = self.proj(x) 374 | x = self.proj_drop(x) 375 | ####################################### 376 | 377 | x = rearrange(x, 'n (h w) c -> n c h w', h=H, w=W) 378 | return x 379 | 380 | 381 | if __name__ == '__main__': 382 | input = torch.randn(1, 128, 16, 16) 383 | att = BiLevelRoutingAttention(128) 384 | output = att(input) 385 | print(output.shape) 386 | --------------------------------------------------------------------------------
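A minimal usage sketch appended for reference (not a file from the repository): it assumes the folder "Attention Modules for improve YOLO" is on the Python path so that SE.py and SimAM.py are importable as modules named SE and SimAM. The class names SEAttention and SimAM and their constructor arguments are taken from the files above; DemoBlock and use_se are hypothetical names used only for illustration. Most of the attention modules above preserve the (B, C, H, W) shape of their input and can be swapped in the same way (CAM with fusion='concat' is an exception, since it triples the channel count).

import torch
import torch.nn as nn

from SE import SEAttention   # defined in SE.py above
from SimAM import SimAM      # defined in SimAM.py above


class DemoBlock(nn.Module):
    # hypothetical Conv-BN-SiLU block with a plug-in attention layer on its output
    def __init__(self, c1, c2, use_se=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, 3, 1, 1, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU()
        # both attention modules return a tensor with the same shape as their input
        self.att = SEAttention(channel=c2, reduction=16) if use_se else SimAM()

    def forward(self, x):
        return self.att(self.act(self.bn(self.conv(x))))


if __name__ == '__main__':
    x = torch.randn(1, 64, 32, 32)
    print(DemoBlock(64, 128)(x).shape)  # torch.Size([1, 128, 32, 32])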