├── README.md ├── simam_module.py ├── ParNetAttention.py ├── SpatialGroupEnhance.py ├── LSKblock.py ├── lstm_model.py ├── MLP_Communicator(多模态适用).py ├── Partial_conv3.py ├── kNNAttention.py ├── MobileViTv2Attention.py ├── ContraNorm(对比归一化层).py ├── SKNet.py ├── CAN(人群计数,CV2维任务通用).py ├── EfficientAdditiveAttnetion.py ├── VisionPermutator.py ├── S2Attention.py ├── TripletAttention.py ├── LFA.py ├── OutlookAtt.py ├── Pacoloss(参数对比损失,用于对比学习).py ├── UFOAttention.py ├── Deepfake(深度伪造检测).py ├── Free_UNetModel(扩散模型).py ├── LinAngularAttention.py ├── SSPCAB(图像和视频异常检测,CV2维任务通用).py ├── MUSEAttention.py ├── DynamicFilter(频域模块动态滤波器用于CV2维图像).py ├── Wave-pooling(轨迹预测,CV2维图像通用).py ├── ISL(用于点云任务).py ├── ScConv.py ├── Crossnorm-Selfnorm(领域泛化).py ├── SDM(3D任务).py ├── SViT.py ├── PFNet(点云).py ├── EGA(边缘检测,CV2维图像通用).py ├── F_Block(频域模块用于时间序列).py ├── efficient kan.py ├── GKONet((三维人体姿态估计).py ├── SAM-单目深度估计-特征融合-CV2维度通用.py ├── sLSTM&mLSTM(NLP和时序任务).py └── MambaIR(CV二维图像).py /README.md: -------------------------------------------------------------------------------- 1 | # DeepLearning 2 | 深度学习领域模块化代码 3 | 代码附有详细注释 4 | -------------------------------------------------------------------------------- /simam_module.py: -------------------------------------------------------------------------------- 1 | # https://github.com/ZjjConan/SimAM 2 | 3 | """ 4 | 该模块的目的是增强图像特征之间的关系,以提高模型的表现。 5 | 6 | 以下是模块的主要组件和功能: 7 | 8 | 初始化:在初始化过程中,模块接受一个参数 e_lambda,它是一个小的正数(默认为1e-4)。e_lambda 用于避免分母为零的情况,以确保数值稳定性。此外,模块还创建了一个 Sigmoid 激活函数 act。 9 | 10 | 前向传播:在前向传播中,模块执行以下步骤: 11 | 12 | 计算输入张量 x 的形状信息,包括批量大小 b、通道数 c、高度 h 和宽度 w。 13 | 计算像素点的数量 n,即图像的高度和宽度的乘积减去1(减1是因为在计算方差时要排除一个像素的均值)。 14 | 计算每个像素点与均值的差的平方,即 (x - x.mean(dim=[2, 3], keepdim=True)).pow(2),这样可以得到差的平方矩阵。 15 | 计算分母部分,即 (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda),并加上小的正数 e_lambda 以确保分母不为零。 16 | 计算 y,通过将差的平方矩阵除以分母部分,然后加上0.5。这个操作应用了 Sigmoid 函数,将结果限制在0到1之间。 17 | 最后,将输入张量 x 与 y 经过 Sigmoid 激活后的结果相乘,以产生最终的输出。 18 | SIMAM 模块的关键思想是计算每个像素点的特征值与均值之间的关系,并通过 Sigmoid 激活函数来调整这种关系,从而增强特征之间的互动性。这有助于捕获图像中不同位置之间的关系,有助于提高模型性能。 19 | """ 20 | 21 | import torch 22 | import torch.nn as nn 23 | from thop import profile 24 | 25 | from einops import rearrange 26 | 27 | def to_3d(x): 28 | return rearrange(x, 'b c h w -> b (h w) c') 29 | 30 | def to_4d(x,h,w): 31 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 32 | 33 | class Simam_module(torch.nn.Module): 34 | def __init__(self, e_lambda=1e-4): 35 | super(Simam_module, self).__init__() 36 | self.act = nn.Sigmoid() 37 | self.e_lambda = e_lambda 38 | 39 | def forward(self, x): 40 | b, c, h, w = x.size() 41 | n = w * h - 1 42 | x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2) 43 | y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5 44 | 45 | return x * self.act(y) 46 | 47 | 48 | # 输入 N C H W, 输出 N C H W 49 | if __name__ == '__main__': 50 | model = Simam_module().cuda() 51 | # x = torch.randn(1, 3, 64, 64).cuda() 52 | x = torch.randn(32, 784, 128).cuda() 53 | x = to_4d(x,h=28,w=28) 54 | y = model(x) 55 | y = to_3d(y) 56 | print(y.shape) 57 | flops, params = profile(model, (x,)) 58 | print(flops / 1e9) 59 | print(params) 60 | -------------------------------------------------------------------------------- /ParNetAttention.py: -------------------------------------------------------------------------------- 1 | # https://github.com/imankgoyal/NonDeepNetworks 2 | 3 | """ 4 | 模块包括以下组件: 5 | 6 | sse(Squeeze-and-Excitation)模块: 7 | 8 | 通过自适应平均池化将输入张量池化到大小为 1x1。 9 | 
然后使用一个具有相同通道数的卷积层,产生一组注意力权重,这些权重通过 Sigmoid 激活函数进行缩放。 10 | 这些注意力权重用于对输入特征进行加权,以突出重要的特征。 11 | conv1x1 和 conv3x3 模块: 12 | 13 | conv1x1 是一个1x1卷积层,用于捕捉输入的全局信息。 14 | conv3x3 是一个3x3卷积层,用于捕捉局部信息。 15 | 两者都后跟批归一化层以稳定训练。 16 | silu 激活函数: 17 | 18 | Silu(或Swish)激活函数是一种非线性激活函数,它将输入映射到一个非线性范围内。 19 | 在前向传播中,输入张量 x 通过这些组件,最终输出特征张量 y。这个模块旨在提高神经网络的特征表示能力,通过不同尺度的特征融合和注意力加权来捕获全局和局部信息。 20 | """ 21 | 22 | import numpy as np 23 | import torch 24 | from torch import nn 25 | from torch.nn import init 26 | 27 | 28 | from einops import rearrange 29 | 30 | def to_3d(x): 31 | return rearrange(x, 'b c h w -> b (h w) c') 32 | 33 | def to_4d(x,h,w): 34 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 35 | 36 | 37 | 38 | class ParNetAttention(nn.Module): 39 | 40 | def __init__(self, channel=512): 41 | super().__init__() 42 | self.sse = nn.Sequential( 43 | nn.AdaptiveAvgPool2d(1), 44 | nn.Conv2d(channel, channel, kernel_size=1), 45 | nn.Sigmoid() 46 | ) 47 | 48 | self.conv1x1 = nn.Sequential( 49 | nn.Conv2d(channel, channel, kernel_size=1), 50 | nn.BatchNorm2d(channel) 51 | ) 52 | self.conv3x3 = nn.Sequential( 53 | nn.Conv2d(channel, channel, kernel_size=3, padding=1), 54 | nn.BatchNorm2d(channel) 55 | ) 56 | self.silu = nn.SiLU() 57 | 58 | def forward(self, x): 59 | b, c, _, _ = x.size() 60 | x1 = self.conv1x1(x) 61 | x2 = self.conv3x3(x) 62 | x3 = self.sse(x) * x 63 | y = self.silu(x1 + x2 + x3) 64 | return y 65 | 66 | 67 | # 输入 N C H W, 输出 N C H W 68 | if __name__ == '__main__': 69 | # input = torch.randn(3, 512, 7, 7).cuda() 70 | input = torch.randn(1, 128, 256, 256).cuda() 71 | pna = ParNetAttention(channel=128).cuda() 72 | output = pna(input) 73 | print(output.shape) 74 | -------------------------------------------------------------------------------- /SpatialGroupEnhance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | class SpatialGroupEnhance(nn.Module): 8 | 9 | def __init__(self, groups): 10 | super().__init__() 11 | self.groups = groups 12 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 13 | self.weight = nn.Parameter(torch.zeros(1, groups, 1, 1)) 14 | self.bias = nn.Parameter(torch.zeros(1, groups, 1, 1)) 15 | self.sig = nn.Sigmoid() 16 | self.init_weights() 17 | 18 | def init_weights(self): 19 | for m in self.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | init.kaiming_normal_(m.weight, mode='fan_out') 22 | if m.bias is not None: 23 | init.constant_(m.bias, 0) 24 | elif isinstance(m, nn.BatchNorm2d): 25 | init.constant_(m.weight, 1) 26 | init.constant_(m.bias, 0) 27 | elif isinstance(m, nn.Linear): 28 | init.normal_(m.weight, std=0.001) 29 | if m.bias is not None: 30 | init.constant_(m.bias, 0) 31 | 32 | def forward(self, x): 33 | b, c, h, w = x.shape 34 | x = x.view(b * self.groups, -1, h, w) # bs*g,dim//g,h,w 35 | xn = x * self.avg_pool(x) # bs*g,dim//g,h,w 36 | xn = xn.sum(dim=1, keepdim=True) # bs*g,1,h,w 37 | t = xn.view(b * self.groups, -1) # bs*g,h*w 38 | 39 | t = t - t.mean(dim=1, keepdim=True) # bs*g,h*w 40 | std = t.std(dim=1, keepdim=True) + 1e-5 41 | t = t / std # bs*g,h*w 42 | t = t.view(b, self.groups, h, w) # bs,g,h*w 43 | 44 | t = t * self.weight + self.bias # bs,g,h*w 45 | t = t.view(b * self.groups, 1, h, w) # bs*g,1,h*w 46 | x = x * self.sig(t) 47 | x = x.view(b, c, h, w) 48 | 49 | return x 50 | 51 | 52 | # 输入 N C H W, 输出 N C H W 53 | if __name__ == '__main__': 54 | input = torch.randn(50, 512, 7, 7) 55 | sge = SpatialGroupEnhance(groups=4) 56 | output 
= sge(input) 57 | print(output.shape) 58 | -------------------------------------------------------------------------------- /LSKblock.py: -------------------------------------------------------------------------------- 1 | # https://github.com/zcablii/Large-Selective-Kernel-Network 2 | 3 | """ 4 | 以下是该模块的主要组件和操作: 5 | 6 | conv0:这是一个深度可分离卷积层,使用 5x5 的卷积核进行卷积操作,groups=dim 意味着将输入的每个通道分为一组进行卷积操作。这一步旨在捕获输入中的空间特征。 7 | 8 | conv_spatial:这是另一个深度可分离卷积层,使用 7x7 的卷积核进行卷积操作,stride=1 表示步幅为 1,padding=9 用于零填充操作,groups=dim 表示将输入的每个通道分为一组进行卷积操作,并且通过 dilation=3 进行扩张卷积。这一步旨在捕获输入中的更大范围的空间特征。 9 | 10 | conv1 和 conv2:这是两个 1x1 的卷积层,用于降低通道数,将输入的通道数减少到 dim // 2。这两个卷积层分别应用于 conv0 和 conv_spatial 的输出。 11 | 12 | conv_squeeze:这是一个 7x7 的卷积层,用于进行通道维度的压缩,将输入通道的数量从 2 降低到 2,通过 sigmoid 函数将输出的值缩放到 (0, 1) 范围内。 13 | 14 | conv:这是一个 1x1 的卷积层,用于将通道数从 dim // 2 恢复到 dim,最终的输出通道数与输入的通道数相同。 15 | 16 | 在前向传播过程中,该模块通过一系列卷积操作将输入的特征图进行加权,其中使用了 sigmoid 权重来调整不同部分的注意力。最终输出的特征图是输入特征图乘以注意力加权的结果。 17 | 18 | 这个 LSKblock 模块的目的是引入空间和通道注意力,以更好地捕获输入特征图中的重要信息。 19 | """ 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | 25 | class LSKblock(nn.Module): 26 | def __init__(self, dim): 27 | super().__init__() 28 | self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) 29 | self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3) 30 | self.conv1 = nn.Conv2d(dim, dim // 2, 1) 31 | self.conv2 = nn.Conv2d(dim, dim // 2, 1) 32 | self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3) 33 | self.conv = nn.Conv2d(dim // 2, dim, 1) 34 | 35 | def forward(self, x): 36 | attn1 = self.conv0(x) 37 | attn2 = self.conv_spatial(attn1) 38 | 39 | attn1 = self.conv1(attn1) 40 | attn2 = self.conv2(attn2) 41 | 42 | attn = torch.cat([attn1, attn2], dim=1) 43 | avg_attn = torch.mean(attn, dim=1, keepdim=True) 44 | max_attn, _ = torch.max(attn, dim=1, keepdim=True) 45 | agg = torch.cat([avg_attn, max_attn], dim=1) 46 | sig = self.conv_squeeze(agg).sigmoid() 47 | attn = attn1 * sig[:, 0, :, :].unsqueeze(1) + attn2 * sig[:, 1, :, :].unsqueeze(1) 48 | attn = self.conv(attn) 49 | return x * attn 50 | 51 | 52 | # 输入 N C H W, 输出 N C H W 53 | if __name__ == '__main__': 54 | block = LSKblock(64).cuda() 55 | input = torch.rand(1, 64, 64, 64).cuda() 56 | output = block(input) 57 | print(input.size(), output.size()) 58 | -------------------------------------------------------------------------------- /lstm_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | class LSTMModel(nn.Module): 4 | def __init__(self, window_size, input_size, 5 | hidden_dim, pred_len, num_layers, batch_size, device) -> None: 6 | super().__init__() 7 | self.pred_len = pred_len 8 | self.batch_size = batch_size 9 | self.input_size = input_size 10 | self.device = device 11 | self.lstm_encoder = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, num_layers=num_layers, 12 | batch_first=True).to(self.device) 13 | self.lstm_decoder = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, num_layers=num_layers, 14 | batch_first=True).to(self.device) 15 | self.relu = nn.GELU() 16 | self.fc = nn.Linear(hidden_dim, input_size) 17 | def forward(self, src): 18 | src = torch.unsqueeze(src, -1)#展平 19 | _, decoder_hidden = self.lstm_encoder(src) 20 | cur_batch = src.shape[0] 21 | decoder_input = torch.zeros(cur_batch, 1, self.input_size).to(self.device) 22 | outputs = torch.zeros(self.pred_len, cur_batch, self.input_size).to(self.device) 23 | for t in range(self.pred_len): 24 | decoder_output, decoder_hidden = 
self.lstm_decoder(decoder_input, decoder_hidden) 25 | decoder_output = self.relu(decoder_output) 26 | decoder_input = self.fc(decoder_output) 27 | outputs[t] = torch.squeeze(decoder_input, dim=-2) 28 | return outputs 29 | if __name__ == '__main__':#模型调用示例 30 | # if torch.cuda.is_available(): 31 | # device = torch.device("cuda") 32 | # else: 33 | # device = torch.device("cpu") 34 | device = torch.device("cpu") 35 | feature = 2#每个时间戳特征数 36 | timestep = 3#时间步长 37 | batch_size = 1#批次 38 | inputseq = torch.randn(timestep, feature).to(device) # 模拟输入,生成batch批次的,序列长度为timestep,序列每个时间戳特征数为feature的随机序列 39 | hidden_dim = 5#lstm隐藏层 40 | num_layers = 1 #lstm层数 41 | window_size = 5 42 | input_size = 2#输入序列每个时间戳的特征数 43 | pred_len = 1#重构序列的个数 44 | 45 | model = LSTMModel(window_size, input_size, hidden_dim, pred_len, num_layers, batch_size, device) 46 | output = model(inputseq)#模型输出与输入相同,用于重构x'误差的计算 47 | ou = 1 48 | "output.size()=[1,3,2],inputseq.size=[3,2],故output.squezze" 49 | -------------------------------------------------------------------------------- /MLP_Communicator(多模态适用).py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops.layers.torch import Rearrange 4 | 5 | 6 | """ 7 | 多模态情感分析旨在判断互联网用户在各种社交媒体平台上上传的多模态数据的情感。一方面,现有研究关注文本、音频和视觉等多模态数据的融合机制,而忽视了文本与音频、文本与视觉的相似性以及音频与视觉的异质性,导致情感分析出现偏差。 8 | 另一方面,多模态数据带来与情感分析无关的噪声,影响融合效果。在本文中,我们提出了一种称为 PS-Mixer 的极向量和强度向量混合模型,它基于 MLP-Mixer,以实现不同模态数据之间更好的通信,以进行多模态情感分析。 9 | 具体来说,我们设计了一个极向量(PV)和一个强度向量(SV)来分别判断情绪的极性和强度。 PV是从文本和视觉特征的交流中获得的,以决定情感是积极的、消极的还是中性的。 10 | SV是从文本和音频特征之间的通信中获得的,以分析0到3范围内的情感强度。 11 | 此外,我们设计了一个由多个全连接层和激活函数组成的MLP通信模块(MLP-C) ,以使得不同模态特征在水平和垂直方向上充分交互,是利用MLP进行多模态信息通信的新颖尝试。 12 | """ 13 | 14 | 15 | class MLP_block(nn.Module): 16 | def __init__(self, input_size, hidden_size, dropout=0.5): 17 | super().__init__() 18 | self.net = nn.Sequential( 19 | nn.Linear(input_size, hidden_size), 20 | nn.GELU(), 21 | nn.Dropout(dropout), 22 | nn.Linear(hidden_size, input_size), 23 | nn.Dropout(dropout) 24 | ) 25 | 26 | def forward(self, x): 27 | x = self.net(x) 28 | return x 29 | 30 | 31 | class MLP_Communicator(nn.Module): 32 | def __init__(self, token, channel, hidden_size, depth=1): 33 | super(MLP_Communicator, self).__init__() 34 | self.depth = depth 35 | self.token_mixer = nn.Sequential( 36 | Rearrange('b n d -> b d n'), 37 | MLP_block(input_size=channel, hidden_size=hidden_size), 38 | Rearrange('b n d -> b d n') 39 | ) 40 | self.channel_mixer = nn.Sequential( 41 | MLP_block(input_size=token, hidden_size=hidden_size) 42 | ) 43 | 44 | def forward(self, x): 45 | for _ in range(self.depth): 46 | x = x + self.token_mixer(x) 47 | x = x + self.channel_mixer(x) 48 | return x 49 | 50 | 51 | if __name__ == '__main__': 52 | # 创建模型实例 53 | block = MLP_Communicator( 54 | token=32, # token 的大小 55 | channel=128, # 通道的大小 56 | hidden_size=64, # 隐藏层的大小 57 | depth=1 # 深度 58 | ) 59 | 60 | # 准备输入张量 61 | input_tensor = torch.randn(8, 128, 32) # 32与token对应 128与channel对应 62 | 63 | # 执行前向传播 64 | output_tensor = block(input_tensor) 65 | 66 | # 打印输入张量和输出张量的形状 67 | print("Input Tensor Shape:", input_tensor.size()) 68 | print("Output Tensor Shape:", output_tensor.size()) 69 | -------------------------------------------------------------------------------- /Partial_conv3.py: -------------------------------------------------------------------------------- 1 | # https://github.com/JierunChen/FasterNet 2 | 3 | """ 4 | 这个代码实现了一个名为Partial_conv3的自定义卷积模块,它根据参数的不同执行不同的操作。这个模块的主要特点如下: 5 | 6 | 
部分卷积操作:这个模块使用了一个nn.Conv2d的部分卷积操作,其中dim_conv3表示卷积操作的输出通道数,通常是输入通道数dim的一部分。这部分卷积操作在输入图像的特定通道上执行。 7 | 8 | 前向传播策略:这个模块可以采用两种不同的前向传播策略,具体取决于forward参数的设置: 9 | 10 | 'slicing':在前向传播时,仅对输入张量的部分通道进行部分卷积操作。这对应于仅在推理时使用部分卷积。 11 | 'split_cat':在前向传播时,将输入张量分为两部分,其中一部分进行部分卷积操作,然后将两部分重新连接。这对应于在训练和推理过程中都使用部分卷积。 12 | 部分卷积操作的应用:部分卷积操作被用于输入张量的部分通道上,而保持其他通道不变。这有助于模型有选择性地应用卷积操作到特定通道上,从而可以灵活地控制特征的提取和传播。 13 | 14 | 残差连接:在部分卷积操作之后,模块保留了未经处理的部分通道,然后将两部分连接起来,以保持输入和输出的通道数一致,以便与其他模块连接。 15 | 16 | 总的来说,Partial_conv3模块提供了一种自定义卷积策略,可以根据应用的需要选择性地应用卷积操作到输入图像的特定通道上。这种模块可以用于特征选择、通道交互等任务,增加了神经网络的灵活性。 17 | """ 18 | 19 | 20 | from torch import nn 21 | import torch 22 | 23 | from einops.einops import rearrange 24 | 25 | def to_3d(x): 26 | return rearrange(x, 'b c h w -> b (h w) c') 27 | 28 | def to_4d(x, h, w): 29 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 30 | 31 | 32 | class Partial_conv3(nn.Module): 33 | 34 | def __init__(self, dim, n_div, forward): 35 | super().__init__() 36 | self.dim_conv3 = dim // n_div 37 | self.dim_untouched = dim - self.dim_conv3 38 | self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False) 39 | 40 | if forward == 'slicing': 41 | self.forward = self.forward_slicing 42 | elif forward == 'split_cat': 43 | self.forward = self.forward_split_cat 44 | else: 45 | raise NotImplementedError 46 | 47 | def forward_slicing(self, x): 48 | # only for inference 49 | x = x.clone() # !!! Keep the original input intact for the residual connection later 50 | x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :]) 51 | 52 | return x 53 | 54 | def forward_split_cat(self, x): 55 | # for training/inference 56 | x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1) 57 | x1 = self.partial_conv3(x1) 58 | x = torch.cat((x1, x2), 1) 59 | 60 | return x 61 | 62 | 63 | if __name__ == '__main__': 64 | block = Partial_conv3(128, 2, 'split_cat') 65 | input = torch.rand(32, 784, 128) 66 | input = to_4d(input, 28, 28) 67 | output = block(input) 68 | output = to_3d(input) 69 | print(input.size()) 70 | print(output.size()) 71 | -------------------------------------------------------------------------------- /kNNAttention.py: -------------------------------------------------------------------------------- 1 | # https://github.com/damo-cv/KVT 2 | 3 | """ 4 | 以下是该模块的主要组件和操作: 5 | 6 | qkv:这是一个线性层,将输入特征 x 映射到三个不同的线性变换,分别对应查询 (query),键 (key),和值 (value)。这三个变换将输入特征的通道划分成多个头 (heads)。 7 | 8 | attn_drop 和 proj_drop:这是用于进行注意力矩阵和输出特征的丢弃操作的 Dropout 层。 9 | 10 | topk:这是一个超参数,表示要选择每个查询的前 k 个最相关的键。它控制了 k-最近邻注意力机制的行为。 11 | 12 | 在前向传播过程中,该模块首先将输入特征 x 映射为查询、键和值。然后,通过矩阵乘法操作计算注意力矩阵,但注意力矩阵的计算在这里进行了修改。具体来说,它使用 torch.topk 函数来选择每个查询的前 k 个最相关的键,然后将其余的注意力权重设为负无穷大,以实现 k-最近邻注意力机制。之后,应用 softmax 归一化得到最终的注意力矩阵。最后,利用注意力矩阵对值进行加权平均,得到最终的输出特征。 13 | 14 | 这个模块的核心思想是在计算注意力时仅考虑与每个查询最相关的 k 个键,从而减少计算复杂度并提高效率。这对于处理大规模数据或具有长序列的模型特别有用。 15 | """ 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class kNNAttention(nn.Module): 22 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,topk=100): 23 | super().__init__() 24 | self.num_heads = num_heads 25 | head_dim = dim // num_heads 26 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 27 | self.scale = qk_scale or head_dim ** -0.5 28 | 29 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 30 | self.attn_drop = nn.Dropout(attn_drop) 31 | self.proj = nn.Linear(dim, dim) 32 | self.proj_drop = nn.Dropout(proj_drop) 33 | self.topk = topk 34 | 35 | def 
forward(self, x): 36 | B, N, C = x.shape 37 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 38 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 39 | attn = (q @ k.transpose(-2, -1)) * self.scale 40 | # the core code block 41 | mask=torch.zeros(B,self.num_heads,N,N,device=x.device,requires_grad=False) 42 | index=torch.topk(attn,k=self.topk,dim=-1,largest=True)[1] 43 | mask.scatter_(-1,index,1.) 44 | attn=torch.where(mask>0, attn,torch.full_like(attn, float('-inf'))) 45 | # end of the core code block 46 | 47 | attn = attn.softmax(dim=-1) 48 | attn = self.attn_drop(attn) 49 | 50 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 51 | x = self.proj(x) 52 | x = self.proj_drop(x) 53 | return x 54 | 55 | 56 | if __name__ == '__main__': 57 | block = kNNAttention(dim=128) 58 | input = torch.rand(32,784,128) 59 | output = block(input) 60 | print(input.size()) 61 | print(output.size()) -------------------------------------------------------------------------------- /MobileViTv2Attention.py: -------------------------------------------------------------------------------- 1 | # https://github.com/apple/ml-cvnets 2 | 3 | """ 4 | 以下是该模块的主要组件和操作: 5 | 6 | 自注意力计算:使用线性变换(fc_i, fc_k, fc_v和fc_o)将输入映射到不同的子空间,并计算权重(weight_i)来为每个查询分配注意力权重。注意力权重通过对fc_i的输出进行softmax操作得到,然后用于加权fc_k(input)的输出,得到context_score。接下来,通过对context_score进行求和,以获得一个上下文向量(context_vector),该向量用于加权fc_v(input)的输出。最后,对v进行线性变换(fc_o)以获得最终的输出。 7 | 8 | 初始化权重:通过init_weights方法来初始化模块中的权重。 9 | 10 | 前向传播:根据输入执行自注意力计算,返回计算得到的注意力输出。 11 | """ 12 | 13 | import numpy as np 14 | import torch 15 | from torch import nn 16 | from torch.nn import init 17 | 18 | 19 | class MobileViTv2Attention(nn.Module): 20 | ''' 21 | Scaled dot-product attention 22 | ''' 23 | 24 | def __init__(self, d_model): 25 | ''' 26 | :param d_model: Output dimensionality of the model 27 | :param d_k: Dimensionality of queries and keys 28 | :param d_v: Dimensionality of values 29 | :param h: Number of heads 30 | ''' 31 | super(MobileViTv2Attention, self).__init__() 32 | self.fc_i = nn.Linear(d_model, 1) 33 | self.fc_k = nn.Linear(d_model, d_model) 34 | self.fc_v = nn.Linear(d_model, d_model) 35 | self.fc_o = nn.Linear(d_model, d_model) 36 | 37 | self.d_model = d_model 38 | self.init_weights() 39 | 40 | def init_weights(self): 41 | for m in self.modules(): 42 | if isinstance(m, nn.Conv2d): 43 | init.kaiming_normal_(m.weight, mode='fan_out') 44 | if m.bias is not None: 45 | init.constant_(m.bias, 0) 46 | elif isinstance(m, nn.BatchNorm2d): 47 | init.constant_(m.weight, 1) 48 | init.constant_(m.bias, 0) 49 | elif isinstance(m, nn.Linear): 50 | init.normal_(m.weight, std=0.001) 51 | if m.bias is not None: 52 | init.constant_(m.bias, 0) 53 | 54 | def forward(self, input): 55 | ''' 56 | Computes 57 | :param queries: Queries (b_s, nq, d_model) 58 | :return: 59 | ''' 60 | i = self.fc_i(input) # (bs,nq,1) 61 | weight_i = torch.softmax(i, dim=1) # bs,nq,1 62 | context_score = weight_i * self.fc_k(input) # bs,nq,d_model 63 | context_vector = torch.sum(context_score, dim=1, keepdim=True) # bs,1,d_model 64 | v = self.fc_v(input) * context_vector # bs,nq,d_model 65 | out = self.fc_o(v) # bs,nq,d_model 66 | 67 | return out 68 | 69 | 70 | if __name__ == '__main__': 71 | block = MobileViTv2Attention(d_model=256) 72 | # input = torch.rand(64, 64, 512).cuda() 73 | input = torch.rand(1, 128, 256, 256) 74 | output = block(input) 75 | print(input.size(), output.size()) 76 | 
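# Added usage note (an assumption, not part of the original file): the forward() docstring above
# describes the expected input as a 3-D sequence tensor of shape (b_s, nq, d_model). The 4-D call
# in __main__ only happens to run because nn.Linear acts on the last dimension; a minimal 3-D
# sketch with hypothetical sizes:
#   block_seq = MobileViTv2Attention(d_model=256)
#   tokens = torch.rand(4, 196, 256)    # (batch, num_tokens, d_model)
#   print(block_seq(tokens).shape)      # expected: torch.Size([4, 196, 256])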
-------------------------------------------------------------------------------- /ContraNorm(对比归一化层).py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | """ContraNorm:对比学习视角下的过度平滑及其超越 5 | 过度平滑是各种图神经网络 (GNN) 和 Transformer 中的常见现象,其性能会随着层数的增加而下降。我们不是从表示收敛到单个点的完全崩溃的角度来描述过度平滑,而是深入研究维度崩溃的更一般视角,其中表示位于一个狭窄的锥体中。 6 | 因此,受到对比学习在防止维度崩溃方面的有效性的启发,我们提出了一种称为 ContraNorm 的新型规范化层。直观地说,ContraNorm 隐式地破坏了嵌入空间中的表示,从而导致更均匀的分布和更轻微的维度崩溃。 7 | 在理论分析中,我们证明了 ContraNorm 在某些条件下可以缓解完全崩溃和维度崩溃。我们提出的规范化层可以轻松集成到 GNN 和 Transformer 中,并且参数开销可以忽略不计。 8 | 在各种真实数据集上的实验证明了我们提出的 ContraNorm 的有效性。 9 | """ 10 | 11 | class ContraNorm(nn.Module): 12 | def __init__(self, dim, scale=0.1, dual_norm=False, pre_norm=False, temp=1.0, learnable=False, positive=False, identity=False): 13 | super().__init__() 14 | if learnable and scale > 0: 15 | import math 16 | if positive: 17 | scale_init = math.log(scale) 18 | else: 19 | scale_init = scale 20 | self.scale_param = nn.Parameter(torch.empty(dim).fill_(scale_init)) 21 | self.dual_norm = dual_norm 22 | self.scale = scale 23 | self.pre_norm = pre_norm 24 | self.temp = temp 25 | self.learnable = learnable 26 | self.positive = positive 27 | self.identity = identity 28 | 29 | self.layernorm = nn.LayerNorm(dim, eps=1e-6) 30 | 31 | def forward(self, x): 32 | if self.scale > 0.0: 33 | xn = nn.functional.normalize(x, dim=2) 34 | if self.pre_norm: 35 | x = xn 36 | sim = torch.bmm(xn, xn.transpose(1,2)) / self.temp 37 | if self.dual_norm: 38 | sim = nn.functional.softmax(sim, dim=2) + nn.functional.softmax(sim, dim=1) 39 | else: 40 | sim = nn.functional.softmax(sim, dim=2) 41 | x_neg = torch.bmm(sim, x) 42 | if not self.learnable: 43 | if self.identity: 44 | x = (1+self.scale) * x - self.scale * x_neg 45 | else: 46 | x = x - self.scale * x_neg 47 | else: 48 | scale = torch.exp(self.scale_param) if self.positive else self.scale_param 49 | scale = scale.view(1, 1, -1) 50 | if self.identity: 51 | x = scale * x - scale * x_neg 52 | else: 53 | x = x - scale * x_neg 54 | x = self.layernorm(x) 55 | return x 56 | 57 | 58 | if __name__ == '__main__': 59 | block = ContraNorm(dim=128, scale=0.1, dual_norm=False, pre_norm=False, temp=1.0, learnable=False, positive=False, identity=False) 60 | input = torch.rand(32, 784, 128) 61 | output = block(input) 62 | print("Input size:", input.size()) 63 | print("Output size:", output.size()) 64 | -------------------------------------------------------------------------------- /SKNet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/implus/SKNet 2 | 3 | """ 4 | 该模块的主要功能是对输入张量进行一系列卷积操作,然后计算不同卷积核的注意力权重,并将它们应用于输入的不同部分以生成最终的输出。以下是该模块的主要组件和步骤: 5 | 6 | 初始化:在初始化中,模块接受以下参数: 7 | 8 | channel:输入通道数。 9 | kernels:用于卷积操作的核大小列表。 10 | reduction:通道减少比例,用于降低通道数。 11 | group:卷积操作的分组数。 12 | L:指定的参数,用于确定最大通道数的值。 13 | 在初始化过程中,模块创建了一系列卷积层、线性层和 Softmax 操作,以用于后续的计算。 14 | 15 | 前向传播:在前向传播过程中,模块执行以下步骤: 16 | 17 | 针对每个核大小,使用相应的卷积操作对输入进行卷积,并将卷积结果存储在列表 conv_outs 中。 18 | 将所有卷积结果叠加起来以生成 U,它代表了输入的融合表示。 19 | 对 U 进行平均池化,然后通过线性层将通道数减少到 d。 20 | 使用线性层计算不同卷积核的注意力权重,并将它们存储在列表 weights 中。 21 | 使用 Softmax 函数将注意力权重归一化。 22 | 将注意力权重应用于不同卷积核的特征表示,并对它们进行加权叠加,生成最终的输出张量 V。 23 | 最终,模块返回张量 V 作为输出。 24 | 25 | 这个模块的核心思想是在不同尺度的卷积核上计算注意力权重,以捕获输入的多尺度信息,然后将不同尺度的特征进行加权叠加以生成最终的输出。这可以增强模型对不同尺度物体的感知能力。 26 | """ 27 | 28 | import torch 29 | from torch import nn 30 | from collections import OrderedDict 31 | 32 | 33 | class SKAttention(nn.Module): 34 | 35 | def __init__(self, channel=512, kernels=[1, 3, 5, 7], 
reduction=16, group=1, L=32): 36 | super().__init__() 37 | self.d = max(L, channel // reduction) 38 | self.convs = nn.ModuleList([]) 39 | for k in kernels: 40 | self.convs.append( 41 | nn.Sequential(OrderedDict([ 42 | ('conv', nn.Conv2d(channel, channel, kernel_size=k, padding=k // 2, groups=group)), 43 | ('bn', nn.BatchNorm2d(channel)), 44 | ('relu', nn.ReLU()) 45 | ])) 46 | ) 47 | self.fc = nn.Linear(channel, self.d) 48 | self.fcs = nn.ModuleList([]) 49 | for i in range(len(kernels)): 50 | self.fcs.append(nn.Linear(self.d, channel)) 51 | self.softmax = nn.Softmax(dim=0) 52 | 53 | def forward(self, x): 54 | bs, c, _, _ = x.size() 55 | conv_outs = [] 56 | ### split 57 | for conv in self.convs: 58 | conv_outs.append(conv(x)) 59 | feats = torch.stack(conv_outs, 0) # k,bs,channel,h,w 60 | 61 | ### fuse 62 | U = sum(conv_outs) # bs,c,h,w 63 | 64 | ### reduction channel 65 | S = U.mean(-1).mean(-1) # bs,c 66 | Z = self.fc(S) # bs,d 67 | 68 | ### calculate attention weight 69 | weights = [] 70 | for fc in self.fcs: 71 | weight = fc(Z) 72 | weights.append(weight.view(bs, c, 1, 1)) # bs,channel 73 | attention_weughts = torch.stack(weights, 0) # k,bs,channel,1,1 74 | attention_weughts = self.softmax(attention_weughts) # k,bs,channel,1,1 75 | 76 | ### fuse 77 | V = (attention_weughts * feats).sum(0) 78 | return V 79 | 80 | # 输入 N C H W, 输出 N C H W 81 | if __name__ == '__main__': 82 | input = torch.randn(50, 512, 7, 7) 83 | se = SKAttention(channel=512, reduction=8) 84 | output = se(input) 85 | print(output.shape) 86 | -------------------------------------------------------------------------------- /CAN(人群计数,CV2维任务通用).py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | import torch.nn as nn 4 | import torch 5 | from torch.nn import functional as F 6 | from torchvision import models 7 | 8 | """ 9 | 在本文中,我们提出了两种基于双路径多尺度融合网络(SFANet)和SegNet的改进神经网络,以实现准确高效的人群计数。 10 | 受 SFANet 的启发,第一个模型被命名为 M-SFANet,附加了多孔空间金字塔池(ASPP)和上下文感知模块(CAN)。 11 | M-SFANet 的编码器通过包含具有不同采样率的并行空洞卷积层的 ASPP 进行了增强,因此能够提取目标对象的多尺度特征并合并更大的上下文。 12 | 为了进一步处理整个输入图像的尺度变化,我们利用 CAN 模块对上下文信息的尺度进行自适应编码。该组合产生了在密集和稀疏人群场景中进行计数的有效模型。 13 | 基于SFANet解码器结构,M-SFANet的解码器具有双路径,用于密度图和注意力图生成。 14 | """ 15 | 16 | 17 | class ContextualModule(nn.Module): 18 | def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)): 19 | super(ContextualModule, self).__init__() 20 | self.scales = [] 21 | self.scales = nn.ModuleList([self._make_scale(features, size) for size in sizes]) 22 | self.bottleneck = nn.Conv2d(features * 2, out_features, kernel_size=1) 23 | self.relu = nn.ReLU() 24 | self.weight_net = nn.Conv2d(features, features, kernel_size=1) 25 | self._initialize_weights() 26 | 27 | def __make_weight(self, feature, scale_feature): 28 | weight_feature = feature - scale_feature 29 | return F.sigmoid(self.weight_net(weight_feature)) 30 | 31 | def _make_scale(self, features, size): 32 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 33 | conv = nn.Conv2d(features, features, kernel_size=1, bias=False) 34 | return nn.Sequential(prior, conv) 35 | 36 | def forward(self, feats): 37 | h, w = feats.size(2), feats.size(3) 38 | multi_scales = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear') for stage in self.scales] 39 | weights = [self.__make_weight(feats, scale_feature) for scale_feature in multi_scales] 40 | overall_features = [(multi_scales[0] * weights[0] + multi_scales[1] * weights[1] + multi_scales[2] * weights[ 41 | 2] + multi_scales[3] * weights[3]) / 
(weights[0] + weights[1] + weights[2] + weights[3])] + [feats] 42 | bottle = self.bottleneck(torch.cat(overall_features, 1)) 43 | return self.relu(bottle) 44 | 45 | def _initialize_weights(self): 46 | for m in self.modules(): 47 | if isinstance(m, nn.Conv2d): 48 | nn.init.normal_(m.weight, std=0.01) 49 | if m.bias is not None: 50 | nn.init.constant_(m.bias, 0) 51 | elif isinstance(m, nn.BatchNorm2d): 52 | nn.init.constant_(m.weight, 1) 53 | nn.init.constant_(m.bias, 0) 54 | 55 | 56 | 57 | 58 | if __name__ == '__main__': 59 | block = ContextualModule(features=64, out_features=64) 60 | input_tensor = torch.rand(1, 64, 128, 128) 61 | output = block(input_tensor) 62 | print("Input size:", input_tensor.size()) 63 | print("Output size:", output.size()) -------------------------------------------------------------------------------- /EfficientAdditiveAttnetion.py: -------------------------------------------------------------------------------- 1 | # https:// tinyurl.com/ 5ft8v46w 2 | """ 3 | 以下是这个模块的主要特点和作用: 4 | 5 | 线性变换:模块中包括两个线性层 to_query 和 to_key,分别用于将输入的特征进行线性变换,将特征维度从 in_dims 映射到 token_dim * num_heads。这两个线性层的输出用于计算查询(Query)和键(Key)。 6 | 7 | 可学习的权重:模块中包括一个可学习的权重向量 w_g,用于计算加性注意力的权重。这个权重向量的形状是 (token_dim * num_heads, 1)。 8 | 9 | 归一化:通过 torch.nn.functional.normalize 函数对查询(Query)和键(Key)进行 L2 归一化,以确保它们具有单位长度。 10 | 11 | 权重计算:计算查询(Query)与权重向量 w_g 的点积,并乘以缩放因子 scale_factor(通常是 token_dim 的倒数的平方根),以得到加性注意力的权重 A。 12 | 13 | 归一化:对权重 A 进行归一化,以确保它们在序列长度维度上的和为 1。 14 | 15 | 加权求和:通过将注意力权重 A 与查询(Query)相乘,然后在序列长度维度上求和,得到全局上下文向量 G。 16 | 17 | 扩展 G:通过 einops.repeat 操作,将全局上下文向量 G 扩展为与键(Key)相同形状的张量。 18 | 19 | 注意力计算:通过将扩展后的 G 与键(Key)相乘,然后加上原始查询(Query),得到注意力加权的输出。 20 | 21 | 投影层:通过线性层 Proj 对注意力加权的输出进行投影,将特征维度从 token_dim * num_heads 投影回 token_dim * num_heads。 22 | 23 | 最终投影:通过线性层 final 对投影后的输出进行最终的线性变换,将特征维度从 token_dim * num_heads 投影回 token_dim,并得到最终的输出。 24 | 25 | 总的来说,这个模块实现了一种高效的加性注意力机制,用于学习输入序列的全局上下文信息,并将加权后的全局上下文信息与原始特征进行融合,生成最终的输出特征。这种模块通常用于自注意力机制的一部分,可以用于处理序列数据,如自然语言处理中的 Transformer 模型。 26 | """ 27 | import torch 28 | import torch.nn as nn 29 | import einops 30 | 31 | 32 | class EfficientAdditiveAttnetion(nn.Module): 33 | """ 34 | Efficient Additive Attention module for SwiftFormer. 
35 | Input: tensor in shape [B, N, D] 36 | Output: tensor in shape [B, N, D] 37 | """ 38 | 39 | def __init__(self, in_dims=512, token_dim=256, num_heads=2): 40 | super().__init__() 41 | 42 | self.to_query = nn.Linear(in_dims, token_dim * num_heads) 43 | self.to_key = nn.Linear(in_dims, token_dim * num_heads) 44 | 45 | self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1)) 46 | self.scale_factor = token_dim ** -0.5 47 | self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads) 48 | self.final = nn.Linear(token_dim * num_heads, token_dim) 49 | 50 | def forward(self, x): 51 | query = self.to_query(x) 52 | key = self.to_key(x) 53 | 54 | query = torch.nn.functional.normalize(query, dim=-1) # BxNxD 55 | key = torch.nn.functional.normalize(key, dim=-1) # BxNxD 56 | 57 | query_weight = query @ self.w_g # BxNx1 (BxNxD @ Dx1) 58 | A = query_weight * self.scale_factor # BxNx1 59 | 60 | A = torch.nn.functional.normalize(A, dim=1) # BxNx1 61 | 62 | G = torch.sum(A * query, dim=1) # BxD 63 | 64 | G = einops.repeat( 65 | G, "b d -> b repeat d", repeat=key.shape[1] 66 | ) # BxNxD 67 | 68 | out = self.Proj(G * key) + query # BxNxD 69 | 70 | out = self.final(out) # BxNxD 71 | 72 | return out 73 | 74 | 75 | # 输入 B N C , 输出 B N C 76 | if __name__ == '__main__': 77 | block = EfficientAdditiveAttnetion(64, 32).cuda() 78 | input = torch.rand(3, 64 * 64, 64).cuda() 79 | output = block(input) 80 | print(input.size(), output.size()) 81 | -------------------------------------------------------------------------------- /VisionPermutator.py: -------------------------------------------------------------------------------- 1 | # https://github.com/Andrew-Qibin/VisionPermutator 2 | 3 | """ 4 | MLP (Multi-Layer Perceptron) 模块: 5 | 6 | MLP 是一个多层感知器(MLP)模块,用于将输入数据进行线性变换和激活函数操作,以学习和提取特征。 7 | 8 | 构造函数 (__init__) 接受以下参数: 9 | 10 | in_features:输入特征的维度。 11 | hidden_features:中间隐藏层的特征维度。 12 | out_features:输出层的特征维度。 13 | act_layer:激活函数,默认为 GELU。 14 | drop:Dropout 概率,默认为 0.1。 15 | MLP 模块包括两个线性层(fc1 和 fc2),一个激活函数(act_layer)和一个 Dropout 层(drop)。 16 | 17 | forward 方法接受输入 x,首先将输入经过第一个线性层和激活函数,然后应用 Dropout,最后通过第二个线性层得到输出。 18 | 19 | WeightedPermuteMLP 模块: 20 | 21 | WeightedPermuteMLP 是一个自注意力模块,它用于对输入张量进行特征变换和加权重组。 22 | 23 | 构造函数 (__init__) 接受以下参数: 24 | 25 | dim:输入特征的维度。 26 | seg_dim:分段维度,默认为 8。 27 | qkv_bias:Q、K 和 V 投影是否包括偏差,默认为 False。 28 | proj_drop:投影层后的 Dropout 概率,默认为 0。 29 | WeightedPermuteMLP 模块首先将输入张量通过三个线性层(mlp_c、mlp_h 和 mlp_w)进行特征变换,分别用于通道、高度和宽度方向。 30 | 31 | 输入张量被分成多个段,并在通道维度上进行重组,然后经过线性层进行特征变换。 32 | 33 | 每个变换后的段都会计算一个权重,然后通过加权平均的方式将这些段组合在一起,以获得最终的输出。 34 | 35 | 最终输出通过投影层和 Dropout 进行后处理。 36 | 37 | 这两个模块通常用于神经网络的不同部分,用于特征提取和建模。MLP 主要用于局部特征的提取,而 WeightedPermuteMLP 主要用于加权重组特征以增强全局特征表示。 38 | """ 39 | 40 | import torch 41 | from torch import nn 42 | 43 | 44 | class MLP(nn.Module): 45 | def __init__(self,in_features,hidden_features,out_features,act_layer=nn.GELU,drop=0.1): 46 | super().__init__() 47 | self.fc1=nn.Linear(in_features,hidden_features) 48 | self.act=act_layer() 49 | self.fc2=nn.Linear(hidden_features,out_features) 50 | self.drop=nn.Dropout(drop) 51 | 52 | def forward(self, x) : 53 | return self.drop(self.fc2(self.drop(self.act(self.fc1(x))))) 54 | 55 | class WeightedPermuteMLP(nn.Module): 56 | def __init__(self,dim,seg_dim=8, qkv_bias=False, proj_drop=0.): 57 | super().__init__() 58 | self.seg_dim=seg_dim 59 | 60 | self.mlp_c=nn.Linear(dim,dim,bias=qkv_bias) 61 | self.mlp_h=nn.Linear(dim,dim,bias=qkv_bias) 62 | self.mlp_w=nn.Linear(dim,dim,bias=qkv_bias) 63 | 64 | self.reweighting=MLP(dim,dim//4,dim*3) 65 | 
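        # reweighting maps the globally pooled feature (dim) to 3*dim logits; forward() reshapes
        # them to (B, C, 3) and softmaxes the three branch scores per channel to weight the
        # channel/height/width embeddings before they are summed.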
66 | self.proj=nn.Linear(dim,dim) 67 | self.proj_drop=nn.Dropout(proj_drop) 68 | 69 | def forward(self,x) : 70 | B,H,W,C=x.shape 71 | 72 | c_embed=self.mlp_c(x) 73 | 74 | S=C//self.seg_dim 75 | h_embed=x.reshape(B,H,W,self.seg_dim,S).permute(0,3,2,1,4).reshape(B,self.seg_dim,W,H*S) 76 | h_embed=self.mlp_h(h_embed).reshape(B,self.seg_dim,W,H,S).permute(0,3,2,1,4).reshape(B,H,W,C) 77 | 78 | w_embed=x.reshape(B,H,W,self.seg_dim,S).permute(0,3,1,2,4).reshape(B,self.seg_dim,H,W*S) 79 | w_embed=self.mlp_w(w_embed).reshape(B,self.seg_dim,H,W,S).permute(0,2,3,1,4).reshape(B,H,W,C) 80 | 81 | weight=(c_embed+h_embed+w_embed).permute(0,3,1,2).flatten(2).mean(2) 82 | weight=self.reweighting(weight).reshape(B,C,3).permute(2,0,1).softmax(0).unsqueeze(2).unsqueeze(2) 83 | 84 | x=c_embed*weight[0]+w_embed*weight[1]+h_embed*weight[2] 85 | 86 | x=self.proj_drop(self.proj(x)) 87 | 88 | return x 89 | 90 | 91 | 92 | if __name__ == '__main__': 93 | input=torch.randn(64,8,8,512) 94 | seg_dim=8 95 | vip=WeightedPermuteMLP(512,seg_dim) 96 | out=vip(input) 97 | print(out.shape) 98 | -------------------------------------------------------------------------------- /S2Attention.py: -------------------------------------------------------------------------------- 1 | # https://paperswithcode.com/paper/s-2-mlpv2-improved-spatial-shift-mlp 2 | 3 | """ 4 | SplitAttention: 5 | 6 | 这是一个分离式注意力(Split Attention)模块,用于增强神经网络的特征表示。 7 | 参数包括 channel(通道数)和 k(分离的注意力头数)。 8 | 在前向传播中,输入张量 x_all 被重塑为形状 (b, k, h*w, c),其中 b 是批次大小,k 是头数,h 和 w 是高度和宽度,c 是通道数。 9 | 然后,计算注意力的权重,通过 MLP 网络计算 hat_a,然后应用 softmax 函数得到 bar_a。 10 | 最后,将 bar_a 与输入张量 x_all 相乘,并对所有头的结果进行求和以获得最终的输出。 11 | S2Attention: 12 | 13 | 这是一个基于Split Attention的注意力模块,用于处理输入张量。 14 | 参数包括 channels(通道数)。 15 | 在前向传播中,首先对输入张量进行线性变换,然后将结果分为三部分(x1、x2 和 x3)。 16 | 接下来,这三部分被传递给 SplitAttention 模块,以计算注意力权重并增强特征表示。 17 | 最后,通过另一个线性变换将注意力增强后的特征表示进行合并并返回。 18 | 这些模块可以用于构建神经网络中的不同层,以提高特征表示的性能和泛化能力。 19 | """ 20 | 21 | import numpy as np 22 | import torch 23 | from torch import nn 24 | from torch.nn import init 25 | 26 | 27 | def spatial_shift1(x): 28 | b, w, h, c = x.size() 29 | x[:, 1:, :, :c // 4] = x[:, :w - 1, :, :c // 4] 30 | x[:, :w - 1, :, c // 4:c // 2] = x[:, 1:, :, c // 4:c // 2] 31 | x[:, :, 1:, c // 2:c * 3 // 4] = x[:, :, :h - 1, c // 2:c * 3 // 4] 32 | x[:, :, :h - 1, 3 * c // 4:] = x[:, :, 1:, 3 * c // 4:] 33 | return x 34 | 35 | 36 | def spatial_shift2(x): 37 | b, w, h, c = x.size() 38 | x[:, :, 1:, :c // 4] = x[:, :, :h - 1, :c // 4] 39 | x[:, :, :h - 1, c // 4:c // 2] = x[:, :, 1:, c // 4:c // 2] 40 | x[:, 1:, :, c // 2:c * 3 // 4] = x[:, :w - 1, :, c // 2:c * 3 // 4] 41 | x[:, :w - 1, :, 3 * c // 4:] = x[:, 1:, :, 3 * c // 4:] 42 | return x 43 | 44 | 45 | class SplitAttention(nn.Module): 46 | def __init__(self, channel=32, k=3): 47 | super().__init__() 48 | self.channel = channel 49 | self.k = k 50 | self.mlp1 = nn.Linear(channel, channel, bias=False) 51 | self.gelu = nn.GELU() 52 | self.mlp2 = nn.Linear(channel, channel * k, bias=False) 53 | self.softmax = nn.Softmax(1) 54 | 55 | def forward(self, x_all): 56 | b, k, h, w, c = x_all.shape 57 | x_all = x_all.reshape(b, k, -1, c) # bs,k,n,c 58 | a = torch.sum(torch.sum(x_all, 1), 1) # bs,c 59 | hat_a = self.mlp2(self.gelu(self.mlp1(a))) # bs,kc 60 | hat_a = hat_a.reshape(b, self.k, c) # bs,k,c 61 | bar_a = self.softmax(hat_a) # bs,k,c 62 | attention = bar_a.unsqueeze(-2) # #bs,k,1,c 63 | out = attention * x_all # #bs,k,n,c 64 | out = torch.sum(out, 1).reshape(b, h, w, c) 65 | return out 66 | 67 | 68 | class S2Attention(nn.Module): 69 | 
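    # S2Attention pipeline: mlp1 expands C -> 3C, two chunks go through spatial_shift1/2 while the
    # third stays un-shifted, SplitAttention fuses the three branches, and mlp2 projects back to C.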
70 | def __init__(self, channels=32): 71 | super().__init__() 72 | self.mlp1 = nn.Linear(channels, channels * 3) 73 | self.mlp2 = nn.Linear(channels, channels) 74 | self.split_attention = SplitAttention() 75 | 76 | def forward(self, x): 77 | b, c, w, h = x.size() 78 | x = x.permute(0, 2, 3, 1) 79 | x = self.mlp1(x) 80 | x1 = spatial_shift1(x[:, :, :, :c]) 81 | x2 = spatial_shift2(x[:, :, :, c:c * 2]) 82 | x3 = x[:, :, :, c * 2:] 83 | x_all = torch.stack([x1, x2, x3], 1) 84 | a = self.split_attention(x_all) 85 | x = self.mlp2(a) 86 | x = x.permute(0, 3, 1, 2) 87 | return x 88 | 89 | 90 | # 输入 N C H W, 输出 N C H W 91 | if __name__ == '__main__': 92 | input = torch.randn(64, 32, 7, 7) 93 | s2att = S2Attention(channels=32) 94 | output = s2att(input) 95 | print(output.shape) 96 | -------------------------------------------------------------------------------- /TripletAttention.py: -------------------------------------------------------------------------------- 1 | # https://github.com/mindspore-courses/External-Attention-MindSpore/blob/main/model/attention/TripletAttention.py 2 | 3 | """ 4 | 以下是这些模块的主要特点和作用: 5 | 6 | BasicConv 模块: 7 | 8 | 这是一个基本的卷积模块,用于进行卷积操作,包括卷积、批归一化(可选)、ReLU 激活函数(可选)。 9 | 可以通过参数来控制是否使用批归一化和ReLU激活函数。 10 | ZPool 模块: 11 | 12 | 这是一个自定义的池化操作,将输入的特征图进行最大池化和平均池化,然后将它们拼接在一起。 13 | AttentionGate 模块: 14 | 15 | 这个模块实现了一个注意力门控机制,用于学习特征图的注意力权重。 16 | 首先通过 ZPool 操作将输入的特征图进行池化。 17 | 然后应用一个卷积层,该卷积层输出一个注意力权重,通过 Sigmoid 激活函数将其归一化。 18 | 最后,将输入特征图与注意力权重相乘,以得到加权的特征图。 19 | TripletAttention 模块: 20 | 21 | 这个模块实现了一种三重注意力机制,用于学习特征图的全局和局部信息。 22 | 该模块包括三个 AttentionGate 模块,分别用于通道维度(c)、高度维度(h)和宽度维度(w)的注意力权重学习。 23 | 可以通过参数 no_spatial 来控制是否忽略空间维度。 24 | 最终,将三个注意力权重加权平均,以得到最终的特征图。 25 | """ 26 | 27 | import torch 28 | import torch.nn as nn 29 | 30 | 31 | class BasicConv(nn.Module): 32 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 33 | bn=True, bias=False): 34 | super(BasicConv, self).__init__() 35 | self.out_channels = out_planes 36 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 37 | dilation=dilation, groups=groups, bias=bias) 38 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 39 | self.relu = nn.ReLU() if relu else None 40 | 41 | def forward(self, x): 42 | x = self.conv(x) 43 | if self.bn is not None: 44 | x = self.bn(x) 45 | if self.relu is not None: 46 | x = self.relu(x) 47 | return x 48 | 49 | 50 | class ZPool(nn.Module): 51 | def forward(self, x): 52 | return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) 53 | 54 | 55 | class AttentionGate(nn.Module): 56 | def __init__(self): 57 | super(AttentionGate, self).__init__() 58 | kernel_size = 7 59 | self.compress = ZPool() 60 | self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False) 61 | 62 | def forward(self, x): 63 | x_compress = self.compress(x) 64 | x_out = self.conv(x_compress) 65 | scale = torch.sigmoid_(x_out) 66 | return x * scale 67 | 68 | 69 | class TripletAttention(nn.Module): 70 | def __init__(self, no_spatial=False): 71 | super(TripletAttention, self).__init__() 72 | self.cw = AttentionGate() 73 | self.hc = AttentionGate() 74 | self.no_spatial = no_spatial 75 | if not no_spatial: 76 | self.hw = AttentionGate() 77 | 78 | def forward(self, x): 79 | x_perm1 = x.permute(0, 2, 1, 3).contiguous() 80 | x_out1 = self.cw(x_perm1) 81 | x_out11 = x_out1.permute(0, 2, 1, 3).contiguous() 82 | x_perm2 = x.permute(0, 3, 2, 
1).contiguous() 83 | x_out2 = self.hc(x_perm2) 84 | x_out21 = x_out2.permute(0, 3, 2, 1).contiguous() 85 | if not self.no_spatial: 86 | x_out = self.hw(x) 87 | x_out = 1 / 3 * (x_out + x_out11 + x_out21) 88 | else: 89 | x_out = 1 / 2 * (x_out11 + x_out21) 90 | return x_out 91 | 92 | 93 | if __name__ == '__main__': 94 | input = torch.randn(50, 512, 7, 7) 95 | triplet = TripletAttention() 96 | output = triplet(input) 97 | print(output.shape) 98 | -------------------------------------------------------------------------------- /LFA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import os 5 | import torch.nn.functional as F 6 | 7 | """ 8 | 随着智慧城市的发展,交通流预测(TFP)越来越受到人们的关注。在过去的几年中,基于神经网络的方法在 TFP 方面表现出了令人印象深刻的性能。 9 | 然而,之前的大多数研究未能明确有效地模拟流入和流出之间的关系。因此,这些方法通常无法解释且不准确。 10 | 在本文中,我们提出了一种用于 TFP 的可解释的局部流注意(LFA)机制,它具有三个优点。 11 | (1) LFA 具有流量感知能力。与现有的在通道维度上混合流入和流出的作品不同,我们通过一种新颖的注意力机制明确地利用了流量之间的相关性。 12 | (2) LFA是可解释的。它是根据交通流的真理制定的,学习到的注意力权重可以很好地解释流量相关性。 13 | (3) LFA高效。 LFA没有像之前的研究那样使用全局空间注意力,而是利用局部模式。注意力查询仅在局部相关区域上执行。这不仅降低了计算成本,还避免了错误关注。 14 | """ 15 | 16 | class LFA(nn.Module): 17 | def __init__(self, hidden_channel): 18 | super(LFA, self).__init__() 19 | self.proj_hq = nn.Conv2d(in_channels=hidden_channel, out_channels=hidden_channel, kernel_size=1, stride=1,bias=False) 20 | self.proj_mk = nn.Conv2d(in_channels=hidden_channel, out_channels=hidden_channel, kernel_size=1, stride=1,bias=False) 21 | self.proj_mv = nn.Conv2d(in_channels=hidden_channel, out_channels=hidden_channel, kernel_size=1, stride=1,bias=False) 22 | 23 | self.kernel_size=7 24 | self.pad=3 25 | 26 | self.dis=self.init_distance() 27 | 28 | def init_distance(self): 29 | dis=torch.zeros(self.kernel_size,self.kernel_size).cuda() 30 | certer_x=int((self.kernel_size-1)/2) 31 | certer_y = int((self.kernel_size - 1) / 2) 32 | for i in range(self.kernel_size): 33 | for j in range(self.kernel_size): 34 | ii=i-certer_x 35 | jj=j-certer_y 36 | tmp=(self.kernel_size-1)*(self.kernel_size-1) 37 | tmp=(ii*ii+jj*jj)/tmp+dis[i,j] 38 | dis[i,j]=torch.exp(-tmp) 39 | dis[certer_x,certer_y]=0 40 | return dis 41 | 42 | 43 | def forward(self, H,M): 44 | b,c, h, w = H.shape 45 | pad_M=F.pad(M,[self.pad,self.pad,self.pad,self.pad]) 46 | 47 | Q_h = self.proj_hq(H) # b,c,h,w 48 | K_m = self.proj_mk(pad_M) # b,c,h+2,w+2 49 | V_m = self.proj_mv(pad_M) # b,c,h+2,w+2 50 | 51 | K_m=K_m.unfold(2,self.kernel_size,1).unfold(3,self.kernel_size,1) # b,c,h,w,k,k 52 | V_m=V_m.unfold(2,self.kernel_size,1).unfold(3,self.kernel_size,1) # b,c,h,w,k,k 53 | 54 | Q_h=Q_h.permute(0,2,3,1) # b,h,w,c 55 | K_m=K_m.permute(0,2,3,4,5,1) # b,h,w,k,k,c 56 | K_m=K_m.contiguous().view(b,h,w,-1,c) # b,h,w,(k*k),c 57 | alpha=torch.einsum('bhwik,bhwkj->bhwij',K_m,Q_h.unsqueeze(-1)) # b,h,w,(k*k),1 58 | dis_alpha=self.dis.view(-1,1) # (k*k),1 59 | alpha=alpha*dis_alpha 60 | alpha = F.softmax(alpha.squeeze(dim=-1), dim=-1) # b,h,w,(k*k) 61 | V_m=V_m.permute(0,2,3,4,5,1).contiguous().view(b,h,w,-1,c) # b,h,w,(k*k),c 62 | res=torch.einsum('bhwik,bhwkj->bhwij',alpha.unsqueeze(dim=-2),V_m) # b,h,w,1,c 63 | res=res.permute(0,4,1,2,3).squeeze(-1) # b,c,h,w 64 | return res 65 | 66 | if __name__ == '__main__': 67 | hidden_channel = 64 # 隐藏通道数 68 | block = LFA(hidden_channel).to(device=0) # 创建 LFA 实例 69 | input_H = torch.rand(1, hidden_channel, 32, 32).to(device=0) # 输入H 70 | input_M = torch.rand(1, hidden_channel, 32, 32).to(device=0) # 输入M 71 | output = block(input_H, input_M) # 模型前向传播 
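    # output keeps the same (B, C, H, W) shape as input_H: each position attends over a 7x7 local
    # window of the padded M features, weighted by the fixed distance prior self.dis.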
72 | print("Input shape (H):", input_H.size()) 73 | print("Input shape (M):", input_M.size()) 74 | print("Output shape: ", output.size()) 75 | -------------------------------------------------------------------------------- /OutlookAtt.py: -------------------------------------------------------------------------------- 1 | # https://github.com/sail-sg/volo 2 | 3 | """ 4 | 以下是该模块的主要组件和操作: 5 | 6 | v_pj:通过线性变换将输入特征映射到新的特征空间,以产生 v。 7 | 8 | attn:通过线性变换将输入图像的局部区域映射到注意力得分的空间。这个得分表示局部区域的重要性。 9 | 10 | attn_drop:一个用于应用注意力得分的丢弃层,以防止过度拟合。 11 | 12 | proj 和 proj_drop:用于最终输出的线性变换和丢弃层。 13 | 14 | unflod:一个用于手动卷积的操作,将 v 特征张量按指定的 kernel_size、padding 和 stride 进行展开。 15 | 16 | pool:用于在输入图像上执行平均池化,以减小图像尺寸。 17 | 18 | 在前向传播中,模块首先将输入图像的局部区域映射到 v 特征空间,然后计算注意力得分。注意力得分被应用于 v 特征以获得加权特征表示。最后,通过线性变换和丢弃层来进一步处理特征表示,以产生最终的输出。 19 | 20 | 这个模块的主要用途是捕获输入图像的局部信息,并根据局部区域的重要性来加权特征表示。这对于各种计算机视觉任务,如图像分类和分割,可能都会有所帮助。 21 | """ 22 | 23 | import numpy as np 24 | import torch 25 | from torch import nn 26 | from torch.nn import init 27 | import math 28 | from torch.nn import functional as F 29 | 30 | 31 | class OutlookAttention(nn.Module): 32 | 33 | def __init__(self, dim, num_heads=1, kernel_size=3, padding=1, stride=1, qkv_bias=False, 34 | attn_drop=0.1): 35 | super().__init__() 36 | self.dim = dim 37 | self.num_heads = num_heads 38 | self.head_dim = dim // num_heads 39 | self.kernel_size = kernel_size 40 | self.padding = padding 41 | self.stride = stride 42 | self.scale = self.head_dim ** (-0.5) 43 | 44 | self.v_pj = nn.Linear(dim, dim, bias=qkv_bias) 45 | self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads) 46 | 47 | self.attn_drop = nn.Dropout(attn_drop) 48 | self.proj = nn.Linear(dim, dim) 49 | self.proj_drop = nn.Dropout(attn_drop) 50 | 51 | self.unflod = nn.Unfold(kernel_size, padding, stride) # 手动卷积 52 | self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True) 53 | 54 | def forward(self, x): 55 | B, H, W, C = x.shape 56 | 57 | # 映射到新的特征v 58 | v = self.v_pj(x).permute(0, 3, 1, 2) # B,C,H,W 59 | h, w = math.ceil(H / self.stride), math.ceil(W / self.stride) 60 | v = self.unflod(v).reshape(B, self.num_heads, self.head_dim, self.kernel_size * self.kernel_size, 61 | h * w).permute(0, 1, 4, 3, 2) # B,num_head,H*W,kxk,head_dim 62 | 63 | # 生成Attention Map 64 | attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) # B,H,W,C 65 | attn = self.attn(attn).reshape(B, h * w, self.num_heads, self.kernel_size * self.kernel_size \ 66 | , self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 67 | 4) # B,num_head,H*W,kxk,kxk 68 | attn = self.scale * attn 69 | attn = attn.softmax(-1) 70 | attn = self.attn_drop(attn) 71 | 72 | # 获取weighted特征 73 | out = (attn @ v).permute(0, 1, 4, 3, 2).reshape(B, C * self.kernel_size * self.kernel_size, 74 | h * w) # B,dimxkxk,H*W 75 | out = F.fold(out, output_size=(H, W), kernel_size=self.kernel_size, 76 | padding=self.padding, stride=self.stride) # B,C,H,W 77 | out = self.proj(out.permute(0, 2, 3, 1)) # B,H,W,C 78 | out = self.proj_drop(out) 79 | 80 | return out 81 | 82 | 83 | # 输入 B, H, W, C, 输出 B, H, W, C 84 | if __name__ == '__main__': 85 | block = OutlookAttention(dim=256).cuda() 86 | # input = torch.rand(1, 64, 64, 512).cuda() 87 | input = torch.rand(1, 128, 256, 256).cuda() 88 | output = block(input) 89 | print(input.size(), output.size()) 90 | -------------------------------------------------------------------------------- /Pacoloss(参数对比损失,用于对比学习).py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 
| """ 5 | 在本文中,我们提出了参数对比学习 (PaCo) 来解决长尾识别问题。基于理论分析,我们观察到监督对比损失倾向于偏向高频类别,从而增加了不平衡学习的难度。我们引入了一组参数化的类可学习中心,从优化的角度重新平衡。 6 | 此外,我们在平衡设置下分析了我们的 PaCo 损失。我们的分析表明,随着更多样本与其相应的中心被拉到一起,PaCo 可以自适应地增强将同一类别的样本推近的强度,并有利于硬示例学习。 7 | 在长尾 CIFAR、ImageNet、Places 和 iNaturalist 2018 上的实验展现了长尾识别的最新进展。 8 | """ 9 | 10 | 11 | class PaCoLoss(nn.Module): 12 | def __init__(self, alpha=1.0, beta=1.0, gamma=0.0, supt=1.0, temperature=1.0, base_temperature=None, K=128, 13 | num_classes=1000): 14 | super(PaCoLoss, self).__init__() 15 | self.temperature = temperature 16 | self.base_temperature = temperature if base_temperature is None else base_temperature 17 | self.K = K 18 | self.alpha = alpha 19 | self.beta = beta 20 | self.gamma = gamma 21 | self.supt = supt 22 | self.num_classes = num_classes 23 | 24 | def forward(self, features, labels=None, sup_logits=None): 25 | device = torch.device('cuda' if features.is_cuda else 'cpu') 26 | 27 | batch_size = features.shape[0] 28 | 29 | labels = labels.contiguous().view(-1, 1) 30 | mask = torch.eq(labels[:batch_size], labels.T).float().to(device) 31 | 32 | # compute logits using complete features tensor 33 | anchor_dot_contrast = torch.div( 34 | torch.matmul(features, features.T), 35 | self.temperature) 36 | 37 | # add supervised logits 38 | anchor_dot_contrast = torch.cat(((sup_logits) / self.supt, anchor_dot_contrast), dim=1) 39 | 40 | # for numerical stability 41 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 42 | logits = anchor_dot_contrast - logits_max.detach() 43 | 44 | # mask-out self-contrast cases 45 | logits_mask = torch.scatter( 46 | torch.ones_like(mask), 47 | 1, 48 | torch.arange(batch_size).view(-1, 1).to(device), 49 | 0 50 | ) 51 | 52 | mask = mask * logits_mask 53 | 54 | # add ground truth 55 | one_hot_label = torch.nn.functional.one_hot(labels[:batch_size, ].view(-1, ), num_classes=self.num_classes).to( 56 | torch.float32) 57 | mask = torch.cat((one_hot_label * self.beta, mask * self.alpha), dim=1) 58 | 59 | # compute log_prob 60 | logits_mask = torch.cat((torch.ones(batch_size, self.num_classes).to(device), self.gamma * logits_mask), dim=1) 61 | exp_logits = torch.exp(logits) * logits_mask 62 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12) 63 | 64 | # compute mean of log-likelihood over positive 65 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 66 | 67 | # loss 68 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 69 | loss = loss.mean() 70 | 71 | return loss 72 | 73 | 74 | if __name__ == '__main__': 75 | # 初始化 PaCoLoss 模型 76 | block = PaCoLoss() 77 | 78 | # 随机生成输入特征、标签和监督logits 79 | input_features = torch.rand(64, 64) # 例如,64 个样本的特征 80 | labels = torch.randint(0, 10, (64,)) # 例如,64 个样本的标签 81 | sup_logits = torch.rand(64, 1000) # 例如,64 个样本的监督 logits 82 | 83 | print("Supervised logits shape:", sup_logits.size()) 84 | 85 | # 使用输入数据计算损失 86 | loss = block(input_features, labels=labels, sup_logits=sup_logits) 87 | 88 | # 输出输入特征的形状和计算得到的损失值 89 | print("Input features shape:", input_features.size()) 90 | print("Output loss:", loss.item()) 91 | 92 | -------------------------------------------------------------------------------- /UFOAttention.py: -------------------------------------------------------------------------------- 1 | # https://github.com/mindspore-courses/External-Attention-MindSpore/blob/main/model/attention/UFOAttention.py 2 | 3 | """ 4 | 以下是这个模块的主要特点和作用: 5 | 6 | 多头自注意力:这个模块使用了多头自注意力机制,通过将输入进行不同线性变换,分为多个头来计算注意力。h 参数表示注意力头的数量。 7 | 8 | 线性变换:模块中的线性层(fc_q、fc_k、fc_v 和 
fc_o)用于将输入进行线性变换,以生成查询(Q)、键(K)和值(V)的向量。 9 | 10 | 权重初始化:模块中的线性层的权重被初始化,以确保良好的训练收敛性。这些初始化方法包括卷积层的 He 初始化和线性层的正态分布初始化。 11 | 12 | 注意力计算:通过计算 Q 和 K 的点积,然后应用归一化函数,得到注意力矩阵。在这个模块中,注意力矩阵经过了一些自定义的归一化(XNorm 函数)。 13 | 14 | 多头特征整合:多个注意力头的输出被整合在一起,然后通过线性层进行进一步的处理,以生成最终的输出。 15 | 16 | Dropout 正则化:模块中使用了 Dropout 操作,以减少过拟合的风险。 17 | 18 | 参数化的缩放因子:模块中包括一个可学习的缩放因子 gamma,用于调整注意力计算的缩放。 19 | 20 | 总的来说,UFOAttention模块提供了一种用于神经网络中的自注意力机制,它可以根据输入数据生成不同的查询、键和值,并计算注意力矩阵,然后整合多个头的输出以生成最终的特征表示。这种模块通常用于处理序列数据,如自然语言处理中的 Transformer 模型中的注意力层。 21 | """ 22 | 23 | import numpy as np 24 | import torch 25 | from torch import nn 26 | from torch.functional import norm 27 | from torch.nn import init 28 | 29 | 30 | def XNorm(x, gamma): 31 | norm_tensor = torch.norm(x, 2, -1, True) 32 | return x * gamma / norm_tensor 33 | 34 | 35 | class UFOAttention(nn.Module): 36 | ''' 37 | Scaled dot-product attention 38 | ''' 39 | 40 | def __init__(self, d_model, d_k, d_v, h, dropout=.1): 41 | ''' 42 | :param d_model: Output dimensionality of the model 43 | :param d_k: Dimensionality of queries and keys 44 | :param d_v: Dimensionality of values 45 | :param h: Number of heads 46 | ''' 47 | super(UFOAttention, self).__init__() 48 | self.fc_q = nn.Linear(d_model, h * d_k) 49 | self.fc_k = nn.Linear(d_model, h * d_k) 50 | self.fc_v = nn.Linear(d_model, h * d_v) 51 | self.fc_o = nn.Linear(h * d_v, d_model) 52 | self.dropout = nn.Dropout(dropout) 53 | self.gamma = nn.Parameter(torch.randn((1, h, 1, 1))) 54 | 55 | self.d_model = d_model 56 | self.d_k = d_k 57 | self.d_v = d_v 58 | self.h = h 59 | 60 | self.init_weights() 61 | 62 | def init_weights(self): 63 | for m in self.modules(): 64 | if isinstance(m, nn.Conv2d): 65 | init.kaiming_normal_(m.weight, mode='fan_out') 66 | if m.bias is not None: 67 | init.constant_(m.bias, 0) 68 | elif isinstance(m, nn.BatchNorm2d): 69 | init.constant_(m.weight, 1) 70 | init.constant_(m.bias, 0) 71 | elif isinstance(m, nn.Linear): 72 | init.normal_(m.weight, std=0.001) 73 | if m.bias is not None: 74 | init.constant_(m.bias, 0) 75 | 76 | def forward(self, queries, keys, values): 77 | b_s, nq = queries.shape[:2] 78 | nk = keys.shape[1] 79 | 80 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 81 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 82 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 83 | 84 | kv = torch.matmul(k, v) # bs,h,c,c 85 | kv_norm = XNorm(kv, self.gamma) # bs,h,c,c 86 | q_norm = XNorm(q, self.gamma) # bs,h,n,c 87 | out = torch.matmul(q_norm, kv_norm).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) 88 | out = self.fc_o(out) # (b_s, nq, d_model) 89 | 90 | return out 91 | 92 | 93 | if __name__ == '__main__': 94 | block = UFOAttention(d_model=512, d_k=512, d_v=512, h=8).cuda() 95 | input = torch.rand(64, 64, 512).cuda() 96 | output = block(input, input, input) 97 | print(input.size(), output.size()) 98 | -------------------------------------------------------------------------------- /Deepfake(深度伪造检测).py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | """ 7 | Deepfake 席卷全球,引发信任危机。当前的深度伪造检测方法通常普遍性不足,容易过度拟合背景等图像内容,这种情况在训练数据集中经常出现但相对不重要。 8 | 此外,当前的方法严重依赖于一些主要的伪造区域,并且可能忽略其他同等重要的区域,导致伪造线索的发现不充分。 9 | 设计了三个功能模块来处理协作学习方案中的多流和多尺度特征 10 | """ 11 | 12 | 13 | 14 | class CMCE(nn.Module): # Contrastive Multimodal Contrastive 
Enhancement 增强模型对特征的关注度,提高模型的性能 15 | def __init__(self, in_channel=3): 16 | super(CMCE, self).__init__() 17 | self.relu = nn.ReLU() 18 | self.bn = nn.BatchNorm2d(in_channel) 19 | self.stage1 = nn.Sequential( 20 | nn.Conv2d(in_channel, in_channel, 3, 1, bias=False), 21 | nn.BatchNorm2d(in_channel), 22 | nn.ReLU() 23 | ) 24 | self.stage2 = nn.Sequential( 25 | nn.Conv2d(in_channel, in_channel, 3, 1, bias=False), 26 | nn.BatchNorm2d(in_channel), 27 | nn.ReLU() 28 | ) 29 | 30 | def forward(self, fa, fb): 31 | (b1, c1, h1, w1), (b2, c2, h2, w2) = fa.size(), fb.size() 32 | assert c1 == c2 33 | cos_sim = F.cosine_similarity(fa, fb, dim=1) 34 | cos_sim = cos_sim.unsqueeze(1) 35 | fa = fa + fb * cos_sim 36 | fb = fb + fa * cos_sim 37 | fa = self.relu(fa) 38 | fb = self.relu(fb) 39 | 40 | return fa, fb 41 | 42 | if __name__ == '__main__': 43 | block = CMCE() 44 | fa = torch.rand(16, 3, 32, 32) 45 | fb = torch.rand(16, 3, 32, 32) 46 | 47 | fa1, fb1 = block(fa, fb) 48 | print(fa.size()) 49 | print(fb.size()) 50 | print(fa1.size()) 51 | print(fb1.size()) 52 | 53 | 54 | class LFGA(nn.Module): # Local Feature Guidance Attention 旨在引导特征图的注意力以更好地聚焦在局部特征上 55 | def __init__(self, in_channel=3, out_channel=None, ratio=4): 56 | super(LFGA, self).__init__() 57 | self.chanel_in = in_channel 58 | 59 | if out_channel is None: 60 | out_channel = in_channel // ratio if in_channel // ratio > 0 else 1 61 | 62 | self.query_conv = nn.Conv2d( 63 | in_channels=in_channel, out_channels=out_channel, kernel_size=1) 64 | self.key_conv = nn.Conv2d( 65 | in_channels=in_channel, out_channels=out_channel, kernel_size=1) 66 | self.value_conv = nn.Conv2d( 67 | in_channels=in_channel, out_channels=in_channel, kernel_size=1) 68 | self.gamma = nn.Parameter(torch.zeros(1)) 69 | 70 | self.softmax = nn.Softmax(dim=-1) 71 | self.relu = nn.ReLU() 72 | self.bn = nn.BatchNorm2d(self.chanel_in) 73 | 74 | def forward(self, fa, fb): 75 | B, C, H, W = fa.size() 76 | proj_query = self.query_conv(fb).view( 77 | B, -1, H * W).permute(0, 2, 1) # B , HW, C 78 | proj_key = self.key_conv(fb).view( 79 | B, -1, H * W) # B X C x (*W*H) 80 | energy = torch.bmm(proj_query, proj_key) # B, HW, HW 81 | attention = self.softmax(energy) # BX (N) X (N) 82 | # attention = F.normalize(energy, dim=-1) 83 | 84 | proj_value = self.value_conv(fa).view( 85 | B, -1, H * W) # B , C , HW 86 | 87 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 88 | out = out.view(B, C, H, W) 89 | 90 | out = self.gamma * out + fa 91 | 92 | return self.relu(out) 93 | 94 | 95 | if __name__ == '__main__': 96 | block = LFGA(in_channel=3, ratio=4) 97 | fa = torch.rand(16, 3, 32, 32) 98 | fb = torch.rand(16, 3, 32, 32) 99 | 100 | output = block(fa, fb) 101 | print(fa.size()) 102 | print(fb.size()) 103 | print(output.size()) 104 | -------------------------------------------------------------------------------- /Free_UNetModel(扩散模型).py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.fft as fft 4 | 5 | """ 6 | 在这篇论文中,我们揭示了扩散 U-Net 的潜力,它被视为一种“免费午餐”,可以在生成过程中大幅提高质量。 7 | 我们首先研究了 U-Net 架构对去噪过程的关键贡献,并确定其主要骨干主要贡献于去噪,而其跳跃连接主要将高频特征引入解码器模块,导致网络忽略了骨干语义。 8 | 基于这一发现,我们提出了一种简单而有效的方法——称为“FreeU”——它可以提高生成质量,而无需额外的训练或微调。我们的关键见解是,战略性地重新加权源自 U-Net 跳跃连接和骨干特征图的贡献 9 | 以利用 U-Net 架构的两个组成部分的优势。在图像和视频生成任务上的有希望的结果表明,我们的 FreeU 可以轻松集成到现有的扩散模型中. 
10 | """ 11 | 12 | def Fourier_filter(x, threshold, scale): 13 | # FFT 14 | x_freq = fft.fftn(x, dim=(-2, -1)) 15 | x_freq = fft.fftshift(x_freq, dim=(-2, -1)) 16 | 17 | B, C, H, W = x_freq.shape 18 | mask = torch.ones((B, C, H, W), device=x.device) 19 | 20 | crow, ccol = H // 2, W // 2 21 | mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale 22 | x_freq = x_freq * mask 23 | 24 | # IFFT 25 | x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) 26 | x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real 27 | 28 | return x_filtered 29 | 30 | 31 | class UNetModel(nn.Module): 32 | def __init__(self, model_channels, num_classes=None): 33 | super().__init__() 34 | self.model_channels = model_channels 35 | self.num_classes = num_classes 36 | self.input_block = nn.Conv2d(3, model_channels, 3, padding=1) 37 | self.middle_block = nn.Conv2d(model_channels, model_channels, 3, padding=1) 38 | self.output_block = nn.Conv2d(model_channels, model_channels, 3, padding=1) 39 | self.final = nn.Conv2d(model_channels, 3, 3, padding=1) # Ensure output has 3 channels 40 | 41 | 42 | def timestep_embedding(timesteps, dim, repeat_only=False): 43 | return torch.randn((timesteps.shape[0], dim)) 44 | 45 | 46 | class Free_UNetModel(UNetModel): 47 | def __init__( 48 | self, 49 | b1, 50 | b2, 51 | s1, 52 | s2, 53 | *args, 54 | **kwargs 55 | ): 56 | super().__init__(*args, **kwargs) 57 | self.b1 = b1 58 | self.b2 = b2 59 | self.s1 = s1 60 | self.s2 = s2 61 | # Define the time embedding layer 62 | self.time_embed = nn.Linear(self.model_channels, self.model_channels) 63 | 64 | if self.num_classes is not None: 65 | self.label_emb = nn.Embedding(self.num_classes, self.model_channels) 66 | 67 | 68 | def forward(self, x, timesteps=None, context=None, y=None, **kwargs): 69 | assert (y is not None) == ( 70 | self.num_classes is not None), "must specify y if and only if the model is class-conditional" 71 | hs = [] 72 | t_emb = timestep_embedding(timesteps, self.model_channels) 73 | emb = self.time_embed(t_emb) 74 | 75 | if self.num_classes is not None: 76 | emb = emb + self.label_emb(y) 77 | 78 | h = x 79 | h = self.input_block(h) # First convolution 80 | hs.append(h) 81 | h = self.middle_block(h) # Middle convolution 82 | for module in [self.output_block, self.final]: # Output convolutions 83 | h = module(h) 84 | 85 | return h 86 | 87 | 88 | if __name__ == '__main__': 89 | block = Free_UNetModel(1.5, 1.2, 0.8, 0.5, model_channels=64, num_classes=10) 90 | 91 | input = torch.rand(32, 3, 256, 256) 92 | timesteps = torch.tensor([1]) 93 | y = torch.tensor([1]) 94 | 95 | # 调用模型进行前向传播,并保存输出到 output 变量中 96 | output = block(input, timesteps=timesteps, y=y) 97 | 98 | print("Input size:", input.size()) 99 | print("Output size:", output.size()) 100 | 101 | 102 | 103 | # 1.5:b1,用于指定 FreeU 中的第一个模块的参数。它控制截断位置,以调整骨干特征的权重。 104 | # 1.2:b2,用于指定 FreeU 中的第二个模块的参数。它控制截断位置,以调整跳跃连接特征的权重。 105 | # 0.8:s1,用于指定 FreeU 中的第一个模块的参数。它控制截断的比例因子,以调整骨干特征的缩放。 106 | # 0.5:s2,用于指定 FreeU 中的第二个模块的参数。它控制截断的比例因子,以调整跳跃连接特征的缩放。 -------------------------------------------------------------------------------- /LinAngularAttention.py: -------------------------------------------------------------------------------- 1 | # https://www.haoranyou.com/castling-vit/ 2 | 3 | """ 4 | 以下是该模块的主要组件和操作: 5 | 6 | qkv:这是一个线性层,将输入特征 x 映射到三个不同的线性变换,分别对应查询 (query),键 (key),和值 (value)。这三个变换将输入特征的通道划分成多个头 (heads)。 7 | 8 | attn_drop 和 proj_drop:这是用于进行注意力矩阵和输出特征的丢弃操作的 Dropout 层。 9 | 10 | kq_matmul、kqv_matmul 和 qk_matmul:这些是自定义的矩阵乘法操作,用于计算注意力矩阵中的各个部分。kq_matmul 
用于计算键和查询的点积,kqv_matmul 用于计算键和值的点积,qk_matmul 用于计算查询和键的点积。 11 | 12 | dconv:这是一个深度卷积层,用于对值进行深度卷积操作。 13 | 14 | 在前向传播过程中,该模块首先将输入特征 x 映射为查询、键和值。然后,通过上述矩阵乘法操作,计算注意力矩阵的各个部分。接下来,对查询和键进行标准化处理,并计算值的深度卷积。最后,根据注意力矩阵和深度卷积的结果,计算最终的输出特征。 15 | 16 | 此模块实现了线性角注意力机制,可用于处理序列或图像数据中的信息交互和特征提取任务。该模块的参数配置如 num_heads、qkv_bias、attn_drop 等可以根据具体任务进行调整。 17 | """ 18 | 19 | import torch 20 | import torch.nn as nn 21 | import math 22 | 23 | 24 | class MatMul(nn.Module): 25 | def __init__(self): 26 | super(MatMul, self).__init__() 27 | 28 | def forward(self, x, y): 29 | return torch.matmul(x, y) 30 | 31 | class LinAngularAttention(nn.Module): 32 | def __init__( 33 | self, 34 | in_channels, 35 | num_heads=8, 36 | qkv_bias=False, 37 | attn_drop=0.0, 38 | proj_drop=0.0, 39 | res_kernel_size=9, 40 | sparse_reg=False, 41 | ): 42 | super().__init__() 43 | assert in_channels % num_heads == 0, "dim should be divisible by num_heads" 44 | self.num_heads = num_heads 45 | head_dim = in_channels // num_heads 46 | self.scale = head_dim**-0.5 47 | self.sparse_reg = sparse_reg 48 | 49 | self.qkv = nn.Linear(in_channels, in_channels * 3, bias=qkv_bias) 50 | self.attn_drop = nn.Dropout(attn_drop) 51 | self.proj = nn.Linear(in_channels, in_channels) 52 | self.proj_drop = nn.Dropout(proj_drop) 53 | 54 | self.kq_matmul = MatMul() 55 | self.kqv_matmul = MatMul() 56 | if self.sparse_reg: 57 | self.qk_matmul = MatMul() 58 | self.sv_matmul = MatMul() 59 | 60 | self.dconv = nn.Conv2d( 61 | in_channels=self.num_heads, 62 | out_channels=self.num_heads, 63 | kernel_size=(res_kernel_size, 1), 64 | padding=(res_kernel_size // 2, 0), 65 | bias=False, 66 | groups=self.num_heads, 67 | ) 68 | 69 | def forward(self, x): 70 | N, L, C = x.shape 71 | qkv = ( 72 | self.qkv(x) 73 | .reshape(N, L, 3, self.num_heads, C // self.num_heads) 74 | .permute(2, 0, 3, 1, 4) 75 | ) 76 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 77 | 78 | if self.sparse_reg: 79 | attn = self.qk_matmul(q * self.scale, k.transpose(-2, -1)) 80 | attn = attn.softmax(dim=-1) 81 | mask = attn > 0.02 # note that the threshold could be different; adapt to your codebases. 
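# `mask` keeps only the attention weights above the threshold; the resulting `sparse`
# matrix (next line) is later applied through self.sv_matmul(sparse, v) as a sparse
# full-attention residual on top of the linear angular attention output.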
82 | sparse = mask * attn 83 | 84 | q = q / q.norm(dim=-1, keepdim=True) 85 | k = k / k.norm(dim=-1, keepdim=True) 86 | dconv_v = self.dconv(v) 87 | 88 | attn = self.kq_matmul(k.transpose(-2, -1), v) 89 | 90 | if self.sparse_reg: 91 | x = ( 92 | self.sv_matmul(sparse, v) 93 | + 0.5 * v 94 | + 1.0 / math.pi * self.kqv_matmul(q, attn) 95 | ) 96 | else: 97 | x = 0.5 * v + 1.0 / math.pi * self.kqv_matmul(q, attn) 98 | x = x / x.norm(dim=-1, keepdim=True) 99 | x += dconv_v 100 | x = x.transpose(1, 2).reshape(N, L, C) 101 | x = self.proj(x) 102 | x = self.proj_drop(x) 103 | return x 104 | 105 | 106 | if __name__ == '__main__': 107 | block = LinAngularAttention(in_channels=128) 108 | input = torch.rand(32,784,128) 109 | output = block(input) 110 | print(input.size(), output.size()) 111 | -------------------------------------------------------------------------------- /SSPCAB(图像和视频异常检测,CV2维任务通用).py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | """ 7 | 异常检测通常被视为一类分类问题,其中模型只能从正常训练样本中学习,同时在正常和异常测试样本上进行评估。 8 | 在成功的异常检测方法中,一类独特的方法依赖于预测屏蔽信息(例如补丁、未来帧等)并利用相对于屏蔽信息的重建误差作为异常分数。与相关方法不同,我们建议将基于重建的功能集成到一种新颖的自监督预测架构构建块中。 9 | 所提出的自监督块是通用的,可以很容易地合并到各种最先进的异常检测方法中。我们的块从带有扩张滤波器的卷积层开始,其中感受野的中心区域被屏蔽。生成的激活图通过通道注意模块传递。 10 | 我们的块配备了一个损失,可以最小化相对于感受野中的掩模区域的重建误差。我们通过将我们的模块集成到几个最先进的图像和视频异常检测框架中来展示该模块的通用性。 11 | """ 12 | 13 | 14 | class SELayer(nn.Module): 15 | def __init__(self, num_channels, reduction_ratio=8): 16 | ''' 17 | num_channels: The number of input channels 18 | reduction_ratio: The reduction ratio 'r' from the paper 19 | ''' 20 | super(SELayer, self).__init__() 21 | num_channels_reduced = num_channels // reduction_ratio 22 | self.reduction_ratio = reduction_ratio 23 | self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True) 24 | self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True) 25 | self.relu = nn.ReLU() 26 | self.sigmoid = nn.Sigmoid() 27 | 28 | def forward(self, input_tensor): 29 | batch_size, num_channels, H, W = input_tensor.size() 30 | 31 | squeeze_tensor = input_tensor.view(batch_size, num_channels, -1).mean(dim=2) 32 | 33 | # channel excitation 34 | fc_out_1 = self.relu(self.fc1(squeeze_tensor)) 35 | fc_out_2 = self.sigmoid(self.fc2(fc_out_1)) 36 | 37 | a, b = squeeze_tensor.size() 38 | output_tensor = torch.mul(input_tensor, fc_out_2.view(a, b, 1, 1)) 39 | return output_tensor 40 | 41 | 42 | # SSPCAB implementation 43 | class SSPCAB(nn.Module): 44 | def __init__(self, channels, kernel_dim=1, dilation=1, reduction_ratio=8): 45 | ''' 46 | channels: The number of filter at the output (usually the same with the number of filter from the input) 47 | kernel_dim: The dimension of the sub-kernels ' k' ' from the paper 48 | dilation: The dilation dimension 'd' from the paper 49 | reduction_ratio: The reduction ratio for the SE block ('r' from the paper) 50 | ''' 51 | super(SSPCAB, self).__init__() 52 | self.pad = kernel_dim + dilation 53 | self.border_input = kernel_dim + 2*dilation + 1 54 | 55 | self.relu = nn.ReLU() 56 | self.se = SELayer(channels, reduction_ratio=reduction_ratio) 57 | 58 | self.conv1 = nn.Conv2d(in_channels=channels, 59 | out_channels=channels, 60 | kernel_size=kernel_dim) 61 | self.conv2 = nn.Conv2d(in_channels=channels, 62 | out_channels=channels, 63 | kernel_size=kernel_dim) 64 | self.conv3 = nn.Conv2d(in_channels=channels, 65 | out_channels=channels, 66 | kernel_size=kernel_dim) 67 | self.conv4 = nn.Conv2d(in_channels=channels, 68 | 
out_channels=channels, 69 | kernel_size=kernel_dim) 70 | 71 | def forward(self, x): 72 | x = F.pad(x, (self.pad, self.pad, self.pad, self.pad), "constant", 0) 73 | 74 | x1 = self.conv1(x[:, :, :-self.border_input, :-self.border_input]) 75 | x2 = self.conv2(x[:, :, self.border_input:, :-self.border_input]) 76 | x3 = self.conv3(x[:, :, :-self.border_input, self.border_input:]) 77 | x4 = self.conv4(x[:, :, self.border_input:, self.border_input:]) 78 | x = self.relu(x1 + x2 + x3 + x4) 79 | 80 | x = self.se(x) 81 | return x 82 | 83 | if __name__ == '__main__': 84 | block = SSPCAB(channels=3) 85 | input = torch.rand(16,3,32,32) 86 | output = block(input) 87 | print(input.size()) 88 | print(output.size()) 89 | 90 | 91 | # Example of how our block should be updated 92 | # mse_loss = nn.MSELoss() 93 | # cost_sspcab = mse_loss(input_sspcab, output_sspcab) -------------------------------------------------------------------------------- /MUSEAttention.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lancopku/MUSE 2 | 3 | """ 4 | 以下是该模块的主要组件和操作: 5 | 6 | 多头自注意力:通过输入的queries、keys和values,首先使用线性变换(fc_q, fc_k和fc_v)将它们映射到不同的子空间,然后计算多头自注意力得分,并使用softmax函数进行归一化。最后,使用这些得分加权values以获得最终的输出。 7 | 8 | 动态参数的卷积融合:在多头自注意力的输出上应用卷积操作,这些卷积操作具有不同的kernel_size(1、3和5),并使用动态参数(dy_paras)来决定它们的权重。这样,可以通过调整这些参数来动态控制不同kernel_size的卷积操作的贡献。 9 | 10 | 初始化权重:通过init_weights方法来初始化模块中的权重。 11 | 12 | 前向传播:根据输入的queries、keys、values以及可选的注意力掩码(attention_mask)和注意力权重(attention_weights),计算多头自注意力的输出,并与动态参数的卷积融合的结果相加以获得最终输出。 13 | """ 14 | 15 | import numpy as np 16 | import torch 17 | from torch import nn 18 | from torch.nn import init 19 | 20 | 21 | class Depth_Pointwise_Conv1d(nn.Module): 22 | def __init__(self, in_ch, out_ch, k): 23 | super().__init__() 24 | if (k == 1): 25 | self.depth_conv = nn.Identity() 26 | else: 27 | self.depth_conv = nn.Conv1d( 28 | in_channels=in_ch, 29 | out_channels=in_ch, 30 | kernel_size=k, 31 | groups=in_ch, 32 | padding=k // 2 33 | ) 34 | self.pointwise_conv = nn.Conv1d( 35 | in_channels=in_ch, 36 | out_channels=out_ch, 37 | kernel_size=1, 38 | groups=1 39 | ) 40 | 41 | def forward(self, x): 42 | out = self.pointwise_conv(self.depth_conv(x)) 43 | return out 44 | 45 | 46 | class MUSEAttention(nn.Module): 47 | 48 | def __init__(self, d_model, d_k, d_v, h, dropout=.1): 49 | 50 | super(MUSEAttention, self).__init__() 51 | self.fc_q = nn.Linear(d_model, h * d_k) 52 | self.fc_k = nn.Linear(d_model, h * d_k) 53 | self.fc_v = nn.Linear(d_model, h * d_v) 54 | self.fc_o = nn.Linear(h * d_v, d_model) 55 | self.dropout = nn.Dropout(dropout) 56 | 57 | self.conv1 = Depth_Pointwise_Conv1d(h * d_v, d_model, 1) 58 | self.conv3 = Depth_Pointwise_Conv1d(h * d_v, d_model, 3) 59 | self.conv5 = Depth_Pointwise_Conv1d(h * d_v, d_model, 5) 60 | self.dy_paras = nn.Parameter(torch.ones(3)) 61 | self.softmax = nn.Softmax(-1) 62 | 63 | self.d_model = d_model 64 | self.d_k = d_k 65 | self.d_v = d_v 66 | self.h = h 67 | 68 | self.init_weights() 69 | 70 | def init_weights(self): 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | init.kaiming_normal_(m.weight, mode='fan_out') 74 | if m.bias is not None: 75 | init.constant_(m.bias, 0) 76 | elif isinstance(m, nn.BatchNorm2d): 77 | init.constant_(m.weight, 1) 78 | init.constant_(m.bias, 0) 79 | elif isinstance(m, nn.Linear): 80 | init.normal_(m.weight, std=0.001) 81 | if m.bias is not None: 82 | init.constant_(m.bias, 0) 83 | 84 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 85 | 86 
| # Self Attention 87 | b_s, nq = queries.shape[:2] 88 | nk = keys.shape[1] 89 | 90 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 91 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 92 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 93 | 94 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 95 | if attention_weights is not None: 96 | att = att * attention_weights 97 | if attention_mask is not None: 98 | att = att.masked_fill(attention_mask, -np.inf) 99 | att = torch.softmax(att, -1) 100 | att = self.dropout(att) 101 | 102 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 103 | out = self.fc_o(out) # (b_s, nq, d_model) 104 | 105 | v2 = v.permute(0, 1, 3, 2).contiguous().view(b_s, -1, nk) # bs,dim,n 106 | self.dy_paras = nn.Parameter(self.softmax(self.dy_paras)) 107 | out2 = self.dy_paras[0] * self.conv1(v2) + self.dy_paras[1] * self.conv3(v2) + self.dy_paras[2] * self.conv5(v2) 108 | out2 = out2.permute(0, 2, 1) # bs.n.dim 109 | 110 | out = out + out2 111 | return out 112 | 113 | 114 | if __name__ == '__main__': 115 | block = MUSEAttention(d_model=256, d_k=256, d_v=256, h=256).cuda() 116 | # input = torch.rand(64, 64, 512).cuda() 117 | input = torch.rand(1, 128, 256, 256).cuda() 118 | output = block(input, input, input) 119 | print(input.size(), output.size()) 120 | -------------------------------------------------------------------------------- /DynamicFilter(频域模块动态滤波器用于CV2维图像).py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.layers.helpers import to_2tuple 4 | 5 | """ 6 | 配备多头自注意力 (MHSA) 的模型在计算机视觉方面取得了显着的性能。它们的计算复杂度与输入特征图中的二次像素数成正比,导致处理速度缓慢,尤其是在处理高分辨率图像时。 7 | 为了规避这个问题,提出了一种新型的代币混合器作为MHSA的替代方案:基于FFT的代币混合器涉及类似于MHSA的全局操作,但计算复杂度较低。 8 | 在这里,我们提出了一种名为动态过滤器的新型令牌混合器以缩小上述差距。 9 | DynamicFilter 模块通过频域滤波和动态调整滤波器权重,能够对图像进行复杂的增强和处理。 10 | """ 11 | 12 | class StarReLU(nn.Module): 13 | """ 14 | StarReLU: s * relu(x) ** 2 + b 15 | """ 16 | 17 | def __init__(self, scale_value=1.0, bias_value=0.0, 18 | scale_learnable=True, bias_learnable=True, 19 | mode=None, inplace=False): 20 | super().__init__() 21 | self.inplace = inplace 22 | self.relu = nn.ReLU(inplace=inplace) 23 | self.scale = nn.Parameter(scale_value * torch.ones(1), 24 | requires_grad=scale_learnable) 25 | self.bias = nn.Parameter(bias_value * torch.ones(1), 26 | requires_grad=bias_learnable) 27 | 28 | def forward(self, x): 29 | return self.scale * self.relu(x) ** 2 + self.bias 30 | 31 | class Mlp(nn.Module): 32 | """ MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks. 33 | Mostly copied from timm. 
34 | """ 35 | 36 | def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., 37 | bias=False, **kwargs): 38 | super().__init__() 39 | in_features = dim 40 | out_features = out_features or in_features 41 | hidden_features = int(mlp_ratio * in_features) 42 | drop_probs = to_2tuple(drop) 43 | 44 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 45 | self.act = act_layer() 46 | self.drop1 = nn.Dropout(drop_probs[0]) 47 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 48 | self.drop2 = nn.Dropout(drop_probs[1]) 49 | 50 | def forward(self, x): 51 | x = self.fc1(x) 52 | x = self.act(x) 53 | x = self.drop1(x) 54 | x = self.fc2(x) 55 | x = self.drop2(x) 56 | return x 57 | 58 | 59 | class DynamicFilter(nn.Module): 60 | def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25, 61 | act1_layer=StarReLU, act2_layer=nn.Identity, 62 | bias=False, num_filters=4, size=14, weight_resize=False, 63 | **kwargs): 64 | super().__init__() 65 | size = to_2tuple(size) 66 | self.size = size[0] 67 | self.filter_size = size[1] // 2 + 1 68 | self.num_filters = num_filters 69 | self.dim = dim 70 | self.med_channels = int(expansion_ratio * dim) 71 | self.weight_resize = weight_resize 72 | self.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias) 73 | self.act1 = act1_layer() 74 | self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels) 75 | self.complex_weights = nn.Parameter( 76 | torch.randn(self.size, self.filter_size, num_filters, 2, 77 | dtype=torch.float32) * 0.02) 78 | self.act2 = act2_layer() 79 | self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias) 80 | 81 | def forward(self, x): 82 | B, H, W, _ = x.shape 83 | 84 | routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters, 85 | -1).softmax(dim=1) 86 | x = self.pwconv1(x) 87 | x = self.act1(x) 88 | x = x.to(torch.float32) 89 | x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho') 90 | 91 | if self.weight_resize: 92 | complex_weights = resize_complex_weight(self.complex_weights, x.shape[1], 93 | x.shape[2]) 94 | complex_weights = torch.view_as_complex(complex_weights.contiguous()) 95 | else: 96 | complex_weights = torch.view_as_complex(self.complex_weights) 97 | routeing = routeing.to(torch.complex64) 98 | weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights) 99 | if self.weight_resize: 100 | weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels) 101 | else: 102 | weight = weight.view(-1, self.size, self.filter_size, self.med_channels) 103 | x = x * weight 104 | x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho') 105 | 106 | x = self.act2(x) 107 | x = self.pwconv2(x) 108 | return x 109 | 110 | 111 | if __name__ == '__main__': 112 | block = DynamicFilter(32, size=64) # size==H,W 113 | input = torch.rand(3, 64, 64, 32) 114 | output = block(input) 115 | print(input.size()) 116 | print(output.size()) -------------------------------------------------------------------------------- /Wave-pooling(轨迹预测,CV2维图像通用).py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from timm.models.layers import DropPath 6 | 7 | 8 | """ 9 | 预测周围车辆的运动对于帮助自动驾驶系统规划安全路径并避免碰撞至关重要。 10 | 尽管最近基于LSTM模型通过考虑彼此靠近的车辆之间的运动交互而取得了显着的性能提升,但由于实际复杂驾驶场景中的动态和高阶交互,车辆轨迹预测仍然是一个具有挑战性的研究问题。 11 | 为此,我们提出了一种受波叠加启发的社交池(简称波池)方法,用于动态聚合来自本地和全局邻居车辆的高阶交互。 12 | 通过将每个车辆建模为具有振幅和相位的波,波池可以更有效地表示车辆的动态运动状态,并通过波叠加捕获它们的高阶动态相互作用。 13 | 
通过集成Wave-pooling,还提出了一种名为WSiP的基于编码器-解码器的学习框架。 14 | 在两个公共高速公路数据集 NGSIM 和 highD 上进行的大量实验通过与当前最先进的基线进行比较来验证 WSiP 的有效性。 15 | 更重要的是,WSiP的结果更具可解释性,因为车辆之间的相互作用强度可以通过它们的相位差直观地反映出来。 16 | """ 17 | 18 | 19 | 20 | class Mlp(nn.Module): 21 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 22 | super().__init__() 23 | out_features = out_features or in_features 24 | hidden_features = hidden_features or in_features 25 | self.act = act_layer() 26 | self.drop = nn.Dropout(drop) 27 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1) 28 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1) 29 | 30 | def forward(self, x): 31 | x = self.fc1(x) 32 | x = self.act(x) 33 | x = self.drop(x) 34 | x = self.fc2(x) 35 | x = self.drop(x) 36 | return x 37 | 38 | class PATM(nn.Module): 39 | def __init__(self, dim, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0. ,mode='fc'): 40 | super().__init__() 41 | 42 | self.fc_h = nn.Conv2d(dim, dim, 1, 1 ,bias=qkv_bias) 43 | self.fc_w = nn.Conv2d(dim, dim, 1, 1 ,bias=qkv_bias) 44 | self.fc_c = nn.Conv2d(dim, dim, 1, 1 ,bias=qkv_bias) 45 | 46 | self.tfc_h = nn.Conv2d( 2 *dim, dim, (1 ,7), stride=1, padding=(0 , 7//2), groups=dim, bias=False) 47 | self.tfc_w = nn.Conv2d( 2 *dim, dim, (7 ,1), stride=1, padding=( 7//2 ,0), groups=dim, bias=False) 48 | self.reweight = Mlp(dim, dim // 4, dim * 3) 49 | self.proj = nn.Conv2d(dim, dim, 1, 1 ,bias=True) 50 | self.proj_drop = nn.Dropout(proj_drop) 51 | self.mode =mode 52 | # 对h和w都学出相位 53 | if mode=='fc': 54 | self.theta_h_conv =nn.Sequential(nn.Conv2d(dim, dim, 1, 1 ,bias=True) ,nn.BatchNorm2d(dim) ,nn.ReLU()) 55 | self.theta_w_conv =nn.Sequential(nn.Conv2d(dim, dim, 1, 1 ,bias=True) ,nn.BatchNorm2d(dim) ,nn.ReLU()) 56 | else: 57 | self.theta_h_conv =nn.Sequential(nn.Conv2d(dim, dim, 3, stride=1, padding=1, groups=dim, bias=False) 58 | ,nn.BatchNorm2d(dim) ,nn.ReLU()) 59 | self.theta_w_conv =nn.Sequential(nn.Conv2d(dim, dim, 3, stride=1, padding=1, groups=dim, bias=False) 60 | ,nn.BatchNorm2d(dim) ,nn.ReLU()) 61 | 62 | 63 | 64 | def forward(self, x): 65 | B, C, H, W = x.shape 66 | # C, H, W = x.shape 67 | # 相位 68 | theta_h =self.theta_h_conv(x) 69 | theta_w =self.theta_w_conv(x) 70 | # Channel-FC提取振幅 71 | x_h =self.fc_h(x) 72 | x_w =self.fc_w(x) 73 | # 用欧拉公式对特征进行展开 74 | x_h =torch.cat([x_h *torch.cos(theta_h) ,x_h *torch.sin(theta_h)] ,dim=1) 75 | x_w =torch.cat([x_w *torch.cos(theta_w) ,x_w *torch.sin(theta_w)] ,dim=1) 76 | # Token-FC 77 | h = self.tfc_h(x_h) 78 | w = self.tfc_w(x_w) 79 | c = self.fc_c(x) 80 | a = F.adaptive_avg_pool2d(h + w + c ,output_size=1) 81 | a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(-1).unsqueeze(-1) 82 | x = h * a[0] + w * a[1] + c * a[2] 83 | x = self.proj(x) 84 | x = self.proj_drop(x) 85 | return x 86 | 87 | class WaveBlock(nn.Module): 88 | def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 89 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.BatchNorm2d, mode='fc'): 90 | super().__init__() 91 | self.norm1 = norm_layer(dim) 92 | self.attn = PATM(dim, qkv_bias=qkv_bias, qk_scale=None, attn_drop=attn_drop ,mode=mode) 93 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 94 | self.norm2 = norm_layer(dim) 95 | mlp_hidden_dim = int(dim * mlp_ratio) 96 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 97 | 98 | def forward(self, x): 99 | x = x + self.drop_path(self.attn(self.norm1(x))) 100 | x = x + self.drop_path(self.mlp(self.norm2(x))) 101 | return x 102 | 103 | 104 | if __name__ == '__main__': 105 | # 实例化模块并定义输入 106 | block = WaveBlock(dim=64) # 假设输入特征的通道数为 64 107 | input = torch.rand(2, 64, 32, 32) # 假设输入大小为 (batch_size=2, channels=64, height=32, width=32) 108 | 109 | # 运行前向传播 110 | output = block(input) 111 | 112 | # 打印输入和输出的大小 113 | print("输入大小:", input.size()) 114 | print("输出大小:", output.size()) -------------------------------------------------------------------------------- /ISL(用于点云任务).py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | """ 5 | 学习区域内上下文和区域间关系是加强点云分析特征表示的两种有效策略。然而,现有方法并未充分强调统一两种点云表示策略。 6 | 为此,我们提出了一种名为点关系感知网络(PRA-Net)的新颖框架,它由区域内结构学习(ISL)模块和区域间关系学习(IRL)模块组成。 7 | ISL模块可以动态地将局部结构信息集成到点特征中,而IRL模块通过可微分区域划分方案和基于代表性点的策略自适应且有效地捕获区域间关系。 8 | 在形状分类、关键点估计和零件分割等多个 3D 基准上进行的大量实验验证了 PRA-Net 的有效性和泛化能力。 9 | """ 10 | 11 | def knn(x, k): 12 | """ 13 | :param x: (B,3,N) 14 | :param k: int 15 | :return: (B,N,k_hat) 16 | """ 17 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 18 | xx = torch.sum(x ** 2, dim = 1, keepdim = True) 19 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 20 | 21 | idx = pairwise_distance.topk(k = k, dim = -1)[1] # (batch_size, num_points, k_hat) 22 | return idx 23 | 24 | 25 | class DFA(nn.Module): 26 | def __init__(self, features, M=2, r=1, L=32): 27 | """ Constructor 28 | Args: 29 | features: input channel dimensionality. 30 | M: the number of branchs. 31 | r: the radio for compute d, the length of z. 32 | stride: stride, default 1. 33 | L: the minimum dim of the vector z in paper, default 32. 34 | """ 35 | super(DFA, self).__init__() 36 | 37 | self.M = M 38 | self.features = features 39 | d = max(int(self.features / r), L) 40 | 41 | self.fc = nn.Sequential(nn.Conv1d(self.features, d, kernel_size=1), 42 | nn.BatchNorm1d(d)) 43 | 44 | self.fc_out = nn.Sequential(nn.Conv1d(d, self.features, kernel_size=1), 45 | nn.BatchNorm1d(self.features)) 46 | 47 | def forward(self, x): 48 | """ 49 | :param x: [x1,x2] (B,C,N) 50 | :return: 51 | """ 52 | 53 | shape = x[0].shape 54 | if len(shape) > 3: 55 | assert NotImplemented('Don not support len(shape)>=3.') 56 | 57 | # (B,MC,N) 58 | fea_U = x[0] + x[1] 59 | 60 | fea_z = self.fc(fea_U) 61 | # B,C,N 62 | fea_cat = self.fc_out(fea_z) 63 | 64 | attention_vectors = torch.sigmoid(fea_cat) 65 | fea_v = attention_vectors * x[0] + (1 - attention_vectors) * x[1] 66 | 67 | return fea_v 68 | 69 | 70 | def get_graph_feature(x, xyz=None, idx=None, k_hat=20): 71 | """ 72 | Get graph features by minus the k_hat nearest neighbors' feature. 
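(i.e., for every point, each of its k_hat neighbours' features minus the centre point's own feature)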
73 | :param x: (B,C,N) 74 | input features 75 | :param xyz: (B,3,N) or None 76 | xyz coordinate 77 | :param idx: (B,N,k_hat) 78 | kNN graph index 79 | :param k_hat: (int) 80 | the neighbor number 81 | :return: graph feature (B,C,N,k_hat) 82 | """ 83 | batch_size = x.size(0) 84 | num_points = x.size(2) 85 | x = x.view(batch_size, -1, num_points) 86 | if idx is None: 87 | idx = knn(xyz, k=k_hat) # (batch_size, num_points, k_hat) 88 | 89 | idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points 90 | idx = idx + idx_base 91 | idx = idx.view(-1) 92 | 93 | _, num_dims, _ = x.size() 94 | 95 | x = x.transpose(2, 1).contiguous() 96 | feature = x.view(batch_size * num_points, -1)[idx, :] 97 | feature = feature.view(batch_size, num_points, k_hat, num_dims) 98 | x = x.view(batch_size, num_points, 1, num_dims) 99 | feature = feature - x 100 | feature = feature.permute(0, 3, 1, 2) 101 | return feature 102 | 103 | 104 | class ISL(nn.Module): 105 | 106 | def __init__(self, in_channel, out_channel_list, k_hat=20, bias=False, ): 107 | """ 108 | :param in_channel: 109 | input feature channel type:int 110 | :param out_channel_list: int or list of int 111 | out channel of MLPs 112 | :param k_hat: int 113 | k_hat in ISL 114 | :param bias: bool 115 | use bias or not 116 | """ 117 | super(ISL, self).__init__() 118 | 119 | out_channel = out_channel_list[0] 120 | 121 | self.self_feature_learning = nn.Conv1d(in_channel // 2, out_channel, kernel_size=1, bias=bias) 122 | self.neighbor_feature_learning = nn.Conv2d(in_channel // 2, out_channel, kernel_size=1, bias=bias) 123 | self.k = k_hat 124 | 125 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 126 | 127 | last_layer_list = [] 128 | 129 | for i in range(len(out_channel_list) - 1): 130 | in_channel = out_channel_list[i] 131 | out_channel = out_channel_list[i + 1] 132 | last_layer_list.append(nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=bias)) 133 | last_layer_list.append(nn.BatchNorm2d(out_channel)) 134 | last_layer_list.append(nn.LeakyReLU(negative_slope=0.2, inplace=True)) 135 | self.last_layers = nn.Sequential(*last_layer_list) 136 | 137 | self.bn = nn.BatchNorm2d(out_channel) 138 | 139 | self.bn2 = nn.BatchNorm1d(out_channel) 140 | self.bn = nn.BatchNorm2d(out_channel) 141 | 142 | self.DFA_layer = DFA(features=out_channel, M=2, r=1) 143 | 144 | def forward(self, x, idx_): 145 | """ 146 | :param x: (B,3,N) 147 | Input point cloud 148 | :param idx_: (B,N,k_hat) 149 | kNN graph index 150 | :return: graph feature: (B,C,N,k_hat) 151 | """ 152 | 153 | x_minus = get_graph_feature(x, idx=idx_, k_hat=self.k) 154 | # (B,C,N,K) 155 | a1 = self.neighbor_feature_learning(x_minus) 156 | # (B,C,N) 157 | a2 = self.self_feature_learning(x) 158 | 159 | a1 = self.leaky_relu(self.bn(a1)) 160 | # (B,C,N) 161 | a1 = a1.max(dim=-1, keepdim=False)[0] 162 | a2 = self.leaky_relu(self.bn2(a2)) 163 | res = self.DFA_layer([a1, a2]) 164 | 165 | res = self.last_layers(res) 166 | 167 | return res 168 | 169 | 170 | 171 | if __name__ == '__main__': 172 | block = ISL(in_channel=6, out_channel_list=[3], k_hat=20, bias=False) 173 | input = torch.rand((2, 3, 100)) 174 | idx = knn(input, k=20) 175 | output = block(input, idx) 176 | print(input.size()) 177 | print(output.size()) -------------------------------------------------------------------------------- /ScConv.py: -------------------------------------------------------------------------------- 1 | # https://github.com/cheng-haha/ScConv 2 | 3 | """ 4 | GroupBatchnorm2d: 5 | 6 | 
这是一个自定义的批量归一化(Batch Normalization)模块。 7 | 它支持将通道分组,即将通道分成多个组,每个组共享统计信息。 8 | 参数包括 c_num(通道数),group_num(分组数),和 eps(防止除以零的小值)。 9 | 在前向传播中,它首先将输入张量按组进行划分,并在每个组内计算均值和标准差,然后使用这些统计信息来对输入进行标准化。 10 | SRU(Self-Reconstruction Unit): 11 | 12 | 这是一个自定义的模块,用于增强神经网络的特征表示。 13 | 参数包括 oup_channels(输出通道数),group_num(分组数),gate_treshold(门控阈值),和 torch_gn(是否使用PyTorch的GroupNorm)。 14 | 在前向传播中,它首先应用分组归一化(Group Normalization),然后通过门控机制(Gate)重新构造输入特征。 15 | 门控机制根据输入特征的分布和权重来决定哪些信息被保留,哪些信息被舍弃。 16 | CRU(Channel Reorganization Unit): 17 | 18 | 这是一个自定义的通道重组模块,用于重新组织神经网络的通道。 19 | 参数包括 op_channel(输出通道数),alpha(通道划分比例),squeeze_radio(压缩比例),group_size(分组大小),和 group_kernel_size(分组卷积核大小)。 20 | 在前向传播中,它首先将输入通道分成两部分,然后对每部分进行压缩(squeeze)操作和分组卷积(Group Convolution)操作,最后将结果进行融合。 21 | ScConv(Scale and Channel Convolution): 22 | 23 | 这是一个结合了SRU和CRU的模块,用于增强特征表示并进行通道重组。 24 | 参数包括 SRU 和 CRU 模块的参数。 25 | 在前向传播中,它首先应用SRU模块,然后应用CRU模块,以改善特征表示并重新组织通道。 26 | 这些自定义模块可以用于构建更复杂的神经网络,以满足特定的任务和需求。模块中的操作和机制可以帮助提高神经网络的性能和泛化能力。 27 | """ 28 | 29 | import torch 30 | import torch.nn.functional as F 31 | import torch.nn as nn 32 | 33 | 34 | class GroupBatchnorm2d(nn.Module): 35 | def __init__(self, c_num: int, 36 | group_num: int = 16, 37 | eps: float = 1e-10 38 | ): 39 | super(GroupBatchnorm2d, self).__init__() 40 | assert c_num >= group_num 41 | self.group_num = group_num 42 | self.weight = nn.Parameter(torch.randn(c_num, 1, 1)) 43 | self.bias = nn.Parameter(torch.zeros(c_num, 1, 1)) 44 | self.eps = eps 45 | 46 | def forward(self, x): 47 | N, C, H, W = x.size() 48 | x = x.view(N, self.group_num, -1) 49 | mean = x.mean(dim=2, keepdim=True) 50 | std = x.std(dim=2, keepdim=True) 51 | x = (x - mean) / (std + self.eps) 52 | x = x.view(N, C, H, W) 53 | return x * self.weight + self.bias 54 | 55 | 56 | class SRU(nn.Module): 57 | def __init__(self, 58 | oup_channels: int, 59 | group_num: int = 16, 60 | gate_treshold: float = 0.5, 61 | torch_gn: bool = False 62 | ): 63 | super().__init__() 64 | 65 | self.gn = nn.GroupNorm(num_channels=oup_channels, num_groups=group_num) if torch_gn else GroupBatchnorm2d( 66 | c_num=oup_channels, group_num=group_num) 67 | self.gate_treshold = gate_treshold 68 | self.sigomid = nn.Sigmoid() 69 | 70 | def forward(self, x): 71 | gn_x = self.gn(x) 72 | w_gamma = self.gn.weight / torch.sum(self.gn.weight) 73 | w_gamma = w_gamma.view(1, -1, 1, 1) 74 | reweigts = self.sigomid(gn_x * w_gamma) 75 | # Gate 76 | info_mask = reweigts >= self.gate_treshold 77 | noninfo_mask = reweigts < self.gate_treshold 78 | x_1 = info_mask * gn_x 79 | x_2 = noninfo_mask * gn_x 80 | x = self.reconstruct(x_1, x_2) 81 | return x 82 | 83 | def reconstruct(self, x_1, x_2): 84 | x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1) 85 | x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1) 86 | return torch.cat([x_11 + x_22, x_12 + x_21], dim=1) 87 | 88 | 89 | class CRU(nn.Module): 90 | ''' 91 | alpha: 0 bbx_thres: 59 | break 60 | 61 | return bbx1, bby1, bbx2, bby2 62 | 63 | 64 | def cn_op_2ins_space_chan(x, crop='neither', beta=1, bbx_thres=0.1, lam=None, chan=False): 65 | """2-instance crossnorm with cropping.""" 66 | 67 | assert crop in ['neither', 'style', 'content', 'both'] 68 | ins_idxs = torch.randperm(x.size()[0]).to(x.device) 69 | 70 | if crop in ['style', 'both']: 71 | bbx3, bby3, bbx4, bby4 = cn_rand_bbox(x.size(), beta=beta, bbx_thres=bbx_thres) 72 | x2 = x[ins_idxs, :, bbx3:bbx4, bby3:bby4] 73 | else: 74 | x2 = x[ins_idxs] 75 | 76 | if chan: 77 | chan_idxs = torch.randperm(x.size()[1]).to(x.device) 78 | x2 = x2[:, chan_idxs, :, :] 79 | 80 | if crop in 
['content', 'both']: 81 | x_aug = torch.zeros_like(x) 82 | bbx1, bby1, bbx2, bby2 = cn_rand_bbox(x.size(), beta=beta, bbx_thres=bbx_thres) 83 | x_aug[:, :, bbx1:bbx2, bby1:bby2] = instance_norm_mix(content_feat=x[:, :, bbx1:bbx2, bby1:bby2], 84 | style_feat=x2) 85 | 86 | mask = torch.ones_like(x, requires_grad=False) 87 | mask[:, :, bbx1:bbx2, bby1:bby2] = 0. 88 | x_aug = x * mask + x_aug 89 | else: 90 | x_aug = instance_norm_mix(content_feat=x, style_feat=x2) 91 | 92 | if lam is not None: 93 | x = x * lam + x_aug * (1-lam) 94 | else: 95 | x = x_aug 96 | 97 | return x 98 | 99 | 100 | class CrossNorm(nn.Module): 101 | """CrossNorm module""" 102 | def __init__(self, crop=None, beta=None): 103 | super(CrossNorm, self).__init__() 104 | 105 | self.active = False 106 | self.cn_op = functools.partial(cn_op_2ins_space_chan, 107 | crop=crop, beta=beta) 108 | 109 | def forward(self, x): 110 | if self.training and self.active: 111 | 112 | x = self.cn_op(x) 113 | 114 | self.active = False 115 | 116 | return x 117 | 118 | 119 | class SelfNorm(nn.Module): 120 | """SelfNorm module""" 121 | def __init__(self, chan_num, is_two=False): 122 | super(SelfNorm, self).__init__() 123 | 124 | # channel-wise fully connected layer 125 | self.g_fc = nn.Conv1d(chan_num, chan_num, kernel_size=2, 126 | bias=False, groups=chan_num) 127 | self.g_bn = nn.BatchNorm1d(chan_num) 128 | 129 | if is_two is True: 130 | self.f_fc = nn.Conv1d(chan_num, chan_num, kernel_size=2, 131 | bias=False, groups=chan_num) 132 | self.f_bn = nn.BatchNorm1d(chan_num) 133 | else: 134 | self.f_fc = None 135 | 136 | def forward(self, x): 137 | b, c, _, _ = x.size() 138 | 139 | mean, std = calc_ins_mean_std(x, eps=1e-12) 140 | 141 | statistics = torch.cat((mean.squeeze(3), std.squeeze(3)), -1) 142 | 143 | g_y = self.g_fc(statistics) 144 | g_y = self.g_bn(g_y) 145 | g_y = torch.sigmoid(g_y) 146 | g_y = g_y.view(b, c, 1, 1) 147 | 148 | if self.f_fc is not None: 149 | f_y = self.f_fc(statistics) 150 | f_y = self.f_bn(f_y) 151 | f_y = torch.sigmoid(f_y) 152 | f_y = f_y.view(b, c, 1, 1) 153 | 154 | return x * g_y.expand_as(x) + mean.expand_as(x) * (f_y.expand_as(x)-g_y.expand_as(x)) 155 | else: 156 | return x * g_y.expand_as(x) 157 | 158 | class CNSN(nn.Module): 159 | """A module to combine CrossNorm and SelfNorm""" 160 | def __init__(self, crossnorm, selfnorm): 161 | super(CNSN, self).__init__() 162 | self.crossnorm = crossnorm 163 | self.selfnorm = selfnorm 164 | 165 | def forward(self, x): 166 | if self.crossnorm and self.crossnorm.active: 167 | x = self.crossnorm(x) 168 | if self.selfnorm: 169 | x = self.selfnorm(x) 170 | return x 171 | 172 | if __name__ == '__main__': 173 | # block = CrossNorm() 174 | 175 | # block = SelfNorm(chan_num=3) 176 | 177 | # 创建 CrossNorm 和 SelfNorm 的实例 178 | crossnorm = CrossNorm() 179 | selfnorm = SelfNorm(chan_num=3) 180 | block = CNSN(crossnorm, selfnorm) 181 | 182 | 183 | input = torch.rand(32, 3, 224, 224) 184 | output = block(input) 185 | print(input.size()) 186 | print(output.size()) -------------------------------------------------------------------------------- /SDM(3D任务).py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | """ 6 | 体积图像的精确边界分割是图像引导诊断和计算机辅助干预的关键任务,特别是对于临床实践中的边界混乱。然而,由于缺乏边界形状约束,U型网络无法有效解决这一挑战。 7 | 此外,现有的细化边界的方法过分强调细长结构,由于网络对微小物体建模的能力有限,导致过拟合现象。在本文中,我们通过包含与相邻区域的相互作用动态来重新概念化边界生成机制。 8 | 此外,我们提出了一个称为 PnPNet 的统一网络来模拟混淆边界区域的形状特征。 PnPNet 
的核心成分包括推分支和拉分支。具体来说,基于扩散理论,我们从推动分支设计了语义差异模块(SDM)来挤压边界区域。 9 | SDM 内的显式和隐式差异信息显着提高了类间边界的表示能力。 10 | 此外,在 K-means 算法的推动下,引入了拉分支的类聚类模块(CCM)来拉伸相交的边界区域。 11 | 因此,推分支和拉分支将分别缩小和放大边界不确定性。他们提供了两种对抗力量来促进模型输出更精确的边界划分。 12 | 我们对三个具有挑战性的公共数据集和一个内部数据集进行了实验,其中包含模型预测中的三种类型的边界混淆。 13 | 实验结果证明了 PnPNet 相对于其他分割网络的优越性,特别是在 HD 和 ASSD 的评估指标上。此外,推拉分支可以作为即插即用模块来增强经典的U形基线模型。 14 | """ 15 | 16 | class SDC(nn.Module): 17 | def __init__(self, in_channels, guidance_channels, kernel_size=3, stride=1, 18 | padding=1, dilation=1, groups=1, bias=False, theta=0.7): 19 | super(SDC, self).__init__() 20 | self.conv = nn.Conv3d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, 21 | dilation=dilation, groups=groups, bias=bias) 22 | self.conv1 = Conv3dbn(guidance_channels, in_channels, kernel_size=3, padding=1) 23 | # self.conv1 = Conv3dGN(guidance_channels, in_channels, kernel_size=3, padding=1) 24 | self.theta = theta 25 | self.guidance_channels = guidance_channels 26 | self.in_channels = in_channels 27 | self.kernel_size = kernel_size 28 | 29 | # initialize 30 | x_initial = torch.randn(in_channels, 1, kernel_size, kernel_size, kernel_size) 31 | x_initial = self.kernel_initialize(x_initial) 32 | 33 | self.x_kernel_diff = nn.Parameter(x_initial) 34 | self.x_kernel_diff[:, :, 0, 0, 0].detach() 35 | self.x_kernel_diff[:, :, 0, 0, 2].detach() 36 | self.x_kernel_diff[:, :, 0, 2, 0].detach() 37 | self.x_kernel_diff[:, :, 2, 0, 0].detach() 38 | self.x_kernel_diff[:, :, 0, 2, 2].detach() 39 | self.x_kernel_diff[:, :, 2, 0, 2].detach() 40 | self.x_kernel_diff[:, :, 2, 2, 0].detach() 41 | self.x_kernel_diff[:, :, 2, 2, 2].detach() 42 | 43 | guidance_initial = torch.randn(in_channels, 1, kernel_size, kernel_size, kernel_size) 44 | guidance_initial = self.kernel_initialize(guidance_initial) 45 | 46 | self.guidance_kernel_diff = nn.Parameter(guidance_initial) 47 | self.guidance_kernel_diff[:, :, 0, 0, 0].detach() 48 | self.guidance_kernel_diff[:, :, 0, 0, 2].detach() 49 | self.guidance_kernel_diff[:, :, 0, 2, 0].detach() 50 | self.guidance_kernel_diff[:, :, 2, 0, 0].detach() 51 | self.guidance_kernel_diff[:, :, 0, 2, 2].detach() 52 | self.guidance_kernel_diff[:, :, 2, 0, 2].detach() 53 | self.guidance_kernel_diff[:, :, 2, 2, 0].detach() 54 | self.guidance_kernel_diff[:, :, 2, 2, 2].detach() 55 | 56 | def kernel_initialize(self, kernel): 57 | kernel[:, :, 0, 0, 0] = -1 58 | 59 | kernel[:, :, 0, 0, 2] = 1 60 | kernel[:, :, 0, 2, 0] = 1 61 | kernel[:, :, 2, 0, 0] = 1 62 | 63 | kernel[:, :, 0, 2, 2] = -1 64 | kernel[:, :, 2, 0, 2] = -1 65 | kernel[:, :, 2, 2, 0] = -1 66 | 67 | kernel[:, :, 2, 2, 2] = 1 68 | 69 | return kernel 70 | 71 | def forward(self, x, guidance): 72 | guidance_channels = self.guidance_channels 73 | in_channels = self.in_channels 74 | kernel_size = self.kernel_size 75 | 76 | guidance = self.conv1(guidance) 77 | 78 | x_diff = F.conv3d(input=x, weight=self.x_kernel_diff, bias=self.conv.bias, stride=self.conv.stride, padding=1, 79 | groups=in_channels) 80 | 81 | guidance_diff = F.conv3d(input=guidance, weight=self.guidance_kernel_diff, bias=self.conv.bias, 82 | stride=self.conv.stride, padding=1, groups=in_channels) 83 | out = self.conv(x_diff * guidance_diff * guidance_diff) 84 | return out 85 | 86 | 87 | class SDM(nn.Module): 88 | def __init__(self, in_channel=3, guidance_channels=2): 89 | super(SDM, self).__init__() 90 | self.sdc1 = SDC(in_channel, guidance_channels) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.bn = nn.BatchNorm3d(in_channel) 93 | 94 | def forward(self, feature, guidance): 95 | 
boundary_enhanced = self.sdc1(feature, guidance) 96 | boundary = self.relu(self.bn(boundary_enhanced)) 97 | boundary_enhanced = boundary + feature 98 | 99 | return boundary_enhanced 100 | 101 | 102 | 103 | class Conv3dbn(nn.Sequential): 104 | def __init__( 105 | self, 106 | in_channels, 107 | out_channels, 108 | kernel_size, 109 | padding=0, 110 | stride=1, 111 | use_batchnorm=True, 112 | ): 113 | conv = nn.Conv3d( 114 | in_channels, 115 | out_channels, 116 | kernel_size, 117 | stride=stride, 118 | padding=padding, 119 | bias=not (use_batchnorm), 120 | ) 121 | 122 | bn = nn.BatchNorm3d(out_channels) 123 | 124 | super(Conv3dbn, self).__init__(conv, bn) 125 | 126 | 127 | class Conv3dGNReLU(nn.Sequential): 128 | def __init__( 129 | self, 130 | in_channels, 131 | out_channels, 132 | kernel_size, 133 | padding=0, 134 | stride=1, 135 | use_batchnorm=True, 136 | ): 137 | conv = nn.Conv3d( 138 | in_channels, 139 | out_channels, 140 | kernel_size, 141 | stride=stride, 142 | padding=padding, 143 | bias=not (use_batchnorm), 144 | ) 145 | gelu = nn.GELU() 146 | 147 | gn = nn.GroupNorm(4, out_channels) 148 | 149 | super(Conv3dGNReLU, self).__init__(conv, gn, gelu) 150 | 151 | 152 | class Conv3dGN(nn.Sequential): 153 | def __init__( 154 | self, 155 | in_channels, 156 | out_channels, 157 | kernel_size, 158 | padding=0, 159 | stride=1, 160 | use_batchnorm=True, 161 | ): 162 | conv = nn.Conv3d( 163 | in_channels, 164 | out_channels, 165 | kernel_size, 166 | stride=stride, 167 | padding=padding, 168 | bias=not (use_batchnorm), 169 | ) 170 | 171 | gn = nn.GroupNorm(4, out_channels) 172 | 173 | super(Conv3dGN, self).__init__(conv, gn) 174 | 175 | if __name__ == '__main__': 176 | block = SDM(in_channel=3, guidance_channels=3) 177 | input = torch.rand(32, 3, 64, 32, 32) 178 | guidance = torch.randn((32, 3, 64, 32, 32)) 179 | 180 | output = block(input,guidance) 181 | print(input.size()) 182 | print(output.size()) -------------------------------------------------------------------------------- /SViT.py: -------------------------------------------------------------------------------- 1 | # https://github.com/hhb072/SViT 2 | 3 | """" 4 | 以下是该模块的主要组件和功能: 5 | 6 | Unfold 操作:Unfold 类定义了一个卷积操作,用于将输入图像进行解展开(unfolding)。具体来说,它将输入图像划分成不重叠的局部块,并将这些块展平成向量。这有助于在局部区域之间建立联系。 7 | 8 | Fold 操作:Fold 类定义了一个卷积转置操作,用于将展开的局部块还原为原始的图像形状。这有助于将局部特征重新组合成图像。 9 | 10 | Attention 操作:Attention 类定义了一个加性注意力机制,用于计算局部块之间的关联权重。通过对展开的局部块执行注意力操作,可以确定不同块之间的相关性,从而更好地捕获局部特征。 11 | 12 | Stoken 操作:StokenAttention 类将图像划分为多个小块,并在这些小块之间执行加性注意力操作。它还包括对块之间的关系进行迭代更新的逻辑,以更好地捕获图像中的局部特征。 13 | 14 | 直接传递操作:direct_forward 方法用于直接传递输入图像,而不进行块划分和注意力操作。这对于某些情况下不需要局部特征建模的情况很有用。 15 | 16 | Stoken 操作和直接传递操作的选择:根据 self.stoken_size 参数的设置,模块可以选择执行 Stoken 操作或直接传递操作。如果 self.stoken_size 的值大于 1,则执行 Stoken 操作,否则执行直接传递操作。 17 | 18 | 总的来说,这个模块提供了一种有效的方式来处理图像数据,并在图像的不同局部区域之间建立关联,以捕获局部特征。这对于许多计算机视觉任务,如目标检测和图像分割,都具有重要意义。 19 | """ 20 | 21 | 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | 26 | 27 | class Unfold(nn.Module): 28 | def __init__(self, kernel_size=3): 29 | super().__init__() 30 | 31 | self.kernel_size = kernel_size 32 | 33 | weights = torch.eye(kernel_size ** 2) 34 | weights = weights.reshape(kernel_size ** 2, 1, kernel_size, kernel_size) 35 | self.weights = nn.Parameter(weights, requires_grad=False) 36 | 37 | def forward(self, x): 38 | b, c, h, w = x.shape 39 | x = F.conv2d(x.reshape(b * c, 1, h, w), self.weights, stride=1, padding=self.kernel_size // 2) 40 | return x.reshape(b, c * 9, h * w) 41 | 42 | 43 | class Fold(nn.Module): 44 | def 
__init__(self, kernel_size=3): 45 | super().__init__() 46 | 47 | self.kernel_size = kernel_size 48 | 49 | weights = torch.eye(kernel_size ** 2) 50 | weights = weights.reshape(kernel_size ** 2, 1, kernel_size, kernel_size) 51 | self.weights = nn.Parameter(weights, requires_grad=False) 52 | 53 | def forward(self, x): 54 | b, _, h, w = x.shape 55 | x = F.conv_transpose2d(x, self.weights, stride=1, padding=self.kernel_size // 2) 56 | return x 57 | 58 | 59 | class Attention(nn.Module): 60 | def __init__(self, dim, window_size=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 61 | super().__init__() 62 | 63 | self.dim = dim 64 | self.num_heads = num_heads 65 | head_dim = dim // num_heads 66 | 67 | self.window_size = window_size 68 | 69 | self.scale = qk_scale or head_dim ** -0.5 70 | 71 | self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias) 72 | self.attn_drop = nn.Dropout(attn_drop) 73 | self.proj = nn.Conv2d(dim, dim, 1) 74 | self.proj_drop = nn.Dropout(proj_drop) 75 | 76 | def forward(self, x): 77 | B, C, H, W = x.shape 78 | N = H * W 79 | 80 | q, k, v = self.qkv(x).reshape(B, self.num_heads, C // self.num_heads * 3, N).chunk(3, 81 | dim=2) # (B, num_heads, head_dim, N) 82 | 83 | attn = (k.transpose(-1, -2) @ q) * self.scale 84 | 85 | attn = attn.softmax(dim=-2) # (B, h, N, N) 86 | attn = self.attn_drop(attn) 87 | 88 | x = (v @ attn).reshape(B, C, H, W) 89 | 90 | x = self.proj(x) 91 | x = self.proj_drop(x) 92 | return x 93 | 94 | 95 | class StokenAttention(nn.Module): 96 | def __init__(self, dim, stoken_size, n_iter=1, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 97 | proj_drop=0.): 98 | super().__init__() 99 | 100 | self.n_iter = n_iter 101 | self.stoken_size = stoken_size 102 | 103 | self.scale = dim ** - 0.5 104 | 105 | self.unfold = Unfold(3) 106 | self.fold = Fold(3) 107 | 108 | self.stoken_refine = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 109 | attn_drop=attn_drop, proj_drop=proj_drop) 110 | 111 | def stoken_forward(self, x): 112 | ''' 113 | x: (B, C, H, W) 114 | ''' 115 | B, C, H0, W0 = x.shape 116 | h, w = self.stoken_size 117 | 118 | pad_l = pad_t = 0 119 | pad_r = (w - W0 % w) % w 120 | pad_b = (h - H0 % h) % h 121 | if pad_r > 0 or pad_b > 0: 122 | x = F.pad(x, (pad_l, pad_r, pad_t, pad_b)) 123 | 124 | _, _, H, W = x.shape 125 | 126 | hh, ww = H // h, W // w 127 | 128 | stoken_features = F.adaptive_avg_pool2d(x, (hh, ww)) # (B, C, hh, ww) 129 | 130 | pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh * ww, h * w, C) 131 | 132 | with torch.no_grad(): 133 | for idx in range(self.n_iter): 134 | stoken_features = self.unfold(stoken_features) # (B, C*9, hh*ww) 135 | stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9) 136 | affinity_matrix = pixel_features @ stoken_features * self.scale # (B, hh*ww, h*w, 9) 137 | 138 | affinity_matrix = affinity_matrix.softmax(-1) # (B, hh*ww, h*w, 9) 139 | 140 | affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww) 141 | 142 | affinity_matrix_sum = self.fold(affinity_matrix_sum) 143 | if idx < self.n_iter - 1: 144 | stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix # (B, hh*ww, C, 9) 145 | 146 | stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape( 147 | B, C, hh, ww) 148 | 149 | stoken_features = stoken_features / (affinity_matrix_sum + 1e-12) # (B, C, hh, ww) 150 | 151 | stoken_features = pixel_features.transpose(-1, -2) @ 
affinity_matrix # (B, hh*ww, C, 9) 152 | 153 | stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww) 154 | 155 | stoken_features = stoken_features / (affinity_matrix_sum.detach() + 1e-12) # (B, C, hh, ww) 156 | 157 | stoken_features = self.stoken_refine(stoken_features) 158 | 159 | stoken_features = self.unfold(stoken_features) # (B, C*9, hh*ww) 160 | stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9) # (B, hh*ww, C, 9) 161 | 162 | pixel_features = stoken_features @ affinity_matrix.transpose(-1, -2) # (B, hh*ww, C, h*w) 163 | 164 | pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W) 165 | 166 | if pad_r > 0 or pad_b > 0: 167 | pixel_features = pixel_features[:, :, :H0, :W0] 168 | 169 | return pixel_features 170 | 171 | def direct_forward(self, x): 172 | B, C, H, W = x.shape 173 | stoken_features = x 174 | stoken_features = self.stoken_refine(stoken_features) 175 | return stoken_features 176 | 177 | def forward(self, x): 178 | if self.stoken_size[0] > 1 or self.stoken_size[1] > 1: 179 | return self.stoken_forward(x) 180 | else: 181 | return self.direct_forward(x) 182 | 183 | 184 | # 输入 N C H W, 输出 N C H W 185 | if __name__ == '__main__': 186 | input = torch.randn(3, 64, 32, 64).cuda() 187 | se = StokenAttention(64, stoken_size=[8,8]).cuda() 188 | output = se(input) 189 | print(output.shape) 190 | -------------------------------------------------------------------------------- /PFNet(点云).py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.data 5 | 6 | """ 7 | 在本文中,我们提出了一种点分形网络(PF-Net),这是一种基于学习的新型方法,用于精确和高保真度的点云完成。 8 | 与现有的点云完成网络不同,现有的点云完成网络从不完整的点云生成点云的整体形状,并且总是改变现有点并遇到噪声和几何损失,PF-Net保留了不完全点云的空间排列,并可以计算出预测中缺失区域的详细几何结构。 9 | 为了成功完成这项任务,PF-Net利用基于特征点的多尺度发电网络,对缺失的点云进行分层估计。此外,我们将多阶段完成损失和对抗性损失相加,以生成更真实的缺失区域。 10 | 对抗性损失可以更好地处理预测中的多种模式。我们的实验证明了我们的方法在几个具有挑战性的点云完成任务中的有效性。 11 | """ 12 | 13 | 14 | class Convlayer(nn.Module): 15 | def __init__(self, point_scales): 16 | super(Convlayer, self).__init__() 17 | self.point_scales = point_scales 18 | self.conv1 = torch.nn.Conv2d(1, 64, (1, 3)) 19 | self.conv2 = torch.nn.Conv2d(64, 64, 1) 20 | self.conv3 = torch.nn.Conv2d(64, 128, 1) 21 | self.conv4 = torch.nn.Conv2d(128, 256, 1) 22 | self.conv5 = torch.nn.Conv2d(256, 512, 1) 23 | self.conv6 = torch.nn.Conv2d(512, 1024, 1) 24 | self.maxpool = torch.nn.MaxPool2d((self.point_scales, 1), 1) 25 | self.bn1 = nn.BatchNorm2d(64) 26 | self.bn2 = nn.BatchNorm2d(64) 27 | self.bn3 = nn.BatchNorm2d(128) 28 | self.bn4 = nn.BatchNorm2d(256) 29 | self.bn5 = nn.BatchNorm2d(512) 30 | self.bn6 = nn.BatchNorm2d(1024) 31 | 32 | def forward(self, x): 33 | x = torch.unsqueeze(x, 1) 34 | x = F.relu(self.bn1(self.conv1(x))) 35 | x = F.relu(self.bn2(self.conv2(x))) 36 | x_128 = F.relu(self.bn3(self.conv3(x))) 37 | x_256 = F.relu(self.bn4(self.conv4(x_128))) 38 | x_512 = F.relu(self.bn5(self.conv5(x_256))) 39 | x_1024 = F.relu(self.bn6(self.conv6(x_512))) 40 | x_128 = torch.squeeze(self.maxpool(x_128), 2) 41 | x_256 = torch.squeeze(self.maxpool(x_256), 2) 42 | x_512 = torch.squeeze(self.maxpool(x_512), 2) 43 | x_1024 = torch.squeeze(self.maxpool(x_1024), 2) 44 | L = [x_1024, x_512, x_256, x_128] 45 | x = torch.cat(L, 1) 46 | return x 47 | 48 | 49 | class Latentfeature(nn.Module): 50 | def __init__(self, num_scales, each_scales_size, point_scales_list): 51 | 
super(Latentfeature, self).__init__() 52 | self.num_scales = num_scales 53 | self.each_scales_size = each_scales_size 54 | self.point_scales_list = point_scales_list 55 | self.Convlayers1 = nn.ModuleList( 56 | [Convlayer(point_scales=self.point_scales_list[0]) for i in range(self.each_scales_size)]) 57 | self.Convlayers2 = nn.ModuleList( 58 | [Convlayer(point_scales=self.point_scales_list[1]) for i in range(self.each_scales_size)]) 59 | self.Convlayers3 = nn.ModuleList( 60 | [Convlayer(point_scales=self.point_scales_list[2]) for i in range(self.each_scales_size)]) 61 | self.conv1 = torch.nn.Conv1d(3, 1, 1) 62 | self.bn1 = nn.BatchNorm1d(1) 63 | 64 | def forward(self, x): 65 | outs = [] 66 | for i in range(self.each_scales_size): 67 | outs.append(self.Convlayers1[i](x[0])) 68 | for j in range(self.each_scales_size): 69 | outs.append(self.Convlayers2[j](x[1])) 70 | for k in range(self.each_scales_size): 71 | outs.append(self.Convlayers3[k](x[2])) 72 | latentfeature = torch.cat(outs, 2) 73 | latentfeature = latentfeature.transpose(1, 2) 74 | latentfeature = F.relu(self.bn1(self.conv1(latentfeature))) 75 | latentfeature = torch.squeeze(latentfeature, 1) 76 | return latentfeature 77 | 78 | 79 | class PointcloudCls(nn.Module): 80 | def __init__(self, num_scales, each_scales_size, point_scales_list, k=40): 81 | super(PointcloudCls, self).__init__() 82 | self.latentfeature = Latentfeature(num_scales, each_scales_size, point_scales_list) 83 | self.fc1 = nn.Linear(1920, 1024) 84 | self.fc2 = nn.Linear(1024, 512) 85 | self.fc3 = nn.Linear(512, 256) 86 | self.fc4 = nn.Linear(256, k) 87 | self.dropout = nn.Dropout(p=0.3) 88 | self.bn1 = nn.BatchNorm1d(1024) 89 | self.bn2 = nn.BatchNorm1d(512) 90 | self.bn3 = nn.BatchNorm1d(256) 91 | self.relu = nn.ReLU() 92 | 93 | def forward(self, x): 94 | x = self.latentfeature(x) 95 | x = F.relu(self.bn1(self.fc1(x))) 96 | x = F.relu(self.bn2(self.dropout(self.fc2(x)))) 97 | x = F.relu(self.bn3(self.dropout(self.fc3(x)))) 98 | x = self.fc4(x) 99 | return F.log_softmax(x, dim=1) 100 | 101 | 102 | class _netG(nn.Module): 103 | def __init__(self, num_scales, each_scales_size, point_scales_list, point_num): 104 | super(_netG, self).__init__() 105 | self.point_num = point_num # 保存输入的点数 106 | self.latentfeature = Latentfeature(num_scales, each_scales_size, point_scales_list) 107 | self.fc1 = nn.Linear(1920, 1024) 108 | self.fc2 = nn.Linear(1024, 512) 109 | self.fc3 = nn.Linear(512, 256) 110 | self.fc_final = nn.Linear(256, point_num * 3) # 最后一个全连接层输出维度为 point_num * 3 111 | 112 | def forward(self, x): 113 | x = self.latentfeature(x) 114 | x = F.relu(self.fc1(x)) 115 | x = F.relu(self.fc2(x)) 116 | x = F.relu(self.fc3(x)) 117 | x = self.fc_final(x) # 输出维度为 [batch_size, point_num * 3] 118 | x = x.reshape(-1, self.point_num, 3) # 重塑为 [batch_size, point_num, 3] 119 | return x 120 | 121 | 122 | 123 | class _netlocalD(nn.Module): 124 | def __init__(self, crop_point_num): 125 | super(_netlocalD, self).__init__() 126 | self.crop_point_num = crop_point_num 127 | self.conv1 = torch.nn.Conv2d(1, 64, (1, 3)) 128 | self.conv2 = torch.nn.Conv2d(64, 64, 1) 129 | self.conv3 = torch.nn.Conv2d(64, 128, 1) 130 | self.conv4 = torch.nn.Conv2d(128, 256, 1) 131 | self.maxpool = torch.nn.MaxPool2d((self.crop_point_num, 1), 1) 132 | self.bn1 = nn.BatchNorm2d(64) 133 | self.bn2 = nn.BatchNorm2d(64) 134 | self.bn3 = nn.BatchNorm2d(128) 135 | self.bn4 = nn.BatchNorm2d(256) 136 | self.fc1 = nn.Linear(448, 256) 137 | self.fc2 = nn.Linear(256, 128) 138 | self.fc3 = nn.Linear(128, 16) 139 | self.fc4 = 
nn.Linear(16, 1) 140 | self.bn_1 = nn.BatchNorm1d(256) 141 | self.bn_2 = nn.BatchNorm1d(128) 142 | self.bn_3 = nn.BatchNorm1d(16) 143 | 144 | def forward(self, x): 145 | x = F.relu(self.bn1(self.conv1(x))) 146 | x_64 = F.relu(self.bn2(self.conv2(x))) 147 | x_128 = F.relu(self.bn3(self.conv3(x_64))) 148 | x_256 = F.relu(self.bn4(self.conv4(x_128))) 149 | x_64 = torch.squeeze(self.maxpool(x_64)) 150 | x_128 = torch.squeeze(self.maxpool(x_128)) 151 | x_256 = torch.squeeze(self.maxpool(x_256)) 152 | Layers = [x_256, x_128, x_64] 153 | x = torch.cat(Layers, 1) 154 | x = F.relu(self.bn_1(self.fc1(x))) 155 | x = F.relu(self.bn_2(self.fc2(x))) 156 | x = F.relu(self.bn_3(self.fc3(x))) 157 | x = self.fc4(x) 158 | return x 159 | 160 | 161 | 162 | if __name__ == '__main__': 163 | # 假设你有三个不同尺度的输入,尺度分别是2048, 512, 256 164 | input1 = torch.randn(64, 2048, 3) # 第一个尺度的输入 165 | input2 = torch.randn(64, 512, 3) # 第二个尺度的输入 166 | input3 = torch.randn(64, 256, 3) # 第三个尺度的输入 167 | 168 | # 将这三个输入作为一个列表传递给模型 169 | input_ = [input1, input2, input3] 170 | 171 | # 初始化模型 172 | netG = _netG(num_scales=3, each_scales_size=1, point_scales_list=[2048, 512, 256], point_num=2048) 173 | 174 | # 调用模型并获取输出 175 | output = netG(input_) 176 | print(output.shape) # 输出的形状应该与 input1 相同 177 | 178 | 179 | -------------------------------------------------------------------------------- /EGA(边缘检测,CV2维图像通用).py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | import torch 4 | if torch.cuda.is_available(): 5 | torch.cuda.init() 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | 9 | 10 | """ Edge-Guided Attention module (EGA) 11 | 这个模块实现的是一个基于边缘引导的注意力模块 (Edge-Guided Attention Module, EGA),主要用于计算机视觉中的图像处理任务,特别是关注于边缘信息的场景。下面是对代码各个部分及其用途的详细解释: 12 | 13 | 高斯模糊和金字塔 14 | gauss_kernel(channels=3, cuda=True): 创建一个用于高斯模糊的卷积核。 15 | downsample(x): 实现下采样,将图像尺寸减半。 16 | conv_gauss(img, kernel): 对图像进行高斯卷积模糊处理。 17 | upsample(x, channels): 实现上采样,将图像尺寸加倍。 18 | make_laplace(img, channels): 生成拉普拉斯金字塔的一层,捕获高频信息。 19 | make_laplace_pyramid(img, level, channels): 生成整个拉普拉斯金字塔,包含多层高频信息。 20 | 这些函数的组合用于提取和处理图像的多尺度特征,特别是高频边缘信息,这对于图像分割和检测等任务非常有用。 21 | 22 | CBAM (Convolutional Block Attention Module) 23 | ChannelGate: 用于生成通道注意力权重,基于全局平均池化和最大池化。 24 | SpatialGate: 用于生成空间注意力权重,基于最大池化和平均池化的特征融合。 25 | CBAM: 组合ChannelGate和SpatialGate,提供通道和空间的联合注意力机制。 26 | CBAM模块用于增强重要特征,抑制不重要特征,从而提升模型的性能。 27 | 28 | EGA (Edge-Guided Attention Module) 29 | EGA类: 主要模块,整合边缘特征、输入特征和预测特征,通过注意力机制融合这些特征,从而增强重要特征的表达。 30 | EGA模块的处理流程如下: 31 | 32 | Reverse Attention: 基于预测结果计算背景注意力,并得到背景特征。 33 | Boundary Attention: 通过拉普拉斯金字塔提取预测边缘特征,并生成预测特征。 34 | High-Frequency Feature: 利用高频边缘特征,生成输入特征。 35 | Feature Fusion: 将以上三种特征融合,经过卷积和注意力机制生成融合特征。 36 | Output: 融合特征加上残差,经过CBAM模块,得到最终输出。 37 | 主要用途 38 | 这个模块主要用于增强图像处理任务中的边缘和高频特征,适用于以下场景: 39 | 图像分割: 提高边界精度。 40 | 目标检测: 强化边缘特征,提升检测效果。 41 | 图像超分辨率: 更好地恢复高频细节。 42 | 通过EGA模块,可以更好地捕捉和利用图像中的边缘和细节信息,提升模型在各种图像处理任务中的表现。 43 | """ 44 | 45 | def gauss_kernel(channels=3, cuda=True): 46 | kernel = torch.tensor([[1., 4., 6., 4., 1], 47 | [4., 16., 24., 16., 4.], 48 | [6., 24., 36., 24., 6.], 49 | [4., 16., 24., 16., 4.], 50 | [1., 4., 6., 4., 1.]]) 51 | kernel /= 256. 
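# The 5x5 kernel above is the outer product of the binomial filter [1, 4, 6, 4, 1],
# normalized by its sum (256); repeated per channel below, it acts as a depthwise
# Gaussian blur for building the Laplacian pyramid.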
52 | kernel = kernel.repeat(channels, 1, 1, 1) 53 | if cuda: 54 | kernel = kernel.cuda() 55 | return kernel 56 | 57 | 58 | def downsample(x): 59 | return x[:, :, ::2, ::2] 60 | 61 | 62 | def conv_gauss(img, kernel): 63 | img = F.pad(img, (2, 2, 2, 2), mode='reflect') 64 | out = F.conv2d(img, kernel, groups=img.shape[1]) 65 | return out 66 | 67 | 68 | def upsample(x, channels): 69 | cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3) 70 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3]) 71 | cc = cc.permute(0, 1, 3, 2) 72 | cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device)], dim=3) 73 | cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2) 74 | x_up = cc.permute(0, 1, 3, 2) 75 | return conv_gauss(x_up, 4 * gauss_kernel(channels)) 76 | 77 | 78 | def make_laplace(img, channels): 79 | filtered = conv_gauss(img, gauss_kernel(channels)) 80 | down = downsample(filtered) 81 | up = upsample(down, channels) 82 | if up.shape[2] != img.shape[2] or up.shape[3] != img.shape[3]: 83 | up = nn.functional.interpolate(up, size=(img.shape[2], img.shape[3])) 84 | diff = img - up 85 | return diff 86 | 87 | 88 | def make_laplace_pyramid(img, level, channels): 89 | current = img 90 | pyr = [] 91 | for _ in range(level): 92 | filtered = conv_gauss(current, gauss_kernel(channels)) 93 | down = downsample(filtered) 94 | up = upsample(down, channels) 95 | if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]: 96 | up = nn.functional.interpolate(up, size=(current.shape[2], current.shape[3])) 97 | diff = current - up 98 | pyr.append(diff) 99 | current = down 100 | pyr.append(current) 101 | return pyr 102 | 103 | 104 | class ChannelGate(nn.Module): 105 | def __init__(self, gate_channels, reduction_ratio=16): 106 | super(ChannelGate, self).__init__() 107 | self.gate_channels = gate_channels 108 | self.mlp = nn.Sequential( 109 | nn.Flatten(), 110 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 111 | nn.ReLU(), 112 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 113 | ) 114 | 115 | def forward(self, x): 116 | avg_out = self.mlp(F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))).unsqueeze(-1).unsqueeze(-1) 117 | max_out = self.mlp(F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))).unsqueeze(-1).unsqueeze(-1) 118 | channel_att_sum = avg_out + max_out 119 | 120 | scale = torch.sigmoid(channel_att_sum).expand_as(x) 121 | return x * scale 122 | 123 | 124 | 125 | class SpatialGate(nn.Module): 126 | def __init__(self): 127 | super(SpatialGate, self).__init__() 128 | kernel_size = 7 129 | self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2) 130 | 131 | def forward(self, x): 132 | x_compress = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) 133 | x_out = self.spatial(x_compress) 134 | scale = torch.sigmoid(x_out) # broadcasting 135 | return x * scale 136 | 137 | 138 | class CBAM(nn.Module): 139 | def __init__(self, gate_channels, reduction_ratio=16): 140 | super(CBAM, self).__init__() 141 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio) 142 | self.SpatialGate = SpatialGate() 143 | 144 | def forward(self, x): 145 | x_out = self.ChannelGate(x) 146 | x_out = self.SpatialGate(x_out) 147 | return x_out 148 | 149 | 150 | # Edge-Guided Attention Module 151 | class EGA(nn.Module): 152 | def __init__(self, in_channels): 153 | super(EGA, 
self).__init__() 154 | 155 | self.fusion_conv = nn.Sequential( 156 | nn.Conv2d(in_channels * 3, in_channels, 3, 1, 1), 157 | nn.BatchNorm2d(in_channels), 158 | nn.ReLU(inplace=True)) 159 | 160 | self.attention = nn.Sequential( 161 | nn.Conv2d(in_channels, 1, 3, 1, 1), 162 | nn.BatchNorm2d(1), 163 | nn.Sigmoid()) 164 | 165 | self.cbam = CBAM(in_channels) 166 | 167 | def forward(self, edge_feature, x, pred): 168 | residual = x 169 | xsize = x.size()[2:] 170 | 171 | pred = torch.sigmoid(pred) 172 | 173 | # reverse attention 174 | background_att = 1 - pred 175 | background_x = x * background_att 176 | 177 | # boundary attention 178 | edge_pred = make_laplace(pred, 1) 179 | pred_feature = x * edge_pred 180 | 181 | # high-frequency feature 182 | edge_input = F.interpolate(edge_feature, size=xsize, mode='bilinear', align_corners=True) 183 | input_feature = x * edge_input 184 | 185 | fusion_feature = torch.cat([background_x, pred_feature, input_feature], dim=1) 186 | fusion_feature = self.fusion_conv(fusion_feature) 187 | 188 | attention_map = self.attention(fusion_feature) 189 | fusion_feature = fusion_feature * attention_map 190 | 191 | out = fusion_feature + residual 192 | out = self.cbam(out) 193 | return out 194 | 195 | if __name__ == '__main__': 196 | in_channels = 3 197 | height, width = 224, 224 198 | edge_feature = torch.rand(1, in_channels, height, width).cuda() 199 | x = torch.rand(1, in_channels, height, width).cuda() 200 | pred = torch.rand(1, 1, height, width).cuda() 201 | 202 | block = EGA(in_channels).cuda() 203 | output = block(edge_feature, x, pred) 204 | 205 | print("Edge feature size:", edge_feature.size()) # torch.Size([1, 3, 224, 224]) 206 | print("Input feature size:", x.size()) # torch.Size([1, 3, 224, 224]) 207 | print("Prediction size:", pred.size()) # torch.Size([1, 1, 224, 224]) 208 | print("Output size:", output.size()) # torch.Size([1, 3, 224, 224]) 209 | 210 | -------------------------------------------------------------------------------- /F_Block(频域模块用于时间序列).py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from math import sqrt 4 | from torch.nn.functional import relu 5 | 6 | """ 7 | 时序数据分析的复杂性极大地受益于时域和频域表示提供的独特优势。虽然时域在表示局部依赖性方面更胜一筹,特别是在非周期序列中,但频域在捕获全局依赖性方面表现出色,使其成为具有明显周期模式的序列的理想选择。 8 | 为了利用这两个优势,我们提出了ATFNet,这是一个创新的框架,它结合了时域模块和频域模块,以同时捕获时间序列数据中的局部和全局依赖关系。 9 | 具体来说,我们引入了主谐波级数能量加权,这是一种基于输入时间序列的周期性动态调整两个模块之间权重的新机制。 10 | 在频域模块中,我们通过扩展DFT增强了传统的离散傅里叶变换(DFT),旨在解决离散频率错位的挑战。 11 | """ 12 | 13 | def complex_mse(pred, label): 14 | delta = pred - label 15 | return torch.mean(torch.abs(delta) ** 2) 16 | 17 | def complex_relu(input): 18 | return relu(input.real).type(torch.complex64) + 1j * relu(input.imag).type(torch.complex64) 19 | 20 | def complex_dropout(input, dropout): 21 | mask_r = torch.ones(input.shape, dtype=torch.float32) 22 | mask_r = dropout(mask_r) 23 | mask_i = torch.zeros_like(mask_r) 24 | mask = torch.complex(mask_r, mask_i).to(input.device) 25 | return mask * input 26 | 27 | def apply_complex(fr, fi, input, dtype=torch.complex64): 28 | return (fr(input.real) - fi(input.imag)).type(dtype) \ 29 | + 1j * (fr(input.imag) + fi(input.real)).type(dtype) 30 | 31 | def complex_mul(order, mat1, mat2): # (a+bi)(c+di) = (ac-bd) + i(ad+bc), matching apply_complex above 32 | return (torch.einsum(order, mat1.real, mat2.real) - torch.einsum(order, mat1.imag, mat2.imag)) \ 33 | + 1j * (torch.einsum(order, mat1.real, mat2.imag) + torch.einsum(order, mat1.imag, mat2.real)) 34 | 35 | class ComplexLN(nn.Module): 36 | def __init__(self, C): 37 | super(ComplexLN, 
self).__init__() 38 | z = torch.zeros((C)) 39 | self.w_r = nn.Parameter(torch.tensor(sqrt(2)/2) + z) 40 | self.w_i = nn.Parameter(torch.tensor(sqrt(2)/2) + z) 41 | self.b_r = nn.Parameter(torch.tensor(0) + z) 42 | self.b_i = nn.Parameter(torch.tensor(0) + z) 43 | def forward(self, x): 44 | means = torch.mean(x, dim=-1, keepdim=True) 45 | std = torch.sqrt(torch.var(x, dim=-1, keepdim=True)) 46 | x = (x - means) / std 47 | w = torch.complex(self.w_r, self.w_i) 48 | b = torch.complex(self.b_r, self.b_i) 49 | return x * w + b 50 | 51 | class ComplexLinear(nn.Module): 52 | def __init__(self, in_features, out_features): 53 | super(ComplexLinear, self).__init__() 54 | self.fc_r = nn.Linear(in_features, out_features) 55 | self.fc_i = nn.Linear(in_features, out_features) 56 | def forward(self, input): 57 | return apply_complex(self.fc_r, self.fc_i, input) 58 | 59 | class ComplexAttention(nn.Module): 60 | def __init__(self): 61 | super(ComplexAttention, self).__init__() 62 | self.dropout = nn.Dropout(0.1) 63 | self.is_activate = True 64 | def forward(self, q, k, v): 65 | scores = complex_mul("bnlhe,bnshe->bnhls", q, k) 66 | if self.is_activate: 67 | scores = self.dropout(torch.softmax(torch.abs(scores), dim=-1)) 68 | scores = torch.complex(scores, torch.zeros_like(scores)) 69 | V = complex_mul("bnhls,bnshd->bnlhd", scores, v) 70 | return V.contiguous() 71 | 72 | class ComplexAttentionLayer(nn.Module): 73 | def __init__(self, d_model, n_heads, attention=ComplexAttention(), d_keys=None, 74 | d_values=None): 75 | super(ComplexAttentionLayer, self).__init__() 76 | 77 | d_keys = d_keys or (d_model // n_heads) 78 | d_values = d_values or (d_model // n_heads) 79 | 80 | self.inner_attention = attention 81 | self.query_projection = ComplexLinear(d_model, d_keys * n_heads) 82 | self.key_projection = ComplexLinear(d_model, d_keys * n_heads) 83 | self.value_projection = ComplexLinear(d_model, d_values * n_heads) 84 | self.out_projection = ComplexLinear(d_values * n_heads, d_model) 85 | self.n_heads = n_heads 86 | 87 | def forward(self, queries, keys, values): # queries.shape = [batch_size, n_vars, d_model] 88 | B, L, _ = queries.shape 89 | _, S, _ = keys.shape 90 | H = self.n_heads 91 | 92 | queries = self.query_projection(queries).view(B, L, -1, H, 1) 93 | keys = self.key_projection(keys).view(B, S, -1, H, 1) 94 | values = self.value_projection(values).view(B, S, -1, H, 1) 95 | 96 | out= self.inner_attention( 97 | queries, 98 | keys, 99 | values, 100 | ) 101 | out = out.view(B, L, -1) 102 | 103 | return self.out_projection(out) 104 | 105 | class ComplexEncoderLayer(nn.Module): 106 | def __init__(self, attention, d_model, d_ff, dropout=0.2): 107 | super(ComplexEncoderLayer, self).__init__() 108 | self.norm1 = ComplexLN(d_model) 109 | self.norm2 = ComplexLN(d_model) 110 | self.attention = attention 111 | self.Linear1 = ComplexLinear(d_model, d_ff) 112 | self.Linear2 = ComplexLinear(d_ff, d_model) 113 | self.dropout = nn.Dropout(dropout) 114 | 115 | def forward(self, x): 116 | new_x = self.attention( 117 | x, x, x 118 | ) 119 | new_x = complex_dropout(new_x, self.dropout) 120 | x = x + new_x 121 | 122 | x = self.norm1(x) 123 | y = x 124 | y = complex_relu(self.Linear1(y)) 125 | y = complex_dropout(y, self.dropout) 126 | y = self.Linear2(y) 127 | y = complex_dropout(y, self.dropout) 128 | 129 | return self.norm2(x + y) 130 | 131 | class ComplexEncoder(nn.Module): 132 | def __init__(self, attn_layers): 133 | super(ComplexEncoder, self).__init__() 134 | self.attn_layers = nn.ModuleList(attn_layers) 135 | def 
forward(self, x): 136 | for attn_layer in self.attn_layers: 137 | x = attn_layer(x) 138 | return x 139 | 140 | class CompEncoderBlock(nn.Module): 141 | # Input: Extended_DFT(input sequence), shape: [B, L:(seq_len+pred_len)//2+1, n_vars] dtype: torch.cfloat 142 | # Output: DFT(input sequence + pred sequence), shape: [B, L:(seq_len+pred_len)//2+1, n_vars] dtype: torch.cfloat 143 | def __init__(self, configs, extended=True): 144 | super(CompEncoderBlock, self).__init__() 145 | self.configs = configs 146 | self.device = configs.device 147 | self.is_emb = True 148 | self.ori_len = int((configs.seq_len) / 2) + 1 149 | self.tar_len = int((configs.seq_len + configs.pred_len) / 2) + 1 150 | if extended: 151 | self.ori_len = int((configs.seq_len + configs.pred_len) / 2) + 1 152 | self.d_ff = configs.fnet_d_ff 153 | self.d_model = configs.fnet_d_model 154 | if not self.is_emb: 155 | self.d_model = self.tar_len 156 | self.emb = ComplexLinear(self.ori_len, self.d_model) 157 | self.dropout = nn.Dropout(configs.complex_dropout) 158 | self.projection = ComplexLinear(self.d_model, self.tar_len) 159 | self.encoder = ComplexEncoder( 160 | attn_layers=[ 161 | ComplexEncoderLayer( 162 | attention=ComplexAttentionLayer( 163 | d_model=self.d_model, n_heads=configs.n_heads 164 | ), 165 | d_model=self.d_model, 166 | d_ff=self.d_ff, 167 | dropout = configs.complex_dropout 168 | ) for _ in range(configs.fnet_layers) 169 | ] 170 | ) 171 | def forward(self, x): 172 | x = x.permute(0, 2, 1) 173 | if self.is_emb: 174 | x = self.emb(x) 175 | x = complex_dropout(x, self.dropout) 176 | x = self.encoder(x) 177 | x = self.projection(x) 178 | return x.permute(0, 2, 1) 179 | 180 | 181 | 182 | class F_Block(nn.Module): 183 | def __init__(self, configs): 184 | super(F_Block, self).__init__() 185 | self.configs = configs 186 | self.seq_len = configs.seq_len 187 | self.pred_len = configs.pred_len 188 | self.model = CompEncoderBlock(configs) 189 | 190 | 191 | # def forward(self, x_enc, x_enc_mark, x_dec, x_dec_mark): 192 | def forward(self, x_enc): 193 | paddings = torch.zeros((x_enc.shape[0], self.pred_len ,x_enc.shape[2])).to(x_enc.device) 194 | x_enc = torch.concatenate((x_enc, paddings), dim=1) 195 | freq = torch.fft.rfft(x_enc, dim=1) # [B, L, n_vars], dtype=torch.complex 196 | freq = freq / x_enc.shape[1] 197 | 198 | # Frequency Normalization 199 | means = torch.mean(freq, dim=1) 200 | freq_abs = torch.abs(freq) 201 | stdev = torch.sqrt(torch.var(freq_abs, dim=1, keepdim=True)) 202 | freq = (freq - means.unsqueeze(1).detach()) / stdev 203 | 204 | freq_pred = self.model(freq) 205 | 206 | # Frequency De-Normalization 207 | freq_pred = freq_pred * stdev 208 | freq_pred = freq_pred + means.unsqueeze(1).detach() 209 | 210 | 211 | freq_pred = freq_pred * freq_pred.shape[1] 212 | pred_seq = torch.fft.irfft(freq_pred, dim=1)[:, -self.configs.pred_len:] 213 | return pred_seq 214 | 215 | 216 | if __name__ == '__main__': 217 | # 定义一个示例配置对象 218 | class Configs: 219 | def __init__(self): 220 | self.seq_len = 96 # 输入序列长度 221 | self.pred_len = 96 # 预测序列长度 222 | self.d_model = 512 # 模型的维度 223 | self.factor = 5 # 用于缩放注意力机制的因子 224 | self.n_heads = 8 # 注意力头的数量 225 | self.e_layers = 3 # 编码器的层数 226 | self.d_ff = 2048 # 前馈神经网络的维度 227 | self.dropout = 0.1 # dropout的概率 228 | self.activation = 'gelu' # 激活函数 229 | self.enc_in = 7 # 编码器输入的特征数量 230 | self.dec_in = 7 # 解码器输入的特征数量 231 | self.c_out = 1 # 输出的特征数量 232 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' # 使用GPU还是CPU 233 | self.fnet_d_ff = 1024 # 频域前馈神经网络的维度 234 | self.fnet_d_model 
= 512 # 频域模型的维度 235 | self.complex_dropout = 0.1 # 复数dropout的概率 236 | self.fnet_layers = 2 # 频域网络的层数 237 | self.is_emb = False # 是否使用嵌入层 238 | 239 | configs = Configs() 240 | 241 | block = F_Block(configs).to(configs.device) # 初始化并将模型移动到指定设备 242 | 243 | x_enc = torch.rand(2, configs.seq_len, configs.enc_in).to(configs.device) # (batch_size, seq_len, n_vars) 244 | 245 | 246 | output = block(x_enc) 247 | 248 | # 打印输入和输出张量的尺寸 249 | print("x_enc size: ", x_enc.size()) 250 | 251 | print("Output size: ", output.size()) 252 | 253 | -------------------------------------------------------------------------------- /efficient kan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | 5 | 6 | class KANLinear(torch.nn.Module): 7 | def __init__( 8 | self, 9 | in_features, 10 | out_features, 11 | grid_size=5, 12 | spline_order=3, 13 | scale_noise=0.1, 14 | scale_base=1.0, 15 | scale_spline=1.0, 16 | enable_standalone_scale_spline=True, 17 | base_activation=torch.nn.SiLU, 18 | grid_eps=0.02, 19 | grid_range=[-1, 1], 20 | ): 21 | super(KANLinear, self).__init__() 22 | self.in_features = in_features 23 | self.out_features = out_features 24 | self.grid_size = grid_size 25 | self.spline_order = spline_order 26 | 27 | h = (grid_range[1] - grid_range[0]) / grid_size 28 | grid = ( 29 | ( 30 | torch.arange(-spline_order, grid_size + spline_order + 1) * h 31 | + grid_range[0] 32 | ) 33 | .expand(in_features, -1) 34 | .contiguous() 35 | ) 36 | self.register_buffer("grid", grid) 37 | 38 | self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) 39 | self.spline_weight = torch.nn.Parameter( 40 | torch.Tensor(out_features, in_features, grid_size + spline_order) 41 | ) 42 | if enable_standalone_scale_spline: 43 | self.spline_scaler = torch.nn.Parameter( 44 | torch.Tensor(out_features, in_features) 45 | ) 46 | 47 | self.scale_noise = scale_noise 48 | self.scale_base = scale_base 49 | self.scale_spline = scale_spline 50 | self.enable_standalone_scale_spline = enable_standalone_scale_spline 51 | self.base_activation = base_activation() 52 | self.grid_eps = grid_eps 53 | 54 | self.reset_parameters() 55 | 56 | def reset_parameters(self): 57 | torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base) 58 | with torch.no_grad(): 59 | noise = ( 60 | ( 61 | torch.rand(self.grid_size + 1, self.in_features, self.out_features) 62 | - 1 / 2 63 | ) 64 | * self.scale_noise 65 | / self.grid_size 66 | ) 67 | self.spline_weight.data.copy_( 68 | (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) 69 | * self.curve2coeff( 70 | self.grid.T[self.spline_order : -self.spline_order], 71 | noise, 72 | ) 73 | ) 74 | if self.enable_standalone_scale_spline: 75 | # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) 76 | torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline) 77 | 78 | def b_splines(self, x: torch.Tensor): 79 | """ 80 | Compute the B-spline bases for the given input tensor. 81 | 82 | Args: 83 | x (torch.Tensor): Input tensor of shape (batch_size, in_features). 84 | 85 | Returns: 86 | torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). 
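
        Example (illustrative, using the default grid_size=5 and spline_order=3):
            >>> layer = KANLinear(in_features=3, out_features=2)
            >>> bases = layer.b_splines(torch.rand(4, 3))  # inputs fall inside the default grid_range [-1, 1]
            >>> bases.shape
            torch.Size([4, 3, 8])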
87 | """ 88 | assert x.dim() == 2 and x.size(1) == self.in_features 89 | 90 | grid: torch.Tensor = ( 91 | self.grid 92 | ) # (in_features, grid_size + 2 * spline_order + 1) 93 | x = x.unsqueeze(-1) 94 | bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) 95 | for k in range(1, self.spline_order + 1): 96 | bases = ( 97 | (x - grid[:, : -(k + 1)]) 98 | / (grid[:, k:-1] - grid[:, : -(k + 1)]) 99 | * bases[:, :, :-1] 100 | ) + ( 101 | (grid[:, k + 1 :] - x) 102 | / (grid[:, k + 1 :] - grid[:, 1:(-k)]) 103 | * bases[:, :, 1:] 104 | ) 105 | 106 | assert bases.size() == ( 107 | x.size(0), 108 | self.in_features, 109 | self.grid_size + self.spline_order, 110 | ) 111 | return bases.contiguous() 112 | 113 | def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): 114 | """ 115 | Compute the coefficients of the curve that interpolates the given points. 116 | 117 | Args: 118 | x (torch.Tensor): Input tensor of shape (batch_size, in_features). 119 | y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). 120 | 121 | Returns: 122 | torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). 123 | """ 124 | assert x.dim() == 2 and x.size(1) == self.in_features 125 | assert y.size() == (x.size(0), self.in_features, self.out_features) 126 | 127 | A = self.b_splines(x).transpose( 128 | 0, 1 129 | ) # (in_features, batch_size, grid_size + spline_order) 130 | B = y.transpose(0, 1) # (in_features, batch_size, out_features) 131 | solution = torch.linalg.lstsq( 132 | A, B 133 | ).solution # (in_features, grid_size + spline_order, out_features) 134 | result = solution.permute( 135 | 2, 0, 1 136 | ) # (out_features, in_features, grid_size + spline_order) 137 | 138 | assert result.size() == ( 139 | self.out_features, 140 | self.in_features, 141 | self.grid_size + self.spline_order, 142 | ) 143 | return result.contiguous() 144 | 145 | @property 146 | def scaled_spline_weight(self): 147 | return self.spline_weight * ( 148 | self.spline_scaler.unsqueeze(-1) 149 | if self.enable_standalone_scale_spline 150 | else 1.0 151 | ) 152 | 153 | def forward(self, x: torch.Tensor): 154 | assert x.dim() == 2 and x.size(1) == self.in_features 155 | 156 | base_output = F.linear(self.base_activation(x), self.base_weight) 157 | spline_output = F.linear( 158 | self.b_splines(x).view(x.size(0), -1), 159 | self.scaled_spline_weight.view(self.out_features, -1), 160 | ) 161 | return base_output + spline_output 162 | 163 | @torch.no_grad() 164 | def update_grid(self, x: torch.Tensor, margin=0.01): 165 | assert x.dim() == 2 and x.size(1) == self.in_features 166 | batch = x.size(0) 167 | 168 | splines = self.b_splines(x) # (batch, in, coeff) 169 | splines = splines.permute(1, 0, 2) # (in, batch, coeff) 170 | orig_coeff = self.scaled_spline_weight # (out, in, coeff) 171 | orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) 172 | unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out) 173 | unreduced_spline_output = unreduced_spline_output.permute( 174 | 1, 0, 2 175 | ) # (batch, in, out) 176 | 177 | # sort each channel individually to collect data distribution 178 | x_sorted = torch.sort(x, dim=0)[0] 179 | grid_adaptive = x_sorted[ 180 | torch.linspace( 181 | 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device 182 | ) 183 | ] 184 | 185 | uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size 186 | grid_uniform = ( 187 | torch.arange( 188 | self.grid_size + 1, dtype=torch.float32, device=x.device 189 | 
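                # evenly spaced fallback knots spanning [per-feature min - margin, per-feature max + margin];
                # blended with the sample-quantile grid (grid_adaptive) below through grid_eps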
).unsqueeze(1) 190 | * uniform_step 191 | + x_sorted[0] 192 | - margin 193 | ) 194 | 195 | grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive 196 | grid = torch.concatenate( 197 | [ 198 | grid[:1] 199 | - uniform_step 200 | * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), 201 | grid, 202 | grid[-1:] 203 | + uniform_step 204 | * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), 205 | ], 206 | dim=0, 207 | ) 208 | 209 | self.grid.copy_(grid.T) 210 | self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output)) 211 | 212 | def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): 213 | """ 214 | Compute the regularization loss. 215 | 216 | This is a dumb simulation of the original L1 regularization as stated in the 217 | paper, since the original one requires computing absolutes and entropy from the 218 | expanded (batch, in_features, out_features) intermediate tensor, which is hidden 219 | behind the F.linear function if we want an memory efficient implementation. 220 | 221 | The L1 regularization is now computed as mean absolute value of the spline 222 | weights. The authors implementation also includes this term in addition to the 223 | sample-based regularization. 224 | """ 225 | l1_fake = self.spline_weight.abs().mean(-1) 226 | regularization_loss_activation = l1_fake.sum() 227 | p = l1_fake / regularization_loss_activation 228 | regularization_loss_entropy = -torch.sum(p * p.log()) 229 | return ( 230 | regularize_activation * regularization_loss_activation 231 | + regularize_entropy * regularization_loss_entropy 232 | ) 233 | 234 | 235 | class KAN(torch.nn.Module): 236 | def __init__( 237 | self, 238 | layers_hidden, 239 | grid_size=5, 240 | spline_order=3, 241 | scale_noise=0.1, 242 | scale_base=1.0, 243 | scale_spline=1.0, 244 | base_activation=torch.nn.SiLU, 245 | grid_eps=0.02, 246 | grid_range=[-1, 1], 247 | ): 248 | super(KAN, self).__init__() 249 | self.grid_size = grid_size 250 | self.spline_order = spline_order 251 | 252 | self.layers = torch.nn.ModuleList() 253 | for in_features, out_features in zip(layers_hidden, layers_hidden[1:]): 254 | self.layers.append( 255 | KANLinear( 256 | in_features, 257 | out_features, 258 | grid_size=grid_size, 259 | spline_order=spline_order, 260 | scale_noise=scale_noise, 261 | scale_base=scale_base, 262 | scale_spline=scale_spline, 263 | base_activation=base_activation, 264 | grid_eps=grid_eps, 265 | grid_range=grid_range, 266 | ) 267 | ) 268 | 269 | def forward(self, x: torch.Tensor, update_grid=False): 270 | for layer in self.layers: 271 | if update_grid: 272 | layer.update_grid(x) 273 | x = layer(x) 274 | return x 275 | 276 | def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): 277 | return sum( 278 | layer.regularization_loss(regularize_activation, regularize_entropy) 279 | for layer in self.layers 280 | ) 281 | -------------------------------------------------------------------------------- /GKONet((三维人体姿态估计).py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | import torch.nn as nn 4 | from timm.models.layers import DropPath 5 | 6 | """ 7 | 作为 3D 人体姿态估计 (HPE) 的关键部分,建立 2D 到 3D 提升映射受到深度模糊性的限制。目前大多数工作普遍缺乏对提升映射中相对深度表达和深度模糊误差表达的定量分析,导致预测效率低、可解释性差。 8 | 为此,本文基于针孔成像原理挖掘和利用这些表达式的先验几何知识,解耦2D到3D的提升映射并简化模型训练。 9 | 具体来说,本文提出了一种具有两分支变压器架构的面向先验几何知识的姿态估计模型,明确引入高维先验几何特征以提高模型效率和可解释性。 10 | 它将空间坐标的回归转化为关节之间空间方向向量的预测,以生成多个可行解,进一步减轻深度模糊性。 11 | 
此外,本文首次提出了一种基于先验几何关系与相对深度表达式解耦的基于非学习的绝对深度估计算法。 12 | 它建立从非根节点到根节点的多个独立深度映射来计算绝对深度候选,无参数、即插即用、可解释。 13 | 实验表明,所提出的姿态估计模型以更低的参数和更快的推理速度在 Human 3.6M 和 MPI-INF-3DHP 基准上实现了最先进的性能,并且所提出的绝对深度估计算法实现了与传统方法相似的性能,而无需任何网络参数。 14 | """ 15 | 16 | class Mlp(nn.Module): 17 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 18 | super().__init__() 19 | out_features = out_features or in_features 20 | hidden_features = hidden_features or in_features 21 | self.fc1 = nn.Linear(in_features, hidden_features) 22 | self.act = act_layer() 23 | self.fc2 = nn.Linear(hidden_features, out_features) 24 | self.drop = nn.Dropout(drop) 25 | 26 | def forward(self, x): 27 | x = self.fc1(x) 28 | x = self.act(x) 29 | x = self.drop(x) 30 | x = self.fc2(x) 31 | x = self.drop(x) 32 | return x 33 | 34 | 35 | class Attention(nn.Module): 36 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 37 | super().__init__() 38 | self.num_heads = num_heads 39 | self.dim = dim 40 | head_dim = dim // num_heads 41 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 42 | self.scale = qk_scale or head_dim ** -0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x): 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 53 | 54 | attn = (q @ k.transpose(-2, -1)) * self.scale 55 | attn = attn.softmax(dim=-1) 56 | attn = self.attn_drop(attn) 57 | 58 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 59 | x = self.proj(x) 60 | x = self.proj_drop(x) 61 | return x 62 | 63 | 64 | class Block(nn.Module): 65 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 66 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 67 | super().__init__() 68 | self.norm1 = norm_layer(dim) 69 | self.attn = Attention( 70 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 71 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 72 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 73 | self.norm2 = norm_layer(dim) 74 | mlp_hidden_dim = int(dim * mlp_ratio) 75 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 76 | 77 | self.dim = dim 78 | self.mlp_hidden_dim = mlp_hidden_dim 79 | self.num_heads = num_heads 80 | 81 | def forward(self, x): 82 | x = x + self.drop_path(self.attn(self.norm1(x))) 83 | x = x + self.drop_path(self.mlp(self.norm2(x))) 84 | return x 85 | 86 | 87 | class GKONet(nn.Module): 88 | def __init__(self, 89 | num_joints=17, 90 | in_chans=(2, 5), 91 | embed_dim_pose=32, 92 | embed_dim_joint=128, 93 | depth=4, 94 | num_heads=8, 95 | mlp_ratio=2., 96 | qkv_bias=True, 97 | qk_scale=None, 98 | mlp_drop_pose=0., 99 | attn_drop_pose=0., 100 | mlp_drop_joint=0., 101 | attn_drop_joint=0., 102 | drop_path_rate=0.1, 103 | norm_layer=None): 104 | super().__init__() 105 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 106 | self.embed_dim_pose = embed_dim_pose 107 | self.embed_dim_joint = embed_dim_joint 108 | 109 | # embedding 110 | self.Pose_embedding = nn.Linear(in_chans[0], embed_dim_pose) 111 | self.Pose_embedding_position = nn.Parameter(torch.zeros(1, num_joints, embed_dim_pose)) 112 | self.Joint_pose_embedding = nn.Linear(in_chans[1] * num_joints, embed_dim_joint) 113 | self.Joint_embedding_position = nn.Parameter(torch.zeros(1, num_joints, embed_dim_joint)) 114 | self.pos_drop = nn.Dropout(p=0.) 115 | 116 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 117 | self.Pose_blocks = nn.ModuleList([ 118 | Block( 119 | dim=embed_dim_pose, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 120 | drop=mlp_drop_pose, attn_drop=attn_drop_pose, drop_path=dpr[i], norm_layer=norm_layer) 121 | for i in range(depth)]) 122 | 123 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 124 | self.Joint_blocks = nn.ModuleList([ 125 | Block( 126 | dim=embed_dim_joint, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 127 | drop=mlp_drop_joint, attn_drop=attn_drop_joint, drop_path=dpr[i], norm_layer=norm_layer) 128 | for i in range(depth)]) 129 | 130 | self.merge_pose2joint = nn.ModuleList([ 131 | nn.Linear(embed_dim_pose, embed_dim_joint) 132 | for i in range(depth)]) 133 | 134 | self.merge_joint2pose = nn.ModuleList([ 135 | nn.Linear(embed_dim_joint, embed_dim_pose) 136 | for i in range(depth)]) 137 | 138 | self.Joint_norm = norm_layer(embed_dim_joint + embed_dim_pose) 139 | self.head_joint = nn.Sequential( 140 | nn.Linear(embed_dim_joint + embed_dim_pose, 3 * num_joints), 141 | ) 142 | 143 | def forward(self, pose_2d, joint_vector_2d): 144 | # initiation 145 | b, j, _ = pose_2d.shape 146 | joint_vector_2d = joint_vector_2d.reshape(b, j, -1) 147 | 148 | # embedding 149 | pose_embedding = self.Pose_embedding(pose_2d) 150 | pose_embedding += self.Pose_embedding_position 151 | pose_embedding = self.pos_drop(pose_embedding) 152 | 153 | joint_embedding = self.Joint_pose_embedding(joint_vector_2d) 154 | joint_embedding += self.Joint_embedding_position 155 | joint_embedding = self.pos_drop(joint_embedding) 156 | 157 | # feature 158 | for blk_pose, blk_joint, blk_joint2pose, blk_pose2joint in zip(self.Pose_blocks, self.Joint_blocks, 159 | self.merge_joint2pose, self.merge_pose2joint): 160 | pose_embedding_merge = pose_embedding + blk_joint2pose(joint_embedding) 161 | joint_embedding_merge = joint_embedding + blk_pose2joint(pose_embedding) 162 | 
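            # cross-branch fusion: each branch is first enriched with a linear projection of the other
            # branch's features (the two *_merge tensors above), then refined by its own transformer Block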
pose_embedding = blk_pose(pose_embedding_merge) 163 | joint_embedding = blk_joint(joint_embedding_merge) 164 | 165 | # head 166 | joint_embedding = torch.cat((joint_embedding, pose_embedding), dim=-1) 167 | joint_embedding = self.Joint_norm(joint_embedding) 168 | joint_embedding = self.head_joint(joint_embedding) 169 | joint_embedding = joint_embedding.view(b, j, j, -1) 170 | 171 | joint_vector_filp = -torch.transpose(joint_embedding, 1, 2) 172 | joint_vector = (joint_embedding + joint_vector_filp) / 2 173 | feasible_solution = joint_vector[:, :, :1, :] - joint_vector[:, :, :, :] 174 | final_pose_3d = torch.mean(feasible_solution, dim=1) 175 | 176 | return -joint_vector[:, 0], joint_embedding, final_pose_3d 177 | 178 | def count_flops(self): 179 | flops = 0 180 | # embedding 181 | flops += self.Pose_embedding.in_features * self.Pose_embedding.out_features * 17 182 | flops += self.Joint_pose_embedding.in_features * self.Joint_pose_embedding.out_features * 17 183 | 184 | # transformer 185 | for blk_pose in self.Pose_blocks: 186 | # qkv 187 | flops += 17 * blk_pose.dim * 3 * blk_pose.dim 188 | # attn = (q @ k.tanspose(-2, -1)) 189 | flops += blk_pose.num_heads * 17 * (blk_pose.dim // blk_pose.num_heads) * 17 190 | # x = (attn @ v) 191 | flops += blk_pose.num_heads * 17 * 17 * (blk_pose.dim // blk_pose.num_heads) 192 | # proj 193 | flops += 17 * blk_pose.dim * blk_pose.dim 194 | 195 | # mlp 196 | flops += blk_pose.mlp.fc1.in_features * blk_pose.mlp.fc1.out_features * 17 197 | flops += blk_pose.mlp.fc2.in_features * blk_pose.mlp.fc2.out_features * 17 198 | 199 | # norm 200 | flops += 17 * blk_pose.dim 201 | 202 | for blk_pose in self.Joint_blocks: 203 | # qkv 204 | flops += 17 * blk_pose.dim * 3 * blk_pose.dim 205 | # attn = (q @ k.tanspose(-2, -1)) 206 | flops += blk_pose.num_heads * 17 * (blk_pose.dim // blk_pose.num_heads) * 17 207 | # x = (attn @ v) 208 | flops += blk_pose.num_heads * 17 * 17 * (blk_pose.dim // blk_pose.num_heads) 209 | # proj 210 | flops += 17 * blk_pose.dim * blk_pose.dim 211 | 212 | # mlp 213 | flops += blk_pose.mlp.fc1.in_features * blk_pose.mlp.fc1.out_features * 17 214 | flops += blk_pose.mlp.fc2.in_features * blk_pose.mlp.fc2.out_features * 17 215 | 216 | # norm 217 | flops += 17 * blk_pose.dim 218 | 219 | # norm 220 | flops += 17 * self.embed_dim_pose 221 | flops += 17 * self.embed_dim_joint 222 | 223 | for blk_joint2pose in self.merge_joint2pose: 224 | flops += blk_joint2pose.in_features * blk_joint2pose.out_features * 17 225 | 226 | for blk_pose2joint in self.merge_pose2joint: 227 | flops += blk_pose2joint.in_features * blk_pose2joint.out_features * 17 228 | 229 | # head 230 | flops += self.head_joint[0].in_features * self.head_joint[0].out_features * 17 231 | flops += 17 * self.embed_dim_pose 232 | return flops 233 | 234 | if __name__ == '__main__': 235 | block = GKONet() 236 | 237 | # 创建一个随机输入张量 238 | # 假设每个关节有2个输入通道,共有17个关节 239 | pose_2d = torch.rand(1, 17, 2) # (batch_size, num_joints, in_chans[0]) 240 | joint_vector_2d = torch.rand(1, 17, 85) # (batch_size, num_joints, in_chans[1]) 241 | 242 | # 运行模型 243 | joint_vector, joint_embedding, final_pose_3d = block(pose_2d, joint_vector_2d) 244 | 245 | # 打印输出张量的尺寸 246 | print("joint_vector size:", joint_vector.size()) 247 | print("joint_embedding size:", joint_embedding.size()) 248 | print("final_pose_3d size:", final_pose_3d.size()) 249 | 250 | -------------------------------------------------------------------------------- /SAM-单目深度估计-特征融合-CV2维度通用.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 5 | 6 | """ 7 | 单目深度估计 (MDE) 旨在预测给定单个 RGB 图像的像素级深度。对于卷积模型和最近的基于注意力的模型,由于同时需要全局上下文和像素级分辨率,基于编码器-解码器的架构被发现非常有用。 8 | 通常,跳跃连接模块用于融合编码器和解码器特征,其中包括特征图串联和卷积运算。受到注意力在众多计算机视觉问题中所展示的好处的启发,我们提出了一种基于注意力的编码器和解码器特征融合。 9 | 我们将 MDE 视为像素查询细化问题,其中最粗级编码器特征用于初始化像素级查询,然后通过提出的跳过注意模块(SAM)将其细化为更高分辨率。 10 | """ 11 | 12 | 13 | class Mlp(nn.Module): 14 | """ Multilayer perceptron.""" 15 | 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | 34 | def window_partition(x, window_size): 35 | """ 36 | Args: 37 | x: (B, H, W, C) 38 | window_size (int): window size 39 | 40 | Returns: 41 | windows: (num_windows*B, window_size, window_size, C) 42 | """ 43 | B, H, W, C = x.shape 44 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 45 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 46 | return windows 47 | 48 | 49 | def window_reverse(windows, window_size, H, W): 50 | """ 51 | Args: 52 | windows: (num_windows*B, window_size, window_size, C) 53 | window_size (int): Window size 54 | H (int): Height of image 55 | W (int): Width of image 56 | 57 | Returns: 58 | x: (B, H, W, C) 59 | """ 60 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 61 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 62 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 63 | return x 64 | 65 | 66 | class WindowAttention(nn.Module): 67 | """ Window based multi-head self attention (W-MSA) module with relative position bias. 68 | It supports both of shifted and non-shifted window. 69 | 70 | Args: 71 | dim (int): Number of input channels. 72 | window_size (tuple[int]): The height and width of the window. 73 | num_heads (int): Number of attention heads. 74 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 75 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 76 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 77 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 78 | """ 79 | 80 | def __init__(self, dim, window_size, num_heads, v_dim, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 81 | 82 | super().__init__() 83 | self.dim = dim 84 | self.window_size = window_size # Wh, Ww 85 | self.num_heads = num_heads 86 | head_dim = dim // num_heads 87 | self.scale = qk_scale or head_dim ** -0.5 88 | 89 | # define a parameter table of relative position bias 90 | self.relative_position_bias_table = nn.Parameter( 91 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 92 | 93 | # get pair-wise relative position index for each token inside the window 94 | coords_h = torch.arange(self.window_size[0]) 95 | coords_w = torch.arange(self.window_size[1]) 96 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 97 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 98 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 99 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 100 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 101 | relative_coords[:, :, 1] += self.window_size[1] - 1 102 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 103 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 104 | self.register_buffer("relative_position_index", relative_position_index) 105 | 106 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 107 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 108 | self.attn_drop = nn.Dropout(attn_drop) 109 | self.proj = nn.Linear(v_dim, v_dim) 110 | self.proj_drop = nn.Dropout(proj_drop) 111 | 112 | trunc_normal_(self.relative_position_bias_table, std=.02) 113 | self.softmax = nn.Softmax(dim=-1) 114 | 115 | def forward(self, x, v, mask=None): 116 | """ Forward function. 117 | 118 | Args: 119 | x: input features with shape of (num_windows*B, N, C) 120 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 121 | """ 122 | B_, N, C = x.shape 123 | q = self.q(x).view(B_, N, self.num_heads, -1).transpose(1, 2) 124 | kv = self.kv(v).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 125 | k, v = kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) 126 | 127 | q = q * self.scale 128 | attn = (q @ k.transpose(-2, -1)) 129 | 130 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 131 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 132 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 133 | attn = attn + relative_position_bias.unsqueeze(0) 134 | 135 | if mask is not None: 136 | nW = mask.shape[0] 137 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 138 | attn = attn.view(-1, self.num_heads, N, N) 139 | attn = self.softmax(attn) 140 | else: 141 | attn = self.softmax(attn) 142 | 143 | attn = self.attn_drop(attn) 144 | 145 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 146 | x = self.proj(x) 147 | x = self.proj_drop(x) 148 | return x 149 | 150 | 151 | class SAMBLOCK(nn.Module): 152 | """ 153 | Args: 154 | dim (int): Number of feature channels 155 | num_heads (int): Number of attention head. 156 | window_size (int): Local window size. Default: 7. 157 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 
158 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 159 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 160 | drop (float, optional): Dropout rate. Default: 0.0 161 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 162 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 163 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 164 | """ 165 | 166 | def __init__(self, 167 | dim, 168 | num_heads, 169 | v_dim, 170 | window_size=7, 171 | mlp_ratio=4., 172 | qkv_bias=True, 173 | qk_scale=None, 174 | drop=0., 175 | attn_drop=0., 176 | drop_path=0., 177 | norm_layer=nn.LayerNorm, 178 | ): 179 | super().__init__() 180 | self.window_size = window_size 181 | self.dim = dim 182 | self.num_heads = num_heads 183 | self.v_dim = v_dim 184 | self.window_size = window_size 185 | self.mlp_ratio = mlp_ratio 186 | act_layer = nn.GELU 187 | norm_layer = nn.LayerNorm 188 | 189 | self.norm1 = norm_layer(dim) 190 | self.normv = norm_layer(dim) 191 | self.attn = WindowAttention( 192 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, v_dim=v_dim, 193 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 194 | 195 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 196 | self.norm2 = norm_layer(v_dim) 197 | mlp_hidden_dim = int(v_dim * mlp_ratio) 198 | self.mlp = Mlp(in_features=v_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 199 | 200 | def forward(self, x, v, H, W): 201 | """ Forward function. 202 | 203 | Args: 204 | x: Input feature, tensor size (B, H*W, C). 205 | H, W: Spatial resolution of the input feature. 206 | """ 207 | 208 | B, L, C = x.shape 209 | assert L == H * W, "input feature has wrong size" 210 | 211 | shortcut = x 212 | x = self.norm1(x) 213 | x = x.view(B, H, W, C) 214 | 215 | shortcut_v = v 216 | v = self.normv(v) 217 | v = v.view(B, H, W, C) 218 | 219 | # pad feature maps to multiples of window size 220 | pad_l = pad_t = 0 221 | pad_r = (self.window_size - W % self.window_size) % self.window_size 222 | pad_b = (self.window_size - H % self.window_size) % self.window_size 223 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 224 | v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b)) 225 | _, Hp, Wp, _ = x.shape 226 | 227 | # partition windows 228 | x_windows = window_partition(x, self.window_size) # nW*B, window_size, window_size, C 229 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 230 | v_windows = window_partition(v, self.window_size) # nW*B, window_size, window_size, C 231 | v_windows = v_windows.view(-1, self.window_size * self.window_size, 232 | v_windows.shape[-1]) # nW*B, window_size*window_size, C 233 | 234 | # W-MSA/SW-MSA 235 | attn_windows = self.attn(x_windows, v_windows, mask=None) # nW*B, window_size*window_size, C 236 | 237 | # merge windows 238 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.v_dim) 239 | x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C 240 | 241 | if pad_r > 0 or pad_b > 0: 242 | x = x[:, :H, :W, :].contiguous() 243 | 244 | x = x.view(B, H * W, self.v_dim) 245 | 246 | # FFN 247 | x = self.drop_path(x) + shortcut 248 | x = x + self.drop_path(self.mlp(self.norm2(x))) 249 | 250 | return x, H, W 251 | 252 | 253 | class SAM(nn.Module): 254 | def __init__(self, 255 | input_dim=96, 256 | embed_dim=96, 257 | 
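                 # v_dim below is the channel width of the decoder query map q; when it differs from
                 # embed_dim, proj_q (a 3x3 conv) lifts q to embed_dim before the window attention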
v_dim=64, 258 | window_size=7, 259 | num_heads=4, 260 | patch_size=4, 261 | in_chans=3, 262 | norm_layer=nn.LayerNorm, 263 | patch_norm=True): 264 | super().__init__() 265 | 266 | self.embed_dim = embed_dim 267 | 268 | if input_dim != embed_dim: 269 | self.proj_e = nn.Conv2d(input_dim, embed_dim, 3, padding=1) 270 | else: 271 | self.proj_e = None 272 | 273 | if v_dim != embed_dim: 274 | self.proj_q = nn.Conv2d(v_dim, embed_dim, 3, padding=1) 275 | elif embed_dim % v_dim == 0: 276 | self.proj_q = None 277 | self.proj = nn.Conv2d(embed_dim, embed_dim, 3, padding=1) 278 | 279 | v_dim = embed_dim 280 | self.sam_block = SAMBLOCK( 281 | dim=embed_dim, 282 | num_heads=num_heads, 283 | v_dim=v_dim, 284 | window_size=window_size, 285 | mlp_ratio=4., 286 | qkv_bias=True, 287 | qk_scale=None, 288 | drop=0., 289 | attn_drop=0., 290 | drop_path=0., 291 | norm_layer=norm_layer) 292 | 293 | layer = norm_layer(embed_dim) 294 | layer_name = 'norm_sam' 295 | self.add_module(layer_name, layer) 296 | 297 | def forward(self, e, q): 298 | if self.proj_q is not None: 299 | q = self.proj_q(q) 300 | if self.proj_e is not None: 301 | e = self.proj_e(e) 302 | e_proj = e 303 | q_proj = q 304 | 305 | Wh, Ww = q.size(2), q.size(3) 306 | q = q.flatten(2).transpose(1, 2) 307 | e = e.flatten(2).transpose(1, 2) 308 | 309 | q_out, H, W = self.sam_block(q, e, Wh, Ww) 310 | norm_layer = getattr(self, f'norm_sam') 311 | q_out = norm_layer(q_out) 312 | q_out = q_out.view(-1, H, W, self.embed_dim).permute(0, 3, 1, 2).contiguous() 313 | 314 | return q_out + e_proj + q_proj 315 | 316 | 317 | 318 | if __name__ == '__main__': 319 | # SAM 模块的参数 320 | input_dim = 96 # `e` 张量的输入通道 321 | embed_dim = 96 # 内部处理的嵌入维度 322 | v_dim = 96 # `q` 张量的维度 323 | window_size = 7 # 自注意力块的窗口大小 324 | num_heads = 4 # 注意力头的数量 325 | patch_size = 4 # 并未在 SAM 类中直接使用 326 | in_chans = 3 # 图像片段的输入通道,这里没有直接使用 327 | 328 | # 创建一个 SAM 模型 329 | model = SAM(input_dim=input_dim, embed_dim=embed_dim, v_dim=v_dim, window_size=window_size, num_heads=num_heads, patch_size=patch_size, in_chans=in_chans) 330 | 331 | # 模拟输入张量 332 | B, H, W = 2, 128, 128 # 批量大小和空间维度 333 | e = torch.rand(B, input_dim, H, W) # 'e' 输入张量 334 | q = torch.rand(B, v_dim, H, W) # 'q' 输入张量 335 | 336 | # 前向传播 337 | output = model(e, q) 338 | 339 | # 打印输出大小 340 | print("Input 'e' size:", e.size()) 341 | print("Input 'q' size:", q.size()) 342 | print("Output size: ", output.size()) -------------------------------------------------------------------------------- /sLSTM&mLSTM(NLP和时序任务).py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Tuple, Optional, List 4 | 5 | """xLSTM:扩展长短期记忆 6 | 20 世纪 90 年代,恒定误差轮播和门控被引入作为长短期记忆 (LSTM) 的核心思想。从那时起,LSTM 经受住了时间的考验,并为众多深度学习成功案例做出了贡献,特别是它们构成了第一个大型语言模型 (LLM)。 7 | 然而,以可并行自注意力为核心的 Transformer 技术的出现标志着一个新时代的到来,其规模超过了 LSTM。 8 | 我们现在提出一个简单的问题:当将 LSTM 扩展到数十亿个参数,利用现代 LLM 的最新技术,但减轻 LSTM 已知的局限性时,我们在语言建模方面能走多远? 
9 | 首先,我们引入指数门控,并采用适当的规范化和稳定技术。其次,我们修改 LSTM 内存结构,获得:(i) 具有标量内存、标量更新和新内存混合的 sLSTM,(ii) 具有矩阵内存和协方差更新规则的完全可并行的 mLSTM。 10 | 将这些 LSTM 扩展集成到残差块主干中会产生 xLSTM 块,然后将其残差堆叠到 xLSTM 架构中。 11 | 指数门控和修改后的内存结构增强了 xLSTM 功能,与最先进的 Transformers 和状态空间模型相比,无论是在性能还是扩展方面,都表现出色。 12 | """ 13 | 14 | 15 | class sLSTMCell(nn.Module): 16 | def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None: 17 | super().__init__() 18 | 19 | # Store the input and hidden size 20 | self.input_size = input_size 21 | self.hidden_size = hidden_size 22 | self.bias = bias 23 | 24 | # Combine the Weights and Recurrent weights into a single matrix 25 | self.W = nn.Parameter( 26 | nn.init.xavier_uniform_( 27 | torch.randn(self.input_size + self.hidden_size, 4 * self.hidden_size) 28 | ), 29 | requires_grad=True, 30 | ) 31 | # Combine the Bias into a single matrix 32 | if self.bias: 33 | self.B = nn.Parameter( 34 | (torch.zeros(4 * self.hidden_size)), requires_grad=True 35 | ) 36 | 37 | def forward( 38 | self, 39 | x: torch.Tensor, 40 | internal_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], 41 | ) -> Tuple[ 42 | torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] 43 | ]: 44 | # Unpack the internal state 45 | h, c, n, m = internal_state # (batch_size, hidden_size) 46 | 47 | # Combine the weights and the input 48 | combined = torch.cat((x, h), dim=1) # (batch_size, input_size + hidden_size) 49 | # Calculate the linear transformation 50 | gates = torch.matmul(combined, self.W) # (batch_size, 4 * hidden_size) 51 | 52 | # Add the bias if included 53 | if self.bias: 54 | gates += self.B 55 | 56 | # Split the gates into the input, forget, output and stabilization gates 57 | z_tilda, i_tilda, f_tilda, o_tilda = torch.split(gates, self.hidden_size, dim=1) 58 | 59 | # Calculate the activation of the states 60 | z_t = torch.tanh(z_tilda) # (batch_size, hidden_size) 61 | # Exponential activation of the input gate 62 | i_t = torch.exp(i_tilda) # (batch_size, hidden_size) 63 | # Exponential activation of the forget gate 64 | f_t = torch.sigmoid(f_tilda) # (batch_size, hidden_size) 65 | 66 | # Sigmoid activation of the output gate 67 | o_t = torch.sigmoid(o_tilda) # (batch_size, input_size) 68 | # Calculate the stabilization state 69 | m_t = torch.max(torch.log(f_t) + m, torch.log(i_t)) # (batch_size, hidden_size) 70 | # Calculate the input stabilization state 71 | i_prime = torch.exp(i_tilda - m_t) # (batch_size, hidden_size) 72 | 73 | # Calculate the new internal states 74 | c_t = f_t * c + i_prime * z_t # (batch_size, hidden_size) 75 | n_t = f_t * n + i_prime # (batch_size, hidden_size) 76 | 77 | # Calculate the stabilized hidden state 78 | h_tilda = c_t / n_t # (batch_size, hidden_size) 79 | 80 | # Calculate the new hidden state 81 | h_t = o_t * h_tilda # (batch_size, hidden_size) 82 | return h_t, ( 83 | h_t, 84 | c_t, 85 | n_t, 86 | m_t, 87 | ) # (batch_size, hidden_size), (batch_size, hidden_size), (batch_size, hidden_size), (batch_size, hidden_size) 88 | 89 | def init_hidden( 90 | self, batch_size: int, **kwargs 91 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 92 | return ( 93 | torch.zeros(batch_size, self.hidden_size, **kwargs), 94 | torch.zeros(batch_size, self.hidden_size, **kwargs), 95 | torch.zeros(batch_size, self.hidden_size, **kwargs), 96 | torch.zeros(batch_size, self.hidden_size, **kwargs), 97 | ) 98 | 99 | 100 | class sLSTM(nn.Module): 101 | def __init__( 102 | self, 103 | input_size: int, 104 | hidden_size: int, 105 | num_layers: int, 106 | 
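        # num_layers above sets the depth of the stack: layer 0 consumes input_size features,
        # every deeper sLSTMCell consumes hidden_size (see the ModuleList of cells below)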
bias: bool = True, 107 | batch_first: bool = False, 108 | ) -> None: 109 | super().__init__() 110 | self.input_size = input_size 111 | self.hidden_size = hidden_size 112 | self.num_layers = num_layers 113 | self.bias = bias 114 | self.batch_first = batch_first 115 | 116 | self.cells = nn.ModuleList( 117 | [ 118 | sLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias) 119 | for layer in range(num_layers) 120 | ] 121 | ) 122 | 123 | def forward( 124 | self, 125 | x: torch.Tensor, 126 | hidden_states: Optional[ 127 | List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] 128 | ] = None, 129 | ) -> Tuple[ 130 | torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] 131 | ]: 132 | # Permute the input tensor if batch_first is True 133 | if self.batch_first: 134 | x = x.permute(1, 0, 2) 135 | 136 | # Initialize the hidden states if not provided 137 | if hidden_states is None: 138 | hidden_states = self.init_hidden(x.size(1), device=x.device, dtype=x.dtype) 139 | else: 140 | # Check if the hidden states are of the correct length 141 | if len(hidden_states) != self.num_layers: 142 | raise ValueError( 143 | f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}" 144 | ) 145 | if any(state[0].size(0) != x.size(1) for state in hidden_states): 146 | raise ValueError( 147 | f"Expected hidden states of batch size {x.size(1)}, but got {hidden_states[0][0].size(0)}" 148 | ) 149 | 150 | H, C, N, M = [], [], [], [] 151 | 152 | for layer, cell in enumerate(self.cells): 153 | lh, lc, ln, lm = [], [], [], [] 154 | for t in range(x.size(0)): 155 | h_t, hidden_states[layer] = ( 156 | cell(x[t], hidden_states[layer]) 157 | if layer == 0 158 | else cell(H[layer - 1][t], hidden_states[layer]) 159 | ) 160 | lh.append(h_t) 161 | lc.append(hidden_states[layer][0]) 162 | ln.append(hidden_states[layer][1]) 163 | lm.append(hidden_states[layer][2]) 164 | 165 | H.append(torch.stack(lh, dim=0)) 166 | C.append(torch.stack(lc, dim=0)) 167 | N.append(torch.stack(ln, dim=0)) 168 | M.append(torch.stack(lm, dim=0)) 169 | 170 | H = torch.stack(H, dim=0) 171 | C = torch.stack(C, dim=0) 172 | N = torch.stack(N, dim=0) 173 | M = torch.stack(M, dim=0) 174 | 175 | return H[-1], (H, C, N, M) 176 | 177 | def init_hidden( 178 | self, batch_size: int, **kwargs 179 | ) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: 180 | 181 | return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells] 182 | 183 | 184 | 185 | if __name__ == '__main__': 186 | # 定义输入张量的参数 187 | input_size = 128 188 | hidden_size = 128 189 | num_layers = 2 190 | seq_length = 10 191 | batch_size = 32 192 | dropout = 0.1 193 | 194 | # 初始化 mLSTM 模块 195 | block = sLSTM(input_size, hidden_size, num_layers) 196 | 197 | # 随机生成输入张量 198 | input_seq = torch.rand(batch_size, seq_length, input_size) 199 | 200 | # 运行前向传递 201 | output, hidden_state = block(input_seq) 202 | 203 | # 输出输入张量和输出张量的形状 204 | print(" sLSTM.Input size:", input_seq.size()) 205 | print("sLSTM.Output size:", output.size()) 206 | 207 | 208 | class mLSTMCell(nn.Module): 209 | def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None: 210 | 211 | super().__init__() 212 | 213 | self.input_size = input_size 214 | self.hidden_size = hidden_size 215 | self.bias = bias 216 | 217 | # Initialize weights and biases 218 | self.W_i = nn.Parameter( 219 | nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)), 220 | requires_grad=True, 221 | ) 222 | self.W_f = nn.Parameter( 223 | 
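            # unlike sLSTMCell, which fuses all gates into a single weight matrix self.W, the mLSTM cell
            # keeps separate Xavier-initialised projections for the i/f/o gates and for q/k/v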
nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)), 224 | requires_grad=True, 225 | ) 226 | self.W_o = nn.Parameter( 227 | nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)), 228 | requires_grad=True, 229 | ) 230 | self.W_q = nn.Parameter( 231 | nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)), 232 | requires_grad=True, 233 | ) 234 | self.W_k = nn.Parameter( 235 | nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)), 236 | requires_grad=True, 237 | ) 238 | self.W_v = nn.Parameter( 239 | nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)), 240 | requires_grad=True, 241 | ) 242 | 243 | if self.bias: 244 | self.B_i = nn.Parameter(torch.zeros(hidden_size), requires_grad=True) 245 | self.B_f = nn.Parameter(torch.zeros(hidden_size), requires_grad=True) 246 | self.B_o = nn.Parameter(torch.zeros(hidden_size), requires_grad=True) 247 | self.B_q = nn.Parameter(torch.zeros(hidden_size), requires_grad=True) 248 | self.B_k = nn.Parameter(torch.zeros(hidden_size), requires_grad=True) 249 | self.B_v = nn.Parameter(torch.zeros(hidden_size), requires_grad=True) 250 | 251 | def forward( 252 | self, 253 | x: torch.Tensor, 254 | internal_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 255 | ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: 256 | # Get the internal state 257 | C, n, m = internal_state 258 | 259 | # Calculate the input, forget, output, query, key and value gates 260 | i_tilda = ( 261 | torch.matmul(x, self.W_i) + self.B_i 262 | if self.bias 263 | else torch.matmul(x, self.W_i) 264 | ) 265 | f_tilda = ( 266 | torch.matmul(x, self.W_f) + self.B_f 267 | if self.bias 268 | else torch.matmul(x, self.W_f) 269 | ) 270 | o_tilda = ( 271 | torch.matmul(x, self.W_o) + self.B_o 272 | if self.bias 273 | else torch.matmul(x, self.W_o) 274 | ) 275 | q_t = ( 276 | torch.matmul(x, self.W_q) + self.B_q 277 | if self.bias 278 | else torch.matmul(x, self.W_q) 279 | ) 280 | k_t = ( 281 | torch.matmul(x, self.W_k) / torch.sqrt(torch.tensor(self.hidden_size)) 282 | + self.B_k 283 | if self.bias 284 | else torch.matmul(x, self.W_k) / torch.sqrt(torch.tensor(self.hidden_size)) 285 | ) 286 | v_t = ( 287 | torch.matmul(x, self.W_v) + self.B_v 288 | if self.bias 289 | else torch.matmul(x, self.W_v) 290 | ) 291 | 292 | # Exponential activation of the input gate 293 | i_t = torch.exp(i_tilda) 294 | f_t = torch.sigmoid(f_tilda) 295 | o_t = torch.sigmoid(o_tilda) 296 | 297 | # Stabilization state 298 | m_t = torch.max(torch.log(f_t) + m, torch.log(i_t)) 299 | i_prime = torch.exp(i_tilda - m_t) 300 | 301 | # Covarieance matrix and normalization state 302 | C_t = f_t.unsqueeze(-1) * C + i_prime.unsqueeze(-1) * torch.einsum( 303 | "bi, bk -> bik", v_t, k_t 304 | ) 305 | n_t = f_t * n + i_prime * k_t 306 | 307 | normalize_inner = torch.diagonal(torch.matmul(n_t, q_t.T)) 308 | divisor = torch.max( 309 | torch.abs(normalize_inner), torch.ones_like(normalize_inner) 310 | ) 311 | h_tilda = torch.einsum("bkj,bj -> bk", C_t, q_t) / divisor.view(-1, 1) 312 | h_t = o_t * h_tilda 313 | 314 | return h_t, (C_t, n_t, m_t) 315 | 316 | def init_hidden( 317 | self, batch_size: int, **kwargs 318 | ) -> Tuple[torch.Tensor, torch.Tensor]: 319 | return ( 320 | torch.zeros(batch_size, self.hidden_size, self.hidden_size, **kwargs), 321 | torch.zeros(batch_size, self.hidden_size, **kwargs), 322 | torch.zeros(batch_size, self.hidden_size, **kwargs), 323 | ) 324 | 325 | 326 | class mLSTM(nn.Module): 327 | def __init__( 328 | self, 329 | input_size: int, 330 | 
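        # hidden_size below also fixes the width of each cell's matrix memory: mLSTMCell.init_hidden
        # allocates a (batch, hidden_size, hidden_size) covariance state C per layer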
hidden_size: int, 331 | num_layers: int, 332 | bias: bool = True, 333 | batch_first: bool = False, 334 | ) -> None: 335 | super().__init__() 336 | self.input_size = input_size 337 | self.hidden_size = hidden_size 338 | self.num_layers = num_layers 339 | self.bias = bias 340 | self.batch_first = batch_first 341 | 342 | self.cells = nn.ModuleList( 343 | [ 344 | mLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias) 345 | for layer in range(num_layers) 346 | ] 347 | ) 348 | 349 | def forward( 350 | self, 351 | x: torch.Tensor, 352 | hidden_states: Optional[List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]] = None, 353 | ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: 354 | # Permute the input tensor if batch_first is True 355 | if self.batch_first: 356 | x = x.permute(1, 0, 2) 357 | 358 | if hidden_states is None: 359 | hidden_states = self.init_hidden(x.size(1), device=x.device, dtype=x.dtype) 360 | else: 361 | # Check if the hidden states are of the correct length 362 | if len(hidden_states) != self.num_layers: 363 | raise ValueError( 364 | f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}" 365 | ) 366 | if any(state[0].size(0) != x.size(1) for state in hidden_states): 367 | raise ValueError( 368 | f"Expected hidden states of batch size {x.size(1)}, but got {hidden_states[0][0].size(0)}" 369 | ) 370 | 371 | H, C, N, M = [], [], [], [] 372 | 373 | for layer, cell in enumerate(self.cells): 374 | lh, lc, ln, lm = [], [], [], [] 375 | for t in range(x.size(0)): 376 | h_t, hidden_states[layer] = ( 377 | cell(x[t], hidden_states[layer]) 378 | if layer == 0 379 | else cell(H[layer - 1][t], hidden_states[layer]) 380 | ) 381 | lh.append(h_t) 382 | lc.append(hidden_states[layer][0]) 383 | ln.append(hidden_states[layer][1]) 384 | lm.append(hidden_states[layer][2]) 385 | 386 | H.append(torch.stack(lh, dim=0)) 387 | C.append(torch.stack(lc, dim=0)) 388 | N.append(torch.stack(ln, dim=0)) 389 | M.append(torch.stack(lm, dim=0)) 390 | 391 | H = torch.stack(H, dim=0) 392 | C = torch.stack(C, dim=0) 393 | N = torch.stack(N, dim=0) 394 | M = torch.stack(M, dim=0) 395 | 396 | return H[-1], (H, C, N, M) 397 | 398 | def init_hidden( 399 | self, batch_size: int, **kwargs 400 | ) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: 401 | return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells] 402 | 403 | 404 | if __name__ == '__main__': 405 | # Define the parameters of the input tensor 406 | input_size = 128 407 | hidden_size = 128 408 | num_layers = 2 409 | seq_length = 10 410 | batch_size = 32 411 | dropout = 0.1  # defined for reference only; not used by this block 412 | 413 | # Initialize the mLSTM module 414 | block = mLSTM(input_size, hidden_size, num_layers) 415 | 416 | # Generate a random input tensor 417 | input_seq = torch.rand(batch_size, seq_length, input_size) 418 | 419 | # Run the forward pass 420 | output, hidden_state = block(input_seq) 421 | 422 | # Print the shapes of the input and output tensors 423 | print("mLSTM.Input size:", input_seq.size()) 424 | print("mLSTM.Output size:", output.size()) -------------------------------------------------------------------------------- /MambaIR(CV二维图像).py: -------------------------------------------------------------------------------- 1 | # Code Implementation of the MambaIR Model 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from functools import partial 9 | from typing import Optional, Callable 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn,
selective_scan_ref 12 | from einops import rearrange, repeat 13 | 14 | 15 | """ 16 | Recently, selective structured state-space models, especially the improved Mamba, have shown great potential for long-range dependency modeling with linear complexity. 17 | However, the standard Mamba still faces challenges in low-level vision, such as local pixel forgetting and channel redundancy. In this work, we introduce local enhancement and channel attention to improve the vanilla Mamba. 18 | In this way, we exploit local pixel similarity and reduce channel redundancy. Extensive experiments demonstrate the superiority of our method. 19 | """ 20 | 21 | 22 | NEG_INF = -1000000 23 | 24 | 25 | class ChannelAttention(nn.Module): 26 | """Channel attention used in RCAN. 27 | Args: 28 | num_feat (int): Channel number of intermediate features. 29 | squeeze_factor (int): Channel squeeze factor. Default: 16. 30 | """ 31 | 32 | def __init__(self, num_feat, squeeze_factor=16): 33 | super(ChannelAttention, self).__init__() 34 | self.attention = nn.Sequential( 35 | nn.AdaptiveAvgPool2d(1), 36 | nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), 39 | nn.Sigmoid()) 40 | 41 | def forward(self, x): 42 | y = self.attention(x) 43 | return x * y 44 | 45 | 46 | class CAB(nn.Module): 47 | def __init__(self, num_feat, is_light_sr=False, compress_ratio=3, squeeze_factor=30): 48 | super(CAB, self).__init__() 49 | if is_light_sr: # use a depth-wise conv for lightweight SR to be more efficient 50 | self.cab = nn.Sequential( 51 | nn.Conv2d(num_feat, num_feat, 3, 1, 1, groups=num_feat), 52 | ChannelAttention(num_feat, squeeze_factor) 53 | ) 54 | else: # for classic SR 55 | self.cab = nn.Sequential( 56 | nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1), 57 | nn.GELU(), 58 | nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1), 59 | ChannelAttention(num_feat, squeeze_factor) 60 | ) 61 | 62 | def forward(self, x): 63 | return self.cab(x) 64 | 65 | 66 | class Mlp(nn.Module): 67 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 68 | super().__init__() 69 | out_features = out_features or in_features 70 | hidden_features = hidden_features or in_features 71 | self.fc1 = nn.Linear(in_features, hidden_features) 72 | self.act = act_layer() 73 | self.fc2 = nn.Linear(hidden_features, out_features) 74 | self.drop = nn.Dropout(drop) 75 | 76 | def forward(self, x): 77 | x = self.fc1(x) 78 | x = self.act(x) 79 | x = self.drop(x) 80 | x = self.fc2(x) 81 | x = self.drop(x) 82 | return x 83 | 84 | 85 | class DynamicPosBias(nn.Module): 86 | def __init__(self, dim, num_heads): 87 | super().__init__() 88 | self.num_heads = num_heads 89 | self.pos_dim = dim // 4 90 | self.pos_proj = nn.Linear(2, self.pos_dim) 91 | self.pos1 = nn.Sequential( 92 | nn.LayerNorm(self.pos_dim), 93 | nn.ReLU(inplace=True), 94 | nn.Linear(self.pos_dim, self.pos_dim), 95 | ) 96 | self.pos2 = nn.Sequential( 97 | nn.LayerNorm(self.pos_dim), 98 | nn.ReLU(inplace=True), 99 | nn.Linear(self.pos_dim, self.pos_dim) 100 | ) 101 | self.pos3 = nn.Sequential( 102 | nn.LayerNorm(self.pos_dim), 103 | nn.ReLU(inplace=True), 104 | nn.Linear(self.pos_dim, self.num_heads) 105 | ) 106 | 107 | def forward(self, biases): 108 | pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) 109 | return pos 110 | 111 | def flops(self, N): 112 | flops = N * 2 * self.pos_dim 113 | flops += N * self.pos_dim * self.pos_dim 114 | flops += N * self.pos_dim * self.pos_dim 115 | flops += N * self.pos_dim * self.num_heads 116 | return flops 117 | 118 | 119 | class Attention(nn.Module): 120 | r""" Multi-head self attention module with dynamic position bias. 121 | 122 | Args: 123 | dim (int): Number of input channels. 124 | num_heads (int): Number of attention heads.
125 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 126 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 127 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 128 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 129 | """ 130 | 131 | def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., 132 | position_bias=True): 133 | 134 | super().__init__() 135 | self.dim = dim 136 | self.num_heads = num_heads 137 | head_dim = dim // num_heads 138 | self.scale = qk_scale or head_dim ** -0.5 139 | self.position_bias = position_bias 140 | if self.position_bias: 141 | self.pos = DynamicPosBias(self.dim // 4, self.num_heads) 142 | 143 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 144 | self.attn_drop = nn.Dropout(attn_drop) 145 | self.proj = nn.Linear(dim, dim) 146 | self.proj_drop = nn.Dropout(proj_drop) 147 | 148 | self.softmax = nn.Softmax(dim=-1) 149 | 150 | def forward(self, x, H, W, mask=None): 151 | """ 152 | Args: 153 | x: input features with shape of (num_groups*B, N, C) 154 | mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None 155 | H: height of each group 156 | W: width of each group 157 | """ 158 | group_size = (H, W) 159 | B_, N, C = x.shape 160 | assert H * W == N 161 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() 162 | q, k, v = qkv[0], qkv[1], qkv[2] 163 | 164 | q = q * self.scale 165 | attn = (q @ k.transpose(-2, -1)) # (B_, self.num_heads, N, N), N = H*W 166 | 167 | if self.position_bias: 168 | # generate mother-set 169 | position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) 170 | position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) 171 | biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 172 | biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 173 | 174 | # get pair-wise relative position index for each token inside the window 175 | coords_h = torch.arange(group_size[0], device=attn.device) 176 | coords_w = torch.arange(group_size[1], device=attn.device) 177 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw 178 | coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw 179 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw 180 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 181 | relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 182 | relative_coords[:, :, 1] += group_size[1] - 1 183 | relative_coords[:, :, 0] *= 2 * group_size[1] - 1 184 | relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw 185 | 186 | pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads 187 | # select position bias 188 | relative_position_bias = pos[relative_position_index.view(-1)].view( 189 | group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH 190 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw 191 | attn = attn + relative_position_bias.unsqueeze(0) 192 | 193 | if mask is not None: 194 | nP = mask.shape[0] 195 | attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( 196 | 0) # (B, nP, nHead, N, N) 197 | attn = attn.view(-1, self.num_heads, N, N) 198 | attn = self.softmax(attn) 199 | else: 200 | 
attn = self.softmax(attn) 201 | 202 | attn = self.attn_drop(attn) 203 | 204 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 205 | x = self.proj(x) 206 | x = self.proj_drop(x) 207 | return x 208 | 209 | 210 | class SS2D(nn.Module): 211 | def __init__( 212 | self, 213 | d_model, 214 | d_state=16, 215 | d_conv=3, 216 | expand=2., 217 | dt_rank="auto", 218 | dt_min=0.001, 219 | dt_max=0.1, 220 | dt_init="random", 221 | dt_scale=1.0, 222 | dt_init_floor=1e-4, 223 | dropout=0., 224 | conv_bias=True, 225 | bias=False, 226 | device=None, 227 | dtype=None, 228 | **kwargs, 229 | ): 230 | factory_kwargs = {"device": device, "dtype": dtype} 231 | super().__init__() 232 | self.d_model = d_model 233 | self.d_state = d_state 234 | self.d_conv = d_conv 235 | self.expand = expand 236 | self.d_inner = int(self.expand * self.d_model) 237 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 238 | 239 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 240 | self.conv2d = nn.Conv2d( 241 | in_channels=self.d_inner, 242 | out_channels=self.d_inner, 243 | groups=self.d_inner, 244 | bias=conv_bias, 245 | kernel_size=d_conv, 246 | padding=(d_conv - 1) // 2, 247 | **factory_kwargs, 248 | ) 249 | self.act = nn.SiLU() 250 | 251 | self.x_proj = ( 252 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 253 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 254 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 255 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 256 | ) 257 | self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner) 258 | del self.x_proj 259 | 260 | self.dt_projs = ( 261 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, 262 | **factory_kwargs), 263 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, 264 | **factory_kwargs), 265 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, 266 | **factory_kwargs), 267 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, 268 | **factory_kwargs), 269 | ) 270 | self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank) 271 | self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner) 272 | del self.dt_projs 273 | 274 | self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N) 275 | self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N) 276 | 277 | self.selective_scan = selective_scan_fn 278 | 279 | self.out_norm = nn.LayerNorm(self.d_inner) 280 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 281 | self.dropout = nn.Dropout(dropout) if dropout > 0. 
else None 282 | 283 | @staticmethod 284 | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, 285 | **factory_kwargs): 286 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) 287 | 288 | # Initialize special dt projection to preserve variance at initialization 289 | dt_init_std = dt_rank ** -0.5 * dt_scale 290 | if dt_init == "constant": 291 | nn.init.constant_(dt_proj.weight, dt_init_std) 292 | elif dt_init == "random": 293 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 294 | else: 295 | raise NotImplementedError 296 | 297 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 298 | dt = torch.exp( 299 | torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 300 | + math.log(dt_min) 301 | ).clamp(min=dt_init_floor) 302 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 303 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 304 | with torch.no_grad(): 305 | dt_proj.bias.copy_(inv_dt) 306 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 307 | dt_proj.bias._no_reinit = True 308 | 309 | return dt_proj 310 | 311 | @staticmethod 312 | def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): 313 | # S4D real initialization 314 | A = repeat( 315 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device), 316 | "n -> d n", 317 | d=d_inner, 318 | ).contiguous() 319 | A_log = torch.log(A) # Keep A_log in fp32 320 | if copies > 1: 321 | A_log = repeat(A_log, "d n -> r d n", r=copies) 322 | if merge: 323 | A_log = A_log.flatten(0, 1) 324 | A_log = nn.Parameter(A_log) 325 | A_log._no_weight_decay = True 326 | return A_log 327 | 328 | @staticmethod 329 | def D_init(d_inner, copies=1, device=None, merge=True): 330 | # D "skip" parameter 331 | D = torch.ones(d_inner, device=device) 332 | if copies > 1: 333 | D = repeat(D, "n1 -> r n1", r=copies) 334 | if merge: 335 | D = D.flatten(0, 1) 336 | D = nn.Parameter(D) # Keep in fp32 337 | D._no_weight_decay = True 338 | return D 339 | 340 | def forward_core(self, x: torch.Tensor): 341 | B, C, H, W = x.shape 342 | L = H * W 343 | K = 4 344 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 345 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (1, 4, 192, 3136) 346 | 347 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) 348 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) 349 | dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) 350 | xs = xs.float().view(B, -1, L) 351 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) 352 | Bs = Bs.float().view(B, K, -1, L) 353 | Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) 354 | Ds = self.Ds.float().view(-1) 355 | As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) 356 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 357 | out_y = self.selective_scan( 358 | xs, dts, 359 | As, Bs, Cs, Ds, z=None, 360 | delta_bias=dt_projs_bias, 361 | delta_softplus=True, 362 | return_last_state=False, 363 | ).view(B, K, -1, L) 364 | assert out_y.dtype == torch.float 365 | 366 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 367 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 368 | invwh_y = 
torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 369 | 370 | return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y 371 | 372 | def forward(self, x: torch.Tensor, **kwargs): 373 | B, H, W, C = x.shape 374 | 375 | xz = self.in_proj(x) 376 | x, z = xz.chunk(2, dim=-1) 377 | 378 | x = x.permute(0, 3, 1, 2).contiguous() 379 | x = self.act(self.conv2d(x)) 380 | y1, y2, y3, y4 = self.forward_core(x) 381 | assert y1.dtype == torch.float32 382 | y = y1 + y2 + y3 + y4 383 | y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) 384 | y = self.out_norm(y) 385 | y = y * F.silu(z) 386 | out = self.out_proj(y) 387 | if self.dropout is not None: 388 | out = self.dropout(out) 389 | return out 390 | 391 | 392 | class VSSBlock(nn.Module): 393 | def __init__( 394 | self, 395 | hidden_dim: int = 0, 396 | drop_path: float = 0, 397 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 398 | attn_drop_rate: float = 0, 399 | d_state: int = 16, 400 | expand: float = 2., 401 | is_light_sr: bool = False, 402 | **kwargs, 403 | ): 404 | super().__init__() 405 | self.ln_1 = norm_layer(hidden_dim) 406 | self.self_attention = SS2D(d_model=hidden_dim, d_state=d_state,expand=expand,dropout=attn_drop_rate, **kwargs) 407 | self.drop_path = DropPath(drop_path) 408 | self.skip_scale= nn.Parameter(torch.ones(hidden_dim)) 409 | self.conv_blk = CAB(hidden_dim,is_light_sr) 410 | self.ln_2 = nn.LayerNorm(hidden_dim) 411 | self.skip_scale2 = nn.Parameter(torch.ones(hidden_dim)) 412 | 413 | 414 | 415 | def forward(self, input, x_size): 416 | # x [B,HW,C] 417 | B, L, C = input.shape 418 | input = input.view(B, *x_size, C).contiguous() # [B,H,W,C] 419 | x = self.ln_1(input) 420 | x = input*self.skip_scale + self.drop_path(self.self_attention(x)) 421 | x = x*self.skip_scale2 + self.conv_blk(self.ln_2(x).permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous() 422 | x = x.view(B, -1, C).contiguous() 423 | return x 424 | 425 | 426 | if __name__ == '__main__': 427 | # 初始化VSSBlock模块,hidden_dim为128 428 | block = VSSBlock(hidden_dim=128, drop_path=0.1, attn_drop_rate=0.1, d_state=16, expand=2.0, is_light_sr=False) 429 | 430 | # 将模块转移到合适的设备上 431 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 432 | block = block.to(device) 433 | 434 | # 生成随机输入张量,尺寸为[B, H*W, C],这里模拟的是批次大小为4,每个图像的尺寸是32x32,通道数为128 435 | B, H, W, C = 4, 32, 32, 128 436 | input_tensor = torch.rand(B, H * W, C).to(device) 437 | 438 | # 计算输出 439 | output_tensor = block(input_tensor, (H, W)) 440 | 441 | # 打印输入和输出张量的尺寸 442 | print("Input tensor size:", input_tensor.size()) 443 | print("Output tensor size:", output_tensor.size()) 444 | 445 | --------------------------------------------------------------------------------
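Editor's note: the short sketch below is not part of MambaIR(CV二维图像).py above. It is a standalone, hypothetical illustration (a tiny 1x2x3x3 feature map with made-up sizes) of how SS2D.forward_core lays out its four scan orders, namely row-major, column-major, and the reversals of both, before handing them to selective_scan_fn. Only the names x, L, x_hwwh and xs mirror the original code; everything else is an assumption for demonstration.

import torch

# Tiny, hypothetical feature map so the scan orders are easy to read off.
B, C, H, W = 1, 2, 3, 3
x = torch.arange(B * C * H * W, dtype=torch.float32).view(B, C, H, W)
L = H * W

# Row-major scan: rows left-to-right, top-to-bottom.
row_major = x.view(B, -1, L)
# Column-major scan: columns top-to-bottom, left-to-right (transpose H and W, then flatten).
col_major = torch.transpose(x, 2, 3).contiguous().view(B, -1, L)
# Stack the two orders, then append both reversed orders, exactly as forward_core does.
x_hwwh = torch.stack([row_major, col_major], dim=1)                 # (B, 2, C, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)      # (B, 4, C, L)

print(xs.shape)     # torch.Size([1, 4, 2, 9])
print(xs[0, 0, 0])  # tensor([0., 1., 2., 3., 4., 5., 6., 7., 8.])  row-major
print(xs[0, 1, 0])  # tensor([0., 3., 6., 1., 4., 7., 2., 5., 8.])  column-major
print(xs[0, 2, 0])  # tensor([8., 7., 6., 5., 4., 3., 2., 1., 0.])  reversed row-major

Running the same 1-D selective scan over these four orderings and then summing the results (y1 + y2 + y3 + y4 in SS2D.forward) is what gives every pixel both horizontal and vertical context, even though each individual scan is strictly sequential.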